takatost 2024-07-16 17:46:20 +08:00
commit 775e52db4d
645 changed files with 15648 additions and 4689 deletions


@ -1,6 +1,7 @@
#!/bin/bash
cd web && npm install
pipx install poetry
echo 'alias start-api="cd /workspaces/dify/api && flask run --host 0.0.0.0 --port=5001 --debug"' >> ~/.bashrc
echo 'alias start-worker="cd /workspaces/dify/api && celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion"' >> ~/.bashrc


@ -75,7 +75,7 @@ jobs:
- name: Run Workflow
run: poetry run -C api bash dev/pytest/pytest_workflow.sh
- name: Set up Vector Stores (Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS, Chroma, MyScale)
uses: hoverkraft-tech/compose-action@v2.0.0
with:
compose-file: |
@ -89,5 +89,6 @@ jobs:
pgvecto-rs
pgvector
chroma
myscale
- name: Test Vector Stores
run: poetry run -C api bash dev/pytest/pytest_vdb.sh


@ -48,18 +48,18 @@ jobs:
platform=${{ matrix.platform }}
echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
- name: Set up QEMU
uses: docker/setup-qemu-action@v3
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@v2
with:
username: ${{ env.DOCKERHUB_USER }}
password: ${{ env.DOCKERHUB_TOKEN }}
- - name: Set up QEMU
- uses: docker/setup-qemu-action@v3
- - name: Set up Docker Buildx
- uses: docker/setup-buildx-action@v3
- name: Extract metadata for Docker
id: meta
uses: docker/metadata-action@v5

.vscode/launch.json vendored

@ -13,7 +13,6 @@
"jinja": true,
"env": {
"FLASK_APP": "app.py",
- "FLASK_DEBUG": "1",
"GEVENT_SUPPORT": "True"
},
"args": [


@ -2,17 +2,17 @@
考虑到我们的现状,我们需要灵活快速地交付,但我们也希望确保像你这样的贡献者在贡献过程中获得尽可能顺畅的体验。我们为此编写了这份贡献指南,旨在让你熟悉代码库和我们与贡献者的合作方式,以便你能快速进入有趣的部分。
这份指南,就像 Dify 本身一样,是一个不断改进的工作。如果有时它落后于实际项目,我们非常感谢你的理解,并欢迎提供任何反馈以供我们改进。
在许可方面,请花一分钟阅读我们简短的 [许可证和贡献者协议](./LICENSE)。社区还遵守 [行为准则](https://github.com/langgenius/.github/blob/main/CODE_OF_CONDUCT.md)。
## 在开始之前
[查找](https://github.com/langgenius/dify/issues?q=is:issue+is:closed)现有问题,或 [创建](https://github.com/langgenius/dify/issues/new/choose) 一个新问题。我们将问题分为两类:
### 功能请求:
* 如果您要提出新的功能请求,请解释所提议的功能的目标,并尽可能提供详细的上下文。[@perzeusss](https://github.com/perzeuss) 制作了一个很好的 [功能请求助手](https://udify.app/chat/MK2kVSnw1gakVwMX),可以帮助您起草需求。随时尝试一下。
* 如果您想从现有问题中选择一个,请在其下方留下评论表示您的意愿。
@ -20,45 +20,44 @@
根据所提议的功能所属的领域不同,您可能需要与不同的团队成员交流。以下是我们团队成员目前正在从事的各个领域的概述:
| 团队成员 | 工作范围 |
| ------------------------------------------------------------ | ---------------------------------------------------- |
| [@yeuoly](https://github.com/Yeuoly) | 架构 Agents |
| [@jyong](https://github.com/JohnJyong) | RAG 流水线设计 |
| [@GarfieldDai](https://github.com/GarfieldDai) | 构建 workflow 编排 |
| [@iamjoel](https://github.com/iamjoel) & [@zxhlyh](https://github.com/zxhlyh) | 让我们的前端更易用 |
| [@guchenhe](https://github.com/guchenhe) & [@crazywoola](https://github.com/crazywoola) | 开发人员体验, 综合事项联系人 |
| [@takatost](https://github.com/takatost) | 产品整体方向和架构 |
事项优先级:
| 功能类型 | 优先级 |
| ------------------------------------------------------------ | --------------- |
| 被团队成员标记为高优先级的功能 | 高优先级 |
| [community feedback board](https://github.com/langgenius/dify/discussions/categories/feedbacks) 内反馈的常见功能请求 | 中等优先级 |
| 非核心功能和小幅改进 | 低优先级 |
| 有价值但不紧急 | 未来功能 |
### 其他任何事情(例如 bug 报告、性能优化、拼写错误更正):
* 立即开始编码。
事项优先级:
| Issue 类型 | 优先级 |
| ------------------------------------------------------------ | --------------- |
| 核心功能的 Bugs(例如无法登录、应用无法工作、安全漏洞) | 紧急 |
| 非紧急 bugs, 性能提升 | 中等优先级 |
| 小幅修复(错别字, 能正常工作但存在误导的 UI) | 低优先级 |
## 安装
以下是设置 Dify 进行开发的步骤:
### 1. Fork 该仓库
### 2. 克隆仓库
从终端克隆代码仓库:
```
git clone git@github.com:<github_username>/dify.git
@ -76,72 +75,72 @@ Dify 依赖以下工具和库:
### 4. 安装
Dify 由后端和前端组成。通过 `cd api/` 导航到后端目录,然后按照 [后端 README](api/README.md) 进行安装。在另一个终端中,通过 `cd web/` 导航到前端目录,然后按照 [前端 README](web/README.md) 进行安装。
查看 [安装常见问题解答](https://docs.dify.ai/getting-started/faq/install-faq) 以获取常见问题列表和故障排除步骤。
### 5. 在浏览器中访问 Dify
为了验证您的设置,打开浏览器并访问 [http://localhost:3000](http://localhost:3000)(默认或您自定义的 URL 和端口)。现在您应该看到 Dify 正在运行。
## 开发
如果您要添加模型提供程序,请参考 [此指南](https://github.com/langgenius/dify/blob/main/api/core/model_runtime/README.md)。
如果您要向 Agent 或 Workflow 添加工具提供程序,请参考 [此指南](./api/core/tools/README.md)。
为了帮助您快速了解您的贡献在哪个部分,以下是 Dify 后端和前端的简要注释大纲:
### 后端
Dify 的后端使用 Python 编写,使用 [Flask](https://flask.palletsprojects.com/en/3.0.x/) 框架。它使用 [SQLAlchemy](https://www.sqlalchemy.org/) 作为 ORM,使用 [Celery](https://docs.celeryq.dev/en/stable/getting-started/introduction.html) 作为任务队列。授权逻辑通过 Flask-login 进行处理。
```
[api/]
├── constants // 用于整个代码库的常量设置。
├── controllers // API 路由定义和请求处理逻辑。
├── core // 核心应用编排、模型集成和工具。
├── docker // Docker 和容器化相关配置。
├── events // 事件处理和处理。
├── extensions // 与第三方框架/平台的扩展。
├── fields // 用于序列化/封装的字段定义。
├── libs // 可重用的库和助手。
├── migrations // 数据库迁移脚本。
├── models // 数据库模型和架构定义。
├── services // 指定业务逻辑。
├── storage // 私钥存储。
├── tasks // 异步任务和后台作业的处理。
└── tests
```
### 前端
该网站使用基于 Typescript 的 [Next.js](https://nextjs.org/) 模板进行引导,并使用 [Tailwind CSS](https://tailwindcss.com/) 进行样式设计。[React-i18next](https://react.i18next.com/) 用于国际化。
```
[web/]
├── app // 布局、页面和组件
│ ├── (commonLayout) // 整个应用通用的布局
│ ├── (shareLayout) // 在特定会话中共享的布局
│ ├── activate // 激活页面
│ ├── components // 页面和布局共享的组件
│ ├── install // 安装页面
│ ├── signin // 登录页面
│ └── styles // 全局共享的样式
├── assets // 静态资源
├── bin // 构建步骤运行的脚本
├── config // 可调整的设置和选项
├── context // 应用中不同部分使用的共享上下文
├── dictionaries // 语言特定的翻译文件
├── docker // 容器配置
├── hooks // 可重用的钩子
├── i18n // 国际化配置
├── models // 描述数据模型和 API 响应的形状
├── public // 如 favicon 等元资源
├── service // 定义 API 操作的形状
├── test
├── types // 函数参数和返回值的描述
└── utils // 共享的实用函数
```
## 提交你的 PR


@ -83,7 +83,7 @@ OCI_REGION=your-region
WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
# Vector database configuration, support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, pgvector, chroma, opensearch, tidb_vector
VECTOR_STORE=weaviate
# Weaviate configuration
@ -106,6 +106,14 @@ MILVUS_USER=root
MILVUS_PASSWORD=Milvus
MILVUS_SECURE=false
# MyScale configuration
MYSCALE_HOST=127.0.0.1
MYSCALE_PORT=8123
MYSCALE_USER=default
MYSCALE_PASSWORD=
MYSCALE_DATABASE=default
MYSCALE_FTS_PARAMS=
# Relyt configuration
RELYT_HOST=127.0.0.1
RELYT_PORT=5432
@ -151,6 +159,16 @@ CHROMA_DATABASE=default_database
CHROMA_AUTH_PROVIDER=chromadb.auth.token_authn.TokenAuthenticationServerProvider
CHROMA_AUTH_CREDENTIALS=difyai123456
# AnalyticDB configuration
ANALYTICDB_KEY_ID=your-ak
ANALYTICDB_KEY_SECRET=your-sk
ANALYTICDB_REGION_ID=cn-hangzhou
ANALYTICDB_INSTANCE_ID=gp-ab123456
ANALYTICDB_ACCOUNT=testaccount
ANALYTICDB_PASSWORD=testpassword
ANALYTICDB_NAMESPACE=dify
ANALYTICDB_NAMESPACE_PASSWORD=difypassword
# OpenSearch configuration
OPENSEARCH_HOST=127.0.0.1
OPENSEARCH_PORT=9200
@ -237,4 +255,4 @@ WORKFLOW_CALL_MAX_DEPTH=5
# App configuration
APP_MAX_EXECUTION_TIME=1200
APP_MAX_ACTIVE_REQUESTS=0


@ -337,6 +337,14 @@ def migrate_knowledge_vector_database():
"vector_store": {"class_prefix": collection_name}
}
dataset.index_struct = json.dumps(index_struct_dict)
elif vector_type == VectorType.ANALYTICDB:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
index_struct_dict = {
"type": VectorType.ANALYTICDB,
"vector_store": {"class_prefix": collection_name}
}
dataset.index_struct = json.dumps(index_struct_dict)
else:
raise ValueError(f"Vector store {vector_type} is not supported.")
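For orientation, the `index_struct` written by the new AnalyticDB branch has the same shape as the other vector stores. A minimal sketch, assuming `VectorType.ANALYTICDB` serializes to the string `"analyticdb"` and using an illustrative collection name in place of the real `Dataset.gen_collection_name_by_id` output:

```python
import json

# Hedged sketch of the resulting dataset.index_struct value; the collection
# name below is illustrative only and stands in for gen_collection_name_by_id().
index_struct_dict = {
    "type": "analyticdb",  # assumed string value of VectorType.ANALYTICDB
    "vector_store": {"class_prefix": "Vector_index_<dataset_id>_Node"},
}
print(json.dumps(index_struct_dict))
```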


@ -31,6 +31,10 @@ class AppExecutionConfig(BaseSettings):
description='execution timeout in seconds for app execution',
default=1200,
)
APP_MAX_ACTIVE_REQUESTS: NonNegativeInt = Field(
description='max active request per app, 0 means unlimited',
default=0,
)
class CodeExecutionSandboxConfig(BaseSettings):
@ -396,6 +400,11 @@ class DataSetConfig(BaseSettings):
default=30,
)
DATASET_OPERATOR_ENABLED: bool = Field(
description='whether to enable dataset operator',
default=False,
)
class WorkspaceConfig(BaseSettings):
"""


@ -10,8 +10,10 @@ from configs.middleware.storage.azure_blob_storage_config import AzureBlobStorag
from configs.middleware.storage.google_cloud_storage_config import GoogleCloudStorageConfig
from configs.middleware.storage.oci_storage_config import OCIStorageConfig
from configs.middleware.storage.tencent_cos_storage_config import TencentCloudCOSStorageConfig
from configs.middleware.vdb.analyticdb_config import AnalyticdbConfig
from configs.middleware.vdb.chroma_config import ChromaConfig
from configs.middleware.vdb.milvus_config import MilvusConfig
from configs.middleware.vdb.myscale_config import MyScaleConfig
from configs.middleware.vdb.opensearch_config import OpenSearchConfig
from configs.middleware.vdb.oracle_config import OracleConfig
from configs.middleware.vdb.pgvector_config import PGVectorConfig
@ -183,8 +185,10 @@
# configs of vdb and vdb providers
VectorStoreConfig,
AnalyticdbConfig,
ChromaConfig,
MilvusConfig,
MyScaleConfig,
OpenSearchConfig,
OracleConfig,
PGVectorConfig,


@ -0,0 +1,44 @@
from typing import Optional
from pydantic import BaseModel, Field
class AnalyticdbConfig(BaseModel):
"""
Configuration for connecting to AnalyticDB.
Refer to the following documentation for details on obtaining credentials:
https://www.alibabacloud.com/help/en/analyticdb-for-postgresql/getting-started/create-an-instance-instances-with-vector-engine-optimization-enabled
"""
ANALYTICDB_KEY_ID : Optional[str] = Field(
default=None,
description="The Access Key ID provided by Alibaba Cloud for authentication."
)
ANALYTICDB_KEY_SECRET : Optional[str] = Field(
default=None,
description="The Secret Access Key corresponding to the Access Key ID for secure access."
)
ANALYTICDB_REGION_ID : Optional[str] = Field(
default=None,
description="The region where the AnalyticDB instance is deployed (e.g., 'cn-hangzhou')."
)
ANALYTICDB_INSTANCE_ID : Optional[str] = Field(
default=None,
description="The unique identifier of the AnalyticDB instance you want to connect to (e.g., 'gp-ab123456').."
)
ANALYTICDB_ACCOUNT : Optional[str] = Field(
default=None,
description="The account name used to log in to the AnalyticDB instance."
)
ANALYTICDB_PASSWORD : Optional[str] = Field(
default=None,
description="The password associated with the AnalyticDB account for authentication."
)
ANALYTICDB_NAMESPACE : Optional[str] = Field(
default=None,
description="The namespace within AnalyticDB for schema isolation."
)
ANALYTICDB_NAMESPACE_PASSWORD : Optional[str] = Field(
default=None,
description="The password for accessing the specified namespace within the AnalyticDB instance."
)
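Since every field is optional and defaults to `None`, the model can be instantiated directly; in the application it is mixed into `MiddlewareConfig`, which reads the same `ANALYTICDB_*` names from the environment. A minimal sketch using the illustrative values from `.env.example`:

```python
from configs.middleware.vdb.analyticdb_config import AnalyticdbConfig

# Hedged sketch: direct instantiation with the sample values from .env.example;
# in the running app these fields are populated from environment variables instead.
config = AnalyticdbConfig(
    ANALYTICDB_KEY_ID="your-ak",
    ANALYTICDB_KEY_SECRET="your-sk",
    ANALYTICDB_REGION_ID="cn-hangzhou",
    ANALYTICDB_INSTANCE_ID="gp-ab123456",
    ANALYTICDB_ACCOUNT="testaccount",
    ANALYTICDB_PASSWORD="testpassword",
    ANALYTICDB_NAMESPACE="dify",
    ANALYTICDB_NAMESPACE_PASSWORD="difypassword",
)
print(config.model_dump())
```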


@ -0,0 +1,39 @@
from typing import Optional
from pydantic import BaseModel, Field, PositiveInt
class MyScaleConfig(BaseModel):
"""
MyScale configs
"""
MYSCALE_HOST: Optional[str] = Field(
description='MyScale host',
default=None,
)
MYSCALE_PORT: Optional[PositiveInt] = Field(
description='MyScale port',
default=8123,
)
MYSCALE_USER: Optional[str] = Field(
description='MyScale user',
default=None,
)
MYSCALE_PASSWORD: Optional[str] = Field(
description='MyScale password',
default=None,
)
MYSCALE_DATABASE: Optional[str] = Field(
description='MyScale database name',
default=None,
)
MYSCALE_FTS_PARAMS: Optional[str] = Field(
description='MyScale fts index parameters',
default=None,
)
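A hedged sketch of turning these settings into a connection: MyScale speaks the ClickHouse protocol on port 8123, so `clickhouse_connect` is assumed here purely for illustration; the vector-store client added elsewhere in this commit may be wired differently.

```python
import clickhouse_connect  # assumption: ClickHouse-compatible client, not taken from this diff

from configs.middleware.vdb.myscale_config import MyScaleConfig

# Illustrative values mirroring .env.example; in the app they come from MYSCALE_* env vars.
config = MyScaleConfig(
    MYSCALE_HOST="127.0.0.1",
    MYSCALE_PORT=8123,
    MYSCALE_USER="default",
    MYSCALE_PASSWORD="",
    MYSCALE_DATABASE="default",
)

client = clickhouse_connect.get_client(
    host=config.MYSCALE_HOST,
    port=config.MYSCALE_PORT,
    username=config.MYSCALE_USER,
    password=config.MYSCALE_PASSWORD,
    database=config.MYSCALE_DATABASE,
)
print(client.command("SELECT version()"))
```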


@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):
CURRENT_VERSION: str = Field(
description='Dify version',
- default='0.6.12-fix1',
default='0.6.14',
)
COMMIT_SHA: str = Field(

File diff suppressed because one or more lines are too long


@ -0,0 +1,4 @@
TTS_AUTO_PLAY_TIMEOUT = 5
# sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file)
TTS_AUTO_PLAY_YIELD_CPU_TIME = 0.02
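A hedged sketch of how constants like these are typically consumed when streaming synthesized audio: forward each chunk, then yield the CPU for ~20 ms so a cooperative worker (e.g. gevent) can schedule other greenlets; the actual consumer of these constants lives elsewhere in the codebase.

```python
import time
from collections.abc import Iterable, Iterator

TTS_AUTO_PLAY_YIELD_CPU_TIME = 0.02  # 20 ms, as defined above


def stream_audio(chunks: Iterable[bytes]) -> Iterator[bytes]:
    # Hedged sketch: pass each audio chunk through and briefly sleep between
    # chunks so other greenlets/threads get a chance to run.
    for chunk in chunks:
        yield chunk
        time.sleep(TTS_AUTO_PLAY_YIELD_CPU_TIME)
```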


@ -15,6 +15,7 @@ from fields.app_fields import (
app_pagination_fields,
)
from libs.login import login_required
from services.app_dsl_service import AppDslService
from services.app_service import AppService
ALLOW_CREATE_APP_MODES = ['chat', 'agent-chat', 'advanced-chat', 'workflow', 'completion']
@ -97,8 +98,42 @@ class AppImportApi(Resource):
parser.add_argument('icon_background', type=str, location='json')
args = parser.parse_args()
- app_service = AppService()
- app = app_service.import_app(current_user.current_tenant_id, args['data'], args, current_user)
app = AppDslService.import_and_create_new_app(
tenant_id=current_user.current_tenant_id,
data=args['data'],
args=args,
account=current_user
)
return app, 201
class AppImportFromUrlApi(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(app_detail_fields_with_site)
@cloud_edition_billing_resource_check('apps')
def post(self):
"""Import app from url"""
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument('url', type=str, required=True, nullable=False, location='json')
parser.add_argument('name', type=str, location='json')
parser.add_argument('description', type=str, location='json')
parser.add_argument('icon', type=str, location='json')
parser.add_argument('icon_background', type=str, location='json')
args = parser.parse_args()
app = AppDslService.import_and_create_new_app_from_url(
tenant_id=current_user.current_tenant_id,
url=args['url'],
args=args,
account=current_user
)
return app, 201
@ -134,6 +169,7 @@ class AppApi(Resource):
parser.add_argument('description', type=str, location='json')
parser.add_argument('icon', type=str, location='json')
parser.add_argument('icon_background', type=str, location='json')
parser.add_argument('max_active_requests', type=int, location='json')
args = parser.parse_args()
app_service = AppService()
@ -176,9 +212,13 @@ class AppCopyApi(Resource):
parser.add_argument('icon_background', type=str, location='json')
args = parser.parse_args()
- app_service = AppService()
- data = app_service.export_app(app_model)
- app = app_service.import_app(current_user.current_tenant_id, data, args, current_user)
data = AppDslService.export_dsl(app_model=app_model)
app = AppDslService.import_and_create_new_app(
tenant_id=current_user.current_tenant_id,
data=data,
args=args,
account=current_user
)
return app, 201
@ -194,10 +234,8 @@ class AppExportApi(Resource):
if not current_user.is_editor:
raise Forbidden()
- app_service = AppService()
return {
- "data": app_service.export_app(app_model)
"data": AppDslService.export_dsl(app_model=app_model)
}
@ -321,6 +359,7 @@ class AppTraceApi(Resource):
api.add_resource(AppListApi, '/apps')
api.add_resource(AppImportApi, '/apps/import')
api.add_resource(AppImportFromUrlApi, '/apps/import/url')
api.add_resource(AppApi, '/apps/<uuid:app_id>')
api.add_resource(AppCopyApi, '/apps/<uuid:app_id>/copy')
api.add_resource(AppExportApi, '/apps/<uuid:app_id>/export')
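A hedged usage sketch for the new import-from-URL route: only the path and the JSON fields come from the resource above; the base URL and the console authentication header are assumptions, since the console API normally sits behind a logged-in session.

```python
import requests  # assumption: plain HTTP client used only for illustration

CONSOLE_API = "http://localhost:5001/console/api"  # assumption: default local console API base

resp = requests.post(
    f"{CONSOLE_API}/apps/import/url",
    json={
        "url": "https://example.com/app.yml",  # DSL to import (illustrative URL)
        "name": "Imported app",                # optional overrides, per the parser above
        "description": "Imported via URL",
        "icon": "🤖",
        "icon_background": "#FFEAD5",
    },
    headers={"Authorization": "Bearer <console-access-token>"},  # assumption
)
print(resp.status_code)  # 201 with the app detail payload on success
```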


@ -81,15 +81,36 @@ class ChatMessageTextApi(Resource):
@account_initialization_required
@get_app_model
def post(self, app_model):
from werkzeug.exceptions import InternalServerError
try:
parser = reqparse.RequestParser()
parser.add_argument('message_id', type=str, location='json')
parser.add_argument('text', type=str, location='json')
parser.add_argument('voice', type=str, location='json')
parser.add_argument('streaming', type=bool, location='json')
args = parser.parse_args()
message_id = args.get('message_id', None)
text = args.get('text', None)
if (app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
and app_model.workflow
and app_model.workflow.features_dict):
text_to_speech = app_model.workflow.features_dict.get('text_to_speech')
voice = args.get('voice') if args.get('voice') else text_to_speech.get('voice')
else:
try:
voice = args.get('voice') if args.get('voice') else app_model.app_model_config.text_to_speech_dict.get(
'voice')
except Exception:
voice = None
response = AudioService.transcript_tts(
app_model=app_model,
- text=request.form['text'],
- voice=request.form['voice'],
- streaming=False
text=text,
message_id=message_id,
voice=voice
)
- return {'data': response.data.decode('latin1')}
return response
except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.")
raise AppUnavailableError()


@ -19,7 +19,12 @@ from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
- from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.errors.error import (
AppInvokeQuotaExceededError,
ModelCurrentlyNotSupportError,
ProviderTokenNotInitError,
QuotaExceededError,
)
from core.model_runtime.errors.invoke import InvokeError
from libs import helper
from libs.helper import uuid_value
@ -75,7 +80,7 @@ class CompletionMessageApi(Resource):
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
- except ValueError as e:
except (ValueError, AppInvokeQuotaExceededError) as e:
raise e
except Exception as e:
logging.exception("internal server error.")
@ -141,7 +146,7 @@ class ChatMessageApi(Resource):
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
- except ValueError as e:
except (ValueError, AppInvokeQuotaExceededError) as e:
raise e
except Exception as e:
logging.exception("internal server error.")


@ -13,12 +13,14 @@ from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import AppInvokeQuotaExceededError
from fields.workflow_fields import workflow_fields
from fields.workflow_run_fields import workflow_run_node_execution_fields
from libs import helper
from libs.helper import TimestampField, uuid_value
from libs.login import current_user, login_required
from models.model import App, AppMode
from services.app_dsl_service import AppDslService
from services.app_generate_service import AppGenerateService
from services.errors.app import WorkflowHashNotEqualError
from services.workflow_service import WorkflowService
@ -127,8 +129,7 @@ class DraftWorkflowImportApi(Resource):
parser.add_argument('data', type=str, required=True, nullable=False, location='json')
args = parser.parse_args()
- workflow_service = WorkflowService()
- workflow = workflow_service.import_draft_workflow(
workflow = AppDslService.import_and_overwrite_workflow(
app_model=app_model,
data=args['data'],
account=current_user
@ -279,7 +280,7 @@ class DraftWorkflowRunApi(Resource):
)
return helper.compact_generate_response(response)
- except ValueError as e:
except (ValueError, AppInvokeQuotaExceededError) as e:
raise e
except Exception as e:
logging.exception("internal server error.")


@ -25,7 +25,7 @@ from fields.document_fields import document_status_fields
from libs.login import login_required
from models.dataset import Dataset, Document, DocumentSegment
from models.model import ApiToken, UploadFile
- from services.dataset_service import DatasetService, DocumentService
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
def _validate_name(name):
@ -85,6 +85,12 @@ class DatasetListApi(Resource):
else:
item['embedding_available'] = True
if item.get('permission') == 'partial_members':
part_users_list = DatasetPermissionService.get_dataset_partial_member_list(item['id'])
item.update({'partial_member_list': part_users_list})
else:
item.update({'partial_member_list': []})
response = {
'data': data,
'has_more': len(datasets) == limit,
@ -108,8 +114,8 @@ class DatasetListApi(Resource):
help='Invalid indexing technique.')
args = parser.parse_args()
- # The role of the current user in the ta table must be admin, owner, or editor
- if not current_user.is_editor:
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
if not current_user.is_dataset_editor:
raise Forbidden()
try:
@ -140,6 +146,10 @@ class DatasetApi(Resource):
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
data = marshal(dataset, dataset_detail_fields)
if data.get('permission') == 'partial_members':
part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
data.update({'partial_member_list': part_users_list})
# check embedding setting
provider_manager = ProviderManager()
configurations = provider_manager.get_configurations(
@ -163,6 +173,11 @@ class DatasetApi(Resource):
data['embedding_available'] = False
else:
data['embedding_available'] = True
if data.get('permission') == 'partial_members':
part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
data.update({'partial_member_list': part_users_list})
return data, 200
@setup_required
@ -188,17 +203,21 @@ class DatasetApi(Resource):
nullable=True,
help='Invalid indexing technique.')
parser.add_argument('permission', type=str, location='json', choices=(
- 'only_me', 'all_team_members'), help='Invalid permission.')
'only_me', 'all_team_members', 'partial_members'), help='Invalid permission.'
)
parser.add_argument('embedding_model', type=str,
location='json', help='Invalid embedding model.')
parser.add_argument('embedding_model_provider', type=str,
location='json', help='Invalid embedding model provider.')
parser.add_argument('retrieval_model', type=dict, location='json', help='Invalid retrieval model.')
parser.add_argument('partial_member_list', type=list, location='json', help='Invalid parent user list.')
args = parser.parse_args()
data = request.get_json()
- # The role of the current user in the ta table must be admin, owner, or editor
- if not current_user.is_editor:
- raise Forbidden()
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
DatasetPermissionService.check_permission(
current_user, dataset, data.get('permission'), data.get('partial_member_list')
)
dataset = DatasetService.update_dataset(
dataset_id_str, args, current_user)
@ -206,7 +225,20 @@ class DatasetApi(Resource):
if dataset is None:
raise NotFound("Dataset not found.")
- return marshal(dataset, dataset_detail_fields), 200
result_data = marshal(dataset, dataset_detail_fields)
tenant_id = current_user.current_tenant_id
if data.get('partial_member_list') and data.get('permission') == 'partial_members':
DatasetPermissionService.update_partial_member_list(
tenant_id, dataset_id_str, data.get('partial_member_list')
)
else:
DatasetPermissionService.clear_partial_member_list(dataset_id_str)
partial_member_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
result_data.update({'partial_member_list': partial_member_list})
return result_data, 200
@setup_required
@login_required
@ -215,11 +247,12 @@ class DatasetApi(Resource):
dataset_id_str = str(dataset_id)
# The role of the current user in the ta table must be admin, owner, or editor
- if not current_user.is_editor:
if not current_user.is_editor or current_user.is_dataset_operator:
raise Forbidden()
try:
if DatasetService.delete_dataset(dataset_id_str, current_user):
DatasetPermissionService.clear_partial_member_list(dataset_id_str)
return {'result': 'success'}, 204
else:
raise NotFound("Dataset not found.")
@ -515,7 +548,7 @@ class DatasetRetrievalSettingApi(Resource):
RetrievalMethod.SEMANTIC_SEARCH
]
}
- case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH:
case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH | VectorType.ANALYTICDB | VectorType.MYSCALE:
return {
'retrieval_method': [
RetrievalMethod.SEMANTIC_SEARCH,
@ -539,7 +572,7 @@ class DatasetRetrievalSettingMockApi(Resource):
RetrievalMethod.SEMANTIC_SEARCH
]
}
- case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH:
case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH | VectorType.ANALYTICDB | VectorType.MYSCALE:
return {
'retrieval_method': [
RetrievalMethod.SEMANTIC_SEARCH,
@ -569,6 +602,27 @@ class DatasetErrorDocs(Resource):
}, 200
class DatasetPermissionUserListApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id):
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
partial_members_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
return {
'data': partial_members_list,
}, 200
api.add_resource(DatasetListApi, '/datasets')
api.add_resource(DatasetApi, '/datasets/<uuid:dataset_id>')
api.add_resource(DatasetUseCheckApi, '/datasets/<uuid:dataset_id>/use-check')
@ -582,3 +636,4 @@ api.add_resource(DatasetApiDeleteApi, '/datasets/api-keys/<uuid:api_key_id>')
api.add_resource(DatasetApiBaseUrlApi, '/datasets/api-base-info')
api.add_resource(DatasetRetrievalSettingApi, '/datasets/retrieval-setting')
api.add_resource(DatasetRetrievalSettingMockApi, '/datasets/retrieval-setting/<string:vector_type>')
api.add_resource(DatasetPermissionUserListApi, '/datasets/<uuid:dataset_id>/permission-part-users')
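A hedged sketch for the new partial-member listing endpoint: only the route and the `{'data': [...]}` envelope come from the handler above; the base URL, auth header, and member field layout are assumptions.

```python
import requests  # assumption: plain HTTP client used only for illustration

CONSOLE_API = "http://localhost:5001/console/api"    # assumption
dataset_id = "00000000-0000-0000-0000-000000000000"  # illustrative UUID

resp = requests.get(
    f"{CONSOLE_API}/datasets/{dataset_id}/permission-part-users",
    headers={"Authorization": "Bearer <console-access-token>"},  # assumption
)
# The member fields depend on DatasetPermissionService.get_dataset_partial_member_list.
print(resp.json()["data"])
```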


@ -228,7 +228,7 @@ class DatasetDocumentListApi(Resource):
raise NotFound('Dataset not found.')
# The role of the current user in the ta table must be admin, owner, or editor
- if not current_user.is_editor:
if not current_user.is_dataset_editor:
raise Forbidden()
try:
@ -294,6 +294,11 @@ class DatasetInitApi(Resource):
parser.add_argument('retrieval_model', type=dict, required=False, nullable=False,
location='json')
args = parser.parse_args()
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
if not current_user.is_dataset_editor:
raise Forbidden()
if args['indexing_technique'] == 'high_quality':
try:
model_manager = ModelManager()
@ -757,14 +762,18 @@ class DocumentStatusApi(DocumentResource):
dataset = DatasetService.get_dataset(dataset_id)
if dataset is None:
raise NotFound("Dataset not found.")
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_dataset_editor:
raise Forbidden()
# check user's model setting
DatasetService.check_dataset_model_setting(dataset)
- document = self.get_document(dataset_id, document_id)
- # The role of the current user in the ta table must be admin, owner, or editor
- if not current_user.is_editor:
- raise Forbidden()
# check user's permission
DatasetService.check_dataset_permission(dataset, current_user)
document = self.get_document(dataset_id, document_id)
indexing_cache_key = 'document_{}_indexing'.format(document.id)
cache_result = redis_client.get(indexing_cache_key)
@ -955,10 +964,11 @@ class DocumentRenameApi(DocumentResource):
@account_initialization_required
@marshal_with(document_fields)
def post(self, dataset_id, document_id):
- # The role of the current user in the ta table must be admin or owner
- if not current_user.is_admin_or_owner:
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
if not current_user.is_dataset_editor:
raise Forbidden()
dataset = DatasetService.get_dataset(dataset_id)
DatasetService.check_dataset_operator_permission(current_user, dataset)
parser = reqparse.RequestParser()
parser.add_argument('name', type=str, required=True, nullable=False, location='json')
args = parser.parse_args()


@ -19,6 +19,7 @@ from controllers.console.app.error import (
from controllers.console.explore.wraps import InstalledAppResource
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError
from models.model import AppMode
from services.audio_service import AudioService
from services.errors.audio import (
AudioTooLargeServiceError,
@ -70,16 +71,33 @@ class ChatAudioApi(InstalledAppResource):
class ChatTextApi(InstalledAppResource):
def post(self, installed_app):
from flask_restful import reqparse
app_model = installed_app.app
try:
parser = reqparse.RequestParser()
parser.add_argument('message_id', type=str, required=False, location='json')
parser.add_argument('voice', type=str, location='json')
parser.add_argument('streaming', type=bool, location='json')
args = parser.parse_args()
message_id = args.get('message_id')
if (app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
and app_model.workflow
and app_model.workflow.features_dict):
text_to_speech = app_model.workflow.features_dict.get('text_to_speech')
voice = args.get('voice') if args.get('voice') else text_to_speech.get('voice')
else:
try:
voice = args.get('voice') if args.get('voice') else app_model.app_model_config.text_to_speech_dict.get('voice')
except Exception:
voice = None
response = AudioService.transcript_tts(
app_model=app_model,
- text=request.form['text'],
- voice=request.form['voice'] if request.form.get('voice') else app_model.app_model_config.text_to_speech_dict.get('voice'),
- streaming=False
message_id=message_id,
voice=voice
)
- return {'data': response.data.decode('latin1')}
return response
except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.")
raise AppUnavailableError()
@ -108,3 +126,5 @@ class ChatTextApi(InstalledAppResource):
api.add_resource(ChatAudioApi, '/installed-apps/<uuid:installed_app_id>/audio-to-text', endpoint='installed_app_audio')
api.add_resource(ChatTextApi, '/installed-apps/<uuid:installed_app_id>/text-to-audio', endpoint='installed_app_text')
# api.add_resource(ChatTextApiWithMessageId, '/installed-apps/<uuid:installed_app_id>/text-to-audio/message-id',
# endpoint='installed_app_text_with_message_id')


@ -36,7 +36,7 @@ class TagListApi(Resource):
@account_initialization_required
def post(self):
# The role of the current user in the ta table must be admin, owner, or editor
- if not current_user.is_editor:
if not (current_user.is_editor or current_user.is_dataset_editor):
raise Forbidden()
parser = reqparse.RequestParser()
@ -68,7 +68,7 @@ class TagUpdateDeleteApi(Resource):
def patch(self, tag_id):
tag_id = str(tag_id)
# The role of the current user in the ta table must be admin, owner, or editor
- if not current_user.is_editor:
if not (current_user.is_editor or current_user.is_dataset_editor):
raise Forbidden()
parser = reqparse.RequestParser()
@ -109,8 +109,8 @@ class TagBindingCreateApi(Resource):
@login_required
@account_initialization_required
def post(self):
- # The role of the current user in the ta table must be admin, owner, or editor
- if not current_user.is_editor:
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
if not (current_user.is_editor or current_user.is_dataset_editor):
raise Forbidden()
parser = reqparse.RequestParser()
@ -134,8 +134,8 @@ class TagBindingDeleteApi(Resource):
@login_required
@account_initialization_required
def post(self):
- # The role of the current user in the ta table must be admin, owner, or editor
- if not current_user.is_editor:
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
if not (current_user.is_editor or current_user.is_dataset_editor):
raise Forbidden()
parser = reqparse.RequestParser()


@ -131,7 +131,20 @@ class MemberUpdateRoleApi(Resource):
return {'result': 'success'}
class DatasetOperatorMemberListApi(Resource):
"""List all members of current tenant."""
@setup_required
@login_required
@account_initialization_required
@marshal_with(account_with_role_list_fields)
def get(self):
members = TenantService.get_dataset_operator_members(current_user.current_tenant)
return {'result': 'success', 'accounts': members}, 200
api.add_resource(MemberListApi, '/workspaces/current/members')
api.add_resource(MemberInviteEmailApi, '/workspaces/current/members/invite-email')
api.add_resource(MemberCancelInviteApi, '/workspaces/current/members/<uuid:member_id>')
api.add_resource(MemberUpdateRoleApi, '/workspaces/current/members/<uuid:member_id>/update-role')
api.add_resource(DatasetOperatorMemberListApi, '/workspaces/current/dataset-operators')
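And a hedged sketch for the new dataset-operator listing: the `{'result': ..., 'accounts': ...}` shape comes from the handler above; base URL and auth are assumptions.

```python
import requests  # assumption

CONSOLE_API = "http://localhost:5001/console/api"  # assumption

resp = requests.get(
    f"{CONSOLE_API}/workspaces/current/dataset-operators",
    headers={"Authorization": "Bearer <console-access-token>"},  # assumption
)
payload = resp.json()
print(payload["result"], len(payload["accounts"]))
```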


@ -3,8 +3,9 @@ from functools import wraps
from hashlib import sha1
from hmac import new as hmac_new
- from flask import abort, current_app, request
from flask import abort, request
from configs import dify_config
from extensions.ext_database import db
from models.model import EndUser
@ -12,12 +13,12 @@ from models.model import EndUser
def inner_api_only(view):
@wraps(view)
def decorated(*args, **kwargs):
- if not current_app.config['INNER_API']:
if not dify_config.INNER_API:
abort(404)
# get header 'X-Inner-Api-Key'
inner_api_key = request.headers.get('X-Inner-Api-Key')
- if not inner_api_key or inner_api_key != current_app.config['INNER_API_KEY']:
if not inner_api_key or inner_api_key != dify_config.INNER_API_KEY:
abort(404)
return view(*args, **kwargs)
@ -28,7 +29,7 @@ def inner_api_only(view):
def inner_api_user_auth(view):
@wraps(view)
def decorated(*args, **kwargs):
- if not current_app.config['INNER_API']:
if not dify_config.INNER_API:
return view(*args, **kwargs)
# get header 'X-Inner-Api-Key'


@ -1,7 +1,7 @@
- from flask import current_app
from flask_restful import Resource, fields, marshal_with
from configs import dify_config
from controllers.service_api import api
from controllers.service_api.app.error import AppUnavailableError
from controllers.service_api.wraps import validate_app_token
@ -78,7 +78,7 @@ class AppParameterApi(Resource):
"transfer_methods": ["remote_url", "local_file"]
}}),
'system_parameters': {
- 'image_file_size_limit': current_app.config.get('UPLOAD_IMAGE_FILE_SIZE_LIMIT')
'image_file_size_limit': dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT
}
}


@ -20,7 +20,7 @@ from controllers.service_api.app.error import (
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError
- from models.model import App, EndUser
from models.model import App, AppMode, EndUser
from services.audio_service import AudioService
from services.errors.audio import (
AudioTooLargeServiceError,
@ -72,19 +72,30 @@ class AudioApi(Resource):
class TextApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
def post(self, app_model: App, end_user: EndUser):
- parser = reqparse.RequestParser()
- parser.add_argument('text', type=str, required=True, nullable=False, location='json')
- parser.add_argument('voice', type=str, location='json')
- parser.add_argument('streaming', type=bool, required=False, nullable=False, location='json')
- args = parser.parse_args()
try:
parser = reqparse.RequestParser()
parser.add_argument('message_id', type=str, required=False, location='json')
parser.add_argument('voice', type=str, location='json')
parser.add_argument('streaming', type=bool, location='json')
args = parser.parse_args()
message_id = args.get('message_id')
if (app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
and app_model.workflow
and app_model.workflow.features_dict):
text_to_speech = app_model.workflow.features_dict.get('text_to_speech')
voice = args.get('voice') if args.get('voice') else text_to_speech.get('voice')
else:
try:
voice = args.get('voice') if args.get('voice') else app_model.app_model_config.text_to_speech_dict.get(
'voice')
except Exception:
voice = None
response = AudioService.transcript_tts(
app_model=app_model,
- text=args['text'],
- end_user=end_user,
- voice=args.get('voice'),
- streaming=args['streaming']
message_id=message_id,
end_user=end_user.external_user_id,
voice=voice
)
return response
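A hedged sketch of calling the service-API text-to-speech endpoint after this change: the body now carries `message_id` (plus optional `voice`/`streaming`) rather than raw text; the route path, base URL, and Bearer app token are assumptions not shown in this diff.

```python
import requests  # assumption

SERVICE_API = "http://localhost:5001/v1"  # assumption: service API base URL
APP_TOKEN = "app-xxxxxxxxxxxxxxxx"        # assumption: app API key

resp = requests.post(
    f"{SERVICE_API}/text-to-audio",       # assumption: route registered for TextApi
    json={
        "message_id": "00000000-0000-0000-0000-000000000000",  # message to synthesize
        "voice": None,       # falls back to the app's text_to_speech voice, per the handler
        "streaming": False,
    },
    headers={"Authorization": f"Bearer {APP_TOKEN}"},
)
print(resp.status_code)
```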


@ -17,7 +17,12 @@ from controllers.service_api.app.error import (
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
- from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.errors.error import (
AppInvokeQuotaExceededError,
ModelCurrentlyNotSupportError,
ProviderTokenNotInitError,
QuotaExceededError,
)
from core.model_runtime.errors.invoke import InvokeError
from libs import helper
from libs.helper import uuid_value
@ -69,7 +74,7 @@ class CompletionApi(Resource):
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
- except ValueError as e:
except (ValueError, AppInvokeQuotaExceededError) as e:
raise e
except Exception as e:
logging.exception("internal server error.")
@ -132,7 +137,7 @@ class ChatApi(Resource):
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
- except ValueError as e:
except (ValueError, AppInvokeQuotaExceededError) as e:
raise e
except Exception as e:
logging.exception("internal server error.")


@ -14,7 +14,12 @@ from controllers.service_api.app.error import (
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
- from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.errors.error import (
AppInvokeQuotaExceededError,
ModelCurrentlyNotSupportError,
ProviderTokenNotInitError,
QuotaExceededError,
)
from core.model_runtime.errors.invoke import InvokeError
from libs import helper
from models.model import App, AppMode, EndUser
@ -59,7 +64,7 @@ class WorkflowRunApi(Resource):
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
- except ValueError as e:
except (ValueError, AppInvokeQuotaExceededError) as e:
raise e
except Exception as e:
logging.exception("internal server error.")


@ -1,6 +1,6 @@
- from flask import current_app
from flask_restful import Resource
from configs import dify_config
from controllers.service_api import api
@ -9,7 +9,7 @@ class IndexApi(Resource):
return { return {
"welcome": "Dify OpenAPI", "welcome": "Dify OpenAPI",
"api_version": "v1", "api_version": "v1",
"server_version": current_app.config['CURRENT_VERSION'] "server_version": dify_config.CURRENT_VERSION,
} }

View File

@ -1,6 +1,6 @@
from flask import current_app
from flask_restful import fields, marshal_with from flask_restful import fields, marshal_with
from configs import dify_config
from controllers.web import api from controllers.web import api
from controllers.web.error import AppUnavailableError from controllers.web.error import AppUnavailableError
from controllers.web.wraps import WebApiResource from controllers.web.wraps import WebApiResource
@ -75,7 +75,7 @@ class AppParameterApi(WebApiResource):
"transfer_methods": ["remote_url", "local_file"] "transfer_methods": ["remote_url", "local_file"]
}}), }}),
'system_parameters': { 'system_parameters': {
'image_file_size_limit': current_app.config.get('UPLOAD_IMAGE_FILE_SIZE_LIMIT') 'image_file_size_limit': dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT
} }
} }

View File

@ -19,7 +19,7 @@ from controllers.web.error import (
from controllers.web.wraps import WebApiResource from controllers.web.wraps import WebApiResource
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError from core.model_runtime.errors.invoke import InvokeError
from models.model import App from models.model import App, AppMode
from services.audio_service import AudioService from services.audio_service import AudioService
from services.errors.audio import ( from services.errors.audio import (
AudioTooLargeServiceError, AudioTooLargeServiceError,
@ -69,16 +69,35 @@ class AudioApi(WebApiResource):
class TextApi(WebApiResource): class TextApi(WebApiResource):
def post(self, app_model: App, end_user): def post(self, app_model: App, end_user):
from flask_restful import reqparse
try: try:
parser = reqparse.RequestParser()
parser.add_argument('message_id', type=str, required=False, location='json')
parser.add_argument('voice', type=str, location='json')
parser.add_argument('streaming', type=bool, location='json')
args = parser.parse_args()
message_id = args.get('message_id')
if (app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
and app_model.workflow
and app_model.workflow.features_dict):
text_to_speech = app_model.workflow.features_dict.get('text_to_speech')
voice = args.get('voice') if args.get('voice') else text_to_speech.get('voice')
else:
try:
voice = args.get('voice') if args.get(
'voice') else app_model.app_model_config.text_to_speech_dict.get('voice')
except Exception:
voice = None
response = AudioService.transcript_tts( response = AudioService.transcript_tts(
app_model=app_model, app_model=app_model,
text=request.form['text'], message_id=message_id,
end_user=end_user.external_user_id, end_user=end_user.external_user_id,
voice=request.form['voice'] if request.form.get('voice') else None, voice=voice
streaming=False
) )
return {'data': response.data.decode('latin1')} return response
except services.errors.app_model_config.AppModelConfigBrokenError: except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.") logging.exception("App model config broken.")
raise AppUnavailableError() raise AppUnavailableError()
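For orientation, a hedged example of the JSON body the reworked endpoint now parses (the field names come from the reqparse arguments above; both fields are optional and the values are invented):

    payload = {
        "message_id": "existing-message-uuid",  # synthesize the stored answer of this message
        "voice": "preferred-voice",             # overrides the voice from the app's text_to_speech settings
    }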

View File

@ -1,8 +1,8 @@
from flask import current_app
from flask_restful import fields, marshal_with from flask_restful import fields, marshal_with
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
from configs import dify_config
from controllers.web import api from controllers.web import api
from controllers.web.wraps import WebApiResource from controllers.web.wraps import WebApiResource
from extensions.ext_database import db from extensions.ext_database import db
@ -84,7 +84,7 @@ class AppSiteInfo:
self.can_replace_logo = can_replace_logo self.can_replace_logo = can_replace_logo
if can_replace_logo: if can_replace_logo:
base_url = current_app.config.get('FILES_URL') base_url = dify_config.FILES_URL
remove_webapp_brand = tenant.custom_config_dict.get('remove_webapp_brand', False) remove_webapp_brand = tenant.custom_config_dict.get('remove_webapp_brand', False)
replace_webapp_logo = f'{base_url}/files/workspaces/{tenant.id}/webapp-logo' if tenant.custom_config_dict.get('replace_webapp_logo') else None replace_webapp_logo = f'{base_url}/files/workspaces/{tenant.id}/webapp-logo' if tenant.custom_config_dict.get('replace_webapp_logo') else None
self.custom_config = { self.custom_config = {

View File

@ -0,0 +1,135 @@
import base64
import concurrent.futures
import logging
import queue
import re
import threading
from core.app.entities.queue_entities import QueueAgentMessageEvent, QueueLLMChunkEvent, QueueTextChunkEvent
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
class AudioTrunk:
def __init__(self, status: str, audio):
self.audio = audio
self.status = status
def _invoiceTTS(text_content: str, model_instance, tenant_id: str, voice: str):
if not text_content or text_content.isspace():
return
return model_instance.invoke_tts(
content_text=text_content.strip(),
user="responding_tts",
tenant_id=tenant_id,
voice=voice
)
def _process_future(future_queue, audio_queue):
while True:
try:
future = future_queue.get()
if future is None:
break
for audio in future.result():
audio_base64 = base64.b64encode(bytes(audio))
audio_queue.put(AudioTrunk("responding", audio=audio_base64))
except Exception as e:
logging.getLogger(__name__).warning(e)
break
audio_queue.put(AudioTrunk("finish", b''))
class AppGeneratorTTSPublisher:
def __init__(self, tenant_id: str, voice: str):
self.logger = logging.getLogger(__name__)
self.tenant_id = tenant_id
self.msg_text = ''
self._audio_queue = queue.Queue()
self._msg_queue = queue.Queue()
self.match = re.compile(r'[。.!?]')
self.model_manager = ModelManager()
self.model_instance = self.model_manager.get_default_model_instance(
tenant_id=self.tenant_id,
model_type=ModelType.TTS
)
self.voices = self.model_instance.get_tts_voices()
values = [voice.get('value') for voice in self.voices]
self.voice = voice
if not voice or voice not in values:
self.voice = self.voices[0].get('value')
self.MAX_SENTENCE = 2
self._last_audio_event = None
self._runtime_thread = threading.Thread(target=self._runtime)
self._runtime_thread.start()
self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=3)
def publish(self, message):
try:
self._msg_queue.put(message)
except Exception as e:
self.logger.warning(e)
def _runtime(self):
future_queue = queue.Queue()
threading.Thread(target=_process_future, args=(future_queue, self._audio_queue)).start()
while True:
try:
message = self._msg_queue.get()
if message is None:
if self.msg_text and len(self.msg_text.strip()) > 0:
futures_result = self.executor.submit(_invoiceTTS, self.msg_text,
self.model_instance, self.tenant_id, self.voice)
future_queue.put(futures_result)
break
elif isinstance(message.event, QueueAgentMessageEvent | QueueLLMChunkEvent):
self.msg_text += message.event.chunk.delta.message.content
elif isinstance(message.event, QueueTextChunkEvent):
self.msg_text += message.event.text
self.last_message = message
sentence_arr, text_tmp = self._extract_sentence(self.msg_text)
if len(sentence_arr) >= min(self.MAX_SENTENCE, 7):
self.MAX_SENTENCE += 1
text_content = ''.join(sentence_arr)
futures_result = self.executor.submit(_invoiceTTS, text_content,
self.model_instance,
self.tenant_id,
self.voice)
future_queue.put(futures_result)
if text_tmp:
self.msg_text = text_tmp
else:
self.msg_text = ''
except Exception as e:
self.logger.warning(e)
break
future_queue.put(None)
def checkAndGetAudio(self) -> AudioTrunk | None:
try:
if self._last_audio_event and self._last_audio_event.status == "finish":
if self.executor:
self.executor.shutdown(wait=False)
return self._last_audio_event
audio = self._audio_queue.get_nowait()
if audio and audio.status == "finish":
self.executor.shutdown(wait=False)
self._runtime_thread = None
if audio:
self._last_audio_event = audio
return audio
except queue.Empty:
return None
def _extract_sentence(self, org_text):
tx = self.match.finditer(org_text)
start = 0
result = []
for i in tx:
end = i.regs[0][1]
result.append(org_text[start:end])
start = end
return result, org_text[start:]
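A minimal sketch of how the publisher is meant to be drained; it mirrors the polling loops added to the task pipelines below (constructing AppGeneratorTTSPublisher itself requires a tenant with a default TTS model configured, which is omitted here):

    import time

    def drain_audio(publisher):
        """Yield base64-encoded audio chunks until the publisher signals 'finish'."""
        while True:
            trunk = publisher.checkAndGetAudio()
            if trunk is None:
                time.sleep(0.02)      # nothing buffered yet; yield the CPU briefly
                continue
            if trunk.status == "finish":
                break
            yield trunk.audio         # base64-encoded audio produced by the worker threads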

View File

@ -255,6 +255,12 @@ class AdvancedChatAppRunner(AppRunner):
) )
index += 1 index += 1
time.sleep(0.01) time.sleep(0.01)
else:
queue_manager.publish(
QueueTextChunkEvent(
text=text
), PublishFrom.APPLICATION_MANAGER
)
queue_manager.publish( queue_manager.publish(
QueueStopEvent(stopped_by=stopped_by), QueueStopEvent(stopped_by=stopped_by),

View File

@ -4,6 +4,8 @@ import time
from collections.abc import Generator from collections.abc import Generator
from typing import Any, Optional, Union, cast from typing import Any, Optional, Union, cast
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.app_invoke_entities import ( from core.app.entities.app_invoke_entities import (
AdvancedChatAppGenerateEntity, AdvancedChatAppGenerateEntity,
@ -33,6 +35,8 @@ from core.app.entities.task_entities import (
ChatbotAppStreamResponse, ChatbotAppStreamResponse,
ChatflowStreamGenerateRoute, ChatflowStreamGenerateRoute,
ErrorStreamResponse, ErrorStreamResponse,
MessageAudioEndStreamResponse,
MessageAudioStreamResponse,
MessageEndStreamResponse, MessageEndStreamResponse,
StreamResponse, StreamResponse,
) )
@ -71,13 +75,13 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
_iteration_nested_relations: dict[str, list[str]] _iteration_nested_relations: dict[str, list[str]]
def __init__( def __init__(
self, application_generate_entity: AdvancedChatAppGenerateEntity, self, application_generate_entity: AdvancedChatAppGenerateEntity,
workflow: Workflow, workflow: Workflow,
queue_manager: AppQueueManager, queue_manager: AppQueueManager,
conversation: Conversation, conversation: Conversation,
message: Message, message: Message,
user: Union[Account, EndUser], user: Union[Account, EndUser],
stream: bool stream: bool
) -> None: ) -> None:
""" """
Initialize AdvancedChatAppGenerateTaskPipeline. Initialize AdvancedChatAppGenerateTaskPipeline.
@ -129,7 +133,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
self._application_generate_entity.query self._application_generate_entity.query
) )
generator = self._process_stream_response( generator = self._wrapper_process_stream_response(
trace_manager=self._application_generate_entity.trace_manager trace_manager=self._application_generate_entity.trace_manager
) )
if self._stream: if self._stream:
@ -138,7 +142,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
return self._to_blocking_response(generator) return self._to_blocking_response(generator)
def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]) \ def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]) \
-> ChatbotAppBlockingResponse: -> ChatbotAppBlockingResponse:
""" """
Process blocking response. Process blocking response.
:return: :return:
@ -169,7 +173,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
raise Exception('Queue listening stopped unexpectedly.') raise Exception('Queue listening stopped unexpectedly.')
def _to_stream_response(self, generator: Generator[StreamResponse, None, None]) \ def _to_stream_response(self, generator: Generator[StreamResponse, None, None]) \
-> Generator[ChatbotAppStreamResponse, None, None]: -> Generator[ChatbotAppStreamResponse, None, None]:
""" """
To stream response. To stream response.
:return: :return:
@ -182,14 +186,68 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
stream_response=stream_response stream_response=stream_response
) )
def _listenAudioMsg(self, publisher, task_id: str):
if not publisher:
return None
audio_msg: AudioTrunk = publisher.checkAndGetAudio()
if audio_msg and audio_msg.status != "finish":
return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
return None
def _wrapper_process_stream_response(self, trace_manager: Optional[TraceQueueManager] = None) -> \
Generator[StreamResponse, None, None]:
publisher = None
task_id = self._application_generate_entity.task_id
tenant_id = self._application_generate_entity.app_config.tenant_id
features_dict = self._workflow.features_dict
if features_dict.get('text_to_speech') and features_dict['text_to_speech'].get('enabled') and features_dict[
'text_to_speech'].get('autoPlay') == 'enabled':
publisher = AppGeneratorTTSPublisher(tenant_id, features_dict['text_to_speech'].get('voice'))
for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager):
while True:
audio_response = self._listenAudioMsg(publisher, task_id=task_id)
if audio_response:
yield audio_response
else:
break
yield response
start_listener_time = time.time()
# timeout
while (time.time() - start_listener_time) < TTS_AUTO_PLAY_TIMEOUT:
try:
if not publisher:
break
audio_trunk = publisher.checkAndGetAudio()
if audio_trunk is None:
# release cpu
# sleep 20 ms (40 ms => 1280-byte audio file, 20 ms => 640-byte audio file)
time.sleep(TTS_AUTO_PLAY_YIELD_CPU_TIME)
continue
if audio_trunk.status == "finish":
break
else:
start_listener_time = time.time()
yield MessageAudioStreamResponse(audio=audio_trunk.audio, task_id=task_id)
except Exception as e:
logger.error(e)
break
yield MessageAudioEndStreamResponse(audio='', task_id=task_id)
def _process_stream_response( def _process_stream_response(
self, trace_manager: Optional[TraceQueueManager] = None self,
publisher: AppGeneratorTTSPublisher,
trace_manager: Optional[TraceQueueManager] = None
) -> Generator[StreamResponse, None, None]: ) -> Generator[StreamResponse, None, None]:
""" """
Process stream response. Process stream response.
:return: :return:
""" """
for message in self._queue_manager.listen(): for message in self._queue_manager.listen():
if publisher:
publisher.publish(message=message)
event = message.event event = message.event
if isinstance(event, QueueErrorEvent): if isinstance(event, QueueErrorEvent):
@ -301,7 +359,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
continue continue
if not self._is_stream_out_support( if not self._is_stream_out_support(
event=event event=event
): ):
continue continue
@ -318,7 +376,8 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
yield self._ping_stream_response() yield self._ping_stream_response()
else: else:
continue continue
if publisher:
publisher.publish(None)
if self._conversation_name_generate_thread: if self._conversation_name_generate_thread:
self._conversation_name_generate_thread.join() self._conversation_name_generate_thread.join()
@ -402,7 +461,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
return stream_generate_routes return stream_generate_routes
def _get_answer_start_at_node_ids(self, graph: dict, target_node_id: str) \ def _get_answer_start_at_node_ids(self, graph: dict, target_node_id: str) \
-> list[str]: -> list[str]:
""" """
Get answer start at node id. Get answer start at node id.
:param graph: graph :param graph: graph
@ -457,7 +516,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
start_node_id = target_node_id start_node_id = target_node_id
start_node_ids.append(start_node_id) start_node_ids.append(start_node_id)
elif node_type == NodeType.START.value or \ elif node_type == NodeType.START.value or \
node_iteration_id is not None and iteration_start_node_id == source_node.get('id'): node_iteration_id is not None and iteration_start_node_id == source_node.get('id'):
start_node_id = source_node_id start_node_id = source_node_id
start_node_ids.append(start_node_id) start_node_ids.append(start_node_id)
else: else:
@ -515,7 +574,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
# all route chunks are generated # all route chunks are generated
if self._task_state.current_stream_generate_state.current_route_position == len( if self._task_state.current_stream_generate_state.current_route_position == len(
self._task_state.current_stream_generate_state.generate_route self._task_state.current_stream_generate_state.generate_route
): ):
self._task_state.current_stream_generate_state = None self._task_state.current_stream_generate_state = None
@ -525,7 +584,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
:return: :return:
""" """
if not self._task_state.current_stream_generate_state: if not self._task_state.current_stream_generate_state:
return None return
route_chunks = self._task_state.current_stream_generate_state.generate_route[ route_chunks = self._task_state.current_stream_generate_state.generate_route[
self._task_state.current_stream_generate_state.current_route_position:] self._task_state.current_stream_generate_state.current_route_position:]
@ -573,7 +632,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
# get route chunk node execution info # get route chunk node execution info
route_chunk_node_execution_info = self._task_state.ran_node_execution_infos[route_chunk_node_id] route_chunk_node_execution_info = self._task_state.ran_node_execution_infos[route_chunk_node_id]
if (route_chunk_node_execution_info.node_type == NodeType.LLM if (route_chunk_node_execution_info.node_type == NodeType.LLM
and latest_node_execution_info.node_type == NodeType.LLM): and latest_node_execution_info.node_type == NodeType.LLM):
# only LLM support chunk stream output # only LLM support chunk stream output
self._task_state.current_stream_generate_state.current_route_position += 1 self._task_state.current_stream_generate_state.current_route_position += 1
continue continue
@ -643,7 +702,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
# all route chunks are generated # all route chunks are generated
if self._task_state.current_stream_generate_state.current_route_position == len( if self._task_state.current_stream_generate_state.current_route_position == len(
self._task_state.current_stream_generate_state.generate_route self._task_state.current_stream_generate_state.generate_route
): ):
self._task_state.current_stream_generate_state = None self._task_state.current_stream_generate_state = None

View File

@ -51,7 +51,6 @@ class AppQueueManager:
listen_timeout = current_app.config.get("APP_MAX_EXECUTION_TIME") listen_timeout = current_app.config.get("APP_MAX_EXECUTION_TIME")
start_time = time.time() start_time = time.time()
last_ping_time = 0 last_ping_time = 0
while True: while True:
try: try:
message = self._q.get(timeout=1) message = self._q.get(timeout=1)

View File

@ -1,7 +1,10 @@
import logging import logging
import time
from collections.abc import Generator from collections.abc import Generator
from typing import Any, Optional, Union from typing import Any, Optional, Union
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk
from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import ( from core.app.entities.app_invoke_entities import (
InvokeFrom, InvokeFrom,
@ -25,6 +28,8 @@ from core.app.entities.queue_entities import (
) )
from core.app.entities.task_entities import ( from core.app.entities.task_entities import (
ErrorStreamResponse, ErrorStreamResponse,
MessageAudioEndStreamResponse,
MessageAudioStreamResponse,
StreamResponse, StreamResponse,
TextChunkStreamResponse, TextChunkStreamResponse,
TextReplaceStreamResponse, TextReplaceStreamResponse,
@ -105,7 +110,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
db.session.refresh(self._user) db.session.refresh(self._user)
db.session.close() db.session.close()
generator = self._process_stream_response( generator = self._wrapper_process_stream_response(
trace_manager=self._application_generate_entity.trace_manager trace_manager=self._application_generate_entity.trace_manager
) )
if self._stream: if self._stream:
@ -161,8 +166,58 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
stream_response=stream_response stream_response=stream_response
) )
def _listenAudioMsg(self, publisher, task_id: str):
if not publisher:
return None
audio_msg: AudioTrunk = publisher.checkAndGetAudio()
if audio_msg and audio_msg.status != "finish":
return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
return None
def _wrapper_process_stream_response(self, trace_manager: Optional[TraceQueueManager] = None) -> \
Generator[StreamResponse, None, None]:
publisher = None
task_id = self._application_generate_entity.task_id
tenant_id = self._application_generate_entity.app_config.tenant_id
features_dict = self._workflow.features_dict
if features_dict.get('text_to_speech') and features_dict['text_to_speech'].get('enabled') and features_dict[
'text_to_speech'].get('autoPlay') == 'enabled':
publisher = AppGeneratorTTSPublisher(tenant_id, features_dict['text_to_speech'].get('voice'))
for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager):
while True:
audio_response = self._listenAudioMsg(publisher, task_id=task_id)
if audio_response:
yield audio_response
else:
break
yield response
start_listener_time = time.time()
while (time.time() - start_listener_time) < TTS_AUTO_PLAY_TIMEOUT:
try:
if not publisher:
break
audio_trunk = publisher.checkAndGetAudio()
if audio_trunk is None:
# release cpu
# sleep 20 ms (40 ms => 1280-byte audio file, 20 ms => 640-byte audio file)
time.sleep(TTS_AUTO_PLAY_YIELD_CPU_TIME)
continue
if audio_trunk.status == "finish":
break
else:
yield MessageAudioStreamResponse(audio=audio_trunk.audio, task_id=task_id)
except Exception as e:
logger.error(e)
break
yield MessageAudioEndStreamResponse(audio='', task_id=task_id)
def _process_stream_response( def _process_stream_response(
self, self,
publisher: AppGeneratorTTSPublisher,
trace_manager: Optional[TraceQueueManager] = None trace_manager: Optional[TraceQueueManager] = None
) -> Generator[StreamResponse, None, None]: ) -> Generator[StreamResponse, None, None]:
""" """
@ -170,6 +225,8 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
:return: :return:
""" """
for message in self._queue_manager.listen(): for message in self._queue_manager.listen():
if publisher:
publisher.publish(message=message)
event = message.event event = message.event
if isinstance(event, QueueErrorEvent): if isinstance(event, QueueErrorEvent):
@ -251,6 +308,10 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
else: else:
continue continue
if publisher:
publisher.publish(None)
def _save_workflow_app_log(self, workflow_run: WorkflowRun) -> None: def _save_workflow_app_log(self, workflow_run: WorkflowRun) -> None:
""" """
Save workflow app log. Save workflow app log.

View File

@ -69,6 +69,7 @@ class WorkflowTaskState(TaskState):
iteration_nested_node_ids: list[str] = None iteration_nested_node_ids: list[str] = None
class AdvancedChatTaskState(WorkflowTaskState): class AdvancedChatTaskState(WorkflowTaskState):
""" """
AdvancedChatTaskState entity AdvancedChatTaskState entity
@ -86,6 +87,8 @@ class StreamEvent(Enum):
ERROR = "error" ERROR = "error"
MESSAGE = "message" MESSAGE = "message"
MESSAGE_END = "message_end" MESSAGE_END = "message_end"
TTS_MESSAGE = "tts_message"
TTS_MESSAGE_END = "tts_message_end"
MESSAGE_FILE = "message_file" MESSAGE_FILE = "message_file"
MESSAGE_REPLACE = "message_replace" MESSAGE_REPLACE = "message_replace"
AGENT_THOUGHT = "agent_thought" AGENT_THOUGHT = "agent_thought"
@ -130,6 +133,22 @@ class MessageStreamResponse(StreamResponse):
answer: str answer: str
class MessageAudioStreamResponse(StreamResponse):
"""
MessageAudioStreamResponse entity
"""
event: StreamEvent = StreamEvent.TTS_MESSAGE
audio: str
class MessageAudioEndStreamResponse(StreamResponse):
"""
MessageAudioEndStreamResponse entity
"""
event: StreamEvent = StreamEvent.TTS_MESSAGE_END
audio: str
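On the wire these two new events are expected to serialize roughly as follows (a hedged sketch; any additional base fields come from StreamResponse):

    tts_chunk = {"event": "tts_message", "task_id": "task-id-placeholder", "audio": "<base64 audio>"}
    tts_end = {"event": "tts_message_end", "task_id": "task-id-placeholder", "audio": ""}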
class MessageEndStreamResponse(StreamResponse): class MessageEndStreamResponse(StreamResponse):
""" """
MessageEndStreamResponse entity MessageEndStreamResponse entity
@ -186,6 +205,7 @@ class WorkflowStartStreamResponse(StreamResponse):
""" """
WorkflowStartStreamResponse entity WorkflowStartStreamResponse entity
""" """
class Data(BaseModel): class Data(BaseModel):
""" """
Data entity Data entity
@ -205,6 +225,7 @@ class WorkflowFinishStreamResponse(StreamResponse):
""" """
WorkflowFinishStreamResponse entity WorkflowFinishStreamResponse entity
""" """
class Data(BaseModel): class Data(BaseModel):
""" """
Data entity Data entity
@ -232,6 +253,7 @@ class NodeStartStreamResponse(StreamResponse):
""" """
NodeStartStreamResponse entity NodeStartStreamResponse entity
""" """
class Data(BaseModel): class Data(BaseModel):
""" """
Data entity Data entity
@ -273,6 +295,7 @@ class NodeFinishStreamResponse(StreamResponse):
""" """
NodeFinishStreamResponse entity NodeFinishStreamResponse entity
""" """
class Data(BaseModel): class Data(BaseModel):
""" """
Data entity Data entity
@ -323,10 +346,12 @@ class NodeFinishStreamResponse(StreamResponse):
} }
} }
class IterationNodeStartStreamResponse(StreamResponse): class IterationNodeStartStreamResponse(StreamResponse):
""" """
NodeStartStreamResponse entity NodeStartStreamResponse entity
""" """
class Data(BaseModel): class Data(BaseModel):
""" """
Data entity Data entity
@ -344,10 +369,12 @@ class IterationNodeStartStreamResponse(StreamResponse):
workflow_run_id: str workflow_run_id: str
data: Data data: Data
class IterationNodeNextStreamResponse(StreamResponse): class IterationNodeNextStreamResponse(StreamResponse):
""" """
NodeStartStreamResponse entity NodeStartStreamResponse entity
""" """
class Data(BaseModel): class Data(BaseModel):
""" """
Data entity Data entity
@ -365,10 +392,12 @@ class IterationNodeNextStreamResponse(StreamResponse):
workflow_run_id: str workflow_run_id: str
data: Data data: Data
class IterationNodeCompletedStreamResponse(StreamResponse): class IterationNodeCompletedStreamResponse(StreamResponse):
""" """
NodeCompletedStreamResponse entity NodeCompletedStreamResponse entity
""" """
class Data(BaseModel): class Data(BaseModel):
""" """
Data entity Data entity
@ -393,10 +422,12 @@ class IterationNodeCompletedStreamResponse(StreamResponse):
workflow_run_id: str workflow_run_id: str
data: Data data: Data
class TextChunkStreamResponse(StreamResponse): class TextChunkStreamResponse(StreamResponse):
""" """
TextChunkStreamResponse entity TextChunkStreamResponse entity
""" """
class Data(BaseModel): class Data(BaseModel):
""" """
Data entity Data entity
@ -411,6 +442,7 @@ class TextReplaceStreamResponse(StreamResponse):
""" """
TextReplaceStreamResponse entity TextReplaceStreamResponse entity
""" """
class Data(BaseModel): class Data(BaseModel):
""" """
Data entity Data entity
@ -473,6 +505,7 @@ class ChatbotAppBlockingResponse(AppBlockingResponse):
""" """
ChatbotAppBlockingResponse entity ChatbotAppBlockingResponse entity
""" """
class Data(BaseModel): class Data(BaseModel):
""" """
Data entity Data entity
@ -492,6 +525,7 @@ class CompletionAppBlockingResponse(AppBlockingResponse):
""" """
CompletionAppBlockingResponse entity CompletionAppBlockingResponse entity
""" """
class Data(BaseModel): class Data(BaseModel):
""" """
Data entity Data entity
@ -510,6 +544,7 @@ class WorkflowAppBlockingResponse(AppBlockingResponse):
""" """
WorkflowAppBlockingResponse entity WorkflowAppBlockingResponse entity
""" """
class Data(BaseModel): class Data(BaseModel):
""" """
Data entity Data entity
@ -528,10 +563,12 @@ class WorkflowAppBlockingResponse(AppBlockingResponse):
workflow_run_id: str workflow_run_id: str
data: Data data: Data
class WorkflowIterationState(BaseModel): class WorkflowIterationState(BaseModel):
""" """
WorkflowIterationState entity WorkflowIterationState entity
""" """
class Data(BaseModel): class Data(BaseModel):
""" """
Data entity Data entity

View File

@ -0,0 +1 @@
from .rate_limit import RateLimit

View File

@ -0,0 +1,120 @@
import logging
import time
import uuid
from collections.abc import Generator
from datetime import timedelta
from typing import Optional, Union
from core.errors.error import AppInvokeQuotaExceededError
from extensions.ext_redis import redis_client
logger = logging.getLogger(__name__)
class RateLimit:
_MAX_ACTIVE_REQUESTS_KEY = "dify:rate_limit:{}:max_active_requests"
_ACTIVE_REQUESTS_KEY = "dify:rate_limit:{}:active_requests"
_UNLIMITED_REQUEST_ID = "unlimited_request_id"
_REQUEST_MAX_ALIVE_TIME = 10 * 60 # 10 minutes
_ACTIVE_REQUESTS_COUNT_FLUSH_INTERVAL = 5 * 60 # recalculate request_count from request_detail every 5 minutes
_instance_dict = {}
def __new__(cls: type['RateLimit'], client_id: str, max_active_requests: int):
if client_id not in cls._instance_dict:
instance = super().__new__(cls)
cls._instance_dict[client_id] = instance
return cls._instance_dict[client_id]
def __init__(self, client_id: str, max_active_requests: int):
self.max_active_requests = max_active_requests
if hasattr(self, 'initialized'):
return
self.initialized = True
self.client_id = client_id
self.active_requests_key = self._ACTIVE_REQUESTS_KEY.format(client_id)
self.max_active_requests_key = self._MAX_ACTIVE_REQUESTS_KEY.format(client_id)
self.last_recalculate_time = float('-inf')
self.flush_cache(use_local_value=True)
def flush_cache(self, use_local_value=False):
self.last_recalculate_time = time.time()
# flush max active requests
if use_local_value or not redis_client.exists(self.max_active_requests_key):
with redis_client.pipeline() as pipe:
pipe.set(self.max_active_requests_key, self.max_active_requests)
pipe.expire(self.max_active_requests_key, timedelta(days=1))
pipe.execute()
else:
with redis_client.pipeline() as pipe:
self.max_active_requests = int(redis_client.get(self.max_active_requests_key).decode('utf-8'))
redis_client.expire(self.max_active_requests_key, timedelta(days=1))
# flush the in-flight (active) request list
if not redis_client.exists(self.active_requests_key):
return
request_details = redis_client.hgetall(self.active_requests_key)
redis_client.expire(self.active_requests_key, timedelta(days=1))
timeout_requests = [k for k, v in request_details.items() if
time.time() - float(v.decode('utf-8')) > RateLimit._REQUEST_MAX_ALIVE_TIME]
if timeout_requests:
redis_client.hdel(self.active_requests_key, *timeout_requests)
def enter(self, request_id: Optional[str] = None) -> str:
if time.time() - self.last_recalculate_time > RateLimit._ACTIVE_REQUESTS_COUNT_FLUSH_INTERVAL:
self.flush_cache()
if self.max_active_requests <= 0:
return RateLimit._UNLIMITED_REQUEST_ID
if not request_id:
request_id = RateLimit.gen_request_key()
active_requests_count = redis_client.hlen(self.active_requests_key)
if active_requests_count >= self.max_active_requests:
raise AppInvokeQuotaExceededError("Too many requests. Please try again later. The current maximum "
"concurrent requests allowed is {}.".format(self.max_active_requests))
redis_client.hset(self.active_requests_key, request_id, str(time.time()))
return request_id
def exit(self, request_id: str):
if request_id == RateLimit._UNLIMITED_REQUEST_ID:
return
redis_client.hdel(self.active_requests_key, request_id)
@staticmethod
def gen_request_key() -> str:
return str(uuid.uuid4())
def generate(self, generator: Union[Generator, callable, dict], request_id: str):
if isinstance(generator, dict):
return generator
else:
return RateLimitGenerator(self, generator, request_id)
class RateLimitGenerator:
def __init__(self, rate_limit: RateLimit, generator: Union[Generator, callable], request_id: str):
self.rate_limit = rate_limit
if callable(generator):
self.generator = generator()
else:
self.generator = generator
self.request_id = request_id
self.closed = False
def __iter__(self):
return self
def __next__(self):
if self.closed:
raise StopIteration
try:
return next(self.generator)
except StopIteration:
self.close()
raise
def close(self):
if not self.closed:
self.closed = True
self.rate_limit.exit(self.request_id)
if self.generator is not None and hasattr(self.generator, 'close'):
self.generator.close()
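A minimal usage sketch for the new rate limiter; the client id and the wrapped generator are illustrative, and a reachable Redis (via extensions.ext_redis) is assumed. Passing max_active_requests <= 0 disables limiting: enter() returns the unlimited sentinel and exit() becomes a no-op.

    from core.errors.error import AppInvokeQuotaExceededError

    # RateLimit comes from the new rate_limit module added above (package path not shown in this commit).
    limiter = RateLimit(client_id="app-id-placeholder", max_active_requests=5)

    def answer_chunks():
        yield "hello "
        yield "world"

    try:
        request_id = limiter.enter()   # raises AppInvokeQuotaExceededError once 5 requests are in flight
        for chunk in limiter.generate(answer_chunks, request_id):
            print(chunk)               # RateLimitGenerator calls exit() when the stream is exhausted
    except AppInvokeQuotaExceededError as e:
        print(e.description)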

View File

@ -4,6 +4,8 @@ import time
from collections.abc import Generator from collections.abc import Generator
from typing import Optional, Union, cast from typing import Optional, Union, cast
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.app_invoke_entities import ( from core.app.entities.app_invoke_entities import (
AgentChatAppGenerateEntity, AgentChatAppGenerateEntity,
@ -32,6 +34,8 @@ from core.app.entities.task_entities import (
CompletionAppStreamResponse, CompletionAppStreamResponse,
EasyUITaskState, EasyUITaskState,
ErrorStreamResponse, ErrorStreamResponse,
MessageAudioEndStreamResponse,
MessageAudioStreamResponse,
MessageEndStreamResponse, MessageEndStreamResponse,
StreamResponse, StreamResponse,
) )
@ -87,6 +91,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
""" """
super().__init__(application_generate_entity, queue_manager, user, stream) super().__init__(application_generate_entity, queue_manager, user, stream)
self._model_config = application_generate_entity.model_conf self._model_config = application_generate_entity.model_conf
self._app_config = application_generate_entity.app_config
self._conversation = conversation self._conversation = conversation
self._message = message self._message = message
@ -102,7 +107,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
self._conversation_name_generate_thread = None self._conversation_name_generate_thread = None
def process( def process(
self, self,
) -> Union[ ) -> Union[
ChatbotAppBlockingResponse, ChatbotAppBlockingResponse,
CompletionAppBlockingResponse, CompletionAppBlockingResponse,
@ -123,7 +128,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
self._application_generate_entity.query self._application_generate_entity.query
) )
generator = self._process_stream_response( generator = self._wrapper_process_stream_response(
trace_manager=self._application_generate_entity.trace_manager trace_manager=self._application_generate_entity.trace_manager
) )
if self._stream: if self._stream:
@ -202,14 +207,64 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
stream_response=stream_response stream_response=stream_response
) )
def _listenAudioMsg(self, publisher, task_id: str):
if publisher is None:
return None
audio_msg: AudioTrunk = publisher.checkAndGetAudio()
if audio_msg and audio_msg.status != "finish":
# audio_str = audio_msg.audio.decode('utf-8', errors='ignore')
return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
return None
def _wrapper_process_stream_response(self, trace_manager: Optional[TraceQueueManager] = None) -> \
Generator[StreamResponse, None, None]:
tenant_id = self._application_generate_entity.app_config.tenant_id
task_id = self._application_generate_entity.task_id
publisher = None
text_to_speech_dict = self._app_config.app_model_config_dict.get('text_to_speech')
if text_to_speech_dict and text_to_speech_dict.get('autoPlay') == 'enabled' and text_to_speech_dict.get('enabled'):
publisher = AppGeneratorTTSPublisher(tenant_id, text_to_speech_dict.get('voice', None))
for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager):
while True:
audio_response = self._listenAudioMsg(publisher, task_id)
if audio_response:
yield audio_response
else:
break
yield response
start_listener_time = time.time()
# timeout
while (time.time() - start_listener_time) < TTS_AUTO_PLAY_TIMEOUT:
if publisher is None:
break
audio = publisher.checkAndGetAudio()
if audio is None:
# release cpu
# sleep 20 ms (40 ms => 1280-byte audio file, 20 ms => 640-byte audio file)
time.sleep(TTS_AUTO_PLAY_YIELD_CPU_TIME)
continue
if audio.status == "finish":
break
else:
start_listener_time = time.time()
yield MessageAudioStreamResponse(audio=audio.audio,
task_id=task_id)
yield MessageAudioEndStreamResponse(audio='', task_id=task_id)
def _process_stream_response( def _process_stream_response(
self, trace_manager: Optional[TraceQueueManager] = None self,
publisher: AppGeneratorTTSPublisher,
trace_manager: Optional[TraceQueueManager] = None
) -> Generator[StreamResponse, None, None]: ) -> Generator[StreamResponse, None, None]:
""" """
Process stream response. Process stream response.
:return: :return:
""" """
for message in self._queue_manager.listen(): for message in self._queue_manager.listen():
if publisher:
publisher.publish(message)
event = message.event event = message.event
if isinstance(event, QueueErrorEvent): if isinstance(event, QueueErrorEvent):
@ -272,12 +327,13 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
yield self._ping_stream_response() yield self._ping_stream_response()
else: else:
continue continue
if publisher:
publisher.publish(None)
if self._conversation_name_generate_thread: if self._conversation_name_generate_thread:
self._conversation_name_generate_thread.join() self._conversation_name_generate_thread.join()
def _save_message( def _save_message(
self, trace_manager: Optional[TraceQueueManager] = None self, trace_manager: Optional[TraceQueueManager] = None
) -> None: ) -> None:
""" """
Save message. Save message.

View File

@ -31,6 +31,13 @@ class QuotaExceededError(Exception):
description = "Quota Exceeded" description = "Quota Exceeded"
class AppInvokeQuotaExceededError(Exception):
"""
Custom exception raised when the quota for an app has been exceeded.
"""
description = "App Invoke Quota Exceeded"
class ModelCurrentlyNotSupportError(Exception): class ModelCurrentlyNotSupportError(Exception):
""" """
Custom exception raised when the model not support Custom exception raised when the model not support

View File

@ -730,7 +730,7 @@ class IndexingRunner:
self._check_document_paused_status(dataset_document.id) self._check_document_paused_status(dataset_document.id)
tokens = 0 tokens = 0
if dataset.indexing_technique == 'high_quality' or embedding_model_type_instance: if embedding_model_instance:
tokens += sum( tokens += sum(
embedding_model_instance.get_text_embedding_num_tokens( embedding_model_instance.get_text_embedding_num_tokens(
[document.page_content] [document.page_content]

View File

@ -264,7 +264,7 @@ class ModelInstance:
user=user user=user
) )
def invoke_tts(self, content_text: str, tenant_id: str, voice: str, streaming: bool, user: Optional[str] = None) \ def invoke_tts(self, content_text: str, tenant_id: str, voice: str, user: Optional[str] = None) \
-> str: -> str:
""" """
Invoke large language tts model Invoke large language tts model
@ -287,8 +287,7 @@ class ModelInstance:
content_text=content_text, content_text=content_text,
user=user, user=user,
tenant_id=tenant_id, tenant_id=tenant_id,
voice=voice, voice=voice
streaming=streaming
) )
def _round_robin_invoke(self, function: Callable, *args, **kwargs): def _round_robin_invoke(self, function: Callable, *args, **kwargs):
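With `streaming` removed from the signature, invoke_tts now hands back whatever the provider's streaming invoke yields; for the reworked TTS providers that is a generator of binary audio chunks. A hedged consumption sketch, assuming model_instance is an already-resolved TTS ModelInstance (e.g. obtained via ModelManager) and the argument values are placeholders:

    audio = bytearray()
    for chunk in model_instance.invoke_tts(content_text="Hello there.",
                                           tenant_id="tenant-id-placeholder",
                                           voice="preferred-voice"):
        audio.extend(chunk)   # accumulate, or forward each chunk to the client as it arrives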

View File

@ -1,4 +1,6 @@
import hashlib import hashlib
import logging
import re
import subprocess import subprocess
import uuid import uuid
from abc import abstractmethod from abc import abstractmethod
@ -10,7 +12,7 @@ from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelTy
from core.model_runtime.errors.invoke import InvokeBadRequestError from core.model_runtime.errors.invoke import InvokeBadRequestError
from core.model_runtime.model_providers.__base.ai_model import AIModel from core.model_runtime.model_providers.__base.ai_model import AIModel
logger = logging.getLogger(__name__)
class TTSModel(AIModel): class TTSModel(AIModel):
""" """
Model class for ttstext model. Model class for ttstext model.
@ -20,7 +22,7 @@ class TTSModel(AIModel):
# pydantic configs # pydantic configs
model_config = ConfigDict(protected_namespaces=()) model_config = ConfigDict(protected_namespaces=())
def invoke(self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, streaming: bool, def invoke(self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str,
user: Optional[str] = None): user: Optional[str] = None):
""" """
Invoke large language model Invoke large language model
@ -35,14 +37,15 @@ class TTSModel(AIModel):
:return: translated audio file :return: translated audio file
""" """
try: try:
logger.info(f"Invoke TTS model: {model} , invoke content : {content_text}")
self._is_ffmpeg_installed() self._is_ffmpeg_installed()
return self._invoke(model=model, credentials=credentials, user=user, streaming=streaming, return self._invoke(model=model, credentials=credentials, user=user,
content_text=content_text, voice=voice, tenant_id=tenant_id) content_text=content_text, voice=voice, tenant_id=tenant_id)
except Exception as e: except Exception as e:
raise self._transform_invoke_error(e) raise self._transform_invoke_error(e)
@abstractmethod @abstractmethod
def _invoke(self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, streaming: bool, def _invoke(self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str,
user: Optional[str] = None): user: Optional[str] = None):
""" """
Invoke large language model Invoke large language model
@ -123,26 +126,26 @@ class TTSModel(AIModel):
return model_schema.model_properties[ModelPropertyKey.MAX_WORKERS] return model_schema.model_properties[ModelPropertyKey.MAX_WORKERS]
@staticmethod @staticmethod
def _split_text_into_sentences(text: str, limit: int, delimiters=None): def _split_text_into_sentences(org_text, max_length=2000, pattern=r'[。.!?]'):
if delimiters is None: match = re.compile(pattern)
delimiters = set('。!?;\n') tx = match.finditer(org_text)
start = 0
buf = [] result = []
word_count = 0 one_sentence = ''
for char in text: for i in tx:
buf.append(char) end = i.regs[0][1]
if char in delimiters: tmp = org_text[start:end]
if word_count >= limit: if len(one_sentence + tmp) > max_length:
yield ''.join(buf) result.append(one_sentence)
buf = [] one_sentence = ''
word_count = 0 one_sentence += tmp
else: start = end
word_count += 1 last_sens = org_text[start:]
else: if last_sens:
word_count += 1 one_sentence += last_sens
if one_sentence != '':
if buf: result.append(one_sentence)
yield ''.join(buf) return result
@staticmethod @staticmethod
def _is_ffmpeg_installed(): def _is_ffmpeg_installed():
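An illustrative call of the reworked splitter (it is a staticmethod, so it can be exercised directly; the default boundary pattern is r'[。.!?]'):

    from core.model_runtime.model_providers.__base.tts_model import TTSModel

    chunks = TTSModel._split_text_into_sentences("First. Second. Third?", max_length=10)
    # -> ['First.', ' Second.', ' Third?']
    # A new chunk starts once appending the next sentence would exceed max_length; a single
    # sentence longer than max_length is still returned as one oversized chunk.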

View File

@ -33,3 +33,4 @@
- deepseek - deepseek
- hunyuan - hunyuan
- siliconflow - siliconflow
- perfxcloud

View File

@ -71,6 +71,9 @@ model_credential_schema:
- label: - label:
en_US: '2024-02-01' en_US: '2024-02-01'
value: '2024-02-01' value: '2024-02-01'
- label:
en_US: '2024-06-01'
value: '2024-06-01'
placeholder: placeholder:
zh_Hans: 在此选择您的 API 版本 zh_Hans: 在此选择您的 API 版本
en_US: Select your API Version here en_US: Select your API Version here

View File

@ -4,7 +4,7 @@ from functools import reduce
from io import BytesIO from io import BytesIO
from typing import Optional from typing import Optional
from flask import Response, stream_with_context from flask import Response
from openai import AzureOpenAI from openai import AzureOpenAI
from pydub import AudioSegment from pydub import AudioSegment
@ -14,7 +14,6 @@ from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.tts_model import TTSModel from core.model_runtime.model_providers.__base.tts_model import TTSModel
from core.model_runtime.model_providers.azure_openai._common import _CommonAzureOpenAI from core.model_runtime.model_providers.azure_openai._common import _CommonAzureOpenAI
from core.model_runtime.model_providers.azure_openai._constant import TTS_BASE_MODELS, AzureBaseModel from core.model_runtime.model_providers.azure_openai._constant import TTS_BASE_MODELS, AzureBaseModel
from extensions.ext_storage import storage
class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel): class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
@ -23,7 +22,7 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
""" """
def _invoke(self, model: str, tenant_id: str, credentials: dict, def _invoke(self, model: str, tenant_id: str, credentials: dict,
content_text: str, voice: str, streaming: bool, user: Optional[str] = None) -> any: content_text: str, voice: str, user: Optional[str] = None) -> any:
""" """
_invoke text2speech model _invoke text2speech model
@ -32,30 +31,23 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
:param credentials: model credentials :param credentials: model credentials
:param content_text: text content to be translated :param content_text: text content to be translated
:param voice: model timbre :param voice: model timbre
:param streaming: output is streaming
:param user: unique user id :param user: unique user id
:return: text translated to audio file :return: text translated to audio file
""" """
audio_type = self._get_model_audio_type(model, credentials)
if not voice or voice not in [d['value'] for d in self.get_tts_model_voices(model=model, credentials=credentials)]: if not voice or voice not in [d['value'] for d in self.get_tts_model_voices(model=model, credentials=credentials)]:
voice = self._get_model_default_voice(model, credentials) voice = self._get_model_default_voice(model, credentials)
if streaming:
return Response(stream_with_context(self._tts_invoke_streaming(model=model,
credentials=credentials,
content_text=content_text,
tenant_id=tenant_id,
voice=voice)),
status=200, mimetype=f'audio/{audio_type}')
else:
return self._tts_invoke(model=model, credentials=credentials, content_text=content_text, voice=voice)
def validate_credentials(self, model: str, credentials: dict, user: Optional[str] = None) -> None: return self._tts_invoke_streaming(model=model,
credentials=credentials,
content_text=content_text,
voice=voice)
def validate_credentials(self, model: str, credentials: dict) -> None:
""" """
validate credentials text2speech model validate credentials text2speech model
:param model: model name :param model: model name
:param credentials: model credentials :param credentials: model credentials
:param user: unique user id
:return: text translated to audio file :return: text translated to audio file
""" """
try: try:
@ -82,7 +74,7 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
word_limit = self._get_model_word_limit(model, credentials) word_limit = self._get_model_word_limit(model, credentials)
max_workers = self._get_model_workers_limit(model, credentials) max_workers = self._get_model_workers_limit(model, credentials)
try: try:
sentences = list(self._split_text_into_sentences(text=content_text, limit=word_limit)) sentences = list(self._split_text_into_sentences(org_text=content_text, max_length=word_limit))
audio_bytes_list = [] audio_bytes_list = []
# Create a thread pool and map the function to the list of sentences # Create a thread pool and map the function to the list of sentences
@ -107,34 +99,37 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
except Exception as ex: except Exception as ex:
raise InvokeBadRequestError(str(ex)) raise InvokeBadRequestError(str(ex))
# Todo: To improve the streaming function def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str,
def _tts_invoke_streaming(self, model: str, tenant_id: str, credentials: dict, content_text: str,
voice: str) -> any: voice: str) -> any:
""" """
_tts_invoke_streaming text2speech model _tts_invoke_streaming text2speech model
:param model: model name :param model: model name
:param tenant_id: user tenant id
:param credentials: model credentials :param credentials: model credentials
:param content_text: text content to be translated :param content_text: text content to be translated
:param voice: model timbre :param voice: model timbre
:return: text translated to audio file :return: text translated to audio file
""" """
# transform credentials to kwargs for model instance
credentials_kwargs = self._to_credential_kwargs(credentials)
if not voice or voice not in self.get_tts_model_voices(model=model, credentials=credentials):
voice = self._get_model_default_voice(model, credentials)
word_limit = self._get_model_word_limit(model, credentials)
audio_type = self._get_model_audio_type(model, credentials)
tts_file_id = self._get_file_name(content_text)
file_path = f'generate_files/audio/{tenant_id}/{tts_file_id}.{audio_type}'
try: try:
# doc: https://platform.openai.com/docs/guides/text-to-speech
credentials_kwargs = self._to_credential_kwargs(credentials)
client = AzureOpenAI(**credentials_kwargs) client = AzureOpenAI(**credentials_kwargs)
sentences = list(self._split_text_into_sentences(text=content_text, limit=word_limit)) # the model input is capped at 4096 characters, so keep each request under 3500
for sentence in sentences: max_length = 3500
response = client.audio.speech.create(model=model, voice=voice, input=sentence.strip()) if len(content_text) > max_length:
# response.stream_to_file(file_path) sentences = self._split_text_into_sentences(content_text, max_length=max_length)
storage.save(file_path, response.read()) executor = concurrent.futures.ThreadPoolExecutor(max_workers=min(3, len(sentences)))
futures = [executor.submit(client.audio.speech.with_streaming_response.create, model=model,
response_format="mp3",
input=sentences[i], voice=voice) for i in range(len(sentences))]
for index, future in enumerate(futures):
yield from future.result().__enter__().iter_bytes(1024)
else:
response = client.audio.speech.with_streaming_response.create(model=model, voice=voice,
response_format="mp3",
input=content_text.strip())
yield from response.__enter__().iter_bytes(1024)
except Exception as ex: except Exception as ex:
raise InvokeBadRequestError(str(ex)) raise InvokeBadRequestError(str(ex))
@ -162,7 +157,7 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
@staticmethod @staticmethod
def _get_ai_model_entity(base_model_name: str, model: str) -> AzureBaseModel: def _get_ai_model_entity(base_model_name: str, model: str) -> AzureBaseModel | None:
for ai_model_entity in TTS_BASE_MODELS: for ai_model_entity in TTS_BASE_MODELS:
if ai_model_entity.base_model_name == base_model_name: if ai_model_entity.base_model_name == base_model_name:
ai_model_entity_copy = copy.deepcopy(ai_model_entity) ai_model_entity_copy = copy.deepcopy(ai_model_entity)
@ -170,5 +165,4 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
ai_model_entity_copy.entity.label.en_US = model ai_model_entity_copy.entity.label.en_US = model
ai_model_entity_copy.entity.label.zh_Hans = model ai_model_entity_copy.entity.label.zh_Hans = model
return ai_model_entity_copy return ai_model_entity_copy
return None return None

View File

@ -66,6 +66,10 @@ provider_credential_schema:
label: label:
en_US: Europe (Frankfurt) en_US: Europe (Frankfurt)
zh_Hans: 欧洲 (法兰克福) zh_Hans: 欧洲 (法兰克福)
- value: eu-west-2
label:
en_US: Europe (London)
zh_Hans: 欧洲西部 (伦敦)
- value: us-gov-west-1 - value: us-gov-west-1
label: label:
en_US: AWS GovCloud (US-West) en_US: AWS GovCloud (US-West)

View File

@ -48,6 +48,28 @@ logger = logging.getLogger(__name__)
class BedrockLargeLanguageModel(LargeLanguageModel): class BedrockLargeLanguageModel(LargeLanguageModel):
# please refer to the documentation: https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html
# TODO: Cohere models currently hit a context-limit issue when invoked via the Converse API; add them once fixed.
CONVERSE_API_ENABLED_MODEL_INFO=[
{'prefix': 'anthropic.claude-v2', 'support_system_prompts': True, 'support_tool_use': False},
{'prefix': 'anthropic.claude-v1', 'support_system_prompts': True, 'support_tool_use': False},
{'prefix': 'anthropic.claude-3', 'support_system_prompts': True, 'support_tool_use': True},
{'prefix': 'meta.llama', 'support_system_prompts': True, 'support_tool_use': False},
{'prefix': 'mistral.mistral-7b-instruct', 'support_system_prompts': False, 'support_tool_use': False},
{'prefix': 'mistral.mixtral-8x7b-instruct', 'support_system_prompts': False, 'support_tool_use': False},
{'prefix': 'mistral.mistral-large', 'support_system_prompts': True, 'support_tool_use': True},
{'prefix': 'mistral.mistral-small', 'support_system_prompts': True, 'support_tool_use': True},
{'prefix': 'amazon.titan', 'support_system_prompts': False, 'support_tool_use': False}
]
@staticmethod
def _find_model_info(model_id):
for model in BedrockLargeLanguageModel.CONVERSE_API_ENABLED_MODEL_INFO:
if model_id.startswith(model['prefix']):
return model
logger.info(f"current model id: {model_id} did not support by Converse API")
return None
def _invoke(self, model: str, credentials: dict, def _invoke(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict, prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
@ -66,10 +88,12 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
:param user: unique user id :param user: unique user id
:return: full response or stream response chunk generator result :return: full response or stream response chunk generator result
""" """
# TODO: consolidate different invocation methods for models based on base model capabilities
# invoke anthropic models via boto3 client model_info= BedrockLargeLanguageModel._find_model_info(model)
if "anthropic" in model: if model_info:
return self._generate_anthropic(model, credentials, prompt_messages, model_parameters, stop, stream, user, tools) model_info['model'] = model
# invoke models via boto3 converse API
return self._generate_with_converse(model_info, credentials, prompt_messages, model_parameters, stop, stream, user, tools)
# invoke Cohere models via boto3 client # invoke Cohere models via boto3 client
if "cohere.command-r" in model: if "cohere.command-r" in model:
return self._generate_cohere_chat(model, credentials, prompt_messages, model_parameters, stop, stream, user, tools) return self._generate_cohere_chat(model, credentials, prompt_messages, model_parameters, stop, stream, user, tools)
@ -151,12 +175,12 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
return self._handle_generate_response(model, credentials, response, prompt_messages) return self._handle_generate_response(model, credentials, response, prompt_messages)
def _generate_anthropic(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, def _generate_with_converse(self, model_info: dict, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict,
stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None, tools: Optional[list[PromptMessageTool]] = None,) -> Union[LLMResult, Generator]: stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None, tools: Optional[list[PromptMessageTool]] = None,) -> Union[LLMResult, Generator]:
""" """
Invoke Anthropic large language model Invoke large language model with converse API
:param model: model name :param model_info: model information
:param credentials: model credentials :param credentials: model credentials
:param prompt_messages: prompt messages :param prompt_messages: prompt messages
:param model_parameters: model parameters :param model_parameters: model parameters
@ -173,24 +197,24 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
inference_config, additional_model_fields = self._convert_converse_api_model_parameters(model_parameters, stop) inference_config, additional_model_fields = self._convert_converse_api_model_parameters(model_parameters, stop)
parameters = { parameters = {
'modelId': model, 'modelId': model_info['model'],
'messages': prompt_message_dicts, 'messages': prompt_message_dicts,
'inferenceConfig': inference_config, 'inferenceConfig': inference_config,
'additionalModelRequestFields': additional_model_fields, 'additionalModelRequestFields': additional_model_fields,
} }
if system and len(system) > 0: if model_info['support_system_prompts'] and system and len(system) > 0:
parameters['system'] = system parameters['system'] = system
if tools: if model_info['support_tool_use'] and tools:
parameters['toolConfig'] = self._convert_converse_tool_config(tools=tools) parameters['toolConfig'] = self._convert_converse_tool_config(tools=tools)
if stream: if stream:
response = bedrock_client.converse_stream(**parameters) response = bedrock_client.converse_stream(**parameters)
return self._handle_converse_stream_response(model, credentials, response, prompt_messages) return self._handle_converse_stream_response(model_info['model'], credentials, response, prompt_messages)
else: else:
response = bedrock_client.converse(**parameters) response = bedrock_client.converse(**parameters)
return self._handle_converse_response(model, credentials, response, prompt_messages) return self._handle_converse_response(model_info['model'], credentials, response, prompt_messages)
def _handle_converse_response(self, model: str, credentials: dict, response: dict, def _handle_converse_response(self, model: str, credentials: dict, response: dict,
prompt_messages: list[PromptMessage]) -> LLMResult: prompt_messages: list[PromptMessage]) -> LLMResult:
@ -203,10 +227,30 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
:param prompt_messages: prompt messages :param prompt_messages: prompt messages
:return: full response chunk generator result :return: full response chunk generator result
""" """
response_content = response['output']['message']['content']
# transform assistant message to prompt message # transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage( if response['stopReason'] == 'tool_use':
content=response['output']['message']['content'][0]['text'] tool_calls = []
) text, tool_use = self._extract_tool_use(response_content)
tool_call = AssistantPromptMessage.ToolCall(
id=tool_use['toolUseId'],
type='function',
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
name=tool_use['name'],
arguments=json.dumps(tool_use['input'])
)
)
tool_calls.append(tool_call)
assistant_prompt_message = AssistantPromptMessage(
content=text,
tool_calls=tool_calls
)
else:
assistant_prompt_message = AssistantPromptMessage(
content=response_content[0]['text']
)
# calculate num tokens # calculate num tokens
if response['usage']: if response['usage']:
@ -229,6 +273,18 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
) )
return result return result
def _extract_tool_use(self, content: list) -> tuple[str, dict]:
tool_use = {}
text = ''
for item in content:
if 'toolUse' in item:
tool_use = item['toolUse']
elif 'text' in item:
text = item['text']
else:
raise ValueError(f"Got unknown item: {item}")
return text, tool_use
def _handle_converse_stream_response(self, model: str, credentials: dict, response: dict, def _handle_converse_stream_response(self, model: str, credentials: dict, response: dict,
prompt_messages: list[PromptMessage], ) -> Generator: prompt_messages: list[PromptMessage], ) -> Generator:
""" """
@ -340,14 +396,12 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
""" """
system = [] system = []
prompt_message_dicts = []
for message in prompt_messages: for message in prompt_messages:
if isinstance(message, SystemPromptMessage): if isinstance(message, SystemPromptMessage):
message.content=message.content.strip() message.content=message.content.strip()
system.append({"text": message.content}) system.append({"text": message.content})
else:
prompt_message_dicts = []
for message in prompt_messages:
if not isinstance(message, SystemPromptMessage):
prompt_message_dicts.append(self._convert_prompt_message_to_dict(message)) prompt_message_dicts.append(self._convert_prompt_message_to_dict(message))
return system, prompt_message_dicts return system, prompt_message_dicts
@ -448,7 +502,6 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
} }
else: else:
raise ValueError(f"Got unknown type {message}") raise ValueError(f"Got unknown type {message}")
return message_dict return message_dict
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage] | str, def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage] | str,
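
The routing change above hinges on the prefix table at the top of the class: any model id that matches an entry is sent through the Converse API with the capabilities recorded there, and everything else falls back to the older per-provider invocation paths. A minimal standalone sketch of that lookup, using a trimmed two-entry table for illustration:

CONVERSE_API_ENABLED_MODEL_INFO = [
    {'prefix': 'mistral.mistral-large', 'support_system_prompts': True, 'support_tool_use': True},
    {'prefix': 'amazon.titan', 'support_system_prompts': False, 'support_tool_use': False},
]

def find_model_info(model_id: str):
    # return the first capability entry whose prefix matches the Bedrock model id
    for entry in CONVERSE_API_ENABLED_MODEL_INFO:
        if model_id.startswith(entry['prefix']):
            return entry
    return None  # caller falls back to the legacy per-provider invocation path

print(find_model_info('mistral.mistral-large-2402-v1:0'))  # matches the first entry
print(find_model_info('cohere.command-r-v1:0'))            # None -> legacy boto3 invoke path
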

View File

@ -2,6 +2,9 @@ model: mistral.mistral-large-2402-v1:0
label: label:
en_US: Mistral Large en_US: Mistral Large
model_type: llm model_type: llm
features:
- tool-call
- agent-thought
model_properties: model_properties:
mode: completion mode: completion
context_size: 32000 context_size: 32000

View File

@ -2,6 +2,8 @@ model: mistral.mistral-small-2402-v1:0
label: label:
en_US: Mistral Small en_US: Mistral Small
model_type: llm model_type: llm
features:
- tool-call
model_properties: model_properties:
mode: completion mode: completion
context_size: 32000 context_size: 32000

View File

@ -7,7 +7,7 @@ features:
- agent-thought - agent-thought
model_properties: model_properties:
mode: chat mode: chat
context_size: 32000 context_size: 128000
parameter_rules: parameter_rules:
- name: temperature - name: temperature
use_template: temperature use_template: temperature

View File

@ -7,7 +7,7 @@ features:
- agent-thought - agent-thought
model_properties: model_properties:
mode: chat mode: chat
context_size: 32000 context_size: 128000
parameter_rules: parameter_rules:
- name: temperature - name: temperature
use_template: temperature use_template: temperature

View File

@ -21,7 +21,7 @@ model_properties:
- mode: 'shimmer' - mode: 'shimmer'
name: 'Shimmer' name: 'Shimmer'
language: [ 'zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID' ] language: [ 'zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID' ]
word_limit: 120 word_limit: 3500
audio_type: 'mp3' audio_type: 'mp3'
max_workers: 5 max_workers: 5
pricing: pricing:

View File

@ -21,7 +21,7 @@ model_properties:
- mode: 'shimmer' - mode: 'shimmer'
name: 'Shimmer' name: 'Shimmer'
language: ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] language: ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID']
word_limit: 120 word_limit: 3500
audio_type: 'mp3' audio_type: 'mp3'
max_workers: 5 max_workers: 5
pricing: pricing:

View File

@ -3,7 +3,7 @@ from functools import reduce
from io import BytesIO from io import BytesIO
from typing import Optional from typing import Optional
from flask import Response, stream_with_context from flask import Response
from openai import OpenAI from openai import OpenAI
from pydub import AudioSegment from pydub import AudioSegment
@ -11,7 +11,6 @@ from core.model_runtime.errors.invoke import InvokeBadRequestError
from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.tts_model import TTSModel from core.model_runtime.model_providers.__base.tts_model import TTSModel
from core.model_runtime.model_providers.openai._common import _CommonOpenAI from core.model_runtime.model_providers.openai._common import _CommonOpenAI
from extensions.ext_storage import storage
class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel): class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel):
@ -20,7 +19,7 @@ class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel):
""" """
def _invoke(self, model: str, tenant_id: str, credentials: dict, def _invoke(self, model: str, tenant_id: str, credentials: dict,
content_text: str, voice: str, streaming: bool, user: Optional[str] = None) -> any: content_text: str, voice: str, user: Optional[str] = None) -> any:
""" """
_invoke text2speech model _invoke text2speech model
@ -29,22 +28,17 @@ class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel):
:param credentials: model credentials :param credentials: model credentials
:param content_text: text content to be translated :param content_text: text content to be translated
:param voice: model timbre :param voice: model timbre
:param streaming: output is streaming
:param user: unique user id :param user: unique user id
:return: text translated to audio file :return: text translated to audio file
""" """
audio_type = self._get_model_audio_type(model, credentials)
if not voice or voice not in [d['value'] for d in self.get_tts_model_voices(model=model, credentials=credentials)]: if not voice or voice not in [d['value'] for d in self.get_tts_model_voices(model=model, credentials=credentials)]:
voice = self._get_model_default_voice(model, credentials) voice = self._get_model_default_voice(model, credentials)
if streaming: # if streaming:
return Response(stream_with_context(self._tts_invoke_streaming(model=model, return self._tts_invoke_streaming(model=model,
credentials=credentials, credentials=credentials,
content_text=content_text, content_text=content_text,
tenant_id=tenant_id, voice=voice)
voice=voice)),
status=200, mimetype=f'audio/{audio_type}')
else:
return self._tts_invoke(model=model, credentials=credentials, content_text=content_text, voice=voice)
def validate_credentials(self, model: str, credentials: dict, user: Optional[str] = None) -> None: def validate_credentials(self, model: str, credentials: dict, user: Optional[str] = None) -> None:
""" """
@ -79,7 +73,7 @@ class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel):
word_limit = self._get_model_word_limit(model, credentials) word_limit = self._get_model_word_limit(model, credentials)
max_workers = self._get_model_workers_limit(model, credentials) max_workers = self._get_model_workers_limit(model, credentials)
try: try:
sentences = list(self._split_text_into_sentences(text=content_text, limit=word_limit)) sentences = list(self._split_text_into_sentences(org_text=content_text, max_length=word_limit))
audio_bytes_list = [] audio_bytes_list = []
# Create a thread pool and map the function to the list of sentences # Create a thread pool and map the function to the list of sentences
@ -104,34 +98,40 @@ class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel):
except Exception as ex: except Exception as ex:
raise InvokeBadRequestError(str(ex)) raise InvokeBadRequestError(str(ex))
# Todo: To improve the streaming function
def _tts_invoke_streaming(self, model: str, tenant_id: str, credentials: dict, content_text: str, def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str,
voice: str) -> any: voice: str) -> any:
""" """
_tts_invoke_streaming text2speech model _tts_invoke_streaming text2speech model
:param model: model name :param model: model name
:param tenant_id: user tenant id
:param credentials: model credentials :param credentials: model credentials
:param content_text: text content to be translated :param content_text: text content to be translated
:param voice: model timbre :param voice: model timbre
:return: text translated to audio file :return: text translated to audio file
""" """
# transform credentials to kwargs for model instance
credentials_kwargs = self._to_credential_kwargs(credentials)
if not voice or voice not in self.get_tts_model_voices(model=model, credentials=credentials):
voice = self._get_model_default_voice(model, credentials)
word_limit = self._get_model_word_limit(model, credentials)
audio_type = self._get_model_audio_type(model, credentials)
tts_file_id = self._get_file_name(content_text)
file_path = f'generate_files/audio/{tenant_id}/{tts_file_id}.{audio_type}'
try: try:
# doc: https://platform.openai.com/docs/guides/text-to-speech
credentials_kwargs = self._to_credential_kwargs(credentials)
client = OpenAI(**credentials_kwargs) client = OpenAI(**credentials_kwargs)
sentences = list(self._split_text_into_sentences(text=content_text, limit=word_limit)) if not voice or voice not in self.get_tts_model_voices(model=model, credentials=credentials):
for sentence in sentences: voice = self._get_model_default_voice(model, credentials)
response = client.audio.speech.create(model=model, voice=voice, input=sentence.strip()) word_limit = self._get_model_word_limit(model, credentials)
# response.stream_to_file(file_path) if len(content_text) > word_limit:
storage.save(file_path, response.read()) sentences = self._split_text_into_sentences(content_text, max_length=word_limit)
executor = concurrent.futures.ThreadPoolExecutor(max_workers=min(3, len(sentences)))
futures = [executor.submit(client.audio.speech.with_streaming_response.create, model=model,
response_format="mp3",
input=sentences[i], voice=voice) for i in range(len(sentences))]
for index, future in enumerate(futures):
yield from future.result().__enter__().iter_bytes(1024)
else:
response = client.audio.speech.with_streaming_response.create(model=model, voice=voice,
response_format="mp3",
input=content_text.strip())
yield from response.__enter__().iter_bytes(1024)
except Exception as ex: except Exception as ex:
raise InvokeBadRequestError(str(ex)) raise InvokeBadRequestError(str(ex))
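
With this change the runtime method returns a raw generator of MP3 byte chunks instead of writing a file to storage, so wrapping it in an HTTP response becomes the caller's job. A hedged sketch of how such a generator can be streamed from a Flask view; the route, function name and placeholder producer are illustrative, not the actual controller code:

from flask import Flask, Response, stream_with_context

app = Flask(__name__)

def text_to_speech_stream(text: str):
    # placeholder producer; the real generator yields MP3 chunks from the TTS API
    for i in range(0, len(text), 16):
        yield text[i:i + 16].encode()

@app.route('/demo-tts')
def demo_tts():
    chunks = text_to_speech_stream('Hello from a streaming text-to-speech endpoint.')
    return Response(stream_with_context(chunks), status=200, mimetype='audio/mp3')
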

View File

@ -616,30 +616,34 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
message = cast(AssistantPromptMessage, message) message = cast(AssistantPromptMessage, message)
message_dict = {"role": "assistant", "content": message.content} message_dict = {"role": "assistant", "content": message.content}
if message.tool_calls: if message.tool_calls:
# message_dict["tool_calls"] = [helper.dump_model(PromptMessageFunction(function=tool_call)) for tool_call function_calling_type = credentials.get('function_calling_type', 'no_call')
# in if function_calling_type == 'tool_call':
# message.tool_calls] message_dict["tool_calls"] = [tool_call.dict() for tool_call in
message.tool_calls]
function_call = message.tool_calls[0] elif function_calling_type == 'function_call':
message_dict["function_call"] = { function_call = message.tool_calls[0]
"name": function_call.function.name, message_dict["function_call"] = {
"arguments": function_call.function.arguments, "name": function_call.function.name,
} "arguments": function_call.function.arguments,
}
elif isinstance(message, SystemPromptMessage): elif isinstance(message, SystemPromptMessage):
message = cast(SystemPromptMessage, message) message = cast(SystemPromptMessage, message)
message_dict = {"role": "system", "content": message.content} message_dict = {"role": "system", "content": message.content}
elif isinstance(message, ToolPromptMessage): elif isinstance(message, ToolPromptMessage):
message = cast(ToolPromptMessage, message) message = cast(ToolPromptMessage, message)
# message_dict = { function_calling_type = credentials.get('function_calling_type', 'no_call')
# "role": "tool", if function_calling_type == 'tool_call':
# "content": message.content, message_dict = {
# "tool_call_id": message.tool_call_id "role": "tool",
# } "content": message.content,
message_dict = { "tool_call_id": message.tool_call_id
"role": "tool" if credentials and credentials.get('function_calling_type', 'no_call') == 'tool_call' else "function", }
"content": message.content, elif function_calling_type == 'function_call':
"name": message.tool_call_id message_dict = {
} "role": "function",
"content": message.content,
"name": message.tool_call_id
}
else: else:
raise ValueError(f"Got unknown type {message}") raise ValueError(f"Got unknown type {message}")
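
The converter now branches on a function_calling_type credential instead of always emitting the legacy function_call field. A small sketch of the two assistant-message payload shapes it produces; the field values are illustrative:

import json

function = {'name': 'get_weather', 'arguments': json.dumps({'city': 'Berlin'})}
tool_call = {'id': 'call_0', 'type': 'function', 'function': function}

# function_calling_type == 'tool_call': OpenAI tools-style payload
assistant_with_tools = {'role': 'assistant', 'content': '', 'tool_calls': [tool_call]}

# function_calling_type == 'function_call': legacy functions-style payload
assistant_with_function = {'role': 'assistant', 'content': '', 'function_call': function}

print(assistant_with_tools)
print(assistant_with_function)
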

File diff suppressed because one or more lines are too long
(image preview: 21 KiB)

File diff suppressed because one or more lines are too long
(image preview: 48 KiB)

View File

@ -0,0 +1,61 @@
model: Qwen-14B-Chat-Int4
label:
en_US: Qwen-14B-Chat-Int4
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 4096
parameter_rules:
- name: temperature
use_template: temperature
type: float
default: 0.3
min: 0.0
max: 2.0
help:
zh_Hans: 用于控制随机性和多样性的程度。具体来说temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值使得更多的低概率词被选择生成结果更加多样化而较低的temperature值则会增强概率分布的峰值使得高概率词更容易被选择生成结果更加确定。
en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected, and the generated results will be more certain.
- name: max_tokens
use_template: max_tokens
type: int
default: 600
min: 1
max: 1248
help:
zh_Hans: 用于指定模型在生成内容时token的最大数量它定义了生成的上限但不保证每次都会生成到这个数量。
en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time.
- name: top_p
use_template: top_p
type: float
default: 0.8
min: 0.1
max: 0.9
help:
zh_Hans: 生成过程中核采样方法概率阈值例如取值为0.8时仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。
en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated.
- name: top_k
type: int
min: 0
max: 99
label:
zh_Hans: 取样数量
en_US: Top k
help:
zh_Hans: 生成时采样候选集的大小。例如取值为50时仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大生成的随机性越高取值越小生成的确定性越高。
en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated.
- name: repetition_penalty
required: false
type: float
default: 1.1
label:
en_US: Repetition penalty
help:
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
en_US: Used to control the repetitiveness of generated output. Increasing repetition_penalty can reduce repetition in the model's output. 1.0 means no penalty.
pricing:
input: '0.000'
output: '0.000'
unit: '0.000'
currency: RMB

View File

@ -0,0 +1,61 @@
model: Qwen1.5-110B-Chat-GPTQ-Int4
label:
en_US: Qwen1.5-110B-Chat-GPTQ-Int4
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 8192
parameter_rules:
- name: temperature
use_template: temperature
type: float
default: 0.3
min: 0.0
max: 2.0
help:
zh_Hans: 用于控制随机性和多样性的程度。具体来说temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值使得更多的低概率词被选择生成结果更加多样化而较低的temperature值则会增强概率分布的峰值使得高概率词更容易被选择生成结果更加确定。
en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected, and the generated results will be more certain.
- name: max_tokens
use_template: max_tokens
type: int
default: 128
min: 1
max: 256
help:
zh_Hans: 用于指定模型在生成内容时token的最大数量它定义了生成的上限但不保证每次都会生成到这个数量。
en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time.
- name: top_p
use_template: top_p
type: float
default: 0.8
min: 0.1
max: 0.9
help:
zh_Hans: 生成过程中核采样方法概率阈值例如取值为0.8时仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。
en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated.
- name: top_k
type: int
min: 0
max: 99
label:
zh_Hans: 取样数量
en_US: Top k
help:
zh_Hans: 生成时采样候选集的大小。例如取值为50时仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大生成的随机性越高取值越小生成的确定性越高。
en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated.
- name: repetition_penalty
required: false
type: float
default: 1.1
label:
en_US: Repetition penalty
help:
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
en_US: Used to control the repetitiveness of generated output. Increasing repetition_penalty can reduce repetition in the model's output. 1.0 means no penalty.
pricing:
input: '0.000'
output: '0.000'
unit: '0.000'
currency: RMB

View File

@ -0,0 +1,61 @@
model: Qwen1.5-72B-Chat-GPTQ-Int4
label:
en_US: Qwen1.5-72B-Chat-GPTQ-Int4
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 8192
parameter_rules:
- name: temperature
use_template: temperature
type: float
default: 0.3
min: 0.0
max: 2.0
help:
zh_Hans: 用于控制随机性和多样性的程度。具体来说temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值使得更多的低概率词被选择生成结果更加多样化而较低的temperature值则会增强概率分布的峰值使得高概率词更容易被选择生成结果更加确定。
en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected, and the generated results will be more certain.
- name: max_tokens
use_template: max_tokens
type: int
default: 600
min: 1
max: 2000
help:
zh_Hans: 用于指定模型在生成内容时token的最大数量它定义了生成的上限但不保证每次都会生成到这个数量。
en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time.
- name: top_p
use_template: top_p
type: float
default: 0.8
min: 0.1
max: 0.9
help:
zh_Hans: 生成过程中核采样方法概率阈值例如取值为0.8时仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。
en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated.
- name: top_k
type: int
min: 0
max: 99
label:
zh_Hans: 取样数量
en_US: Top k
help:
zh_Hans: 生成时采样候选集的大小。例如取值为50时仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大生成的随机性越高取值越小生成的确定性越高。
en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated.
- name: repetition_penalty
required: false
type: float
default: 1.1
label:
en_US: Repetition penalty
help:
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
en_US: Used to control the repetitiveness of generated output. Increasing repetition_penalty can reduce repetition in the model's output. 1.0 means no penalty.
pricing:
input: '0.000'
output: '0.000'
unit: '0.000'
currency: RMB

View File

@ -0,0 +1,61 @@
model: Qwen1.5-7B
label:
en_US: Qwen1.5-7B
model_type: llm
features:
- agent-thought
model_properties:
mode: completion
context_size: 8192
parameter_rules:
- name: temperature
use_template: temperature
type: float
default: 0.3
min: 0.0
max: 2.0
help:
zh_Hans: 用于控制随机性和多样性的程度。具体来说temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值使得更多的低概率词被选择生成结果更加多样化而较低的temperature值则会增强概率分布的峰值使得高概率词更容易被选择生成结果更加确定。
en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected, and the generated results will be more certain.
- name: max_tokens
use_template: max_tokens
type: int
default: 600
min: 1
max: 2000
help:
zh_Hans: 用于指定模型在生成内容时token的最大数量它定义了生成的上限但不保证每次都会生成到这个数量。
en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time.
- name: top_p
use_template: top_p
type: float
default: 0.8
min: 0.1
max: 0.9
help:
zh_Hans: 生成过程中核采样方法概率阈值例如取值为0.8时仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。
en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated.
- name: top_k
type: int
min: 0
max: 99
label:
zh_Hans: 取样数量
en_US: Top k
help:
zh_Hans: 生成时采样候选集的大小。例如取值为50时仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大生成的随机性越高取值越小生成的确定性越高。
en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated.
- name: repetition_penalty
required: false
type: float
default: 1.1
label:
en_US: Repetition penalty
help:
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
en_US: Used to control the repetitiveness of generated output. Increasing repetition_penalty can reduce repetition in the model's output. 1.0 means no penalty.
pricing:
input: '0.000'
output: '0.000'
unit: '0.000'
currency: RMB

View File

@ -0,0 +1,63 @@
model: Qwen2-72B-Instruct-GPTQ-Int4
label:
en_US: Qwen2-72B-Instruct-GPTQ-Int4
model_type: llm
features:
- multi-tool-call
- agent-thought
- stream-tool-call
model_properties:
mode: chat
context_size: 8192
parameter_rules:
- name: temperature
use_template: temperature
type: float
default: 0.3
min: 0.0
max: 2.0
help:
zh_Hans: 用于控制随机性和多样性的程度。具体来说temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值使得更多的低概率词被选择生成结果更加多样化而较低的temperature值则会增强概率分布的峰值使得高概率词更容易被选择生成结果更加确定。
en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected, and the generated results will be more certain.
- name: max_tokens
use_template: max_tokens
type: int
default: 600
min: 1
max: 2000
help:
zh_Hans: 用于指定模型在生成内容时token的最大数量它定义了生成的上限但不保证每次都会生成到这个数量。
en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time.
- name: top_p
use_template: top_p
type: float
default: 0.8
min: 0.1
max: 0.9
help:
zh_Hans: 生成过程中核采样方法概率阈值例如取值为0.8时仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。
en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated.
- name: top_k
type: int
min: 0
max: 99
label:
zh_Hans: 取样数量
en_US: Top k
help:
zh_Hans: 生成时采样候选集的大小。例如取值为50时仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大生成的随机性越高取值越小生成的确定性越高。
en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated.
- name: repetition_penalty
required: false
type: float
default: 1.1
label:
en_US: Repetition penalty
help:
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
en_US: Used to control the repetitiveness of generated output. Increasing repetition_penalty can reduce repetition in the model's output. 1.0 means no penalty.
pricing:
input: '0.000'
output: '0.000'
unit: '0.000'
currency: RMB

View File

@ -0,0 +1,63 @@
model: Qwen2-7B
label:
en_US: Qwen2-7B
model_type: llm
features:
- multi-tool-call
- agent-thought
- stream-tool-call
model_properties:
mode: completion
context_size: 8192
parameter_rules:
- name: temperature
use_template: temperature
type: float
default: 0.3
min: 0.0
max: 2.0
help:
zh_Hans: 用于控制随机性和多样性的程度。具体来说temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值使得更多的低概率词被选择生成结果更加多样化而较低的temperature值则会增强概率分布的峰值使得高概率词更容易被选择生成结果更加确定。
en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected, and the generated results will be more certain.
- name: max_tokens
use_template: max_tokens
type: int
default: 600
min: 1
max: 2000
help:
zh_Hans: 用于指定模型在生成内容时token的最大数量它定义了生成的上限但不保证每次都会生成到这个数量。
en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time.
- name: top_p
use_template: top_p
type: float
default: 0.8
min: 0.1
max: 0.9
help:
zh_Hans: 生成过程中核采样方法概率阈值例如取值为0.8时仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。
en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated.
- name: top_k
type: int
min: 0
max: 99
label:
zh_Hans: 取样数量
en_US: Top k
help:
zh_Hans: 生成时采样候选集的大小。例如取值为50时仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大生成的随机性越高取值越小生成的确定性越高。
en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated.
- name: repetition_penalty
required: false
type: float
default: 1.1
label:
en_US: Repetition penalty
help:
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
en_US: Used to control the repetitiveness of generated output. Increasing repetition_penalty can reduce repetition in the model's output. 1.0 means no penalty.
pricing:
input: '0.000'
output: '0.000'
unit: '0.000'
currency: RMB

View File

@ -0,0 +1,6 @@
- Qwen2-72B-Instruct-GPTQ-Int4
- Qwen2-7B
- Qwen1.5-110B-Chat-GPTQ-Int4
- Qwen1.5-72B-Chat-GPTQ-Int4
- Qwen1.5-7B
- Qwen-14B-Chat-Int4

View File

@ -0,0 +1,110 @@
from collections.abc import Generator
from typing import Optional, Union
from urllib.parse import urlparse
import tiktoken
from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.entities.message_entities import (
PromptMessage,
PromptMessageTool,
)
from core.model_runtime.model_providers.openai.llm.llm import OpenAILargeLanguageModel
class PerfXCloudLargeLanguageModel(OpenAILargeLanguageModel):
def _invoke(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None) \
-> Union[LLMResult, Generator]:
self._add_custom_parameters(credentials)
return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream)
def validate_credentials(self, model: str, credentials: dict) -> None:
self._add_custom_parameters(credentials)
super().validate_credentials(model, credentials)
# refactored from the OpenAI model runtime; use cl100k_base to calculate the number of tokens
def _num_tokens_from_string(self, model: str, text: str,
tools: Optional[list[PromptMessageTool]] = None) -> int:
"""
Calculate num tokens for text completion model with tiktoken package.
:param model: model name
:param text: prompt text
:param tools: tools for tool calling
:return: number of tokens
"""
encoding = tiktoken.get_encoding("cl100k_base")
num_tokens = len(encoding.encode(text))
if tools:
num_tokens += self._num_tokens_for_tools(encoding, tools)
return num_tokens
# refactored from the OpenAI model runtime; use cl100k_base to calculate the number of tokens
def _num_tokens_from_messages(self, model: str, messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None) -> int:
"""Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
Official documentation: https://github.com/openai/openai-cookbook/blob/
main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
encoding = tiktoken.get_encoding("cl100k_base")
tokens_per_message = 3
tokens_per_name = 1
num_tokens = 0
messages_dict = [self._convert_prompt_message_to_dict(m) for m in messages]
for message in messages_dict:
num_tokens += tokens_per_message
for key, value in message.items():
# Cast str(value) in case the message value is not a string
# This occurs with function messages
# TODO: The current token calculation method for the image type is not implemented,
# which need to download the image and then get the resolution for calculation,
# and will increase the request delay
if isinstance(value, list):
text = ''
for item in value:
if isinstance(item, dict) and item['type'] == 'text':
text += item['text']
value = text
if key == "tool_calls":
for tool_call in value:
for t_key, t_value in tool_call.items():
num_tokens += len(encoding.encode(t_key))
if t_key == "function":
for f_key, f_value in t_value.items():
num_tokens += len(encoding.encode(f_key))
num_tokens += len(encoding.encode(f_value))
else:
num_tokens += len(encoding.encode(t_key))
num_tokens += len(encoding.encode(t_value))
else:
num_tokens += len(encoding.encode(str(value)))
if key == "name":
num_tokens += tokens_per_name
# every reply is primed with <im_start>assistant
num_tokens += 3
if tools:
num_tokens += self._num_tokens_for_tools(encoding, tools)
return num_tokens
@staticmethod
def _add_custom_parameters(credentials: dict) -> None:
credentials['mode'] = 'chat'
credentials['openai_api_key']=credentials['api_key']
if 'endpoint_url' not in credentials or credentials['endpoint_url'] == "":
credentials['openai_api_base']='https://cloud.perfxlab.cn'
else:
parsed_url = urlparse(credentials['endpoint_url'])
credentials['openai_api_base']=f"{parsed_url.scheme}://{parsed_url.netloc}"

View File

@ -0,0 +1,32 @@
import logging
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
logger = logging.getLogger(__name__)
class PerfXCloudProvider(ModelProvider):
def validate_provider_credentials(self, credentials: dict) -> None:
"""
Validate provider credentials
if validate failed, raise exception
:param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
"""
try:
model_instance = self.get_model_instance(ModelType.LLM)
# Use the `Qwen2-72B-Instruct-GPTQ-Int4` model for validation,
# no matter what model you pass in, text completion model or chat model
model_instance.validate_credentials(
model='Qwen2-72B-Instruct-GPTQ-Int4',
credentials=credentials
)
except CredentialsValidateFailedError as ex:
raise ex
except Exception as ex:
logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
raise ex

View File

@ -0,0 +1,42 @@
provider: perfxcloud
label:
en_US: PerfXCloud
zh_Hans: PerfXCloud
description:
en_US: PerfXCloud (Pengfeng Technology) is an AI development and deployment platform tailored for developers and enterprises, providing inference capabilities for multiple models.
zh_Hans: PerfXCloud澎峰科技为开发者和企业量身打造的AI开发和部署平台提供多种模型的推理能力。
icon_small:
en_US: icon_s_en.svg
icon_large:
en_US: icon_l_en.svg
background: "#e3f0ff"
help:
title:
en_US: Get your API Key from PerfXCloud
zh_Hans: 从 PerfXCloud 获取 API Key
url:
en_US: https://cloud.perfxlab.cn/panel/token
supported_model_types:
- llm
- text-embedding
configurate_methods:
- predefined-model
provider_credential_schema:
credential_form_schemas:
- variable: api_key
label:
en_US: API Key
type: secret-input
required: true
placeholder:
zh_Hans: 在此输入您的 API Key
en_US: Enter your API Key
- variable: endpoint_url
label:
zh_Hans: 自定义 API endpoint 地址
en_US: Custom API endpoint URL
type: text-input
required: false
placeholder:
zh_Hans: Base URL, e.g. https://cloud.perfxlab.cn/v1
en_US: Base URL, e.g. https://cloud.perfxlab.cn/v1
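
For reference, the two fields defined by this schema arrive as a flat credentials dict by the time validate_provider_credentials runs; a hedged example with placeholder values:

credentials = {
    'api_key': 'pfx-xxxxxxxxxxxxxxxx',               # placeholder value
    'endpoint_url': 'https://cloud.perfxlab.cn/v1',  # optional; an empty string selects the default endpoint
}
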

View File

@ -0,0 +1,4 @@
model: BAAI/bge-m3
model_type: text-embedding
model_properties:
context_size: 32768

View File

@ -0,0 +1,250 @@
import json
import time
from decimal import Decimal
from typing import Optional
from urllib.parse import urljoin
import numpy as np
import requests
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import (
AIModelEntity,
FetchFrom,
ModelPropertyKey,
ModelType,
PriceConfig,
PriceType,
)
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOAI_API_Compat
class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel):
"""
Model class for an OpenAI API-compatible text embedding model.
"""
def _invoke(self, model: str, credentials: dict,
texts: list[str], user: Optional[str] = None) \
-> TextEmbeddingResult:
"""
Invoke text embedding model
:param model: model name
:param credentials: model credentials
:param texts: texts to embed
:param user: unique user id
:return: embeddings result
"""
# Prepare headers and payload for the request
headers = {
'Content-Type': 'application/json'
}
api_key = credentials.get('api_key')
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
if 'endpoint_url' not in credentials or credentials['endpoint_url'] == "":
endpoint_url='https://cloud.perfxlab.cn/v1/'
else:
endpoint_url = credentials.get('endpoint_url')
if not endpoint_url.endswith('/'):
endpoint_url += '/'
endpoint_url = urljoin(endpoint_url, 'embeddings')
extra_model_kwargs = {}
if user:
extra_model_kwargs['user'] = user
extra_model_kwargs['encoding_format'] = 'float'
# get model properties
context_size = self._get_context_size(model, credentials)
max_chunks = self._get_max_chunks(model, credentials)
inputs = []
indices = []
used_tokens = 0
for i, text in enumerate(texts):
# Here token count is only an approximation based on the GPT2 tokenizer
# TODO: Optimize for better token estimation and chunking
num_tokens = self._get_num_tokens_by_gpt2(text)
if num_tokens >= context_size:
cutoff = int(len(text) * (np.floor(context_size / num_tokens)))
# if num tokens is larger than context length, only use the start
inputs.append(text[0: cutoff])
else:
inputs.append(text)
indices += [i]
batched_embeddings = []
_iter = range(0, len(inputs), max_chunks)
for i in _iter:
# Prepare the payload for the request
payload = {
'input': inputs[i: i + max_chunks],
'model': model,
**extra_model_kwargs
}
# Make the request to the OpenAI API
response = requests.post(
endpoint_url,
headers=headers,
data=json.dumps(payload),
timeout=(10, 300)
)
response.raise_for_status() # Raise an exception for HTTP errors
response_data = response.json()
# Extract embeddings and used tokens from the response
embeddings_batch = [data['embedding'] for data in response_data['data']]
embedding_used_tokens = response_data['usage']['total_tokens']
used_tokens += embedding_used_tokens
batched_embeddings += embeddings_batch
# calc usage
usage = self._calc_response_usage(
model=model,
credentials=credentials,
tokens=used_tokens
)
return TextEmbeddingResult(
embeddings=batched_embeddings,
usage=usage,
model=model
)
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
"""
Approximate number of tokens for given messages using GPT2 tokenizer
:param model: model name
:param credentials: model credentials
:param texts: texts to embed
:return:
"""
return sum(self._get_num_tokens_by_gpt2(text) for text in texts)
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials
:param model: model name
:param credentials: model credentials
:return:
"""
try:
headers = {
'Content-Type': 'application/json'
}
api_key = credentials.get('api_key')
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
if 'endpoint_url' not in credentials or credentials['endpoint_url'] == "":
endpoint_url='https://cloud.perfxlab.cn/v1/'
else:
endpoint_url = credentials.get('endpoint_url')
if not endpoint_url.endswith('/'):
endpoint_url += '/'
endpoint_url = urljoin(endpoint_url, 'embeddings')
payload = {
'input': 'ping',
'model': model
}
response = requests.post(
url=endpoint_url,
headers=headers,
data=json.dumps(payload),
timeout=(10, 300)
)
if response.status_code != 200:
raise CredentialsValidateFailedError(
f'Credentials validation failed with status code {response.status_code}')
try:
json_result = response.json()
except json.JSONDecodeError as e:
raise CredentialsValidateFailedError('Credentials validation failed: JSON decode error')
if 'model' not in json_result:
raise CredentialsValidateFailedError(
'Credentials validation failed: invalid response')
except CredentialsValidateFailedError:
raise
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
"""
generate custom model entities from credentials
"""
entity = AIModelEntity(
model=model,
label=I18nObject(en_US=model),
model_type=ModelType.TEXT_EMBEDDING,
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties={
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size')),
ModelPropertyKey.MAX_CHUNKS: 1,
},
parameter_rules=[],
pricing=PriceConfig(
input=Decimal(credentials.get('input_price', 0)),
unit=Decimal(credentials.get('unit', 0)),
currency=credentials.get('currency', "USD")
)
)
return entity
def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
"""
Calculate response usage
:param model: model name
:param credentials: model credentials
:param tokens: input tokens
:return: usage
"""
# get input price info
input_price_info = self.get_price(
model=model,
credentials=credentials,
price_type=PriceType.INPUT,
tokens=tokens
)
# transform usage
usage = EmbeddingUsage(
tokens=tokens,
total_tokens=tokens,
unit_price=input_price_info.unit_price,
price_unit=input_price_info.unit,
total_price=input_price_info.total_amount,
currency=input_price_info.currency,
latency=time.perf_counter() - self.started_at
)
return usage
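
The embedding model batches inputs into max_chunks-sized requests against an OpenAI-compatible /embeddings endpoint. A minimal sketch of one batched call with the same payload shape, assuming a placeholder API key (the actual invocation needs network access, so it is left commented out):

import json
import requests

def embed_batch(texts: list[str], model: str = 'BAAI/bge-m3',
                endpoint_url: str = 'https://cloud.perfxlab.cn/v1/embeddings',
                api_key: str = 'YOUR_API_KEY'):
    headers = {'Content-Type': 'application/json', 'Authorization': f'Bearer {api_key}'}
    payload = {'input': texts, 'model': model, 'encoding_format': 'float'}
    response = requests.post(endpoint_url, headers=headers, data=json.dumps(payload), timeout=(10, 300))
    response.raise_for_status()
    data = response.json()
    return [item['embedding'] for item in data['data']], data['usage']['total_tokens']

# embeddings, used_tokens = embed_batch(['hello world', 'dify'])
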

View File

@ -129,7 +129,7 @@ model_properties:
- mode: "sambert-waan-v1" - mode: "sambert-waan-v1"
name: "Waan泰语女声" name: "Waan泰语女声"
language: [ "th-TH" ] language: [ "th-TH" ]
word_limit: 120 word_limit: 7000
audio_type: 'mp3' audio_type: 'mp3'
max_workers: 5 max_workers: 5
pricing: pricing:

View File

@ -1,17 +1,21 @@
import concurrent.futures import concurrent.futures
import threading
from functools import reduce from functools import reduce
from io import BytesIO from io import BytesIO
from queue import Queue
from typing import Optional from typing import Optional
import dashscope import dashscope
from flask import Response, stream_with_context from dashscope import SpeechSynthesizer
from dashscope.api_entities.dashscope_response import SpeechSynthesisResponse
from dashscope.audio.tts import ResultCallback, SpeechSynthesisResult
from flask import Response
from pydub import AudioSegment from pydub import AudioSegment
from core.model_runtime.errors.invoke import InvokeBadRequestError from core.model_runtime.errors.invoke import InvokeBadRequestError
from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.tts_model import TTSModel from core.model_runtime.model_providers.__base.tts_model import TTSModel
from core.model_runtime.model_providers.tongyi._common import _CommonTongyi from core.model_runtime.model_providers.tongyi._common import _CommonTongyi
from extensions.ext_storage import storage
class TongyiText2SpeechModel(_CommonTongyi, TTSModel): class TongyiText2SpeechModel(_CommonTongyi, TTSModel):
@ -19,7 +23,7 @@ class TongyiText2SpeechModel(_CommonTongyi, TTSModel):
Model class for Tongyi Speech to text model. Model class for Tongyi Speech to text model.
""" """
def _invoke(self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, streaming: bool, def _invoke(self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str,
user: Optional[str] = None) -> any: user: Optional[str] = None) -> any:
""" """
_invoke text2speech model _invoke text2speech model
@ -29,22 +33,17 @@ class TongyiText2SpeechModel(_CommonTongyi, TTSModel):
:param credentials: model credentials :param credentials: model credentials
:param voice: model timbre :param voice: model timbre
:param content_text: text content to be translated :param content_text: text content to be translated
:param streaming: output is streaming
:param user: unique user id :param user: unique user id
:return: text translated to audio file :return: text translated to audio file
""" """
audio_type = self._get_model_audio_type(model, credentials) if not voice or voice not in [d['value'] for d in
if not voice or voice not in [d['value'] for d in self.get_tts_model_voices(model=model, credentials=credentials)]: self.get_tts_model_voices(model=model, credentials=credentials)]:
voice = self._get_model_default_voice(model, credentials) voice = self._get_model_default_voice(model, credentials)
if streaming:
return Response(stream_with_context(self._tts_invoke_streaming(model=model, return self._tts_invoke_streaming(model=model,
credentials=credentials, credentials=credentials,
content_text=content_text, content_text=content_text,
voice=voice, voice=voice)
tenant_id=tenant_id)),
status=200, mimetype=f'audio/{audio_type}')
else:
return self._tts_invoke(model=model, credentials=credentials, content_text=content_text, voice=voice)
def validate_credentials(self, model: str, credentials: dict, user: Optional[str] = None) -> None: def validate_credentials(self, model: str, credentials: dict, user: Optional[str] = None) -> None:
""" """
@ -79,7 +78,7 @@ class TongyiText2SpeechModel(_CommonTongyi, TTSModel):
word_limit = self._get_model_word_limit(model, credentials) word_limit = self._get_model_word_limit(model, credentials)
max_workers = self._get_model_workers_limit(model, credentials) max_workers = self._get_model_workers_limit(model, credentials)
try: try:
sentences = list(self._split_text_into_sentences(text=content_text, limit=word_limit)) sentences = list(self._split_text_into_sentences(org_text=content_text, max_length=word_limit))
audio_bytes_list = [] audio_bytes_list = []
# Create a thread pool and map the function to the list of sentences # Create a thread pool and map the function to the list of sentences
@ -105,14 +104,12 @@ class TongyiText2SpeechModel(_CommonTongyi, TTSModel):
except Exception as ex: except Exception as ex:
raise InvokeBadRequestError(str(ex)) raise InvokeBadRequestError(str(ex))
# Todo: To improve the streaming function def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str,
def _tts_invoke_streaming(self, model: str, tenant_id: str, credentials: dict, content_text: str,
voice: str) -> any: voice: str) -> any:
""" """
_tts_invoke_streaming text2speech model _tts_invoke_streaming text2speech model
:param model: model name :param model: model name
:param tenant_id: user tenant id
:param credentials: model credentials :param credentials: model credentials
:param voice: model timbre :param voice: model timbre
:param content_text: text content to be translated :param content_text: text content to be translated
@ -120,18 +117,32 @@ class TongyiText2SpeechModel(_CommonTongyi, TTSModel):
""" """
word_limit = self._get_model_word_limit(model, credentials) word_limit = self._get_model_word_limit(model, credentials)
audio_type = self._get_model_audio_type(model, credentials) audio_type = self._get_model_audio_type(model, credentials)
tts_file_id = self._get_file_name(content_text)
file_path = f'generate_files/audio/{tenant_id}/{tts_file_id}.{audio_type}'
try: try:
sentences = list(self._split_text_into_sentences(text=content_text, limit=word_limit)) audio_queue: Queue = Queue()
for sentence in sentences: callback = Callback(queue=audio_queue)
response = dashscope.audio.tts.SpeechSynthesizer.call(model=voice, sample_rate=48000,
api_key=credentials.get('dashscope_api_key'), def invoke_remote(content, v, api_key, cb, at, wl):
text=sentence.strip(), if len(content) < word_limit:
format=audio_type, word_timestamp_enabled=True, sentences = [content]
phoneme_timestamp_enabled=True) else:
if isinstance(response.get_audio_data(), bytes): sentences = list(self._split_text_into_sentences(org_text=content, max_length=wl))
storage.save(file_path, response.get_audio_data()) for sentence in sentences:
SpeechSynthesizer.call(model=v, sample_rate=16000,
api_key=api_key,
text=sentence.strip(),
callback=cb,
format=at, word_timestamp_enabled=True,
phoneme_timestamp_enabled=True)
threading.Thread(target=invoke_remote, args=(
content_text, voice, credentials.get('dashscope_api_key'), callback, audio_type, word_limit)).start()
while True:
audio = audio_queue.get()
if audio is None:
break
yield audio
except Exception as ex: except Exception as ex:
raise InvokeBadRequestError(str(ex)) raise InvokeBadRequestError(str(ex))
@ -152,3 +163,29 @@ class TongyiText2SpeechModel(_CommonTongyi, TTSModel):
format=audio_type) format=audio_type)
if isinstance(response.get_audio_data(), bytes): if isinstance(response.get_audio_data(), bytes):
return response.get_audio_data() return response.get_audio_data()
class Callback(ResultCallback):
def __init__(self, queue: Queue):
self._queue = queue
def on_open(self):
pass
def on_complete(self):
self._queue.put(None)
self._queue.task_done()
def on_error(self, response: SpeechSynthesisResponse):
self._queue.put(None)
self._queue.task_done()
def on_close(self):
self._queue.put(None)
self._queue.task_done()
def on_event(self, result: SpeechSynthesisResult):
ad = result.get_audio_frame()
if ad:
self._queue.put(ad)
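
Streaming synthesis now runs the DashScope calls on a background thread that pushes audio frames into a Queue, with None as the end-of-stream sentinel the generator waits for. The same producer/consumer pattern, reduced to a generic runnable sketch:

import threading
from queue import Queue

def stream_chunks(produce):
    queue: Queue = Queue()

    def worker():
        try:
            for chunk in produce():
                queue.put(chunk)
        finally:
            queue.put(None)  # sentinel: tells the consuming generator to stop

    threading.Thread(target=worker, daemon=True).start()
    while True:
        chunk = queue.get()
        if chunk is None:
            break
        yield chunk

# demo with fake "audio" frames
print(list(stream_chunks(lambda: (f'frame-{i}'.encode() for i in range(3)))))
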

View File

@ -29,7 +29,7 @@ model_credential_schema:
label: label:
zh_Hans: 服务器URL zh_Hans: 服务器URL
en_US: Server url en_US: Server url
type: secret-input type: text-input
required: true required: true
placeholder: placeholder:
zh_Hans: 在此输入 Triton Inference Server 的服务器地址,如 http://192.168.1.100:8000 zh_Hans: 在此输入 Triton Inference Server 的服务器地址,如 http://192.168.1.100:8000

View File

@ -0,0 +1,40 @@
model: ernie-4.0-turbo-8k-preview
label:
en_US: Ernie-4.0-turbo-8k-preview
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 8192
parameter_rules:
- name: temperature
use_template: temperature
min: 0.1
max: 1.0
default: 0.8
- name: top_p
use_template: top_p
- name: max_tokens
use_template: max_tokens
default: 1024
min: 2
max: 2048
- name: presence_penalty
use_template: presence_penalty
default: 1.0
min: 1.0
max: 2.0
- name: frequency_penalty
use_template: frequency_penalty
- name: response_format
use_template: response_format
- name: disable_search
label:
zh_Hans: 禁用搜索
en_US: Disable Search
type: boolean
help:
zh_Hans: 禁用模型自行进行外部搜索。
en_US: Disable the model from performing external searches on its own.
required: false

View File

@ -138,6 +138,7 @@ class ErnieBotModel:
'ernie-lite-8k-0922': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant', 'ernie-lite-8k-0922': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant',
'ernie-lite-8k-0308': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-lite-8k', 'ernie-lite-8k-0308': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-lite-8k',
'ernie-character-8k-0321': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k', 'ernie-character-8k-0321': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k',
'ernie-4.0-turbo-8k-preview': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k-preview',
} }
function_calling_supports = [ function_calling_supports = [
@ -149,6 +150,7 @@ class ErnieBotModel:
'ernie-3.5-4k-0205', 'ernie-3.5-4k-0205',
'ernie-3.5-128k', 'ernie-3.5-128k',
'ernie-4.0-8k' 'ernie-4.0-8k',
'ernie-4.0-turbo-8k-preview'
] ]
api_key: str = '' api_key: str = ''

View File

@ -32,7 +32,7 @@ model_credential_schema:
label: label:
zh_Hans: 服务器URL zh_Hans: 服务器URL
en_US: Server url en_US: Server url
type: secret-input type: text-input
required: true required: true
placeholder: placeholder:
zh_Hans: 在此输入Xinference的服务器地址如 http://192.168.1.100:9997 zh_Hans: 在此输入Xinference的服务器地址如 http://192.168.1.100:9997

View File

@ -20,7 +20,7 @@ class ZhipuaiProvider(ModelProvider):
model_instance = self.get_model_instance(ModelType.LLM) model_instance = self.get_model_instance(ModelType.LLM)
model_instance.validate_credentials( model_instance.validate_credentials(
model='chatglm_turbo', model='glm-4',
credentials=credentials credentials=credentials
) )
except CredentialsValidateFailedError as ex: except CredentialsValidateFailedError as ex:

View File

@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Optional
import httpx import httpx
from ..core._base_api import BaseAPI from ..core._base_api import BaseAPI
from ..core._base_type import NOT_GIVEN, Headers, NotGiven from ..core._base_type import NOT_GIVEN, Body, Headers, NotGiven
from ..core._http_client import make_user_request_input from ..core._http_client import make_user_request_input
from ..types.image import ImagesResponded from ..types.image import ImagesResponded
@ -28,7 +28,9 @@ class Images(BaseAPI):
size: Optional[str] | NotGiven = NOT_GIVEN, size: Optional[str] | NotGiven = NOT_GIVEN,
style: Optional[str] | NotGiven = NOT_GIVEN, style: Optional[str] | NotGiven = NOT_GIVEN,
user: str | NotGiven = NOT_GIVEN, user: str | NotGiven = NOT_GIVEN,
request_id: Optional[str] | NotGiven = NOT_GIVEN,
extra_headers: Headers | None = None, extra_headers: Headers | None = None,
extra_body: Body | None = None,
disable_strict_validation: Optional[bool] | None = None, disable_strict_validation: Optional[bool] | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> ImagesResponded: ) -> ImagesResponded:
@ -46,9 +48,12 @@ class Images(BaseAPI):
"size": size, "size": size,
"style": style, "style": style,
"user": user, "user": user,
"request_id": request_id,
}, },
options=make_user_request_input( options=make_user_request_input(
extra_headers=extra_headers, timeout=timeout extra_headers=extra_headers,
extra_body=extra_body,
timeout=timeout
), ),
cast_type=_cast_type, cast_type=_cast_type,
enable_stream=False, enable_stream=False,

View File

@ -11,7 +11,7 @@ from tenacity import retry
from tenacity.stop import stop_after_attempt from tenacity.stop import stop_after_attempt
from . import _errors from . import _errors
from ._base_type import NOT_GIVEN, Body, Data, Headers, NotGiven, Query, RequestFiles, ResponseT from ._base_type import NOT_GIVEN, AnyMapping, Body, Data, Headers, NotGiven, Query, RequestFiles, ResponseT
from ._errors import APIResponseValidationError, APIStatusError, APITimeoutError from ._errors import APIResponseValidationError, APIStatusError, APITimeoutError
from ._files import make_httpx_files from ._files import make_httpx_files
from ._request_opt import ClientRequestParam, UserRequestInput from ._request_opt import ClientRequestParam, UserRequestInput
@ -358,6 +358,7 @@ def make_user_request_input(
max_retries: int | None = None, max_retries: int | None = None,
timeout: float | Timeout | None | NotGiven = NOT_GIVEN, timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
extra_headers: Headers = None, extra_headers: Headers = None,
extra_body: Body | None = None,
query: Query | None = None, query: Query | None = None,
) -> UserRequestInput: ) -> UserRequestInput:
options: UserRequestInput = {} options: UserRequestInput = {}
@ -370,5 +371,7 @@ def make_user_request_input(
options['timeout'] = timeout options['timeout'] = timeout
if query is not None: if query is not None:
options["params"] = query options["params"] = query
if extra_body is not None:
options["extra_json"] = cast(AnyMapping, extra_body)
return options return options
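Taken together, the two hunks above thread a caller-supplied `request_id` and `extra_body` from the public image-generation API down into the request options, where `extra_body` is attached to the outgoing request as `extra_json`. A minimal usage sketch follows; the `ZhipuAI` client constructor and the `model`/`prompt` parameters are assumptions not shown in this diff, so treat it as illustrative only.

```python
# Illustrative sketch only: client construction and the model/prompt parameters
# are assumed; the hunks above only show the request_id/extra_body plumbing.
from zhipuai import ZhipuAI  # assumed import path; Dify vendors an equivalent SDK

client = ZhipuAI(api_key="your-api-key")
response = client.images.generations(
    model="cogview-3",
    prompt="a watercolor lighthouse at dusk",
    request_id="req-0001",               # new optional parameter
    extra_body={"user_id": "user-42"},   # forwarded into the request body as extra_json
)
```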

View File

@ -1,7 +1,6 @@
from typing import Any from typing import Any
from flask import current_app from configs import dify_config
from core.rag.datasource.keyword.jieba.jieba import Jieba from core.rag.datasource.keyword.jieba.jieba import Jieba
from core.rag.datasource.keyword.keyword_base import BaseKeyword from core.rag.datasource.keyword.keyword_base import BaseKeyword
from core.rag.models.document import Document from core.rag.models.document import Document
@ -14,8 +13,8 @@ class Keyword:
self._keyword_processor = self._init_keyword() self._keyword_processor = self._init_keyword()
def _init_keyword(self) -> BaseKeyword: def _init_keyword(self) -> BaseKeyword:
config = current_app.config config = dify_config
keyword_type = config.get('KEYWORD_STORE') keyword_type = config.KEYWORD_STORE
if not keyword_type: if not keyword_type:
raise ValueError("Keyword store must be specified.") raise ValueError("Keyword store must be specified.")

View File

@ -0,0 +1,332 @@
import json
from typing import Any
from pydantic import BaseModel
_import_err_msg = (
"`alibabacloud_gpdb20160503` and `alibabacloud_tea_openapi` packages not found, "
"please run `pip install alibabacloud_gpdb20160503 alibabacloud_tea_openapi`"
)
from flask import current_app
from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from models.dataset import Dataset
class AnalyticdbConfig(BaseModel):
access_key_id: str
access_key_secret: str
region_id: str
instance_id: str
account: str
account_password: str
namespace: str = "dify"
namespace_password: str = None
metrics: str = "cosine"
read_timeout: int = 60000
def to_analyticdb_client_params(self):
return {
"access_key_id": self.access_key_id,
"access_key_secret": self.access_key_secret,
"region_id": self.region_id,
"read_timeout": self.read_timeout,
}
class AnalyticdbVector(BaseVector):
_instance = None
_init = False
def __new__(cls, *args, **kwargs):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self, collection_name: str, config: AnalyticdbConfig):
# collection_name must be updated every time
self._collection_name = collection_name.lower()
if AnalyticdbVector._init:
return
try:
from alibabacloud_gpdb20160503.client import Client
from alibabacloud_tea_openapi import models as open_api_models
except ImportError:
raise ImportError(_import_err_msg)
self.config = config
self._client_config = open_api_models.Config(
user_agent="dify", **config.to_analyticdb_client_params()
)
self._client = Client(self._client_config)
self._initialize()
AnalyticdbVector._init = True
def _initialize(self) -> None:
self._initialize_vector_database()
self._create_namespace_if_not_exists()
def _initialize_vector_database(self) -> None:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
request = gpdb_20160503_models.InitVectorDatabaseRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
manager_account=self.config.account,
manager_account_password=self.config.account_password,
)
self._client.init_vector_database(request)
def _create_namespace_if_not_exists(self) -> None:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
from Tea.exceptions import TeaException
try:
request = gpdb_20160503_models.DescribeNamespaceRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
namespace=self.config.namespace,
manager_account=self.config.account,
manager_account_password=self.config.account_password,
)
self._client.describe_namespace(request)
except TeaException as e:
if e.statusCode == 404:
request = gpdb_20160503_models.CreateNamespaceRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
manager_account=self.config.account,
manager_account_password=self.config.account_password,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
)
self._client.create_namespace(request)
else:
raise ValueError(
f"failed to create namespace {self.config.namespace}: {e}"
)
def _create_collection_if_not_exists(self, embedding_dimension: int):
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
from Tea.exceptions import TeaException
cache_key = f"vector_indexing_{self._collection_name}"
lock_name = f"{cache_key}_lock"
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
if redis_client.get(collection_exist_cache_key):
return
try:
request = gpdb_20160503_models.DescribeCollectionRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
collection=self._collection_name,
)
self._client.describe_collection(request)
except TeaException as e:
if e.statusCode == 404:
metadata = '{"ref_doc_id":"text","page_content":"text","metadata_":"jsonb"}'
full_text_retrieval_fields = "page_content"
request = gpdb_20160503_models.CreateCollectionRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
manager_account=self.config.account,
manager_account_password=self.config.account_password,
namespace=self.config.namespace,
collection=self._collection_name,
dimension=embedding_dimension,
metrics=self.config.metrics,
metadata=metadata,
full_text_retrieval_fields=full_text_retrieval_fields,
)
self._client.create_collection(request)
else:
raise ValueError(
f"failed to create collection {self._collection_name}: {e}"
)
redis_client.set(collection_exist_cache_key, 1, ex=3600)
def get_type(self) -> str:
return VectorType.ANALYTICDB
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
dimension = len(embeddings[0])
self._create_collection_if_not_exists(dimension)
self.add_texts(texts, embeddings)
def add_texts(
self, documents: list[Document], embeddings: list[list[float]], **kwargs
):
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
rows: list[gpdb_20160503_models.UpsertCollectionDataRequestRows] = []
for doc, embedding in zip(documents, embeddings, strict=True):
metadata = {
"ref_doc_id": doc.metadata["doc_id"],
"page_content": doc.page_content,
"metadata_": json.dumps(doc.metadata),
}
rows.append(
gpdb_20160503_models.UpsertCollectionDataRequestRows(
vector=embedding,
metadata=metadata,
)
)
request = gpdb_20160503_models.UpsertCollectionDataRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
collection=self._collection_name,
rows=rows,
)
self._client.upsert_collection_data(request)
def text_exists(self, id: str) -> bool:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
request = gpdb_20160503_models.QueryCollectionDataRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
collection=self._collection_name,
metrics=self.config.metrics,
include_values=True,
vector=None,
content=None,
top_k=1,
filter=f"ref_doc_id='{id}'"
)
response = self._client.query_collection_data(request)
return len(response.body.matches.match) > 0
def delete_by_ids(self, ids: list[str]) -> None:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
ids_str = ",".join(f"'{id}'" for id in ids)
ids_str = f"({ids_str})"
request = gpdb_20160503_models.DeleteCollectionDataRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
collection=self._collection_name,
collection_data=None,
collection_data_filter=f"ref_doc_id IN {ids_str}",
)
self._client.delete_collection_data(request)
def delete_by_metadata_field(self, key: str, value: str) -> None:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
request = gpdb_20160503_models.DeleteCollectionDataRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
collection=self._collection_name,
collection_data=None,
collection_data_filter=f"metadata_ ->> '{key}' = '{value}'",
)
self._client.delete_collection_data(request)
def search_by_vector(
self, query_vector: list[float], **kwargs: Any
) -> list[Document]:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
score_threshold = (
kwargs.get("score_threshold", 0.0)
if kwargs.get("score_threshold", 0.0)
else 0.0
)
request = gpdb_20160503_models.QueryCollectionDataRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
collection=self._collection_name,
include_values=kwargs.pop("include_values", True),
metrics=self.config.metrics,
vector=query_vector,
content=None,
top_k=kwargs.get("top_k", 4),
filter=None,
)
response = self._client.query_collection_data(request)
documents = []
for match in response.body.matches.match:
if match.score > score_threshold:
doc = Document(
page_content=match.metadata.get("page_content"),
metadata=json.loads(match.metadata.get("metadata_")),
)
documents.append(doc)
return documents
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
score_threshold = (
kwargs.get("score_threshold", 0.0)
if kwargs.get("score_threshold", 0.0)
else 0.0
)
request = gpdb_20160503_models.QueryCollectionDataRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
collection=self._collection_name,
include_values=kwargs.pop("include_values", True),
metrics=self.config.metrics,
vector=None,
content=query,
top_k=kwargs.get("top_k", 4),
filter=None,
)
response = self._client.query_collection_data(request)
documents = []
for match in response.body.matches.match:
if match.score > score_threshold:
doc = Document(
page_content=match.metadata.get("page_content"),
metadata=json.loads(match.metadata.get("metadata_")),
)
documents.append(doc)
return documents
def delete(self) -> None:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
request = gpdb_20160503_models.DeleteCollectionRequest(
collection=self._collection_name,
dbinstance_id=self.config.instance_id,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
region_id=self.config.region_id,
)
self._client.delete_collection(request)
class AnalyticdbVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings):
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"][
"class_prefix"
]
collection_name = class_prefix.lower()
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
dataset.index_struct = json.dumps(
self.gen_index_struct_dict(VectorType.ANALYTICDB, collection_name)
)
config = current_app.config
return AnalyticdbVector(
collection_name,
AnalyticdbConfig(
access_key_id=config.get("ANALYTICDB_KEY_ID"),
access_key_secret=config.get("ANALYTICDB_KEY_SECRET"),
region_id=config.get("ANALYTICDB_REGION_ID"),
instance_id=config.get("ANALYTICDB_INSTANCE_ID"),
account=config.get("ANALYTICDB_ACCOUNT"),
account_password=config.get("ANALYTICDB_PASSWORD"),
namespace=config.get("ANALYTICDB_NAMESPACE"),
namespace_password=config.get("ANALYTICDB_NAMESPACE_PASSWORD"),
),
)

View File

@ -0,0 +1,170 @@
import json
import logging
import uuid
from enum import Enum
from typing import Any
from clickhouse_connect import get_client
from flask import current_app
from pydantic import BaseModel
from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.models.document import Document
from models.dataset import Dataset
class MyScaleConfig(BaseModel):
host: str
port: int
user: str
password: str
database: str
fts_params: str
class SortOrder(Enum):
ASC = "ASC"
DESC = "DESC"
class MyScaleVector(BaseVector):
def __init__(self, collection_name: str, config: MyScaleConfig, metric: str = "Cosine"):
super().__init__(collection_name)
self._config = config
self._metric = metric
self._vec_order = SortOrder.ASC if metric.upper() in ["COSINE", "L2"] else SortOrder.DESC
self._client = get_client(
host=config.host,
port=config.port,
username=config.user,
password=config.password,
)
self._client.command("SET allow_experimental_object_type=1")
def get_type(self) -> str:
return VectorType.MYSCALE
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
dimension = len(embeddings[0])
self._create_collection(dimension)
return self.add_texts(documents=texts, embeddings=embeddings, **kwargs)
def _create_collection(self, dimension: int):
logging.info(f"create MyScale collection {self._collection_name} with dimension {dimension}")
self._client.command(f"CREATE DATABASE IF NOT EXISTS {self._config.database}")
fts_params = f"('{self._config.fts_params}')" if self._config.fts_params else ""
sql = f"""
CREATE TABLE IF NOT EXISTS {self._config.database}.{self._collection_name}(
id String,
text String,
vector Array(Float32),
metadata JSON,
CONSTRAINT cons_vec_len CHECK length(vector) = {dimension},
VECTOR INDEX vidx vector TYPE DEFAULT('metric_type = {self._metric}'),
INDEX text_idx text TYPE fts{fts_params}
) ENGINE = MergeTree ORDER BY id
"""
self._client.command(sql)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
ids = []
columns = ["id", "text", "vector", "metadata"]
values = []
for i, doc in enumerate(documents):
doc_id = doc.metadata.get("doc_id", str(uuid.uuid4()))
row = (
doc_id,
self.escape_str(doc.page_content),
embeddings[i],
json.dumps(doc.metadata) if doc.metadata else "{}"
)
values.append(str(row))
ids.append(doc_id)
sql = f"""
INSERT INTO {self._config.database}.{self._collection_name}
({",".join(columns)}) VALUES {",".join(values)}
"""
self._client.command(sql)
return ids
@staticmethod
def escape_str(value: Any) -> str:
return "".join(f"\\{c}" if c in ("\\", "'") else c for c in str(value))
def text_exists(self, id: str) -> bool:
results = self._client.query(f"SELECT id FROM {self._config.database}.{self._collection_name} WHERE id='{id}'")
return results.row_count > 0
def delete_by_ids(self, ids: list[str]) -> None:
self._client.command(
f"DELETE FROM {self._config.database}.{self._collection_name} WHERE id IN {str(tuple(ids))}")
def get_ids_by_metadata_field(self, key: str, value: str):
rows = self._client.query(
f"SELECT DISTINCT id FROM {self._config.database}.{self._collection_name} WHERE metadata.{key}='{value}'"
).result_rows
return [row[0] for row in rows]
def delete_by_metadata_field(self, key: str, value: str) -> None:
self._client.command(
f"DELETE FROM {self._config.database}.{self._collection_name} WHERE metadata.{key}='{value}'"
)
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
return self._search(f"distance(vector, {str(query_vector)})", self._vec_order, **kwargs)
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
return self._search(f"TextSearch(text, '{query}')", SortOrder.DESC, **kwargs)
def _search(self, dist: str, order: SortOrder, **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 5)
score_threshold = kwargs.get("score_threshold", 0.0)
where_str = f"WHERE dist < {1 - score_threshold}" if \
self._metric.upper() == "COSINE" and order == SortOrder.ASC and score_threshold > 0.0 else ""
sql = f"""
SELECT text, metadata, {dist} as dist FROM {self._config.database}.{self._collection_name}
{where_str} ORDER BY dist {order.value} LIMIT {top_k}
"""
try:
return [
Document(
page_content=r["text"],
metadata=r["metadata"],
)
for r in self._client.query(sql).named_results()
]
except Exception as e:
logging.error(f"\033[91m\033[1m{type(e)}\033[0m \033[95m{str(e)}\033[0m")
return []
def delete(self) -> None:
self._client.command(f"DROP TABLE IF EXISTS {self._config.database}.{self._collection_name}")
class MyScaleVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> MyScaleVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
collection_name = class_prefix.lower()
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
dataset.index_struct = json.dumps(
self.gen_index_struct_dict(VectorType.MYSCALE, collection_name))
config = current_app.config
return MyScaleVector(
collection_name=collection_name,
config=MyScaleConfig(
host=config.get("MYSCALE_HOST", "localhost"),
port=int(config.get("MYSCALE_PORT", 8123)),
user=config.get("MYSCALE_USER", "default"),
password=config.get("MYSCALE_PASSWORD", ""),
database=config.get("MYSCALE_DATABASE", "default"),
fts_params=config.get("MYSCALE_FTS_PARAMS", ""),
),
)
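Although Dify builds the instance through the factory above using app config, a standalone sketch of exercising this vector store directly might look as follows (host and credentials are placeholders, not real defaults):

```python
# Standalone sketch; connection details are placeholders.
config = MyScaleConfig(
    host="localhost",
    port=8123,
    user="default",
    password="",
    database="dify",
    fts_params="",
)
store = MyScaleVector(collection_name="vector_index_example_node", config=config)
# Full-text search over previously indexed documents; top_k is an optional kwarg.
docs = store.search_by_full_text("open source LLMOps", top_k=3)
for doc in docs:
    print(doc.page_content, doc.metadata)
```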

View File

@ -57,6 +57,9 @@ class Vector:
case VectorType.MILVUS: case VectorType.MILVUS:
from core.rag.datasource.vdb.milvus.milvus_vector import MilvusVectorFactory from core.rag.datasource.vdb.milvus.milvus_vector import MilvusVectorFactory
return MilvusVectorFactory return MilvusVectorFactory
case VectorType.MYSCALE:
from core.rag.datasource.vdb.myscale.myscale_vector import MyScaleVectorFactory
return MyScaleVectorFactory
case VectorType.PGVECTOR: case VectorType.PGVECTOR:
from core.rag.datasource.vdb.pgvector.pgvector import PGVectorFactory from core.rag.datasource.vdb.pgvector.pgvector import PGVectorFactory
return PGVectorFactory return PGVectorFactory
@ -84,6 +87,9 @@ class Vector:
case VectorType.OPENSEARCH: case VectorType.OPENSEARCH:
from core.rag.datasource.vdb.opensearch.opensearch_vector import OpenSearchVectorFactory from core.rag.datasource.vdb.opensearch.opensearch_vector import OpenSearchVectorFactory
return OpenSearchVectorFactory return OpenSearchVectorFactory
case VectorType.ANALYTICDB:
from core.rag.datasource.vdb.analyticdb.analyticdb_vector import AnalyticdbVectorFactory
return AnalyticdbVectorFactory
case _: case _:
raise ValueError(f"Vector store {vector_type} is not supported.") raise ValueError(f"Vector store {vector_type} is not supported.")

View File

@ -2,8 +2,10 @@ from enum import Enum
class VectorType(str, Enum): class VectorType(str, Enum):
ANALYTICDB = 'analyticdb'
CHROMA = 'chroma' CHROMA = 'chroma'
MILVUS = 'milvus' MILVUS = 'milvus'
MYSCALE = 'myscale'
PGVECTOR = 'pgvector' PGVECTOR = 'pgvector'
PGVECTO_RS = 'pgvecto-rs' PGVECTO_RS = 'pgvecto-rs'
QDRANT = 'qdrant' QDRANT = 'qdrant'

View File

@ -46,7 +46,6 @@ class FirecrawlApp:
raise Exception(f'Failed to scrape URL. Status code: {response.status_code}') raise Exception(f'Failed to scrape URL. Status code: {response.status_code}')
def crawl_url(self, url, params=None) -> str: def crawl_url(self, url, params=None) -> str:
start_time = time.time()
headers = self._prepare_headers() headers = self._prepare_headers()
json_data = {'url': url} json_data = {'url': url}
if params: if params:

View File

@ -18,8 +18,8 @@ class MarkdownExtractor(BaseExtractor):
def __init__( def __init__(
self, self,
file_path: str, file_path: str,
remove_hyperlinks: bool = True, remove_hyperlinks: bool = False,
remove_images: bool = True, remove_images: bool = False,
encoding: Optional[str] = None, encoding: Optional[str] = None,
autodetect_encoding: bool = True, autodetect_encoding: bool = True,
): ):

View File

@ -8,7 +8,7 @@ We have defined a series of helper methods in the `Tool` class to help developer
### Message Return ### Message Return
Dify supports various message types such as `text`, `link`, `image`, and `file BLOB`. You can return different types of messages to the LLM and users through the following interfaces. Dify supports various message types such as `text`, `link`, `json`, `image`, and `file BLOB`. You can return different types of messages to the LLM and users through the following interfaces.
Please note, some parameters in the following interfaces will be introduced in later sections. Please note, some parameters in the following interfaces will be introduced in later sections.
@ -67,6 +67,18 @@ If you need to return the raw data of a file, such as images, audio, video, PPT,
""" """
``` ```
#### JSON
If you need to return formatted JSON, you can use the following interface. This is commonly used for passing data between nodes in a workflow; in agent mode, most LLMs can also read and understand JSON.
- `object` A Python dictionary object that will be automatically serialized into JSON
```python
def create_json_message(self, object: dict) -> ToolInvokeMessage:
"""
create a json message
"""
```
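For example, inside a tool's `_invoke` method you could hand structured data back as a JSON message like this (a minimal sketch; the dictionary fields are made up for illustration):

```python
def _invoke(self, user_id: str, tool_parameters: dict) -> ToolInvokeMessage:
    # Build an arbitrary dictionary and return it as a JSON message; downstream
    # workflow nodes (or the LLM in agent mode) receive it as JSON.
    result = {
        "city": tool_parameters.get("city"),
        "temperature": 23.5,
        "unit": "celsius",
    }
    return self.create_json_message(result)
```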
### Shortcut Tools ### Shortcut Tools
In large model applications, we have two common needs: In large model applications, we have two common needs:

View File

@ -145,19 +145,25 @@ parameters: # Parameter list
- The `identity` field is mandatory, it contains the basic information of the tool, including name, author, label, description, etc. - The `identity` field is mandatory, it contains the basic information of the tool, including name, author, label, description, etc.
- `parameters` Parameter list - `parameters` Parameter list
- `name` Parameter name, unique, no duplication with other parameters - `name` (Mandatory) Parameter name, must be unique and not duplicate with other parameters.
- `type` Parameter type, currently supports `string`, `number`, `boolean`, `select`, `secret-input` four types, corresponding to string, number, boolean, drop-down box, and encrypted input box, respectively. For sensitive information, we recommend using `secret-input` type - `type` (Mandatory) Parameter type, currently supports `string`, `number`, `boolean`, `select`, `secret-input` five types, corresponding to string, number, boolean, drop-down box, and encrypted input box, respectively. For sensitive information, we recommend using the `secret-input` type
- `required` Required or not - `label` (Mandatory) Parameter label, for frontend display
- `form` (Mandatory) Form type, currently supports `llm`, `form` two types.
- In an agent app, `llm` indicates that the parameter is inferred by the LLM itself, while `form` indicates that the parameter can be pre-set for the tool.
- In a workflow app, both `llm` and `form` parameters are filled in on the frontend, but `llm` parameters also serve as input variables of the tool node.
- `required` Indicates whether the parameter is required or not
- In `llm` mode, if the parameter is required, the Agent is required to infer this parameter - In `llm` mode, if the parameter is required, the Agent is required to infer this parameter
- In `form` mode, if the parameter is required, the user is required to fill in this parameter on the frontend before the conversation starts - In `form` mode, if the parameter is required, the user is required to fill in this parameter on the frontend before the conversation starts
- `options` Parameter options - `options` Parameter options
- In `llm` mode, Dify will pass all options to LLM, LLM can infer based on these options - In `llm` mode, Dify will pass all options to LLM, LLM can infer based on these options
- In `form` mode, when `type` is `select`, the frontend will display these options - In `form` mode, when `type` is `select`, the frontend will display these options
- `default` Default value - `default` Default value
- `label` Parameter label, for frontend display - `min` Minimum value, can be set when the parameter type is `number`.
- `max` Maximum value, can be set when the parameter type is `number`.
- `placeholder` The prompt text for input boxes. It can be set when the form type is `form`, and the parameter type is `string`, `number`, or `secret-input`. It supports multiple languages.
- `human_description` Introduction for frontend display, supports multiple languages - `human_description` Introduction for frontend display, supports multiple languages
- `llm_description` Introduction passed to LLM, in order to make LLM better understand this parameter, we suggest to write as detailed information about this parameter as possible here, so that LLM can understand this parameter - `llm_description` Introduction passed to LLM, in order to make LLM better understand this parameter, we suggest to write as detailed information about this parameter as possible here, so that LLM can understand this parameter
- `form` Form type, currently supports `llm`, `form` two types, corresponding to Agent self-inference and frontend filling
## 4. Add Tool Logic ## 4. Add Tool Logic
@ -196,7 +202,7 @@ The overall logic of the tool is in the `_invoke` method, this method accepts tw
### Return Data ### Return Data
When the tool returns, you can choose to return one message or multiple messages, here we return one message, using `create_text_message` and `create_link_message` can create a text message or a link message. When the tool returns, you can choose to return one message or multiple messages, here we return one message, using `create_text_message` and `create_link_message` can create a text message or a link message. If you want to return multiple messages, you can use `[self.create_text_message('msg1'), self.create_text_message('msg2')]` to create a list of messages.
## 5. Add Provider Code ## 5. Add Provider Code
@ -205,8 +211,6 @@ Finally, we need to create a provider class under the provider module to impleme
Create `google.py` under the `google` module, the content is as follows. Create `google.py` under the `google` module, the content is as follows.
```python ```python
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType
from core.tools.tool.tool import Tool
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
from core.tools.errors import ToolProviderCredentialValidationError from core.tools.errors import ToolProviderCredentialValidationError

View File

@ -8,7 +8,7 @@
### 消息返回 ### 消息返回
Dify支持`文本` `链接` `图片` `文件BLOB` 等多种消息类型你可以通过以下几个接口返回不同类型的消息给LLM和用户。 Dify支持`文本` `链接` `图片` `文件BLOB` `JSON` 等多种消息类型你可以通过以下几个接口返回不同类型的消息给LLM和用户。
注意,在下面的接口中的部分参数将在后面的章节中介绍。 注意,在下面的接口中的部分参数将在后面的章节中介绍。
@ -67,6 +67,18 @@ Dify支持`文本` `链接` `图片` `文件BLOB` 等多种消息类型,你可
""" """
``` ```
#### JSON
如果你需要返回一个格式化的JSON可以使用以下接口。这通常用于workflow中的节点间的数据传递当然agent模式中大部分大模型也都能够阅读和理解JSON。
- `object` 一个Python的字典对象会被自动序列化为JSON
```python
def create_json_message(self, object: dict) -> ToolInvokeMessage:
"""
create a json message
"""
```
### 快捷工具 ### 快捷工具
在大模型应用中,我们有两种常见的需求: 在大模型应用中,我们有两种常见的需求:
@ -97,8 +109,8 @@ Dify支持`文本` `链接` `图片` `文件BLOB` 等多种消息类型,你可
```python ```python
def get_url(self, url: str, user_agent: str = None) -> str: def get_url(self, url: str, user_agent: str = None) -> str:
""" """
get url get url from the crawled result
""" the crawled result """
``` ```
### 变量池 ### 变量池

View File

@ -140,8 +140,12 @@ parameters: # 参数列表
- `identity` 字段是必须的,它包含了工具的基本信息,包括名称、作者、标签、描述等 - `identity` 字段是必须的,它包含了工具的基本信息,包括名称、作者、标签、描述等
- `parameters` 参数列表 - `parameters` 参数列表
- `name` 参数名称,唯一,不允许和其他参数重名 - `name` (必填)参数名称,唯一,不允许和其他参数重名
- `type` 参数类型,目前支持`string``number``boolean``select``secret-input` 五种类型,分别对应字符串、数字、布尔值、下拉框、加密输入框,对于敏感信息,我们建议使用`secret-input`类型 - `type` (必填)参数类型,目前支持`string``number``boolean``select``secret-input` 五种类型,分别对应字符串、数字、布尔值、下拉框、加密输入框,对于敏感信息,我们建议使用`secret-input`类型
- `label`(必填)参数标签,用于前端展示
- `form` (必填)表单类型,目前支持`llm``form`两种类型
- 在Agent应用中`llm`表示该参数LLM自行推理`form`表示要使用该工具可提前设定的参数
- 在workflow应用中`llm``form`均需要前端填写,但`llm`的参数会做为工具节点的输入变量
- `required` 是否必填 - `required` 是否必填
- 在`llm`模式下如果参数为必填则会要求Agent必须要推理出这个参数 - 在`llm`模式下如果参数为必填则会要求Agent必须要推理出这个参数
- 在`form`模式下,如果参数为必填,则会要求用户在对话开始前在前端填写这个参数 - 在`form`模式下,如果参数为必填,则会要求用户在对话开始前在前端填写这个参数
@ -149,10 +153,12 @@ parameters: # 参数列表
- 在`llm`模式下Dify会将所有选项传递给LLMLLM可以根据这些选项进行推理 - 在`llm`模式下Dify会将所有选项传递给LLMLLM可以根据这些选项进行推理
- 在`form`模式下,`type``select`时,前端会展示这些选项 - 在`form`模式下,`type``select`时,前端会展示这些选项
- `default` 默认值 - `default` 默认值
- `label` 参数标签,用于前端展示 - `min` 最小值,当参数类型为`number`时可以设定
- `max` 最大值,当参数类型为`number`时可以设定
- `human_description` 用于前端展示的介绍,支持多语言 - `human_description` 用于前端展示的介绍,支持多语言
- `placeholder` 字段输入框的提示文字,在表单类型为`form`,参数类型为`string``number``secret-input`时,可以设定,支持多语言
- `llm_description` 传递给LLM的介绍为了使得LLM更好理解这个参数我们建议在这里写上关于这个参数尽可能详细的信息让LLM能够理解这个参数 - `llm_description` 传递给LLM的介绍为了使得LLM更好理解这个参数我们建议在这里写上关于这个参数尽可能详细的信息让LLM能够理解这个参数
- `form` 表单类型,目前支持`llm``form`两种类型分别对应Agent自行推理和前端填写
## 4. 准备工具代码 ## 4. 准备工具代码
当完成工具的配置以后,我们就可以开始编写工具代码了,主要用于实现工具的逻辑。 当完成工具的配置以后,我们就可以开始编写工具代码了,主要用于实现工具的逻辑。
@ -176,7 +182,6 @@ class GoogleSearchTool(BuiltinTool):
query = tool_parameters['query'] query = tool_parameters['query']
result_type = tool_parameters['result_type'] result_type = tool_parameters['result_type']
api_key = self.runtime.credentials['serpapi_api_key'] api_key = self.runtime.credentials['serpapi_api_key']
# TODO: search with serpapi
result = SerpAPI(api_key).run(query, result_type=result_type) result = SerpAPI(api_key).run(query, result_type=result_type)
if result_type == 'text': if result_type == 'text':
@ -188,7 +193,7 @@ class GoogleSearchTool(BuiltinTool):
工具的整体逻辑都在`_invoke`方法中,这个方法接收两个参数:`user_id``tool_parameters`分别表示用户ID和工具参数 工具的整体逻辑都在`_invoke`方法中,这个方法接收两个参数:`user_id``tool_parameters`分别表示用户ID和工具参数
### 返回数据 ### 返回数据
在工具返回时,你可以选择返回一个消息或者多个消息,这里我们返回一个消息,使用`create_text_message``create_link_message`可以创建一个文本消息或者一个链接消息。 在工具返回时,你可以选择返回一条消息或者多个消息,这里我们返回一条消息,使用`create_text_message``create_link_message`可以创建一条文本消息或者一条链接消息。如需返回多条消息,可以使用列表构建,例如`[self.create_text_message('msg1'), self.create_text_message('msg2')]`
## 5. 准备供应商代码 ## 5. 准备供应商代码
最后,我们需要在供应商模块下创建一个供应商类,用于实现供应商的凭据验证逻辑,如果凭据验证失败,将会抛出`ToolProviderCredentialValidationError`异常。 最后,我们需要在供应商模块下创建一个供应商类,用于实现供应商的凭据验证逻辑,如果凭据验证失败,将会抛出`ToolProviderCredentialValidationError`异常。
@ -196,8 +201,6 @@ class GoogleSearchTool(BuiltinTool):
`google`模块下创建`google.py`,内容如下。 `google`模块下创建`google.py`,内容如下。
```python ```python
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType
from core.tools.tool.tool import Tool
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
from core.tools.errors import ToolProviderCredentialValidationError from core.tools.errors import ToolProviderCredentialValidationError

View File

@ -142,7 +142,8 @@ class ToolParameter(BaseModel):
name: str = Field(..., description="The name of the parameter") name: str = Field(..., description="The name of the parameter")
label: I18nObject = Field(..., description="The label presented to the user") label: I18nObject = Field(..., description="The label presented to the user")
human_description: I18nObject = Field(..., description="The description presented to the user") human_description: Optional[I18nObject] = Field(None, description="The description presented to the user")
placeholder: Optional[I18nObject] = Field(None, description="The placeholder presented to the user")
type: ToolParameterType = Field(..., description="The type of the parameter") type: ToolParameterType = Field(..., description="The type of the parameter")
form: ToolParameterForm = Field(..., description="The form of the parameter, schema/form/llm") form: ToolParameterForm = Field(..., description="The form of the parameter, schema/form/llm")
llm_description: Optional[str] = None llm_description: Optional[str] = None
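As a rough sketch, a parameter built directly against this model with the new optional fields might look like the following; the enum member names (`STRING`, `LLM`) and the `I18nObject` keyword arguments are assumptions based on how these types are typically used, not taken from this diff.

```python
# Hypothetical construction of a ToolParameter showing the now-optional
# human_description and the new placeholder field. Enum members and I18nObject
# kwargs are assumptions and may be spelled differently in the codebase.
parameter = ToolParameter(
    name="query",
    label=I18nObject(en_US="Query", zh_Hans="查询"),
    human_description=I18nObject(en_US="The text to search for", zh_Hans="要搜索的文本"),
    placeholder=I18nObject(en_US="Enter a search query", zh_Hans="输入搜索关键词"),
    type=ToolParameterType.STRING,
    form=ToolParameterForm.LLM,
    llm_description="The query string passed to the search backend",
)
```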

Binary file not shown.


Some files were not shown because too many files have changed in this diff.