mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-17 10:05:59 +08:00
merge
This commit is contained in:
commit
775e52db4d
@ -1,6 +1,7 @@
|
||||
#!/bin/bash
|
||||
|
||||
cd web && npm install
|
||||
pipx install poetry
|
||||
|
||||
echo 'alias start-api="cd /workspaces/dify/api && flask run --host 0.0.0.0 --port=5001 --debug"' >> ~/.bashrc
|
||||
echo 'alias start-worker="cd /workspaces/dify/api && celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion"' >> ~/.bashrc
|
||||
|
3
.github/workflows/api-tests.yml
vendored
3
.github/workflows/api-tests.yml
vendored
@ -75,7 +75,7 @@ jobs:
|
||||
- name: Run Workflow
|
||||
run: poetry run -C api bash dev/pytest/pytest_workflow.sh
|
||||
|
||||
- name: Set up Vector Stores (Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS, Chroma)
|
||||
- name: Set up Vector Stores (Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS, Chroma, MyScale)
|
||||
uses: hoverkraft-tech/compose-action@v2.0.0
|
||||
with:
|
||||
compose-file: |
|
||||
@ -89,5 +89,6 @@ jobs:
|
||||
pgvecto-rs
|
||||
pgvector
|
||||
chroma
|
||||
myscale
|
||||
- name: Test Vector Stores
|
||||
run: poetry run -C api bash dev/pytest/pytest_vdb.sh
|
||||
|
12
.github/workflows/build-push.yml
vendored
12
.github/workflows/build-push.yml
vendored
@ -48,18 +48,18 @@ jobs:
|
||||
platform=${{ matrix.platform }}
|
||||
echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
|
||||
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@v3
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v2
|
||||
with:
|
||||
username: ${{ env.DOCKERHUB_USER }}
|
||||
password: ${{ env.DOCKERHUB_TOKEN }}
|
||||
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@v3
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Extract metadata for Docker
|
||||
id: meta
|
||||
uses: docker/metadata-action@v5
|
||||
|
1
.vscode/launch.json
vendored
1
.vscode/launch.json
vendored
@ -13,7 +13,6 @@
|
||||
"jinja": true,
|
||||
"env": {
|
||||
"FLASK_APP": "app.py",
|
||||
"FLASK_DEBUG": "1",
|
||||
"GEVENT_SUPPORT": "True"
|
||||
},
|
||||
"args": [
|
||||
|
@ -2,17 +2,17 @@
|
||||
|
||||
考虑到我们的现状,我们需要灵活快速地交付,但我们也希望确保像你这样的贡献者在贡献过程中获得尽可能顺畅的体验。我们为此编写了这份贡献指南,旨在让你熟悉代码库和我们与贡献者的合作方式,以便你能快速进入有趣的部分。
|
||||
|
||||
这份指南,就像 Dify 本身一样,是一个不断改进的工作。如果有时它落后于实际项目,我们非常感谢你的理解,并欢迎任何反馈以供我们改进。
|
||||
这份指南,就像 Dify 本身一样,是一个不断改进的工作。如果有时它落后于实际项目,我们非常感谢你的理解,并欢迎提供任何反馈以供我们改进。
|
||||
|
||||
在许可方面,请花一分钟阅读我们简短的[许可证和贡献者协议](./LICENSE)。社区还遵守[行为准则](https://github.com/langgenius/.github/blob/main/CODE_OF_CONDUCT.md)。
|
||||
在许可方面,请花一分钟阅读我们简短的 [许可证和贡献者协议](./LICENSE)。社区还遵守 [行为准则](https://github.com/langgenius/.github/blob/main/CODE_OF_CONDUCT.md)。
|
||||
|
||||
## 在开始之前
|
||||
|
||||
[查找](https://github.com/langgenius/dify/issues?q=is:issue+is:closed)现有问题,或[创建](https://github.com/langgenius/dify/issues/new/choose)一个新问题。我们将问题分为两类:
|
||||
[查找](https://github.com/langgenius/dify/issues?q=is:issue+is:closed)现有问题,或 [创建](https://github.com/langgenius/dify/issues/new/choose) 一个新问题。我们将问题分为两类:
|
||||
|
||||
### 功能请求:
|
||||
|
||||
* 如果您要提出新的功能请求,请解释所提议的功能的目标,并尽可能提供详细的上下文。[@perzeusss](https://github.com/perzeuss)制作了一个很好的[功能请求助手](https://udify.app/chat/MK2kVSnw1gakVwMX),可以帮助您起草需求。随时尝试一下。
|
||||
* 如果您要提出新的功能请求,请解释所提议的功能的目标,并尽可能提供详细的上下文。[@perzeusss](https://github.com/perzeuss) 制作了一个很好的 [功能请求助手](https://udify.app/chat/MK2kVSnw1gakVwMX),可以帮助您起草需求。随时尝试一下。
|
||||
|
||||
* 如果您想从现有问题中选择一个,请在其下方留下评论表示您的意愿。
|
||||
|
||||
@ -20,45 +20,44 @@
|
||||
|
||||
根据所提议的功能所属的领域不同,您可能需要与不同的团队成员交流。以下是我们团队成员目前正在从事的各个领域的概述:
|
||||
|
||||
| Member | Scope |
|
||||
| 团队成员 | 工作范围 |
|
||||
| ------------------------------------------------------------ | ---------------------------------------------------- |
|
||||
| [@yeuoly](https://github.com/Yeuoly) | Architecting Agents |
|
||||
| [@jyong](https://github.com/JohnJyong) | RAG pipeline design |
|
||||
| [@GarfieldDai](https://github.com/GarfieldDai) | Building workflow orchestrations |
|
||||
| [@iamjoel](https://github.com/iamjoel) & [@zxhlyh](https://github.com/zxhlyh) | Making our frontend a breeze to use |
|
||||
| [@guchenhe](https://github.com/guchenhe) & [@crazywoola](https://github.com/crazywoola) | Developer experience, points of contact for anything |
|
||||
| [@takatost](https://github.com/takatost) | Overall product direction and architecture |
|
||||
| [@yeuoly](https://github.com/Yeuoly) | 架构 Agents |
|
||||
| [@jyong](https://github.com/JohnJyong) | RAG 流水线设计 |
|
||||
| [@GarfieldDai](https://github.com/GarfieldDai) | 构建 workflow 编排 |
|
||||
| [@iamjoel](https://github.com/iamjoel) & [@zxhlyh](https://github.com/zxhlyh) | 让我们的前端更易用 |
|
||||
| [@guchenhe](https://github.com/guchenhe) & [@crazywoola](https://github.com/crazywoola) | 开发人员体验, 综合事项联系人 |
|
||||
| [@takatost](https://github.com/takatost) | 产品整体方向和架构 |
|
||||
|
||||
How we prioritize:
|
||||
事项优先级:
|
||||
|
||||
| Feature Type | Priority |
|
||||
| 功能类型 | 优先级 |
|
||||
| ------------------------------------------------------------ | --------------- |
|
||||
| High-Priority Features as being labeled by a team member | High Priority |
|
||||
| Popular feature requests from our [community feedback board](https://github.com/langgenius/dify/discussions/categories/feedbacks) | Medium Priority |
|
||||
| Non-core features and minor enhancements | Low Priority |
|
||||
| Valuable but not immediate | Future-Feature |
|
||||
| 被团队成员标记为高优先级的功能 | 高优先级 |
|
||||
| 在 [community feedback board](https://github.com/langgenius/dify/discussions/categories/feedbacks) 内反馈的常见功能请求 | 中等优先级 |
|
||||
| 非核心功能和小幅改进 | 低优先级 |
|
||||
| 有价值当不紧急 | 未来功能 |
|
||||
|
||||
### 其他任何事情(例如bug报告、性能优化、拼写错误更正):
|
||||
### 其他任何事情(例如 bug 报告、性能优化、拼写错误更正):
|
||||
* 立即开始编码。
|
||||
|
||||
How we prioritize:
|
||||
事项优先级:
|
||||
|
||||
| Issue Type | Priority |
|
||||
| Issue 类型 | 优先级 |
|
||||
| ------------------------------------------------------------ | --------------- |
|
||||
| Bugs in core functions (cannot login, applications not working, security loopholes) | Critical |
|
||||
| Non-critical bugs, performance boosts | Medium Priority |
|
||||
| Minor fixes (typos, confusing but working UI) | Low Priority |
|
||||
|
||||
| 核心功能的 Bugs(例如无法登录、应用无法工作、安全漏洞) | 紧急 |
|
||||
| 非紧急 bugs, 性能提升 | 中等优先级 |
|
||||
| 小幅修复(错别字, 能正常工作但存在误导的 UI) | 低优先级 |
|
||||
|
||||
## 安装
|
||||
|
||||
以下是设置Dify进行开发的步骤:
|
||||
以下是设置 Dify 进行开发的步骤:
|
||||
|
||||
### 1. Fork该仓库
|
||||
### 1. Fork 该仓库
|
||||
|
||||
### 2. 克隆仓库
|
||||
|
||||
从终端克隆fork的仓库:
|
||||
从终端克隆代码仓库:
|
||||
|
||||
```
|
||||
git clone git@github.com:<github_username>/dify.git
|
||||
@ -76,72 +75,72 @@ Dify 依赖以下工具和库:
|
||||
|
||||
### 4. 安装
|
||||
|
||||
Dify由后端和前端组成。通过`cd api/`导航到后端目录,然后按照[后端README](api/README.md)进行安装。在另一个终端中,通过`cd web/`导航到前端目录,然后按照[前端README](web/README.md)进行安装。
|
||||
Dify 由后端和前端组成。通过 `cd api/` 导航到后端目录,然后按照 [后端 README](api/README.md) 进行安装。在另一个终端中,通过 `cd web/` 导航到前端目录,然后按照 [前端 README](web/README.md) 进行安装。
|
||||
|
||||
查看[安装常见问题解答](https://docs.dify.ai/getting-started/faq/install-faq)以获取常见问题列表和故障排除步骤。
|
||||
查看 [安装常见问题解答](https://docs.dify.ai/getting-started/faq/install-faq) 以获取常见问题列表和故障排除步骤。
|
||||
|
||||
### 5. 在浏览器中访问Dify
|
||||
### 5. 在浏览器中访问 Dify
|
||||
|
||||
为了验证您的设置,打开浏览器并访问[http://localhost:3000](http://localhost:3000)(默认或您自定义的URL和端口)。现在您应该看到Dify正在运行。
|
||||
为了验证您的设置,打开浏览器并访问 [http://localhost:3000](http://localhost:3000)(默认或您自定义的 URL 和端口)。现在您应该看到 Dify 正在运行。
|
||||
|
||||
## 开发
|
||||
|
||||
如果您要添加模型提供程序,请参考[此指南](https://github.com/langgenius/dify/blob/main/api/core/model_runtime/README.md)。
|
||||
如果您要添加模型提供程序,请参考 [此指南](https://github.com/langgenius/dify/blob/main/api/core/model_runtime/README.md)。
|
||||
|
||||
如果您要向Agent或Workflow添加工具提供程序,请参考[此指南](./api/core/tools/README.md)。
|
||||
如果您要向 Agent 或 Workflow 添加工具提供程序,请参考 [此指南](./api/core/tools/README.md)。
|
||||
|
||||
为了帮助您快速了解您的贡献在哪个部分,以下是Dify后端和前端的简要注释大纲:
|
||||
为了帮助您快速了解您的贡献在哪个部分,以下是 Dify 后端和前端的简要注释大纲:
|
||||
|
||||
### 后端
|
||||
|
||||
Dify的后端使用Python编写,使用[Flask](https://flask.palletsprojects.com/en/3.0.x/)框架。它使用[SQLAlchemy](https://www.sqlalchemy.org/)作为ORM,使用[Celery](https://docs.celeryq.dev/en/stable/getting-started/introduction.html)作为任务队列。授权逻辑通过Flask-login进行处理。
|
||||
Dify 的后端使用 Python 编写,使用 [Flask](https://flask.palletsprojects.com/en/3.0.x/) 框架。它使用 [SQLAlchemy](https://www.sqlalchemy.org/) 作为 ORM,使用 [Celery](https://docs.celeryq.dev/en/stable/getting-started/introduction.html) 作为任务队列。授权逻辑通过 Flask-login 进行处理。
|
||||
|
||||
```
|
||||
[api/]
|
||||
├── constants // Constant settings used throughout code base.
|
||||
├── controllers // API route definitions and request handling logic.
|
||||
├── core // Core application orchestration, model integrations, and tools.
|
||||
├── docker // Docker & containerization related configurations.
|
||||
├── events // Event handling and processing
|
||||
├── extensions // Extensions with 3rd party frameworks/platforms.
|
||||
├── fields // field definitions for serialization/marshalling.
|
||||
├── libs // Reusable libraries and helpers.
|
||||
├── migrations // Scripts for database migration.
|
||||
├── models // Database models & schema definitions.
|
||||
├── services // Specifies business logic.
|
||||
├── storage // Private key storage.
|
||||
├── tasks // Handling of async tasks and background jobs.
|
||||
├── constants // 用于整个代码库的常量设置。
|
||||
├── controllers // API 路由定义和请求处理逻辑。
|
||||
├── core // 核心应用编排、模型集成和工具。
|
||||
├── docker // Docker 和容器化相关配置。
|
||||
├── events // 事件处理和处理。
|
||||
├── extensions // 与第三方框架/平台的扩展。
|
||||
├── fields // 用于序列化/封装的字段定义。
|
||||
├── libs // 可重用的库和助手。
|
||||
├── migrations // 数据库迁移脚本。
|
||||
├── models // 数据库模型和架构定义。
|
||||
├── services // 指定业务逻辑。
|
||||
├── storage // 私钥存储。
|
||||
├── tasks // 异步任务和后台作业的处理。
|
||||
└── tests
|
||||
```
|
||||
|
||||
### 前端
|
||||
|
||||
该网站使用基于Typescript的[Next.js](https://nextjs.org/)模板进行引导,并使用[Tailwind CSS](https://tailwindcss.com/)进行样式设计。[React-i18next](https://react.i18next.com/)用于国际化。
|
||||
该网站使用基于 Typescript 的 [Next.js](https://nextjs.org/) 模板进行引导,并使用 [Tailwind CSS](https://tailwindcss.com/) 进行样式设计。[React-i18next](https://react.i18next.com/) 用于国际化。
|
||||
|
||||
```
|
||||
[web/]
|
||||
├── app // layouts, pages, and components
|
||||
│ ├── (commonLayout) // common layout used throughout the app
|
||||
│ ├── (shareLayout) // layouts specifically shared across token-specific sessions
|
||||
│ ├── activate // activate page
|
||||
│ ├── components // shared by pages and layouts
|
||||
│ ├── install // install page
|
||||
│ ├── signin // signin page
|
||||
│ └── styles // globally shared styles
|
||||
├── assets // Static assets
|
||||
├── bin // scripts ran at build step
|
||||
├── config // adjustable settings and options
|
||||
├── context // shared contexts used by different portions of the app
|
||||
├── dictionaries // Language-specific translate files
|
||||
├── docker // container configurations
|
||||
├── hooks // Reusable hooks
|
||||
├── i18n // Internationalization configuration
|
||||
├── models // describes data models & shapes of API responses
|
||||
├── public // meta assets like favicon
|
||||
├── service // specifies shapes of API actions
|
||||
├── app // 布局、页面和组件
|
||||
│ ├── (commonLayout) // 整个应用通用的布局
|
||||
│ ├── (shareLayout) // 在特定会话中共享的布局
|
||||
│ ├── activate // 激活页面
|
||||
│ ├── components // 页面和布局共享的组件
|
||||
│ ├── install // 安装页面
|
||||
│ ├── signin // 登录页面
|
||||
│ └── styles // 全局共享的样式
|
||||
├── assets // 静态资源
|
||||
├── bin // 构建步骤运行的脚本
|
||||
├── config // 可调整的设置和选项
|
||||
├── context // 应用中不同部分使用的共享上下文
|
||||
├── dictionaries // 语言特定的翻译文件
|
||||
├── docker // 容器配置
|
||||
├── hooks // 可重用的钩子
|
||||
├── i18n // 国际化配置
|
||||
├── models // 描述数据模型和 API 响应的形状
|
||||
├── public // 如 favicon 等元资源
|
||||
├── service // 定义 API 操作的形状
|
||||
├── test
|
||||
├── types // descriptions of function params and return values
|
||||
└── utils // Shared utility functions
|
||||
├── types // 函数参数和返回值的描述
|
||||
└── utils // 共享的实用函数
|
||||
```
|
||||
|
||||
## 提交你的 PR
|
||||
|
@ -83,7 +83,7 @@ OCI_REGION=your-region
|
||||
WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
|
||||
CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
|
||||
|
||||
# Vector database configuration, support: weaviate, qdrant, milvus, relyt, pgvecto_rs, pgvector, pgvector, chroma, opensearch, tidb_vector
|
||||
# Vector database configuration, support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, pgvector, chroma, opensearch, tidb_vector
|
||||
VECTOR_STORE=weaviate
|
||||
|
||||
# Weaviate configuration
|
||||
@ -106,6 +106,14 @@ MILVUS_USER=root
|
||||
MILVUS_PASSWORD=Milvus
|
||||
MILVUS_SECURE=false
|
||||
|
||||
# MyScale configuration
|
||||
MYSCALE_HOST=127.0.0.1
|
||||
MYSCALE_PORT=8123
|
||||
MYSCALE_USER=default
|
||||
MYSCALE_PASSWORD=
|
||||
MYSCALE_DATABASE=default
|
||||
MYSCALE_FTS_PARAMS=
|
||||
|
||||
# Relyt configuration
|
||||
RELYT_HOST=127.0.0.1
|
||||
RELYT_PORT=5432
|
||||
@ -151,6 +159,16 @@ CHROMA_DATABASE=default_database
|
||||
CHROMA_AUTH_PROVIDER=chromadb.auth.token_authn.TokenAuthenticationServerProvider
|
||||
CHROMA_AUTH_CREDENTIALS=difyai123456
|
||||
|
||||
# AnalyticDB configuration
|
||||
ANALYTICDB_KEY_ID=your-ak
|
||||
ANALYTICDB_KEY_SECRET=your-sk
|
||||
ANALYTICDB_REGION_ID=cn-hangzhou
|
||||
ANALYTICDB_INSTANCE_ID=gp-ab123456
|
||||
ANALYTICDB_ACCOUNT=testaccount
|
||||
ANALYTICDB_PASSWORD=testpassword
|
||||
ANALYTICDB_NAMESPACE=dify
|
||||
ANALYTICDB_NAMESPACE_PASSWORD=difypassword
|
||||
|
||||
# OpenSearch configuration
|
||||
OPENSEARCH_HOST=127.0.0.1
|
||||
OPENSEARCH_PORT=9200
|
||||
@ -237,4 +255,4 @@ WORKFLOW_CALL_MAX_DEPTH=5
|
||||
|
||||
# App configuration
|
||||
APP_MAX_EXECUTION_TIME=1200
|
||||
|
||||
APP_MAX_ACTIVE_REQUESTS=0
|
||||
|
@ -337,6 +337,14 @@ def migrate_knowledge_vector_database():
|
||||
"vector_store": {"class_prefix": collection_name}
|
||||
}
|
||||
dataset.index_struct = json.dumps(index_struct_dict)
|
||||
elif vector_type == VectorType.ANALYTICDB:
|
||||
dataset_id = dataset.id
|
||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
||||
index_struct_dict = {
|
||||
"type": VectorType.ANALYTICDB,
|
||||
"vector_store": {"class_prefix": collection_name}
|
||||
}
|
||||
dataset.index_struct = json.dumps(index_struct_dict)
|
||||
else:
|
||||
raise ValueError(f"Vector store {vector_type} is not supported.")
|
||||
|
||||
|
@ -31,6 +31,10 @@ class AppExecutionConfig(BaseSettings):
|
||||
description='execution timeout in seconds for app execution',
|
||||
default=1200,
|
||||
)
|
||||
APP_MAX_ACTIVE_REQUESTS: NonNegativeInt = Field(
|
||||
description='max active request per app, 0 means unlimited',
|
||||
default=0,
|
||||
)
|
||||
|
||||
|
||||
class CodeExecutionSandboxConfig(BaseSettings):
|
||||
@ -396,6 +400,11 @@ class DataSetConfig(BaseSettings):
|
||||
default=30,
|
||||
)
|
||||
|
||||
DATASET_OPERATOR_ENABLED: bool = Field(
|
||||
description='whether to enable dataset operator',
|
||||
default=False,
|
||||
)
|
||||
|
||||
|
||||
class WorkspaceConfig(BaseSettings):
|
||||
"""
|
||||
|
@ -10,8 +10,10 @@ from configs.middleware.storage.azure_blob_storage_config import AzureBlobStorag
|
||||
from configs.middleware.storage.google_cloud_storage_config import GoogleCloudStorageConfig
|
||||
from configs.middleware.storage.oci_storage_config import OCIStorageConfig
|
||||
from configs.middleware.storage.tencent_cos_storage_config import TencentCloudCOSStorageConfig
|
||||
from configs.middleware.vdb.analyticdb_config import AnalyticdbConfig
|
||||
from configs.middleware.vdb.chroma_config import ChromaConfig
|
||||
from configs.middleware.vdb.milvus_config import MilvusConfig
|
||||
from configs.middleware.vdb.myscale_config import MyScaleConfig
|
||||
from configs.middleware.vdb.opensearch_config import OpenSearchConfig
|
||||
from configs.middleware.vdb.oracle_config import OracleConfig
|
||||
from configs.middleware.vdb.pgvector_config import PGVectorConfig
|
||||
@ -183,8 +185,10 @@ class MiddlewareConfig(
|
||||
|
||||
# configs of vdb and vdb providers
|
||||
VectorStoreConfig,
|
||||
AnalyticdbConfig,
|
||||
ChromaConfig,
|
||||
MilvusConfig,
|
||||
MyScaleConfig,
|
||||
OpenSearchConfig,
|
||||
OracleConfig,
|
||||
PGVectorConfig,
|
||||
|
44
api/configs/middleware/vdb/analyticdb_config.py
Normal file
44
api/configs/middleware/vdb/analyticdb_config.py
Normal file
@ -0,0 +1,44 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class AnalyticdbConfig(BaseModel):
|
||||
"""
|
||||
Configuration for connecting to AnalyticDB.
|
||||
Refer to the following documentation for details on obtaining credentials:
|
||||
https://www.alibabacloud.com/help/en/analyticdb-for-postgresql/getting-started/create-an-instance-instances-with-vector-engine-optimization-enabled
|
||||
"""
|
||||
|
||||
ANALYTICDB_KEY_ID : Optional[str] = Field(
|
||||
default=None,
|
||||
description="The Access Key ID provided by Alibaba Cloud for authentication."
|
||||
)
|
||||
ANALYTICDB_KEY_SECRET : Optional[str] = Field(
|
||||
default=None,
|
||||
description="The Secret Access Key corresponding to the Access Key ID for secure access."
|
||||
)
|
||||
ANALYTICDB_REGION_ID : Optional[str] = Field(
|
||||
default=None,
|
||||
description="The region where the AnalyticDB instance is deployed (e.g., 'cn-hangzhou')."
|
||||
)
|
||||
ANALYTICDB_INSTANCE_ID : Optional[str] = Field(
|
||||
default=None,
|
||||
description="The unique identifier of the AnalyticDB instance you want to connect to (e.g., 'gp-ab123456').."
|
||||
)
|
||||
ANALYTICDB_ACCOUNT : Optional[str] = Field(
|
||||
default=None,
|
||||
description="The account name used to log in to the AnalyticDB instance."
|
||||
)
|
||||
ANALYTICDB_PASSWORD : Optional[str] = Field(
|
||||
default=None,
|
||||
description="The password associated with the AnalyticDB account for authentication."
|
||||
)
|
||||
ANALYTICDB_NAMESPACE : Optional[str] = Field(
|
||||
default=None,
|
||||
description="The namespace within AnalyticDB for schema isolation."
|
||||
)
|
||||
ANALYTICDB_NAMESPACE_PASSWORD : Optional[str] = Field(
|
||||
default=None,
|
||||
description="The password for accessing the specified namespace within the AnalyticDB instance."
|
||||
)
|
39
api/configs/middleware/vdb/myscale_config.py
Normal file
39
api/configs/middleware/vdb/myscale_config.py
Normal file
@ -0,0 +1,39 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field, PositiveInt
|
||||
|
||||
|
||||
class MyScaleConfig(BaseModel):
|
||||
"""
|
||||
MyScale configs
|
||||
"""
|
||||
|
||||
MYSCALE_HOST: Optional[str] = Field(
|
||||
description='MyScale host',
|
||||
default=None,
|
||||
)
|
||||
|
||||
MYSCALE_PORT: Optional[PositiveInt] = Field(
|
||||
description='MyScale port',
|
||||
default=8123,
|
||||
)
|
||||
|
||||
MYSCALE_USER: Optional[str] = Field(
|
||||
description='MyScale user',
|
||||
default=None,
|
||||
)
|
||||
|
||||
MYSCALE_PASSWORD: Optional[str] = Field(
|
||||
description='MyScale password',
|
||||
default=None,
|
||||
)
|
||||
|
||||
MYSCALE_DATABASE: Optional[str] = Field(
|
||||
description='MyScale database name',
|
||||
default=None,
|
||||
)
|
||||
|
||||
MYSCALE_FTS_PARAMS: Optional[str] = Field(
|
||||
description='MyScale fts index parameters',
|
||||
default=None,
|
||||
)
|
@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):
|
||||
|
||||
CURRENT_VERSION: str = Field(
|
||||
description='Dify version',
|
||||
default='0.6.12-fix1',
|
||||
default='0.6.14',
|
||||
)
|
||||
|
||||
COMMIT_SHA: str = Field(
|
||||
|
File diff suppressed because one or more lines are too long
4
api/constants/tts_auto_play_timeout.py
Normal file
4
api/constants/tts_auto_play_timeout.py
Normal file
@ -0,0 +1,4 @@
|
||||
TTS_AUTO_PLAY_TIMEOUT = 5
|
||||
|
||||
# sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file)
|
||||
TTS_AUTO_PLAY_YIELD_CPU_TIME = 0.02
|
@ -15,6 +15,7 @@ from fields.app_fields import (
|
||||
app_pagination_fields,
|
||||
)
|
||||
from libs.login import login_required
|
||||
from services.app_dsl_service import AppDslService
|
||||
from services.app_service import AppService
|
||||
|
||||
ALLOW_CREATE_APP_MODES = ['chat', 'agent-chat', 'advanced-chat', 'workflow', 'completion']
|
||||
@ -97,8 +98,42 @@ class AppImportApi(Resource):
|
||||
parser.add_argument('icon_background', type=str, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
app_service = AppService()
|
||||
app = app_service.import_app(current_user.current_tenant_id, args['data'], args, current_user)
|
||||
app = AppDslService.import_and_create_new_app(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
data=args['data'],
|
||||
args=args,
|
||||
account=current_user
|
||||
)
|
||||
|
||||
return app, 201
|
||||
|
||||
|
||||
class AppImportFromUrlApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(app_detail_fields_with_site)
|
||||
@cloud_edition_billing_resource_check('apps')
|
||||
def post(self):
|
||||
"""Import app from url"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('url', type=str, required=True, nullable=False, location='json')
|
||||
parser.add_argument('name', type=str, location='json')
|
||||
parser.add_argument('description', type=str, location='json')
|
||||
parser.add_argument('icon', type=str, location='json')
|
||||
parser.add_argument('icon_background', type=str, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
app = AppDslService.import_and_create_new_app_from_url(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
url=args['url'],
|
||||
args=args,
|
||||
account=current_user
|
||||
)
|
||||
|
||||
return app, 201
|
||||
|
||||
@ -134,6 +169,7 @@ class AppApi(Resource):
|
||||
parser.add_argument('description', type=str, location='json')
|
||||
parser.add_argument('icon', type=str, location='json')
|
||||
parser.add_argument('icon_background', type=str, location='json')
|
||||
parser.add_argument('max_active_requests', type=int, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
app_service = AppService()
|
||||
@ -176,9 +212,13 @@ class AppCopyApi(Resource):
|
||||
parser.add_argument('icon_background', type=str, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
app_service = AppService()
|
||||
data = app_service.export_app(app_model)
|
||||
app = app_service.import_app(current_user.current_tenant_id, data, args, current_user)
|
||||
data = AppDslService.export_dsl(app_model=app_model)
|
||||
app = AppDslService.import_and_create_new_app(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
data=data,
|
||||
args=args,
|
||||
account=current_user
|
||||
)
|
||||
|
||||
return app, 201
|
||||
|
||||
@ -194,10 +234,8 @@ class AppExportApi(Resource):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
app_service = AppService()
|
||||
|
||||
return {
|
||||
"data": app_service.export_app(app_model)
|
||||
"data": AppDslService.export_dsl(app_model=app_model)
|
||||
}
|
||||
|
||||
|
||||
@ -321,6 +359,7 @@ class AppTraceApi(Resource):
|
||||
|
||||
api.add_resource(AppListApi, '/apps')
|
||||
api.add_resource(AppImportApi, '/apps/import')
|
||||
api.add_resource(AppImportFromUrlApi, '/apps/import/url')
|
||||
api.add_resource(AppApi, '/apps/<uuid:app_id>')
|
||||
api.add_resource(AppCopyApi, '/apps/<uuid:app_id>/copy')
|
||||
api.add_resource(AppExportApi, '/apps/<uuid:app_id>/export')
|
||||
|
@ -81,15 +81,36 @@ class ChatMessageTextApi(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
def post(self, app_model):
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
|
||||
try:
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('message_id', type=str, location='json')
|
||||
parser.add_argument('text', type=str, location='json')
|
||||
parser.add_argument('voice', type=str, location='json')
|
||||
parser.add_argument('streaming', type=bool, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
message_id = args.get('message_id', None)
|
||||
text = args.get('text', None)
|
||||
if (app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
|
||||
and app_model.workflow
|
||||
and app_model.workflow.features_dict):
|
||||
text_to_speech = app_model.workflow.features_dict.get('text_to_speech')
|
||||
voice = args.get('voice') if args.get('voice') else text_to_speech.get('voice')
|
||||
else:
|
||||
try:
|
||||
voice = args.get('voice') if args.get('voice') else app_model.app_model_config.text_to_speech_dict.get(
|
||||
'voice')
|
||||
except Exception:
|
||||
voice = None
|
||||
response = AudioService.transcript_tts(
|
||||
app_model=app_model,
|
||||
text=request.form['text'],
|
||||
voice=request.form['voice'],
|
||||
streaming=False
|
||||
text=text,
|
||||
message_id=message_id,
|
||||
voice=voice
|
||||
)
|
||||
|
||||
return {'data': response.data.decode('latin1')}
|
||||
return response
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logging.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
|
@ -19,7 +19,12 @@ from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.errors.error import (
|
||||
AppInvokeQuotaExceededError,
|
||||
ModelCurrentlyNotSupportError,
|
||||
ProviderTokenNotInitError,
|
||||
QuotaExceededError,
|
||||
)
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from libs import helper
|
||||
from libs.helper import uuid_value
|
||||
@ -75,7 +80,7 @@ class CompletionMessageApi(Resource):
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
except ValueError as e:
|
||||
except (ValueError, AppInvokeQuotaExceededError) as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logging.exception("internal server error.")
|
||||
@ -141,7 +146,7 @@ class ChatMessageApi(Resource):
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
except ValueError as e:
|
||||
except (ValueError, AppInvokeQuotaExceededError) as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logging.exception("internal server error.")
|
||||
|
@ -13,12 +13,14 @@ from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.errors.error import AppInvokeQuotaExceededError
|
||||
from fields.workflow_fields import workflow_fields
|
||||
from fields.workflow_run_fields import workflow_run_node_execution_fields
|
||||
from libs import helper
|
||||
from libs.helper import TimestampField, uuid_value
|
||||
from libs.login import current_user, login_required
|
||||
from models.model import App, AppMode
|
||||
from services.app_dsl_service import AppDslService
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.errors.app import WorkflowHashNotEqualError
|
||||
from services.workflow_service import WorkflowService
|
||||
@ -127,8 +129,7 @@ class DraftWorkflowImportApi(Resource):
|
||||
parser.add_argument('data', type=str, required=True, nullable=False, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
workflow = workflow_service.import_draft_workflow(
|
||||
workflow = AppDslService.import_and_overwrite_workflow(
|
||||
app_model=app_model,
|
||||
data=args['data'],
|
||||
account=current_user
|
||||
@ -279,7 +280,7 @@ class DraftWorkflowRunApi(Resource):
|
||||
)
|
||||
|
||||
return helper.compact_generate_response(response)
|
||||
except ValueError as e:
|
||||
except (ValueError, AppInvokeQuotaExceededError) as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logging.exception("internal server error.")
|
||||
|
@ -25,7 +25,7 @@ from fields.document_fields import document_status_fields
|
||||
from libs.login import login_required
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
from models.model import ApiToken, UploadFile
|
||||
from services.dataset_service import DatasetService, DocumentService
|
||||
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
|
||||
|
||||
|
||||
def _validate_name(name):
|
||||
@ -85,6 +85,12 @@ class DatasetListApi(Resource):
|
||||
else:
|
||||
item['embedding_available'] = True
|
||||
|
||||
if item.get('permission') == 'partial_members':
|
||||
part_users_list = DatasetPermissionService.get_dataset_partial_member_list(item['id'])
|
||||
item.update({'partial_member_list': part_users_list})
|
||||
else:
|
||||
item.update({'partial_member_list': []})
|
||||
|
||||
response = {
|
||||
'data': data,
|
||||
'has_more': len(datasets) == limit,
|
||||
@ -108,8 +114,8 @@ class DatasetListApi(Resource):
|
||||
help='Invalid indexing technique.')
|
||||
args = parser.parse_args()
|
||||
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
|
||||
if not current_user.is_dataset_editor:
|
||||
raise Forbidden()
|
||||
|
||||
try:
|
||||
@ -140,6 +146,10 @@ class DatasetApi(Resource):
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
data = marshal(dataset, dataset_detail_fields)
|
||||
if data.get('permission') == 'partial_members':
|
||||
part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
|
||||
data.update({'partial_member_list': part_users_list})
|
||||
|
||||
# check embedding setting
|
||||
provider_manager = ProviderManager()
|
||||
configurations = provider_manager.get_configurations(
|
||||
@ -163,6 +173,11 @@ class DatasetApi(Resource):
|
||||
data['embedding_available'] = False
|
||||
else:
|
||||
data['embedding_available'] = True
|
||||
|
||||
if data.get('permission') == 'partial_members':
|
||||
part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
|
||||
data.update({'partial_member_list': part_users_list})
|
||||
|
||||
return data, 200
|
||||
|
||||
@setup_required
|
||||
@ -188,17 +203,21 @@ class DatasetApi(Resource):
|
||||
nullable=True,
|
||||
help='Invalid indexing technique.')
|
||||
parser.add_argument('permission', type=str, location='json', choices=(
|
||||
'only_me', 'all_team_members'), help='Invalid permission.')
|
||||
'only_me', 'all_team_members', 'partial_members'), help='Invalid permission.'
|
||||
)
|
||||
parser.add_argument('embedding_model', type=str,
|
||||
location='json', help='Invalid embedding model.')
|
||||
parser.add_argument('embedding_model_provider', type=str,
|
||||
location='json', help='Invalid embedding model provider.')
|
||||
parser.add_argument('retrieval_model', type=dict, location='json', help='Invalid retrieval model.')
|
||||
parser.add_argument('partial_member_list', type=list, location='json', help='Invalid parent user list.')
|
||||
args = parser.parse_args()
|
||||
data = request.get_json()
|
||||
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
||||
DatasetPermissionService.check_permission(
|
||||
current_user, dataset, data.get('permission'), data.get('partial_member_list')
|
||||
)
|
||||
|
||||
dataset = DatasetService.update_dataset(
|
||||
dataset_id_str, args, current_user)
|
||||
@ -206,7 +225,20 @@ class DatasetApi(Resource):
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
|
||||
return marshal(dataset, dataset_detail_fields), 200
|
||||
result_data = marshal(dataset, dataset_detail_fields)
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
if data.get('partial_member_list') and data.get('permission') == 'partial_members':
|
||||
DatasetPermissionService.update_partial_member_list(
|
||||
tenant_id, dataset_id_str, data.get('partial_member_list')
|
||||
)
|
||||
else:
|
||||
DatasetPermissionService.clear_partial_member_list(dataset_id_str)
|
||||
|
||||
partial_member_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
|
||||
result_data.update({'partial_member_list': partial_member_list})
|
||||
|
||||
return result_data, 200
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@ -215,11 +247,12 @@ class DatasetApi(Resource):
|
||||
dataset_id_str = str(dataset_id)
|
||||
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
if not current_user.is_editor or current_user.is_dataset_operator:
|
||||
raise Forbidden()
|
||||
|
||||
try:
|
||||
if DatasetService.delete_dataset(dataset_id_str, current_user):
|
||||
DatasetPermissionService.clear_partial_member_list(dataset_id_str)
|
||||
return {'result': 'success'}, 204
|
||||
else:
|
||||
raise NotFound("Dataset not found.")
|
||||
@ -515,7 +548,7 @@ class DatasetRetrievalSettingApi(Resource):
|
||||
RetrievalMethod.SEMANTIC_SEARCH
|
||||
]
|
||||
}
|
||||
case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH:
|
||||
case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH | VectorType.ANALYTICDB | VectorType.MYSCALE:
|
||||
return {
|
||||
'retrieval_method': [
|
||||
RetrievalMethod.SEMANTIC_SEARCH,
|
||||
@ -539,7 +572,7 @@ class DatasetRetrievalSettingMockApi(Resource):
|
||||
RetrievalMethod.SEMANTIC_SEARCH
|
||||
]
|
||||
}
|
||||
case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH:
|
||||
case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH| VectorType.ANALYTICDB | VectorType.MYSCALE:
|
||||
return {
|
||||
'retrieval_method': [
|
||||
RetrievalMethod.SEMANTIC_SEARCH,
|
||||
@ -569,6 +602,27 @@ class DatasetErrorDocs(Resource):
|
||||
}, 200
|
||||
|
||||
|
||||
class DatasetPermissionUserListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id):
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
try:
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
|
||||
partial_members_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
|
||||
|
||||
return {
|
||||
'data': partial_members_list,
|
||||
}, 200
|
||||
|
||||
|
||||
api.add_resource(DatasetListApi, '/datasets')
|
||||
api.add_resource(DatasetApi, '/datasets/<uuid:dataset_id>')
|
||||
api.add_resource(DatasetUseCheckApi, '/datasets/<uuid:dataset_id>/use-check')
|
||||
@ -582,3 +636,4 @@ api.add_resource(DatasetApiDeleteApi, '/datasets/api-keys/<uuid:api_key_id>')
|
||||
api.add_resource(DatasetApiBaseUrlApi, '/datasets/api-base-info')
|
||||
api.add_resource(DatasetRetrievalSettingApi, '/datasets/retrieval-setting')
|
||||
api.add_resource(DatasetRetrievalSettingMockApi, '/datasets/retrieval-setting/<string:vector_type>')
|
||||
api.add_resource(DatasetPermissionUserListApi, '/datasets/<uuid:dataset_id>/permission-part-users')
|
||||
|
@ -228,7 +228,7 @@ class DatasetDocumentListApi(Resource):
|
||||
raise NotFound('Dataset not found.')
|
||||
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
if not current_user.is_dataset_editor:
|
||||
raise Forbidden()
|
||||
|
||||
try:
|
||||
@ -294,6 +294,11 @@ class DatasetInitApi(Resource):
|
||||
parser.add_argument('retrieval_model', type=dict, required=False, nullable=False,
|
||||
location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
|
||||
if not current_user.is_dataset_editor:
|
||||
raise Forbidden()
|
||||
|
||||
if args['indexing_technique'] == 'high_quality':
|
||||
try:
|
||||
model_manager = ModelManager()
|
||||
@ -757,14 +762,18 @@ class DocumentStatusApi(DocumentResource):
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_dataset_editor:
|
||||
raise Forbidden()
|
||||
|
||||
# check user's model setting
|
||||
DatasetService.check_dataset_model_setting(dataset)
|
||||
|
||||
document = self.get_document(dataset_id, document_id)
|
||||
# check user's permission
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
document = self.get_document(dataset_id, document_id)
|
||||
|
||||
indexing_cache_key = 'document_{}_indexing'.format(document.id)
|
||||
cache_result = redis_client.get(indexing_cache_key)
|
||||
@ -955,10 +964,11 @@ class DocumentRenameApi(DocumentResource):
|
||||
@account_initialization_required
|
||||
@marshal_with(document_fields)
|
||||
def post(self, dataset_id, document_id):
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
if not current_user.is_admin_or_owner:
|
||||
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
||||
if not current_user.is_dataset_editor:
|
||||
raise Forbidden()
|
||||
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
DatasetService.check_dataset_operator_permission(current_user, dataset)
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('name', type=str, required=True, nullable=False, location='json')
|
||||
args = parser.parse_args()
|
||||
|
@ -19,6 +19,7 @@ from controllers.console.app.error import (
|
||||
from controllers.console.explore.wraps import InstalledAppResource
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from models.model import AppMode
|
||||
from services.audio_service import AudioService
|
||||
from services.errors.audio import (
|
||||
AudioTooLargeServiceError,
|
||||
@ -70,16 +71,33 @@ class ChatAudioApi(InstalledAppResource):
|
||||
|
||||
class ChatTextApi(InstalledAppResource):
|
||||
def post(self, installed_app):
|
||||
app_model = installed_app.app
|
||||
from flask_restful import reqparse
|
||||
|
||||
app_model = installed_app.app
|
||||
try:
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('message_id', type=str, required=False, location='json')
|
||||
parser.add_argument('voice', type=str, location='json')
|
||||
parser.add_argument('streaming', type=bool, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
message_id = args.get('message_id')
|
||||
if (app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
|
||||
and app_model.workflow
|
||||
and app_model.workflow.features_dict):
|
||||
text_to_speech = app_model.workflow.features_dict.get('text_to_speech')
|
||||
voice = args.get('voice') if args.get('voice') else text_to_speech.get('voice')
|
||||
else:
|
||||
try:
|
||||
voice = args.get('voice') if args.get('voice') else app_model.app_model_config.text_to_speech_dict.get('voice')
|
||||
except Exception:
|
||||
voice = None
|
||||
response = AudioService.transcript_tts(
|
||||
app_model=app_model,
|
||||
text=request.form['text'],
|
||||
voice=request.form['voice'] if request.form.get('voice') else app_model.app_model_config.text_to_speech_dict.get('voice'),
|
||||
streaming=False
|
||||
message_id=message_id,
|
||||
voice=voice
|
||||
)
|
||||
return {'data': response.data.decode('latin1')}
|
||||
return response
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logging.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
@ -108,3 +126,5 @@ class ChatTextApi(InstalledAppResource):
|
||||
|
||||
api.add_resource(ChatAudioApi, '/installed-apps/<uuid:installed_app_id>/audio-to-text', endpoint='installed_app_audio')
|
||||
api.add_resource(ChatTextApi, '/installed-apps/<uuid:installed_app_id>/text-to-audio', endpoint='installed_app_text')
|
||||
# api.add_resource(ChatTextApiWithMessageId, '/installed-apps/<uuid:installed_app_id>/text-to-audio/message-id',
|
||||
# endpoint='installed_app_text_with_message_id')
|
||||
|
@ -36,7 +36,7 @@ class TagListApi(Resource):
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
if not (current_user.is_editor or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
@ -68,7 +68,7 @@ class TagUpdateDeleteApi(Resource):
|
||||
def patch(self, tag_id):
|
||||
tag_id = str(tag_id)
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
if not (current_user.is_editor or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
@ -109,8 +109,8 @@ class TagBindingCreateApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
||||
if not (current_user.is_editor or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
@ -134,8 +134,8 @@ class TagBindingDeleteApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
||||
if not (current_user.is_editor or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
|
@ -131,7 +131,20 @@ class MemberUpdateRoleApi(Resource):
|
||||
return {'result': 'success'}
|
||||
|
||||
|
||||
class DatasetOperatorMemberListApi(Resource):
|
||||
"""List all members of current tenant."""
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(account_with_role_list_fields)
|
||||
def get(self):
|
||||
members = TenantService.get_dataset_operator_members(current_user.current_tenant)
|
||||
return {'result': 'success', 'accounts': members}, 200
|
||||
|
||||
|
||||
api.add_resource(MemberListApi, '/workspaces/current/members')
|
||||
api.add_resource(MemberInviteEmailApi, '/workspaces/current/members/invite-email')
|
||||
api.add_resource(MemberCancelInviteApi, '/workspaces/current/members/<uuid:member_id>')
|
||||
api.add_resource(MemberUpdateRoleApi, '/workspaces/current/members/<uuid:member_id>/update-role')
|
||||
api.add_resource(DatasetOperatorMemberListApi, '/workspaces/current/dataset-operators')
|
||||
|
@ -3,8 +3,9 @@ from functools import wraps
|
||||
from hashlib import sha1
|
||||
from hmac import new as hmac_new
|
||||
|
||||
from flask import abort, current_app, request
|
||||
from flask import abort, request
|
||||
|
||||
from configs import dify_config
|
||||
from extensions.ext_database import db
|
||||
from models.model import EndUser
|
||||
|
||||
@ -12,12 +13,12 @@ from models.model import EndUser
|
||||
def inner_api_only(view):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
if not current_app.config['INNER_API']:
|
||||
if not dify_config.INNER_API:
|
||||
abort(404)
|
||||
|
||||
# get header 'X-Inner-Api-Key'
|
||||
inner_api_key = request.headers.get('X-Inner-Api-Key')
|
||||
if not inner_api_key or inner_api_key != current_app.config['INNER_API_KEY']:
|
||||
if not inner_api_key or inner_api_key != dify_config.INNER_API_KEY:
|
||||
abort(404)
|
||||
|
||||
return view(*args, **kwargs)
|
||||
@ -28,7 +29,7 @@ def inner_api_only(view):
|
||||
def inner_api_user_auth(view):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
if not current_app.config['INNER_API']:
|
||||
if not dify_config.INNER_API:
|
||||
return view(*args, **kwargs)
|
||||
|
||||
# get header 'X-Inner-Api-Key'
|
||||
|
@ -1,7 +1,7 @@
|
||||
|
||||
from flask import current_app
|
||||
from flask_restful import Resource, fields, marshal_with
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.service_api import api
|
||||
from controllers.service_api.app.error import AppUnavailableError
|
||||
from controllers.service_api.wraps import validate_app_token
|
||||
@ -78,7 +78,7 @@ class AppParameterApi(Resource):
|
||||
"transfer_methods": ["remote_url", "local_file"]
|
||||
}}),
|
||||
'system_parameters': {
|
||||
'image_file_size_limit': current_app.config.get('UPLOAD_IMAGE_FILE_SIZE_LIMIT')
|
||||
'image_file_size_limit': dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -20,7 +20,7 @@ from controllers.service_api.app.error import (
|
||||
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from models.model import App, EndUser
|
||||
from models.model import App, AppMode, EndUser
|
||||
from services.audio_service import AudioService
|
||||
from services.errors.audio import (
|
||||
AudioTooLargeServiceError,
|
||||
@ -72,19 +72,30 @@ class AudioApi(Resource):
|
||||
class TextApi(Resource):
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
|
||||
def post(self, app_model: App, end_user: EndUser):
|
||||
try:
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('text', type=str, required=True, nullable=False, location='json')
|
||||
parser.add_argument('message_id', type=str, required=False, location='json')
|
||||
parser.add_argument('voice', type=str, location='json')
|
||||
parser.add_argument('streaming', type=bool, required=False, nullable=False, location='json')
|
||||
parser.add_argument('streaming', type=bool, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
message_id = args.get('message_id')
|
||||
if (app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
|
||||
and app_model.workflow
|
||||
and app_model.workflow.features_dict):
|
||||
text_to_speech = app_model.workflow.features_dict.get('text_to_speech')
|
||||
voice = args.get('voice') if args.get('voice') else text_to_speech.get('voice')
|
||||
else:
|
||||
try:
|
||||
voice = args.get('voice') if args.get('voice') else app_model.app_model_config.text_to_speech_dict.get(
|
||||
'voice')
|
||||
except Exception:
|
||||
voice = None
|
||||
response = AudioService.transcript_tts(
|
||||
app_model=app_model,
|
||||
text=args['text'],
|
||||
end_user=end_user,
|
||||
voice=args.get('voice'),
|
||||
streaming=args['streaming']
|
||||
message_id=message_id,
|
||||
end_user=end_user.external_user_id,
|
||||
voice=voice
|
||||
)
|
||||
|
||||
return response
|
||||
|
@ -17,7 +17,12 @@ from controllers.service_api.app.error import (
|
||||
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.errors.error import (
|
||||
AppInvokeQuotaExceededError,
|
||||
ModelCurrentlyNotSupportError,
|
||||
ProviderTokenNotInitError,
|
||||
QuotaExceededError,
|
||||
)
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from libs import helper
|
||||
from libs.helper import uuid_value
|
||||
@ -69,7 +74,7 @@ class CompletionApi(Resource):
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
except ValueError as e:
|
||||
except (ValueError, AppInvokeQuotaExceededError) as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logging.exception("internal server error.")
|
||||
@ -132,7 +137,7 @@ class ChatApi(Resource):
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
except ValueError as e:
|
||||
except (ValueError, AppInvokeQuotaExceededError) as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logging.exception("internal server error.")
|
||||
|
@ -14,7 +14,12 @@ from controllers.service_api.app.error import (
|
||||
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.errors.error import (
|
||||
AppInvokeQuotaExceededError,
|
||||
ModelCurrentlyNotSupportError,
|
||||
ProviderTokenNotInitError,
|
||||
QuotaExceededError,
|
||||
)
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from libs import helper
|
||||
from models.model import App, AppMode, EndUser
|
||||
@ -59,7 +64,7 @@ class WorkflowRunApi(Resource):
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
except ValueError as e:
|
||||
except (ValueError, AppInvokeQuotaExceededError) as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logging.exception("internal server error.")
|
||||
|
@ -1,6 +1,6 @@
|
||||
from flask import current_app
|
||||
from flask_restful import Resource
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.service_api import api
|
||||
|
||||
|
||||
@ -9,7 +9,7 @@ class IndexApi(Resource):
|
||||
return {
|
||||
"welcome": "Dify OpenAPI",
|
||||
"api_version": "v1",
|
||||
"server_version": current_app.config['CURRENT_VERSION']
|
||||
"server_version": dify_config.CURRENT_VERSION,
|
||||
}
|
||||
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
from flask import current_app
|
||||
from flask_restful import fields, marshal_with
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.web import api
|
||||
from controllers.web.error import AppUnavailableError
|
||||
from controllers.web.wraps import WebApiResource
|
||||
@ -75,7 +75,7 @@ class AppParameterApi(WebApiResource):
|
||||
"transfer_methods": ["remote_url", "local_file"]
|
||||
}}),
|
||||
'system_parameters': {
|
||||
'image_file_size_limit': current_app.config.get('UPLOAD_IMAGE_FILE_SIZE_LIMIT')
|
||||
'image_file_size_limit': dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -19,7 +19,7 @@ from controllers.web.error import (
|
||||
from controllers.web.wraps import WebApiResource
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from models.model import App
|
||||
from models.model import App, AppMode
|
||||
from services.audio_service import AudioService
|
||||
from services.errors.audio import (
|
||||
AudioTooLargeServiceError,
|
||||
@ -69,16 +69,35 @@ class AudioApi(WebApiResource):
|
||||
|
||||
class TextApi(WebApiResource):
|
||||
def post(self, app_model: App, end_user):
|
||||
from flask_restful import reqparse
|
||||
try:
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('message_id', type=str, required=False, location='json')
|
||||
parser.add_argument('voice', type=str, location='json')
|
||||
parser.add_argument('streaming', type=bool, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
message_id = args.get('message_id')
|
||||
if (app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
|
||||
and app_model.workflow
|
||||
and app_model.workflow.features_dict):
|
||||
text_to_speech = app_model.workflow.features_dict.get('text_to_speech')
|
||||
voice = args.get('voice') if args.get('voice') else text_to_speech.get('voice')
|
||||
else:
|
||||
try:
|
||||
voice = args.get('voice') if args.get(
|
||||
'voice') else app_model.app_model_config.text_to_speech_dict.get('voice')
|
||||
except Exception:
|
||||
voice = None
|
||||
|
||||
response = AudioService.transcript_tts(
|
||||
app_model=app_model,
|
||||
text=request.form['text'],
|
||||
message_id=message_id,
|
||||
end_user=end_user.external_user_id,
|
||||
voice=request.form['voice'] if request.form.get('voice') else None,
|
||||
streaming=False
|
||||
voice=voice
|
||||
)
|
||||
|
||||
return {'data': response.data.decode('latin1')}
|
||||
return response
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logging.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
|
@ -1,8 +1,8 @@
|
||||
|
||||
from flask import current_app
|
||||
from flask_restful import fields, marshal_with
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.web import api
|
||||
from controllers.web.wraps import WebApiResource
|
||||
from extensions.ext_database import db
|
||||
@ -84,7 +84,7 @@ class AppSiteInfo:
|
||||
self.can_replace_logo = can_replace_logo
|
||||
|
||||
if can_replace_logo:
|
||||
base_url = current_app.config.get('FILES_URL')
|
||||
base_url = dify_config.FILES_URL
|
||||
remove_webapp_brand = tenant.custom_config_dict.get('remove_webapp_brand', False)
|
||||
replace_webapp_logo = f'{base_url}/files/workspaces/{tenant.id}/webapp-logo' if tenant.custom_config_dict.get('replace_webapp_logo') else None
|
||||
self.custom_config = {
|
||||
|
135
api/core/app/apps/advanced_chat/app_generator_tts_publisher.py
Normal file
135
api/core/app/apps/advanced_chat/app_generator_tts_publisher.py
Normal file
@ -0,0 +1,135 @@
|
||||
import base64
|
||||
import concurrent.futures
|
||||
import logging
|
||||
import queue
|
||||
import re
|
||||
import threading
|
||||
|
||||
from core.app.entities.queue_entities import QueueAgentMessageEvent, QueueLLMChunkEvent, QueueTextChunkEvent
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
|
||||
|
||||
class AudioTrunk:
|
||||
def __init__(self, status: str, audio):
|
||||
self.audio = audio
|
||||
self.status = status
|
||||
|
||||
|
||||
def _invoiceTTS(text_content: str, model_instance, tenant_id: str, voice: str):
|
||||
if not text_content or text_content.isspace():
|
||||
return
|
||||
return model_instance.invoke_tts(
|
||||
content_text=text_content.strip(),
|
||||
user="responding_tts",
|
||||
tenant_id=tenant_id,
|
||||
voice=voice
|
||||
)
|
||||
|
||||
|
||||
def _process_future(future_queue, audio_queue):
|
||||
while True:
|
||||
try:
|
||||
future = future_queue.get()
|
||||
if future is None:
|
||||
break
|
||||
for audio in future.result():
|
||||
audio_base64 = base64.b64encode(bytes(audio))
|
||||
audio_queue.put(AudioTrunk("responding", audio=audio_base64))
|
||||
except Exception as e:
|
||||
logging.getLogger(__name__).warning(e)
|
||||
break
|
||||
audio_queue.put(AudioTrunk("finish", b''))
|
||||
|
||||
|
||||
class AppGeneratorTTSPublisher:
|
||||
|
||||
def __init__(self, tenant_id: str, voice: str):
|
||||
self.logger = logging.getLogger(__name__)
|
||||
self.tenant_id = tenant_id
|
||||
self.msg_text = ''
|
||||
self._audio_queue = queue.Queue()
|
||||
self._msg_queue = queue.Queue()
|
||||
self.match = re.compile(r'[。.!?]')
|
||||
self.model_manager = ModelManager()
|
||||
self.model_instance = self.model_manager.get_default_model_instance(
|
||||
tenant_id=self.tenant_id,
|
||||
model_type=ModelType.TTS
|
||||
)
|
||||
self.voices = self.model_instance.get_tts_voices()
|
||||
values = [voice.get('value') for voice in self.voices]
|
||||
self.voice = voice
|
||||
if not voice or voice not in values:
|
||||
self.voice = self.voices[0].get('value')
|
||||
self.MAX_SENTENCE = 2
|
||||
self._last_audio_event = None
|
||||
self._runtime_thread = threading.Thread(target=self._runtime).start()
|
||||
self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=3)
|
||||
|
||||
def publish(self, message):
|
||||
try:
|
||||
self._msg_queue.put(message)
|
||||
except Exception as e:
|
||||
self.logger.warning(e)
|
||||
|
||||
def _runtime(self):
|
||||
future_queue = queue.Queue()
|
||||
threading.Thread(target=_process_future, args=(future_queue, self._audio_queue)).start()
|
||||
while True:
|
||||
try:
|
||||
message = self._msg_queue.get()
|
||||
if message is None:
|
||||
if self.msg_text and len(self.msg_text.strip()) > 0:
|
||||
futures_result = self.executor.submit(_invoiceTTS, self.msg_text,
|
||||
self.model_instance, self.tenant_id, self.voice)
|
||||
future_queue.put(futures_result)
|
||||
break
|
||||
elif isinstance(message.event, QueueAgentMessageEvent | QueueLLMChunkEvent):
|
||||
self.msg_text += message.event.chunk.delta.message.content
|
||||
elif isinstance(message.event, QueueTextChunkEvent):
|
||||
self.msg_text += message.event.text
|
||||
self.last_message = message
|
||||
sentence_arr, text_tmp = self._extract_sentence(self.msg_text)
|
||||
if len(sentence_arr) >= min(self.MAX_SENTENCE, 7):
|
||||
self.MAX_SENTENCE += 1
|
||||
text_content = ''.join(sentence_arr)
|
||||
futures_result = self.executor.submit(_invoiceTTS, text_content,
|
||||
self.model_instance,
|
||||
self.tenant_id,
|
||||
self.voice)
|
||||
future_queue.put(futures_result)
|
||||
if text_tmp:
|
||||
self.msg_text = text_tmp
|
||||
else:
|
||||
self.msg_text = ''
|
||||
|
||||
except Exception as e:
|
||||
self.logger.warning(e)
|
||||
break
|
||||
future_queue.put(None)
|
||||
|
||||
def checkAndGetAudio(self) -> AudioTrunk | None:
|
||||
try:
|
||||
if self._last_audio_event and self._last_audio_event.status == "finish":
|
||||
if self.executor:
|
||||
self.executor.shutdown(wait=False)
|
||||
return self.last_message
|
||||
audio = self._audio_queue.get_nowait()
|
||||
if audio and audio.status == "finish":
|
||||
self.executor.shutdown(wait=False)
|
||||
self._runtime_thread = None
|
||||
if audio:
|
||||
self._last_audio_event = audio
|
||||
return audio
|
||||
except queue.Empty:
|
||||
return None
|
||||
|
||||
def _extract_sentence(self, org_text):
|
||||
tx = self.match.finditer(org_text)
|
||||
start = 0
|
||||
result = []
|
||||
for i in tx:
|
||||
end = i.regs[0][1]
|
||||
result.append(org_text[start:end])
|
||||
start = end
|
||||
return result, org_text[start:]
|
@ -255,6 +255,12 @@ class AdvancedChatAppRunner(AppRunner):
|
||||
)
|
||||
index += 1
|
||||
time.sleep(0.01)
|
||||
else:
|
||||
queue_manager.publish(
|
||||
QueueTextChunkEvent(
|
||||
text=text
|
||||
), PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
queue_manager.publish(
|
||||
QueueStopEvent(stopped_by=stopped_by),
|
||||
|
@ -4,6 +4,8 @@ import time
|
||||
from collections.abc import Generator
|
||||
from typing import Any, Optional, Union, cast
|
||||
|
||||
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
|
||||
from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.entities.app_invoke_entities import (
|
||||
AdvancedChatAppGenerateEntity,
|
||||
@ -33,6 +35,8 @@ from core.app.entities.task_entities import (
|
||||
ChatbotAppStreamResponse,
|
||||
ChatflowStreamGenerateRoute,
|
||||
ErrorStreamResponse,
|
||||
MessageAudioEndStreamResponse,
|
||||
MessageAudioStreamResponse,
|
||||
MessageEndStreamResponse,
|
||||
StreamResponse,
|
||||
)
|
||||
@ -129,7 +133,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
self._application_generate_entity.query
|
||||
)
|
||||
|
||||
generator = self._process_stream_response(
|
||||
generator = self._wrapper_process_stream_response(
|
||||
trace_manager=self._application_generate_entity.trace_manager
|
||||
)
|
||||
if self._stream:
|
||||
@ -182,14 +186,68 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
stream_response=stream_response
|
||||
)
|
||||
|
||||
def _listenAudioMsg(self, publisher, task_id: str):
|
||||
if not publisher:
|
||||
return None
|
||||
audio_msg: AudioTrunk = publisher.checkAndGetAudio()
|
||||
if audio_msg and audio_msg.status != "finish":
|
||||
return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
|
||||
return None
|
||||
|
||||
def _wrapper_process_stream_response(self, trace_manager: Optional[TraceQueueManager] = None) -> \
|
||||
Generator[StreamResponse, None, None]:
|
||||
|
||||
publisher = None
|
||||
task_id = self._application_generate_entity.task_id
|
||||
tenant_id = self._application_generate_entity.app_config.tenant_id
|
||||
features_dict = self._workflow.features_dict
|
||||
|
||||
if features_dict.get('text_to_speech') and features_dict['text_to_speech'].get('enabled') and features_dict[
|
||||
'text_to_speech'].get('autoPlay') == 'enabled':
|
||||
publisher = AppGeneratorTTSPublisher(tenant_id, features_dict['text_to_speech'].get('voice'))
|
||||
for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager):
|
||||
while True:
|
||||
audio_response = self._listenAudioMsg(publisher, task_id=task_id)
|
||||
if audio_response:
|
||||
yield audio_response
|
||||
else:
|
||||
break
|
||||
yield response
|
||||
|
||||
start_listener_time = time.time()
|
||||
# timeout
|
||||
while (time.time() - start_listener_time) < TTS_AUTO_PLAY_TIMEOUT:
|
||||
try:
|
||||
if not publisher:
|
||||
break
|
||||
audio_trunk = publisher.checkAndGetAudio()
|
||||
if audio_trunk is None:
|
||||
# release cpu
|
||||
# sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file)
|
||||
time.sleep(TTS_AUTO_PLAY_YIELD_CPU_TIME)
|
||||
continue
|
||||
if audio_trunk.status == "finish":
|
||||
break
|
||||
else:
|
||||
start_listener_time = time.time()
|
||||
yield MessageAudioStreamResponse(audio=audio_trunk.audio, task_id=task_id)
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
break
|
||||
yield MessageAudioEndStreamResponse(audio='', task_id=task_id)
|
||||
|
||||
def _process_stream_response(
|
||||
self, trace_manager: Optional[TraceQueueManager] = None
|
||||
self,
|
||||
publisher: AppGeneratorTTSPublisher,
|
||||
trace_manager: Optional[TraceQueueManager] = None
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
"""
|
||||
Process stream response.
|
||||
:return:
|
||||
"""
|
||||
for message in self._queue_manager.listen():
|
||||
if publisher:
|
||||
publisher.publish(message=message)
|
||||
event = message.event
|
||||
|
||||
if isinstance(event, QueueErrorEvent):
|
||||
@ -318,7 +376,8 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
yield self._ping_stream_response()
|
||||
else:
|
||||
continue
|
||||
|
||||
if publisher:
|
||||
publisher.publish(None)
|
||||
if self._conversation_name_generate_thread:
|
||||
self._conversation_name_generate_thread.join()
|
||||
|
||||
@ -525,7 +584,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
:return:
|
||||
"""
|
||||
if not self._task_state.current_stream_generate_state:
|
||||
return None
|
||||
return
|
||||
|
||||
route_chunks = self._task_state.current_stream_generate_state.generate_route[
|
||||
self._task_state.current_stream_generate_state.current_route_position:]
|
||||
|
@ -51,7 +51,6 @@ class AppQueueManager:
|
||||
listen_timeout = current_app.config.get("APP_MAX_EXECUTION_TIME")
|
||||
start_time = time.time()
|
||||
last_ping_time = 0
|
||||
|
||||
while True:
|
||||
try:
|
||||
message = self._q.get(timeout=1)
|
||||
|
@ -1,7 +1,10 @@
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
|
||||
from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.entities.app_invoke_entities import (
|
||||
InvokeFrom,
|
||||
@ -25,6 +28,8 @@ from core.app.entities.queue_entities import (
|
||||
)
|
||||
from core.app.entities.task_entities import (
|
||||
ErrorStreamResponse,
|
||||
MessageAudioEndStreamResponse,
|
||||
MessageAudioStreamResponse,
|
||||
StreamResponse,
|
||||
TextChunkStreamResponse,
|
||||
TextReplaceStreamResponse,
|
||||
@ -105,7 +110,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
db.session.refresh(self._user)
|
||||
db.session.close()
|
||||
|
||||
generator = self._process_stream_response(
|
||||
generator = self._wrapper_process_stream_response(
|
||||
trace_manager=self._application_generate_entity.trace_manager
|
||||
)
|
||||
if self._stream:
|
||||
@ -161,8 +166,58 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
stream_response=stream_response
|
||||
)
|
||||
|
||||
def _listenAudioMsg(self, publisher, task_id: str):
|
||||
if not publisher:
|
||||
return None
|
||||
audio_msg: AudioTrunk = publisher.checkAndGetAudio()
|
||||
if audio_msg and audio_msg.status != "finish":
|
||||
return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
|
||||
return None
|
||||
|
||||
def _wrapper_process_stream_response(self, trace_manager: Optional[TraceQueueManager] = None) -> \
|
||||
Generator[StreamResponse, None, None]:
|
||||
|
||||
publisher = None
|
||||
task_id = self._application_generate_entity.task_id
|
||||
tenant_id = self._application_generate_entity.app_config.tenant_id
|
||||
features_dict = self._workflow.features_dict
|
||||
|
||||
if features_dict.get('text_to_speech') and features_dict['text_to_speech'].get('enabled') and features_dict[
|
||||
'text_to_speech'].get('autoPlay') == 'enabled':
|
||||
publisher = AppGeneratorTTSPublisher(tenant_id, features_dict['text_to_speech'].get('voice'))
|
||||
for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager):
|
||||
while True:
|
||||
audio_response = self._listenAudioMsg(publisher, task_id=task_id)
|
||||
if audio_response:
|
||||
yield audio_response
|
||||
else:
|
||||
break
|
||||
yield response
|
||||
|
||||
start_listener_time = time.time()
|
||||
while (time.time() - start_listener_time) < TTS_AUTO_PLAY_TIMEOUT:
|
||||
try:
|
||||
if not publisher:
|
||||
break
|
||||
audio_trunk = publisher.checkAndGetAudio()
|
||||
if audio_trunk is None:
|
||||
# release cpu
|
||||
# sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file)
|
||||
time.sleep(TTS_AUTO_PLAY_YIELD_CPU_TIME)
|
||||
continue
|
||||
if audio_trunk.status == "finish":
|
||||
break
|
||||
else:
|
||||
yield MessageAudioStreamResponse(audio=audio_trunk.audio, task_id=task_id)
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
break
|
||||
yield MessageAudioEndStreamResponse(audio='', task_id=task_id)
|
||||
|
||||
|
||||
def _process_stream_response(
|
||||
self,
|
||||
publisher: AppGeneratorTTSPublisher,
|
||||
trace_manager: Optional[TraceQueueManager] = None
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
"""
|
||||
@ -170,6 +225,8 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
:return:
|
||||
"""
|
||||
for message in self._queue_manager.listen():
|
||||
if publisher:
|
||||
publisher.publish(message=message)
|
||||
event = message.event
|
||||
|
||||
if isinstance(event, QueueErrorEvent):
|
||||
@ -251,6 +308,10 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
else:
|
||||
continue
|
||||
|
||||
if publisher:
|
||||
publisher.publish(None)
|
||||
|
||||
|
||||
def _save_workflow_app_log(self, workflow_run: WorkflowRun) -> None:
|
||||
"""
|
||||
Save workflow app log.
|
||||
|
@ -69,6 +69,7 @@ class WorkflowTaskState(TaskState):
|
||||
|
||||
iteration_nested_node_ids: list[str] = None
|
||||
|
||||
|
||||
class AdvancedChatTaskState(WorkflowTaskState):
|
||||
"""
|
||||
AdvancedChatTaskState entity
|
||||
@ -86,6 +87,8 @@ class StreamEvent(Enum):
|
||||
ERROR = "error"
|
||||
MESSAGE = "message"
|
||||
MESSAGE_END = "message_end"
|
||||
TTS_MESSAGE = "tts_message"
|
||||
TTS_MESSAGE_END = "tts_message_end"
|
||||
MESSAGE_FILE = "message_file"
|
||||
MESSAGE_REPLACE = "message_replace"
|
||||
AGENT_THOUGHT = "agent_thought"
|
||||
@ -130,6 +133,22 @@ class MessageStreamResponse(StreamResponse):
|
||||
answer: str
|
||||
|
||||
|
||||
class MessageAudioStreamResponse(StreamResponse):
|
||||
"""
|
||||
MessageStreamResponse entity
|
||||
"""
|
||||
event: StreamEvent = StreamEvent.TTS_MESSAGE
|
||||
audio: str
|
||||
|
||||
|
||||
class MessageAudioEndStreamResponse(StreamResponse):
|
||||
"""
|
||||
MessageStreamResponse entity
|
||||
"""
|
||||
event: StreamEvent = StreamEvent.TTS_MESSAGE_END
|
||||
audio: str
|
||||
|
||||
|
||||
class MessageEndStreamResponse(StreamResponse):
|
||||
"""
|
||||
MessageEndStreamResponse entity
|
||||
@ -186,6 +205,7 @@ class WorkflowStartStreamResponse(StreamResponse):
|
||||
"""
|
||||
WorkflowStartStreamResponse entity
|
||||
"""
|
||||
|
||||
class Data(BaseModel):
|
||||
"""
|
||||
Data entity
|
||||
@ -205,6 +225,7 @@ class WorkflowFinishStreamResponse(StreamResponse):
|
||||
"""
|
||||
WorkflowFinishStreamResponse entity
|
||||
"""
|
||||
|
||||
class Data(BaseModel):
|
||||
"""
|
||||
Data entity
|
||||
@ -232,6 +253,7 @@ class NodeStartStreamResponse(StreamResponse):
|
||||
"""
|
||||
NodeStartStreamResponse entity
|
||||
"""
|
||||
|
||||
class Data(BaseModel):
|
||||
"""
|
||||
Data entity
|
||||
@ -273,6 +295,7 @@ class NodeFinishStreamResponse(StreamResponse):
|
||||
"""
|
||||
NodeFinishStreamResponse entity
|
||||
"""
|
||||
|
||||
class Data(BaseModel):
|
||||
"""
|
||||
Data entity
|
||||
@ -323,10 +346,12 @@ class NodeFinishStreamResponse(StreamResponse):
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class IterationNodeStartStreamResponse(StreamResponse):
|
||||
"""
|
||||
NodeStartStreamResponse entity
|
||||
"""
|
||||
|
||||
class Data(BaseModel):
|
||||
"""
|
||||
Data entity
|
||||
@ -344,10 +369,12 @@ class IterationNodeStartStreamResponse(StreamResponse):
|
||||
workflow_run_id: str
|
||||
data: Data
|
||||
|
||||
|
||||
class IterationNodeNextStreamResponse(StreamResponse):
|
||||
"""
|
||||
NodeStartStreamResponse entity
|
||||
"""
|
||||
|
||||
class Data(BaseModel):
|
||||
"""
|
||||
Data entity
|
||||
@ -365,10 +392,12 @@ class IterationNodeNextStreamResponse(StreamResponse):
|
||||
workflow_run_id: str
|
||||
data: Data
|
||||
|
||||
|
||||
class IterationNodeCompletedStreamResponse(StreamResponse):
|
||||
"""
|
||||
NodeCompletedStreamResponse entity
|
||||
"""
|
||||
|
||||
class Data(BaseModel):
|
||||
"""
|
||||
Data entity
|
||||
@ -393,10 +422,12 @@ class IterationNodeCompletedStreamResponse(StreamResponse):
|
||||
workflow_run_id: str
|
||||
data: Data
|
||||
|
||||
|
||||
class TextChunkStreamResponse(StreamResponse):
|
||||
"""
|
||||
TextChunkStreamResponse entity
|
||||
"""
|
||||
|
||||
class Data(BaseModel):
|
||||
"""
|
||||
Data entity
|
||||
@ -411,6 +442,7 @@ class TextReplaceStreamResponse(StreamResponse):
|
||||
"""
|
||||
TextReplaceStreamResponse entity
|
||||
"""
|
||||
|
||||
class Data(BaseModel):
|
||||
"""
|
||||
Data entity
|
||||
@ -473,6 +505,7 @@ class ChatbotAppBlockingResponse(AppBlockingResponse):
|
||||
"""
|
||||
ChatbotAppBlockingResponse entity
|
||||
"""
|
||||
|
||||
class Data(BaseModel):
|
||||
"""
|
||||
Data entity
|
||||
@ -492,6 +525,7 @@ class CompletionAppBlockingResponse(AppBlockingResponse):
|
||||
"""
|
||||
CompletionAppBlockingResponse entity
|
||||
"""
|
||||
|
||||
class Data(BaseModel):
|
||||
"""
|
||||
Data entity
|
||||
@ -510,6 +544,7 @@ class WorkflowAppBlockingResponse(AppBlockingResponse):
|
||||
"""
|
||||
WorkflowAppBlockingResponse entity
|
||||
"""
|
||||
|
||||
class Data(BaseModel):
|
||||
"""
|
||||
Data entity
|
||||
@ -528,10 +563,12 @@ class WorkflowAppBlockingResponse(AppBlockingResponse):
|
||||
workflow_run_id: str
|
||||
data: Data
|
||||
|
||||
|
||||
class WorkflowIterationState(BaseModel):
|
||||
"""
|
||||
WorkflowIterationState entity
|
||||
"""
|
||||
|
||||
class Data(BaseModel):
|
||||
"""
|
||||
Data entity
|
||||
|
1
api/core/app/features/rate_limiting/__init__.py
Normal file
1
api/core/app/features/rate_limiting/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
from .rate_limit import RateLimit
|
120
api/core/app/features/rate_limiting/rate_limit.py
Normal file
120
api/core/app/features/rate_limiting/rate_limit.py
Normal file
@ -0,0 +1,120 @@
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Generator
|
||||
from datetime import timedelta
|
||||
from typing import Optional, Union
|
||||
|
||||
from core.errors.error import AppInvokeQuotaExceededError
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RateLimit:
|
||||
_MAX_ACTIVE_REQUESTS_KEY = "dify:rate_limit:{}:max_active_requests"
|
||||
_ACTIVE_REQUESTS_KEY = "dify:rate_limit:{}:active_requests"
|
||||
_UNLIMITED_REQUEST_ID = "unlimited_request_id"
|
||||
_REQUEST_MAX_ALIVE_TIME = 10 * 60 # 10 minutes
|
||||
_ACTIVE_REQUESTS_COUNT_FLUSH_INTERVAL = 5 * 60 # recalculate request_count from request_detail every 5 minutes
|
||||
_instance_dict = {}
|
||||
|
||||
def __new__(cls: type['RateLimit'], client_id: str, max_active_requests: int):
|
||||
if client_id not in cls._instance_dict:
|
||||
instance = super().__new__(cls)
|
||||
cls._instance_dict[client_id] = instance
|
||||
return cls._instance_dict[client_id]
|
||||
|
||||
def __init__(self, client_id: str, max_active_requests: int):
|
||||
self.max_active_requests = max_active_requests
|
||||
if hasattr(self, 'initialized'):
|
||||
return
|
||||
self.initialized = True
|
||||
self.client_id = client_id
|
||||
self.active_requests_key = self._ACTIVE_REQUESTS_KEY.format(client_id)
|
||||
self.max_active_requests_key = self._MAX_ACTIVE_REQUESTS_KEY.format(client_id)
|
||||
self.last_recalculate_time = float('-inf')
|
||||
self.flush_cache(use_local_value=True)
|
||||
|
||||
def flush_cache(self, use_local_value=False):
|
||||
self.last_recalculate_time = time.time()
|
||||
# flush max active requests
|
||||
if use_local_value or not redis_client.exists(self.max_active_requests_key):
|
||||
with redis_client.pipeline() as pipe:
|
||||
pipe.set(self.max_active_requests_key, self.max_active_requests)
|
||||
pipe.expire(self.max_active_requests_key, timedelta(days=1))
|
||||
pipe.execute()
|
||||
else:
|
||||
with redis_client.pipeline() as pipe:
|
||||
self.max_active_requests = int(redis_client.get(self.max_active_requests_key).decode('utf-8'))
|
||||
redis_client.expire(self.max_active_requests_key, timedelta(days=1))
|
||||
|
||||
# flush max active requests (in-transit request list)
|
||||
if not redis_client.exists(self.active_requests_key):
|
||||
return
|
||||
request_details = redis_client.hgetall(self.active_requests_key)
|
||||
redis_client.expire(self.active_requests_key, timedelta(days=1))
|
||||
timeout_requests = [k for k, v in request_details.items() if
|
||||
time.time() - float(v.decode('utf-8')) > RateLimit._REQUEST_MAX_ALIVE_TIME]
|
||||
if timeout_requests:
|
||||
redis_client.hdel(self.active_requests_key, *timeout_requests)
|
||||
|
||||
def enter(self, request_id: Optional[str] = None) -> str:
|
||||
if time.time() - self.last_recalculate_time > RateLimit._ACTIVE_REQUESTS_COUNT_FLUSH_INTERVAL:
|
||||
self.flush_cache()
|
||||
if self.max_active_requests <= 0:
|
||||
return RateLimit._UNLIMITED_REQUEST_ID
|
||||
if not request_id:
|
||||
request_id = RateLimit.gen_request_key()
|
||||
|
||||
active_requests_count = redis_client.hlen(self.active_requests_key)
|
||||
if active_requests_count >= self.max_active_requests:
|
||||
raise AppInvokeQuotaExceededError("Too many requests. Please try again later. The current maximum "
|
||||
"concurrent requests allowed is {}.".format(self.max_active_requests))
|
||||
redis_client.hset(self.active_requests_key, request_id, str(time.time()))
|
||||
return request_id
|
||||
|
||||
def exit(self, request_id: str):
|
||||
if request_id == RateLimit._UNLIMITED_REQUEST_ID:
|
||||
return
|
||||
redis_client.hdel(self.active_requests_key, request_id)
|
||||
|
||||
@staticmethod
|
||||
def gen_request_key() -> str:
|
||||
return str(uuid.uuid4())
|
||||
|
||||
def generate(self, generator: Union[Generator, callable, dict], request_id: str):
|
||||
if isinstance(generator, dict):
|
||||
return generator
|
||||
else:
|
||||
return RateLimitGenerator(self, generator, request_id)
|
||||
|
||||
|
||||
class RateLimitGenerator:
|
||||
def __init__(self, rate_limit: RateLimit, generator: Union[Generator, callable], request_id: str):
|
||||
self.rate_limit = rate_limit
|
||||
if callable(generator):
|
||||
self.generator = generator()
|
||||
else:
|
||||
self.generator = generator
|
||||
self.request_id = request_id
|
||||
self.closed = False
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
if self.closed:
|
||||
raise StopIteration
|
||||
try:
|
||||
return next(self.generator)
|
||||
except StopIteration:
|
||||
self.close()
|
||||
raise
|
||||
|
||||
def close(self):
|
||||
if not self.closed:
|
||||
self.closed = True
|
||||
self.rate_limit.exit(self.request_id)
|
||||
if self.generator is not None and hasattr(self.generator, 'close'):
|
||||
self.generator.close()
|
@ -4,6 +4,8 @@ import time
|
||||
from collections.abc import Generator
|
||||
from typing import Optional, Union, cast
|
||||
|
||||
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
|
||||
from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.entities.app_invoke_entities import (
|
||||
AgentChatAppGenerateEntity,
|
||||
@ -32,6 +34,8 @@ from core.app.entities.task_entities import (
|
||||
CompletionAppStreamResponse,
|
||||
EasyUITaskState,
|
||||
ErrorStreamResponse,
|
||||
MessageAudioEndStreamResponse,
|
||||
MessageAudioStreamResponse,
|
||||
MessageEndStreamResponse,
|
||||
StreamResponse,
|
||||
)
|
||||
@ -87,6 +91,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
||||
"""
|
||||
super().__init__(application_generate_entity, queue_manager, user, stream)
|
||||
self._model_config = application_generate_entity.model_conf
|
||||
self._app_config = application_generate_entity.app_config
|
||||
self._conversation = conversation
|
||||
self._message = message
|
||||
|
||||
@ -123,7 +128,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
||||
self._application_generate_entity.query
|
||||
)
|
||||
|
||||
generator = self._process_stream_response(
|
||||
generator = self._wrapper_process_stream_response(
|
||||
trace_manager=self._application_generate_entity.trace_manager
|
||||
)
|
||||
if self._stream:
|
||||
@ -202,14 +207,64 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
||||
stream_response=stream_response
|
||||
)
|
||||
|
||||
def _listenAudioMsg(self, publisher, task_id: str):
|
||||
if publisher is None:
|
||||
return None
|
||||
audio_msg: AudioTrunk = publisher.checkAndGetAudio()
|
||||
if audio_msg and audio_msg.status != "finish":
|
||||
# audio_str = audio_msg.audio.decode('utf-8', errors='ignore')
|
||||
return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
|
||||
return None
|
||||
|
||||
def _wrapper_process_stream_response(self, trace_manager: Optional[TraceQueueManager] = None) -> \
|
||||
Generator[StreamResponse, None, None]:
|
||||
|
||||
tenant_id = self._application_generate_entity.app_config.tenant_id
|
||||
task_id = self._application_generate_entity.task_id
|
||||
publisher = None
|
||||
text_to_speech_dict = self._app_config.app_model_config_dict.get('text_to_speech')
|
||||
if text_to_speech_dict and text_to_speech_dict.get('autoPlay') == 'enabled' and text_to_speech_dict.get('enabled'):
|
||||
publisher = AppGeneratorTTSPublisher(tenant_id, text_to_speech_dict.get('voice', None))
|
||||
for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager):
|
||||
while True:
|
||||
audio_response = self._listenAudioMsg(publisher, task_id)
|
||||
if audio_response:
|
||||
yield audio_response
|
||||
else:
|
||||
break
|
||||
yield response
|
||||
|
||||
start_listener_time = time.time()
|
||||
# timeout
|
||||
while (time.time() - start_listener_time) < TTS_AUTO_PLAY_TIMEOUT:
|
||||
if publisher is None:
|
||||
break
|
||||
audio = publisher.checkAndGetAudio()
|
||||
if audio is None:
|
||||
# release cpu
|
||||
# sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file)
|
||||
time.sleep(TTS_AUTO_PLAY_YIELD_CPU_TIME)
|
||||
continue
|
||||
if audio.status == "finish":
|
||||
break
|
||||
else:
|
||||
start_listener_time = time.time()
|
||||
yield MessageAudioStreamResponse(audio=audio.audio,
|
||||
task_id=task_id)
|
||||
yield MessageAudioEndStreamResponse(audio='', task_id=task_id)
|
||||
|
||||
def _process_stream_response(
|
||||
self, trace_manager: Optional[TraceQueueManager] = None
|
||||
self,
|
||||
publisher: AppGeneratorTTSPublisher,
|
||||
trace_manager: Optional[TraceQueueManager] = None
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
"""
|
||||
Process stream response.
|
||||
:return:
|
||||
"""
|
||||
for message in self._queue_manager.listen():
|
||||
if publisher:
|
||||
publisher.publish(message)
|
||||
event = message.event
|
||||
|
||||
if isinstance(event, QueueErrorEvent):
|
||||
@ -272,7 +327,8 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
||||
yield self._ping_stream_response()
|
||||
else:
|
||||
continue
|
||||
|
||||
if publisher:
|
||||
publisher.publish(None)
|
||||
if self._conversation_name_generate_thread:
|
||||
self._conversation_name_generate_thread.join()
|
||||
|
||||
|
@ -31,6 +31,13 @@ class QuotaExceededError(Exception):
|
||||
description = "Quota Exceeded"
|
||||
|
||||
|
||||
class AppInvokeQuotaExceededError(Exception):
|
||||
"""
|
||||
Custom exception raised when the quota for an app has been exceeded.
|
||||
"""
|
||||
description = "App Invoke Quota Exceeded"
|
||||
|
||||
|
||||
class ModelCurrentlyNotSupportError(Exception):
|
||||
"""
|
||||
Custom exception raised when the model not support
|
||||
|
@ -730,7 +730,7 @@ class IndexingRunner:
|
||||
self._check_document_paused_status(dataset_document.id)
|
||||
|
||||
tokens = 0
|
||||
if dataset.indexing_technique == 'high_quality' or embedding_model_type_instance:
|
||||
if embedding_model_instance:
|
||||
tokens += sum(
|
||||
embedding_model_instance.get_text_embedding_num_tokens(
|
||||
[document.page_content]
|
||||
|
@ -264,7 +264,7 @@ class ModelInstance:
|
||||
user=user
|
||||
)
|
||||
|
||||
def invoke_tts(self, content_text: str, tenant_id: str, voice: str, streaming: bool, user: Optional[str] = None) \
|
||||
def invoke_tts(self, content_text: str, tenant_id: str, voice: str, user: Optional[str] = None) \
|
||||
-> str:
|
||||
"""
|
||||
Invoke large language tts model
|
||||
@ -287,8 +287,7 @@ class ModelInstance:
|
||||
content_text=content_text,
|
||||
user=user,
|
||||
tenant_id=tenant_id,
|
||||
voice=voice,
|
||||
streaming=streaming
|
||||
voice=voice
|
||||
)
|
||||
|
||||
def _round_robin_invoke(self, function: Callable, *args, **kwargs):
|
||||
|
@ -1,4 +1,6 @@
|
||||
import hashlib
|
||||
import logging
|
||||
import re
|
||||
import subprocess
|
||||
import uuid
|
||||
from abc import abstractmethod
|
||||
@ -10,7 +12,7 @@ from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelTy
|
||||
from core.model_runtime.errors.invoke import InvokeBadRequestError
|
||||
from core.model_runtime.model_providers.__base.ai_model import AIModel
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
class TTSModel(AIModel):
|
||||
"""
|
||||
Model class for ttstext model.
|
||||
@ -20,7 +22,7 @@ class TTSModel(AIModel):
|
||||
# pydantic configs
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
def invoke(self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, streaming: bool,
|
||||
def invoke(self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str,
|
||||
user: Optional[str] = None):
|
||||
"""
|
||||
Invoke large language model
|
||||
@ -35,14 +37,15 @@ class TTSModel(AIModel):
|
||||
:return: translated audio file
|
||||
"""
|
||||
try:
|
||||
logger.info(f"Invoke TTS model: {model} , invoke content : {content_text}")
|
||||
self._is_ffmpeg_installed()
|
||||
return self._invoke(model=model, credentials=credentials, user=user, streaming=streaming,
|
||||
return self._invoke(model=model, credentials=credentials, user=user,
|
||||
content_text=content_text, voice=voice, tenant_id=tenant_id)
|
||||
except Exception as e:
|
||||
raise self._transform_invoke_error(e)
|
||||
|
||||
@abstractmethod
|
||||
def _invoke(self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, streaming: bool,
|
||||
def _invoke(self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str,
|
||||
user: Optional[str] = None):
|
||||
"""
|
||||
Invoke large language model
|
||||
@ -123,26 +126,26 @@ class TTSModel(AIModel):
|
||||
return model_schema.model_properties[ModelPropertyKey.MAX_WORKERS]
|
||||
|
||||
@staticmethod
|
||||
def _split_text_into_sentences(text: str, limit: int, delimiters=None):
|
||||
if delimiters is None:
|
||||
delimiters = set('。!?;\n')
|
||||
|
||||
buf = []
|
||||
word_count = 0
|
||||
for char in text:
|
||||
buf.append(char)
|
||||
if char in delimiters:
|
||||
if word_count >= limit:
|
||||
yield ''.join(buf)
|
||||
buf = []
|
||||
word_count = 0
|
||||
else:
|
||||
word_count += 1
|
||||
else:
|
||||
word_count += 1
|
||||
|
||||
if buf:
|
||||
yield ''.join(buf)
|
||||
def _split_text_into_sentences(org_text, max_length=2000, pattern=r'[。.!?]'):
|
||||
match = re.compile(pattern)
|
||||
tx = match.finditer(org_text)
|
||||
start = 0
|
||||
result = []
|
||||
one_sentence = ''
|
||||
for i in tx:
|
||||
end = i.regs[0][1]
|
||||
tmp = org_text[start:end]
|
||||
if len(one_sentence + tmp) > max_length:
|
||||
result.append(one_sentence)
|
||||
one_sentence = ''
|
||||
one_sentence += tmp
|
||||
start = end
|
||||
last_sens = org_text[start:]
|
||||
if last_sens:
|
||||
one_sentence += last_sens
|
||||
if one_sentence != '':
|
||||
result.append(one_sentence)
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _is_ffmpeg_installed():
|
||||
|
@ -33,3 +33,4 @@
|
||||
- deepseek
|
||||
- hunyuan
|
||||
- siliconflow
|
||||
- perfxcloud
|
||||
|
@ -71,6 +71,9 @@ model_credential_schema:
|
||||
- label:
|
||||
en_US: '2024-02-01'
|
||||
value: '2024-02-01'
|
||||
- label:
|
||||
en_US: '2024-06-01'
|
||||
value: '2024-06-01'
|
||||
placeholder:
|
||||
zh_Hans: 在此选择您的 API 版本
|
||||
en_US: Select your API Version here
|
||||
|
@ -4,7 +4,7 @@ from functools import reduce
|
||||
from io import BytesIO
|
||||
from typing import Optional
|
||||
|
||||
from flask import Response, stream_with_context
|
||||
from flask import Response
|
||||
from openai import AzureOpenAI
|
||||
from pydub import AudioSegment
|
||||
|
||||
@ -14,7 +14,6 @@ from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.tts_model import TTSModel
|
||||
from core.model_runtime.model_providers.azure_openai._common import _CommonAzureOpenAI
|
||||
from core.model_runtime.model_providers.azure_openai._constant import TTS_BASE_MODELS, AzureBaseModel
|
||||
from extensions.ext_storage import storage
|
||||
|
||||
|
||||
class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
|
||||
@ -23,7 +22,7 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
|
||||
"""
|
||||
|
||||
def _invoke(self, model: str, tenant_id: str, credentials: dict,
|
||||
content_text: str, voice: str, streaming: bool, user: Optional[str] = None) -> any:
|
||||
content_text: str, voice: str, user: Optional[str] = None) -> any:
|
||||
"""
|
||||
_invoke text2speech model
|
||||
|
||||
@ -32,30 +31,23 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
|
||||
:param credentials: model credentials
|
||||
:param content_text: text content to be translated
|
||||
:param voice: model timbre
|
||||
:param streaming: output is streaming
|
||||
:param user: unique user id
|
||||
:return: text translated to audio file
|
||||
"""
|
||||
audio_type = self._get_model_audio_type(model, credentials)
|
||||
if not voice or voice not in [d['value'] for d in self.get_tts_model_voices(model=model, credentials=credentials)]:
|
||||
voice = self._get_model_default_voice(model, credentials)
|
||||
if streaming:
|
||||
return Response(stream_with_context(self._tts_invoke_streaming(model=model,
|
||||
|
||||
return self._tts_invoke_streaming(model=model,
|
||||
credentials=credentials,
|
||||
content_text=content_text,
|
||||
tenant_id=tenant_id,
|
||||
voice=voice)),
|
||||
status=200, mimetype=f'audio/{audio_type}')
|
||||
else:
|
||||
return self._tts_invoke(model=model, credentials=credentials, content_text=content_text, voice=voice)
|
||||
voice=voice)
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict, user: Optional[str] = None) -> None:
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
"""
|
||||
validate credentials text2speech model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param user: unique user id
|
||||
:return: text translated to audio file
|
||||
"""
|
||||
try:
|
||||
@ -82,7 +74,7 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
|
||||
word_limit = self._get_model_word_limit(model, credentials)
|
||||
max_workers = self._get_model_workers_limit(model, credentials)
|
||||
try:
|
||||
sentences = list(self._split_text_into_sentences(text=content_text, limit=word_limit))
|
||||
sentences = list(self._split_text_into_sentences(org_text=content_text, max_length=word_limit))
|
||||
audio_bytes_list = []
|
||||
|
||||
# Create a thread pool and map the function to the list of sentences
|
||||
@ -107,34 +99,37 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
|
||||
except Exception as ex:
|
||||
raise InvokeBadRequestError(str(ex))
|
||||
|
||||
# Todo: To improve the streaming function
|
||||
def _tts_invoke_streaming(self, model: str, tenant_id: str, credentials: dict, content_text: str,
|
||||
def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str,
|
||||
voice: str) -> any:
|
||||
"""
|
||||
_tts_invoke_streaming text2speech model
|
||||
|
||||
:param model: model name
|
||||
:param tenant_id: user tenant id
|
||||
:param credentials: model credentials
|
||||
:param content_text: text content to be translated
|
||||
:param voice: model timbre
|
||||
:return: text translated to audio file
|
||||
"""
|
||||
# transform credentials to kwargs for model instance
|
||||
credentials_kwargs = self._to_credential_kwargs(credentials)
|
||||
if not voice or voice not in self.get_tts_model_voices(model=model, credentials=credentials):
|
||||
voice = self._get_model_default_voice(model, credentials)
|
||||
word_limit = self._get_model_word_limit(model, credentials)
|
||||
audio_type = self._get_model_audio_type(model, credentials)
|
||||
tts_file_id = self._get_file_name(content_text)
|
||||
file_path = f'generate_files/audio/{tenant_id}/{tts_file_id}.{audio_type}'
|
||||
try:
|
||||
# doc: https://platform.openai.com/docs/guides/text-to-speech
|
||||
credentials_kwargs = self._to_credential_kwargs(credentials)
|
||||
client = AzureOpenAI(**credentials_kwargs)
|
||||
sentences = list(self._split_text_into_sentences(text=content_text, limit=word_limit))
|
||||
for sentence in sentences:
|
||||
response = client.audio.speech.create(model=model, voice=voice, input=sentence.strip())
|
||||
# response.stream_to_file(file_path)
|
||||
storage.save(file_path, response.read())
|
||||
# max font is 4096,there is 3500 limit for each request
|
||||
max_length = 3500
|
||||
if len(content_text) > max_length:
|
||||
sentences = self._split_text_into_sentences(content_text, max_length=max_length)
|
||||
executor = concurrent.futures.ThreadPoolExecutor(max_workers=min(3, len(sentences)))
|
||||
futures = [executor.submit(client.audio.speech.with_streaming_response.create, model=model,
|
||||
response_format="mp3",
|
||||
input=sentences[i], voice=voice) for i in range(len(sentences))]
|
||||
for index, future in enumerate(futures):
|
||||
yield from future.result().__enter__().iter_bytes(1024)
|
||||
|
||||
else:
|
||||
response = client.audio.speech.with_streaming_response.create(model=model, voice=voice,
|
||||
response_format="mp3",
|
||||
input=content_text.strip())
|
||||
|
||||
yield from response.__enter__().iter_bytes(1024)
|
||||
except Exception as ex:
|
||||
raise InvokeBadRequestError(str(ex))
|
||||
|
||||
@ -162,7 +157,7 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
|
||||
|
||||
|
||||
@staticmethod
|
||||
def _get_ai_model_entity(base_model_name: str, model: str) -> AzureBaseModel:
|
||||
def _get_ai_model_entity(base_model_name: str, model: str) -> AzureBaseModel | None:
|
||||
for ai_model_entity in TTS_BASE_MODELS:
|
||||
if ai_model_entity.base_model_name == base_model_name:
|
||||
ai_model_entity_copy = copy.deepcopy(ai_model_entity)
|
||||
@ -170,5 +165,4 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
|
||||
ai_model_entity_copy.entity.label.en_US = model
|
||||
ai_model_entity_copy.entity.label.zh_Hans = model
|
||||
return ai_model_entity_copy
|
||||
|
||||
return None
|
||||
|
@ -66,6 +66,10 @@ provider_credential_schema:
|
||||
label:
|
||||
en_US: Europe (Frankfurt)
|
||||
zh_Hans: 欧洲 (法兰克福)
|
||||
- value: eu-west-2
|
||||
label:
|
||||
en_US: Eu west London (London)
|
||||
zh_Hans: 欧洲西部 (伦敦)
|
||||
- value: us-gov-west-1
|
||||
label:
|
||||
en_US: AWS GovCloud (US-West)
|
||||
|
@ -48,6 +48,28 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class BedrockLargeLanguageModel(LargeLanguageModel):
|
||||
|
||||
# please refer to the documentation: https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html
|
||||
# TODO There is invoke issue: context limit on Cohere Model, will add them after fixed.
|
||||
CONVERSE_API_ENABLED_MODEL_INFO=[
|
||||
{'prefix': 'anthropic.claude-v2', 'support_system_prompts': True, 'support_tool_use': False},
|
||||
{'prefix': 'anthropic.claude-v1', 'support_system_prompts': True, 'support_tool_use': False},
|
||||
{'prefix': 'anthropic.claude-3', 'support_system_prompts': True, 'support_tool_use': True},
|
||||
{'prefix': 'meta.llama', 'support_system_prompts': True, 'support_tool_use': False},
|
||||
{'prefix': 'mistral.mistral-7b-instruct', 'support_system_prompts': False, 'support_tool_use': False},
|
||||
{'prefix': 'mistral.mixtral-8x7b-instruct', 'support_system_prompts': False, 'support_tool_use': False},
|
||||
{'prefix': 'mistral.mistral-large', 'support_system_prompts': True, 'support_tool_use': True},
|
||||
{'prefix': 'mistral.mistral-small', 'support_system_prompts': True, 'support_tool_use': True},
|
||||
{'prefix': 'amazon.titan', 'support_system_prompts': False, 'support_tool_use': False}
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def _find_model_info(model_id):
|
||||
for model in BedrockLargeLanguageModel.CONVERSE_API_ENABLED_MODEL_INFO:
|
||||
if model_id.startswith(model['prefix']):
|
||||
return model
|
||||
logger.info(f"current model id: {model_id} did not support by Converse API")
|
||||
return None
|
||||
|
||||
def _invoke(self, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
|
||||
@ -66,10 +88,12 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
||||
:param user: unique user id
|
||||
:return: full response or stream response chunk generator result
|
||||
"""
|
||||
# TODO: consolidate different invocation methods for models based on base model capabilities
|
||||
# invoke anthropic models via boto3 client
|
||||
if "anthropic" in model:
|
||||
return self._generate_anthropic(model, credentials, prompt_messages, model_parameters, stop, stream, user, tools)
|
||||
|
||||
model_info= BedrockLargeLanguageModel._find_model_info(model)
|
||||
if model_info:
|
||||
model_info['model'] = model
|
||||
# invoke models via boto3 converse API
|
||||
return self._generate_with_converse(model_info, credentials, prompt_messages, model_parameters, stop, stream, user, tools)
|
||||
# invoke Cohere models via boto3 client
|
||||
if "cohere.command-r" in model:
|
||||
return self._generate_cohere_chat(model, credentials, prompt_messages, model_parameters, stop, stream, user, tools)
|
||||
@ -151,12 +175,12 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
||||
return self._handle_generate_response(model, credentials, response, prompt_messages)
|
||||
|
||||
|
||||
def _generate_anthropic(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
def _generate_with_converse(self, model_info: dict, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None, tools: Optional[list[PromptMessageTool]] = None,) -> Union[LLMResult, Generator]:
|
||||
"""
|
||||
Invoke Anthropic large language model
|
||||
Invoke large language model with converse API
|
||||
|
||||
:param model: model name
|
||||
:param model_info: model information
|
||||
:param credentials: model credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param model_parameters: model parameters
|
||||
@ -173,24 +197,24 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
||||
inference_config, additional_model_fields = self._convert_converse_api_model_parameters(model_parameters, stop)
|
||||
|
||||
parameters = {
|
||||
'modelId': model,
|
||||
'modelId': model_info['model'],
|
||||
'messages': prompt_message_dicts,
|
||||
'inferenceConfig': inference_config,
|
||||
'additionalModelRequestFields': additional_model_fields,
|
||||
}
|
||||
|
||||
if system and len(system) > 0:
|
||||
if model_info['support_system_prompts'] and system and len(system) > 0:
|
||||
parameters['system'] = system
|
||||
|
||||
if tools:
|
||||
if model_info['support_tool_use'] and tools:
|
||||
parameters['toolConfig'] = self._convert_converse_tool_config(tools=tools)
|
||||
|
||||
if stream:
|
||||
response = bedrock_client.converse_stream(**parameters)
|
||||
return self._handle_converse_stream_response(model, credentials, response, prompt_messages)
|
||||
return self._handle_converse_stream_response(model_info['model'], credentials, response, prompt_messages)
|
||||
else:
|
||||
response = bedrock_client.converse(**parameters)
|
||||
return self._handle_converse_response(model, credentials, response, prompt_messages)
|
||||
return self._handle_converse_response(model_info['model'], credentials, response, prompt_messages)
|
||||
|
||||
def _handle_converse_response(self, model: str, credentials: dict, response: dict,
|
||||
prompt_messages: list[PromptMessage]) -> LLMResult:
|
||||
@ -203,9 +227,29 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
||||
:param prompt_messages: prompt messages
|
||||
:return: full response chunk generator result
|
||||
"""
|
||||
response_content = response['output']['message']['content']
|
||||
# transform assistant message to prompt message
|
||||
if response['stopReason'] == 'tool_use':
|
||||
tool_calls = []
|
||||
text, tool_use = self._extract_tool_use(response_content)
|
||||
|
||||
tool_call = AssistantPromptMessage.ToolCall(
|
||||
id=tool_use['toolUseId'],
|
||||
type='function',
|
||||
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
|
||||
name=tool_use['name'],
|
||||
arguments=json.dumps(tool_use['input'])
|
||||
)
|
||||
)
|
||||
tool_calls.append(tool_call)
|
||||
|
||||
assistant_prompt_message = AssistantPromptMessage(
|
||||
content=response['output']['message']['content'][0]['text']
|
||||
content=text,
|
||||
tool_calls=tool_calls
|
||||
)
|
||||
else:
|
||||
assistant_prompt_message = AssistantPromptMessage(
|
||||
content=response_content[0]['text']
|
||||
)
|
||||
|
||||
# calculate num tokens
|
||||
@ -229,6 +273,18 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
||||
)
|
||||
return result
|
||||
|
||||
def _extract_tool_use(self, content:dict)-> tuple[str, dict]:
|
||||
tool_use = {}
|
||||
text = ''
|
||||
for item in content:
|
||||
if 'toolUse' in item:
|
||||
tool_use = item['toolUse']
|
||||
elif 'text' in item:
|
||||
text = item['text']
|
||||
else:
|
||||
raise ValueError(f"Got unknown item: {item}")
|
||||
return text, tool_use
|
||||
|
||||
def _handle_converse_stream_response(self, model: str, credentials: dict, response: dict,
|
||||
prompt_messages: list[PromptMessage], ) -> Generator:
|
||||
"""
|
||||
@ -340,14 +396,12 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
||||
"""
|
||||
|
||||
system = []
|
||||
prompt_message_dicts = []
|
||||
for message in prompt_messages:
|
||||
if isinstance(message, SystemPromptMessage):
|
||||
message.content=message.content.strip()
|
||||
system.append({"text": message.content})
|
||||
|
||||
prompt_message_dicts = []
|
||||
for message in prompt_messages:
|
||||
if not isinstance(message, SystemPromptMessage):
|
||||
else:
|
||||
prompt_message_dicts.append(self._convert_prompt_message_to_dict(message))
|
||||
|
||||
return system, prompt_message_dicts
|
||||
@ -448,7 +502,6 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"Got unknown type {message}")
|
||||
|
||||
return message_dict
|
||||
|
||||
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage] | str,
|
||||
|
@ -2,6 +2,9 @@ model: mistral.mistral-large-2402-v1:0
|
||||
label:
|
||||
en_US: Mistral Large
|
||||
model_type: llm
|
||||
features:
|
||||
- tool-call
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: completion
|
||||
context_size: 32000
|
||||
|
@ -2,6 +2,8 @@ model: mistral.mistral-small-2402-v1:0
|
||||
label:
|
||||
en_US: Mistral Small
|
||||
model_type: llm
|
||||
features:
|
||||
- tool-call
|
||||
model_properties:
|
||||
mode: completion
|
||||
context_size: 32000
|
||||
|
@ -7,7 +7,7 @@ features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 32000
|
||||
context_size: 128000
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
|
@ -7,7 +7,7 @@ features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 32000
|
||||
context_size: 128000
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
|
@ -21,7 +21,7 @@ model_properties:
|
||||
- mode: 'shimmer'
|
||||
name: 'Shimmer'
|
||||
language: [ 'zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID' ]
|
||||
word_limit: 120
|
||||
word_limit: 3500
|
||||
audio_type: 'mp3'
|
||||
max_workers: 5
|
||||
pricing:
|
||||
|
@ -21,7 +21,7 @@ model_properties:
|
||||
- mode: 'shimmer'
|
||||
name: 'Shimmer'
|
||||
language: ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID']
|
||||
word_limit: 120
|
||||
word_limit: 3500
|
||||
audio_type: 'mp3'
|
||||
max_workers: 5
|
||||
pricing:
|
||||
|
@ -3,7 +3,7 @@ from functools import reduce
|
||||
from io import BytesIO
|
||||
from typing import Optional
|
||||
|
||||
from flask import Response, stream_with_context
|
||||
from flask import Response
|
||||
from openai import OpenAI
|
||||
from pydub import AudioSegment
|
||||
|
||||
@ -11,7 +11,6 @@ from core.model_runtime.errors.invoke import InvokeBadRequestError
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.tts_model import TTSModel
|
||||
from core.model_runtime.model_providers.openai._common import _CommonOpenAI
|
||||
from extensions.ext_storage import storage
|
||||
|
||||
|
||||
class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel):
|
||||
@ -20,7 +19,7 @@ class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel):
|
||||
"""
|
||||
|
||||
def _invoke(self, model: str, tenant_id: str, credentials: dict,
|
||||
content_text: str, voice: str, streaming: bool, user: Optional[str] = None) -> any:
|
||||
content_text: str, voice: str, user: Optional[str] = None) -> any:
|
||||
"""
|
||||
_invoke text2speech model
|
||||
|
||||
@ -29,22 +28,17 @@ class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel):
|
||||
:param credentials: model credentials
|
||||
:param content_text: text content to be translated
|
||||
:param voice: model timbre
|
||||
:param streaming: output is streaming
|
||||
:param user: unique user id
|
||||
:return: text translated to audio file
|
||||
"""
|
||||
audio_type = self._get_model_audio_type(model, credentials)
|
||||
|
||||
if not voice or voice not in [d['value'] for d in self.get_tts_model_voices(model=model, credentials=credentials)]:
|
||||
voice = self._get_model_default_voice(model, credentials)
|
||||
if streaming:
|
||||
return Response(stream_with_context(self._tts_invoke_streaming(model=model,
|
||||
# if streaming:
|
||||
return self._tts_invoke_streaming(model=model,
|
||||
credentials=credentials,
|
||||
content_text=content_text,
|
||||
tenant_id=tenant_id,
|
||||
voice=voice)),
|
||||
status=200, mimetype=f'audio/{audio_type}')
|
||||
else:
|
||||
return self._tts_invoke(model=model, credentials=credentials, content_text=content_text, voice=voice)
|
||||
voice=voice)
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict, user: Optional[str] = None) -> None:
|
||||
"""
|
||||
@ -79,7 +73,7 @@ class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel):
|
||||
word_limit = self._get_model_word_limit(model, credentials)
|
||||
max_workers = self._get_model_workers_limit(model, credentials)
|
||||
try:
|
||||
sentences = list(self._split_text_into_sentences(text=content_text, limit=word_limit))
|
||||
sentences = list(self._split_text_into_sentences(org_text=content_text, max_length=word_limit))
|
||||
audio_bytes_list = []
|
||||
|
||||
# Create a thread pool and map the function to the list of sentences
|
||||
@ -104,34 +98,40 @@ class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel):
|
||||
except Exception as ex:
|
||||
raise InvokeBadRequestError(str(ex))
|
||||
|
||||
# Todo: To improve the streaming function
|
||||
def _tts_invoke_streaming(self, model: str, tenant_id: str, credentials: dict, content_text: str,
|
||||
|
||||
def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str,
|
||||
voice: str) -> any:
|
||||
"""
|
||||
_tts_invoke_streaming text2speech model
|
||||
|
||||
:param model: model name
|
||||
:param tenant_id: user tenant id
|
||||
:param credentials: model credentials
|
||||
:param content_text: text content to be translated
|
||||
:param voice: model timbre
|
||||
:return: text translated to audio file
|
||||
"""
|
||||
# transform credentials to kwargs for model instance
|
||||
try:
|
||||
# doc: https://platform.openai.com/docs/guides/text-to-speech
|
||||
credentials_kwargs = self._to_credential_kwargs(credentials)
|
||||
client = OpenAI(**credentials_kwargs)
|
||||
if not voice or voice not in self.get_tts_model_voices(model=model, credentials=credentials):
|
||||
voice = self._get_model_default_voice(model, credentials)
|
||||
word_limit = self._get_model_word_limit(model, credentials)
|
||||
audio_type = self._get_model_audio_type(model, credentials)
|
||||
tts_file_id = self._get_file_name(content_text)
|
||||
file_path = f'generate_files/audio/{tenant_id}/{tts_file_id}.{audio_type}'
|
||||
try:
|
||||
client = OpenAI(**credentials_kwargs)
|
||||
sentences = list(self._split_text_into_sentences(text=content_text, limit=word_limit))
|
||||
for sentence in sentences:
|
||||
response = client.audio.speech.create(model=model, voice=voice, input=sentence.strip())
|
||||
# response.stream_to_file(file_path)
|
||||
storage.save(file_path, response.read())
|
||||
if len(content_text) > word_limit:
|
||||
sentences = self._split_text_into_sentences(content_text, max_length=word_limit)
|
||||
executor = concurrent.futures.ThreadPoolExecutor(max_workers=min(3, len(sentences)))
|
||||
futures = [executor.submit(client.audio.speech.with_streaming_response.create, model=model,
|
||||
response_format="mp3",
|
||||
input=sentences[i], voice=voice) for i in range(len(sentences))]
|
||||
for index, future in enumerate(futures):
|
||||
yield from future.result().__enter__().iter_bytes(1024)
|
||||
|
||||
else:
|
||||
response = client.audio.speech.with_streaming_response.create(model=model, voice=voice,
|
||||
response_format="mp3",
|
||||
input=content_text.strip())
|
||||
|
||||
yield from response.__enter__().iter_bytes(1024)
|
||||
except Exception as ex:
|
||||
raise InvokeBadRequestError(str(ex))
|
||||
|
||||
|
@ -616,10 +616,11 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
|
||||
message = cast(AssistantPromptMessage, message)
|
||||
message_dict = {"role": "assistant", "content": message.content}
|
||||
if message.tool_calls:
|
||||
# message_dict["tool_calls"] = [helper.dump_model(PromptMessageFunction(function=tool_call)) for tool_call
|
||||
# in
|
||||
# message.tool_calls]
|
||||
|
||||
function_calling_type = credentials.get('function_calling_type', 'no_call')
|
||||
if function_calling_type == 'tool_call':
|
||||
message_dict["tool_calls"] = [tool_call.dict() for tool_call in
|
||||
message.tool_calls]
|
||||
elif function_calling_type == 'function_call':
|
||||
function_call = message.tool_calls[0]
|
||||
message_dict["function_call"] = {
|
||||
"name": function_call.function.name,
|
||||
@ -630,13 +631,16 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
|
||||
message_dict = {"role": "system", "content": message.content}
|
||||
elif isinstance(message, ToolPromptMessage):
|
||||
message = cast(ToolPromptMessage, message)
|
||||
# message_dict = {
|
||||
# "role": "tool",
|
||||
# "content": message.content,
|
||||
# "tool_call_id": message.tool_call_id
|
||||
# }
|
||||
function_calling_type = credentials.get('function_calling_type', 'no_call')
|
||||
if function_calling_type == 'tool_call':
|
||||
message_dict = {
|
||||
"role": "tool" if credentials and credentials.get('function_calling_type', 'no_call') == 'tool_call' else "function",
|
||||
"role": "tool",
|
||||
"content": message.content,
|
||||
"tool_call_id": message.tool_call_id
|
||||
}
|
||||
elif function_calling_type == 'function_call':
|
||||
message_dict = {
|
||||
"role": "function",
|
||||
"content": message.content,
|
||||
"name": message.tool_call_id
|
||||
}
|
||||
|
File diff suppressed because one or more lines are too long
After Width: | Height: | Size: 21 KiB |
File diff suppressed because one or more lines are too long
After Width: | Height: | Size: 48 KiB |
@ -0,0 +1,61 @@
|
||||
model: Qwen-14B-Chat-Int4
|
||||
label:
|
||||
en_US: Qwen-14B-Chat-Int4
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 4096
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
type: float
|
||||
default: 0.3
|
||||
min: 0.0
|
||||
max: 2.0
|
||||
help:
|
||||
zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。
|
||||
en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain.
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
type: int
|
||||
default: 600
|
||||
min: 1
|
||||
max: 1248
|
||||
help:
|
||||
zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。
|
||||
en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time.
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
type: float
|
||||
default: 0.8
|
||||
min: 0.1
|
||||
max: 0.9
|
||||
help:
|
||||
zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。
|
||||
en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated.
|
||||
- name: top_k
|
||||
type: int
|
||||
min: 0
|
||||
max: 99
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
help:
|
||||
zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。
|
||||
en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated.
|
||||
- name: repetition_penalty
|
||||
required: false
|
||||
type: float
|
||||
default: 1.1
|
||||
label:
|
||||
en_US: Repetition penalty
|
||||
help:
|
||||
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
|
||||
en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
|
||||
pricing:
|
||||
input: '0.000'
|
||||
output: '0.000'
|
||||
unit: '0.000'
|
||||
currency: RMB
|
@ -0,0 +1,61 @@
|
||||
model: Qwen1.5-110B-Chat-GPTQ-Int4
|
||||
label:
|
||||
en_US: Qwen1.5-110B-Chat-GPTQ-Int4
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 8192
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
type: float
|
||||
default: 0.3
|
||||
min: 0.0
|
||||
max: 2.0
|
||||
help:
|
||||
zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。
|
||||
en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain.
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
type: int
|
||||
default: 128
|
||||
min: 1
|
||||
max: 256
|
||||
help:
|
||||
zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。
|
||||
en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time.
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
type: float
|
||||
default: 0.8
|
||||
min: 0.1
|
||||
max: 0.9
|
||||
help:
|
||||
zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。
|
||||
en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated.
|
||||
- name: top_k
|
||||
type: int
|
||||
min: 0
|
||||
max: 99
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
help:
|
||||
zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。
|
||||
en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated.
|
||||
- name: repetition_penalty
|
||||
required: false
|
||||
type: float
|
||||
default: 1.1
|
||||
label:
|
||||
en_US: Repetition penalty
|
||||
help:
|
||||
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
|
||||
en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
|
||||
pricing:
|
||||
input: '0.000'
|
||||
output: '0.000'
|
||||
unit: '0.000'
|
||||
currency: RMB
|
@ -0,0 +1,61 @@
|
||||
model: Qwen1.5-72B-Chat-GPTQ-Int4
|
||||
label:
|
||||
en_US: Qwen1.5-72B-Chat-GPTQ-Int4
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 8192
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
type: float
|
||||
default: 0.3
|
||||
min: 0.0
|
||||
max: 2.0
|
||||
help:
|
||||
zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。
|
||||
en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain.
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
type: int
|
||||
default: 600
|
||||
min: 1
|
||||
max: 2000
|
||||
help:
|
||||
zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。
|
||||
en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time.
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
type: float
|
||||
default: 0.8
|
||||
min: 0.1
|
||||
max: 0.9
|
||||
help:
|
||||
zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。
|
||||
en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated.
|
||||
- name: top_k
|
||||
type: int
|
||||
min: 0
|
||||
max: 99
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
help:
|
||||
zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。
|
||||
en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated.
|
||||
- name: repetition_penalty
|
||||
required: false
|
||||
type: float
|
||||
default: 1.1
|
||||
label:
|
||||
en_US: Repetition penalty
|
||||
help:
|
||||
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
|
||||
en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
|
||||
pricing:
|
||||
input: '0.000'
|
||||
output: '0.000'
|
||||
unit: '0.000'
|
||||
currency: RMB
|
@ -0,0 +1,61 @@
|
||||
model: Qwen1.5-7B
|
||||
label:
|
||||
en_US: Qwen1.5-7B
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: completion
|
||||
context_size: 8192
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
type: float
|
||||
default: 0.3
|
||||
min: 0.0
|
||||
max: 2.0
|
||||
help:
|
||||
zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。
|
||||
en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain.
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
type: int
|
||||
default: 600
|
||||
min: 1
|
||||
max: 2000
|
||||
help:
|
||||
zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。
|
||||
en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time.
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
type: float
|
||||
default: 0.8
|
||||
min: 0.1
|
||||
max: 0.9
|
||||
help:
|
||||
zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。
|
||||
en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated.
|
||||
- name: top_k
|
||||
type: int
|
||||
min: 0
|
||||
max: 99
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
help:
|
||||
zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。
|
||||
en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated.
|
||||
- name: repetition_penalty
|
||||
required: false
|
||||
type: float
|
||||
default: 1.1
|
||||
label:
|
||||
en_US: Repetition penalty
|
||||
help:
|
||||
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
|
||||
en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
|
||||
pricing:
|
||||
input: '0.000'
|
||||
output: '0.000'
|
||||
unit: '0.000'
|
||||
currency: RMB
|
@ -0,0 +1,63 @@
|
||||
model: Qwen2-72B-Instruct-GPTQ-Int4
|
||||
label:
|
||||
en_US: Qwen2-72B-Instruct-GPTQ-Int4
|
||||
model_type: llm
|
||||
features:
|
||||
- multi-tool-call
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 8192
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
type: float
|
||||
default: 0.3
|
||||
min: 0.0
|
||||
max: 2.0
|
||||
help:
|
||||
zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。
|
||||
en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain.
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
type: int
|
||||
default: 600
|
||||
min: 1
|
||||
max: 2000
|
||||
help:
|
||||
zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。
|
||||
en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time.
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
type: float
|
||||
default: 0.8
|
||||
min: 0.1
|
||||
max: 0.9
|
||||
help:
|
||||
zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。
|
||||
en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated.
|
||||
- name: top_k
|
||||
type: int
|
||||
min: 0
|
||||
max: 99
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
help:
|
||||
zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。
|
||||
en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated.
|
||||
- name: repetition_penalty
|
||||
required: false
|
||||
type: float
|
||||
default: 1.1
|
||||
label:
|
||||
en_US: Repetition penalty
|
||||
help:
|
||||
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
|
||||
en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
|
||||
pricing:
|
||||
input: '0.000'
|
||||
output: '0.000'
|
||||
unit: '0.000'
|
||||
currency: RMB
|
@ -0,0 +1,63 @@
|
||||
model: Qwen2-7B
|
||||
label:
|
||||
en_US: Qwen2-7B
|
||||
model_type: llm
|
||||
features:
|
||||
- multi-tool-call
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
model_properties:
|
||||
mode: completion
|
||||
context_size: 8192
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
type: float
|
||||
default: 0.3
|
||||
min: 0.0
|
||||
max: 2.0
|
||||
help:
|
||||
zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。
|
||||
en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain.
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
type: int
|
||||
default: 600
|
||||
min: 1
|
||||
max: 2000
|
||||
help:
|
||||
zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。
|
||||
en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time.
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
type: float
|
||||
default: 0.8
|
||||
min: 0.1
|
||||
max: 0.9
|
||||
help:
|
||||
zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。
|
||||
en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated.
|
||||
- name: top_k
|
||||
type: int
|
||||
min: 0
|
||||
max: 99
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
help:
|
||||
zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。
|
||||
en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated.
|
||||
- name: repetition_penalty
|
||||
required: false
|
||||
type: float
|
||||
default: 1.1
|
||||
label:
|
||||
en_US: Repetition penalty
|
||||
help:
|
||||
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
|
||||
en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
|
||||
pricing:
|
||||
input: '0.000'
|
||||
output: '0.000'
|
||||
unit: '0.000'
|
||||
currency: RMB
|
@ -0,0 +1,6 @@
|
||||
- Qwen2-72B-Instruct-GPTQ-Int4
|
||||
- Qwen2-7B
|
||||
- Qwen1.5-110B-Chat-GPTQ-Int4
|
||||
- Qwen1.5-72B-Chat-GPTQ-Int4
|
||||
- Qwen1.5-7B
|
||||
- Qwen-14B-Chat-Int4
|
110
api/core/model_runtime/model_providers/perfxcloud/llm/llm.py
Normal file
110
api/core/model_runtime/model_providers/perfxcloud/llm/llm.py
Normal file
@ -0,0 +1,110 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Optional, Union
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import tiktoken
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMResult
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
PromptMessage,
|
||||
PromptMessageTool,
|
||||
)
|
||||
from core.model_runtime.model_providers.openai.llm.llm import OpenAILargeLanguageModel
|
||||
|
||||
|
||||
class PerfXCloudLargeLanguageModel(OpenAILargeLanguageModel):
|
||||
def _invoke(self, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None) \
|
||||
-> Union[LLMResult, Generator]:
|
||||
self._add_custom_parameters(credentials)
|
||||
|
||||
return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream)
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
self._add_custom_parameters(credentials)
|
||||
super().validate_credentials(model, credentials)
|
||||
|
||||
# refactored from openai model runtime, use cl100k_base for calculate token number
|
||||
def _num_tokens_from_string(self, model: str, text: str,
|
||||
tools: Optional[list[PromptMessageTool]] = None) -> int:
|
||||
"""
|
||||
Calculate num tokens for text completion model with tiktoken package.
|
||||
|
||||
:param model: model name
|
||||
:param text: prompt text
|
||||
:param tools: tools for tool calling
|
||||
:return: number of tokens
|
||||
"""
|
||||
encoding = tiktoken.get_encoding("cl100k_base")
|
||||
num_tokens = len(encoding.encode(text))
|
||||
|
||||
if tools:
|
||||
num_tokens += self._num_tokens_for_tools(encoding, tools)
|
||||
|
||||
return num_tokens
|
||||
|
||||
# refactored from openai model runtime, use cl100k_base for calculate token number
|
||||
def _num_tokens_from_messages(self, model: str, messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None) -> int:
|
||||
"""Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
|
||||
|
||||
Official documentation: https://github.com/openai/openai-cookbook/blob/
|
||||
main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
|
||||
encoding = tiktoken.get_encoding("cl100k_base")
|
||||
tokens_per_message = 3
|
||||
tokens_per_name = 1
|
||||
|
||||
num_tokens = 0
|
||||
messages_dict = [self._convert_prompt_message_to_dict(m) for m in messages]
|
||||
for message in messages_dict:
|
||||
num_tokens += tokens_per_message
|
||||
for key, value in message.items():
|
||||
# Cast str(value) in case the message value is not a string
|
||||
# This occurs with function messages
|
||||
# TODO: The current token calculation method for the image type is not implemented,
|
||||
# which need to download the image and then get the resolution for calculation,
|
||||
# and will increase the request delay
|
||||
if isinstance(value, list):
|
||||
text = ''
|
||||
for item in value:
|
||||
if isinstance(item, dict) and item['type'] == 'text':
|
||||
text += item['text']
|
||||
|
||||
value = text
|
||||
|
||||
if key == "tool_calls":
|
||||
for tool_call in value:
|
||||
for t_key, t_value in tool_call.items():
|
||||
num_tokens += len(encoding.encode(t_key))
|
||||
if t_key == "function":
|
||||
for f_key, f_value in t_value.items():
|
||||
num_tokens += len(encoding.encode(f_key))
|
||||
num_tokens += len(encoding.encode(f_value))
|
||||
else:
|
||||
num_tokens += len(encoding.encode(t_key))
|
||||
num_tokens += len(encoding.encode(t_value))
|
||||
else:
|
||||
num_tokens += len(encoding.encode(str(value)))
|
||||
|
||||
if key == "name":
|
||||
num_tokens += tokens_per_name
|
||||
|
||||
# every reply is primed with <im_start>assistant
|
||||
num_tokens += 3
|
||||
|
||||
if tools:
|
||||
num_tokens += self._num_tokens_for_tools(encoding, tools)
|
||||
|
||||
return num_tokens
|
||||
|
||||
@staticmethod
|
||||
def _add_custom_parameters(credentials: dict) -> None:
|
||||
credentials['mode'] = 'chat'
|
||||
credentials['openai_api_key']=credentials['api_key']
|
||||
if 'endpoint_url' not in credentials or credentials['endpoint_url'] == "":
|
||||
credentials['openai_api_base']='https://cloud.perfxlab.cn'
|
||||
else:
|
||||
parsed_url = urlparse(credentials['endpoint_url'])
|
||||
credentials['openai_api_base']=f"{parsed_url.scheme}://{parsed_url.netloc}"
|
@ -0,0 +1,32 @@
|
||||
import logging
|
||||
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PerfXCloudProvider(ModelProvider):
|
||||
|
||||
def validate_provider_credentials(self, credentials: dict) -> None:
|
||||
"""
|
||||
Validate provider credentials
|
||||
if validate failed, raise exception
|
||||
|
||||
:param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
|
||||
"""
|
||||
try:
|
||||
model_instance = self.get_model_instance(ModelType.LLM)
|
||||
|
||||
# Use `Qwen2_72B_Chat_GPTQ_Int4` model for validate,
|
||||
# no matter what model you pass in, text completion model or chat model
|
||||
model_instance.validate_credentials(
|
||||
model='Qwen2-72B-Instruct-GPTQ-Int4',
|
||||
credentials=credentials
|
||||
)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
raise ex
|
||||
except Exception as ex:
|
||||
logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
|
||||
raise ex
|
@ -0,0 +1,42 @@
|
||||
provider: perfxcloud
|
||||
label:
|
||||
en_US: PerfXCloud
|
||||
zh_Hans: PerfXCloud
|
||||
description:
|
||||
en_US: PerfXCloud (Pengfeng Technology) is an AI development and deployment platform tailored for developers and enterprises, providing reasoning capabilities for multiple models.
|
||||
zh_Hans: PerfXCloud(澎峰科技)为开发者和企业量身打造的AI开发和部署平台,提供多种模型的的推理能力。
|
||||
icon_small:
|
||||
en_US: icon_s_en.svg
|
||||
icon_large:
|
||||
en_US: icon_l_en.svg
|
||||
background: "#e3f0ff"
|
||||
help:
|
||||
title:
|
||||
en_US: Get your API Key from PerfXCloud
|
||||
zh_Hans: 从 PerfXCloud 获取 API Key
|
||||
url:
|
||||
en_US: https://cloud.perfxlab.cn/panel/token
|
||||
supported_model_types:
|
||||
- llm
|
||||
- text-embedding
|
||||
configurate_methods:
|
||||
- predefined-model
|
||||
provider_credential_schema:
|
||||
credential_form_schemas:
|
||||
- variable: api_key
|
||||
label:
|
||||
en_US: API Key
|
||||
type: secret-input
|
||||
required: true
|
||||
placeholder:
|
||||
zh_Hans: 在此输入您的 API Key
|
||||
en_US: Enter your API Key
|
||||
- variable: endpoint_url
|
||||
label:
|
||||
zh_Hans: 自定义 API endpoint 地址
|
||||
en_US: Custom API endpoint URL
|
||||
type: text-input
|
||||
required: false
|
||||
placeholder:
|
||||
zh_Hans: Base URL, e.g. https://cloud.perfxlab.cn/v1
|
||||
en_US: Base URL, e.g. https://cloud.perfxlab.cn/v1
|
@ -0,0 +1,4 @@
|
||||
model: BAAI/bge-m3
|
||||
model_type: text-embedding
|
||||
model_properties:
|
||||
context_size: 32768
|
@ -0,0 +1,250 @@
|
||||
import json
|
||||
import time
|
||||
from decimal import Decimal
|
||||
from typing import Optional
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.entities.model_entities import (
|
||||
AIModelEntity,
|
||||
FetchFrom,
|
||||
ModelPropertyKey,
|
||||
ModelType,
|
||||
PriceConfig,
|
||||
PriceType,
|
||||
)
|
||||
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||
from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOAI_API_Compat
|
||||
|
||||
|
||||
class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel):
|
||||
"""
|
||||
Model class for an OpenAI API-compatible text embedding model.
|
||||
"""
|
||||
|
||||
def _invoke(self, model: str, credentials: dict,
|
||||
texts: list[str], user: Optional[str] = None) \
|
||||
-> TextEmbeddingResult:
|
||||
"""
|
||||
Invoke text embedding model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param texts: texts to embed
|
||||
:param user: unique user id
|
||||
:return: embeddings result
|
||||
"""
|
||||
|
||||
# Prepare headers and payload for the request
|
||||
headers = {
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
|
||||
api_key = credentials.get('api_key')
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
|
||||
if 'endpoint_url' not in credentials or credentials['endpoint_url'] == "":
|
||||
endpoint_url='https://cloud.perfxlab.cn/v1/'
|
||||
else:
|
||||
endpoint_url = credentials.get('endpoint_url')
|
||||
if not endpoint_url.endswith('/'):
|
||||
endpoint_url += '/'
|
||||
|
||||
endpoint_url = urljoin(endpoint_url, 'embeddings')
|
||||
|
||||
extra_model_kwargs = {}
|
||||
if user:
|
||||
extra_model_kwargs['user'] = user
|
||||
|
||||
extra_model_kwargs['encoding_format'] = 'float'
|
||||
|
||||
# get model properties
|
||||
context_size = self._get_context_size(model, credentials)
|
||||
max_chunks = self._get_max_chunks(model, credentials)
|
||||
|
||||
inputs = []
|
||||
indices = []
|
||||
used_tokens = 0
|
||||
|
||||
for i, text in enumerate(texts):
|
||||
|
||||
# Here token count is only an approximation based on the GPT2 tokenizer
|
||||
# TODO: Optimize for better token estimation and chunking
|
||||
num_tokens = self._get_num_tokens_by_gpt2(text)
|
||||
|
||||
if num_tokens >= context_size:
|
||||
cutoff = int(len(text) * (np.floor(context_size / num_tokens)))
|
||||
# if num tokens is larger than context length, only use the start
|
||||
inputs.append(text[0: cutoff])
|
||||
else:
|
||||
inputs.append(text)
|
||||
indices += [i]
|
||||
|
||||
batched_embeddings = []
|
||||
_iter = range(0, len(inputs), max_chunks)
|
||||
|
||||
for i in _iter:
|
||||
# Prepare the payload for the request
|
||||
payload = {
|
||||
'input': inputs[i: i + max_chunks],
|
||||
'model': model,
|
||||
**extra_model_kwargs
|
||||
}
|
||||
|
||||
# Make the request to the OpenAI API
|
||||
response = requests.post(
|
||||
endpoint_url,
|
||||
headers=headers,
|
||||
data=json.dumps(payload),
|
||||
timeout=(10, 300)
|
||||
)
|
||||
|
||||
response.raise_for_status() # Raise an exception for HTTP errors
|
||||
response_data = response.json()
|
||||
|
||||
# Extract embeddings and used tokens from the response
|
||||
embeddings_batch = [data['embedding'] for data in response_data['data']]
|
||||
embedding_used_tokens = response_data['usage']['total_tokens']
|
||||
|
||||
used_tokens += embedding_used_tokens
|
||||
batched_embeddings += embeddings_batch
|
||||
|
||||
# calc usage
|
||||
usage = self._calc_response_usage(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
tokens=used_tokens
|
||||
)
|
||||
|
||||
return TextEmbeddingResult(
|
||||
embeddings=batched_embeddings,
|
||||
usage=usage,
|
||||
model=model
|
||||
)
|
||||
|
||||
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
|
||||
"""
|
||||
Approximate number of tokens for given messages using GPT2 tokenizer
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param texts: texts to embed
|
||||
:return:
|
||||
"""
|
||||
return sum(self._get_num_tokens_by_gpt2(text) for text in texts)
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
"""
|
||||
Validate model credentials
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
headers = {
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
|
||||
api_key = credentials.get('api_key')
|
||||
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
|
||||
if 'endpoint_url' not in credentials or credentials['endpoint_url'] == "":
|
||||
endpoint_url='https://cloud.perfxlab.cn/v1/'
|
||||
else:
|
||||
endpoint_url = credentials.get('endpoint_url')
|
||||
if not endpoint_url.endswith('/'):
|
||||
endpoint_url += '/'
|
||||
|
||||
endpoint_url = urljoin(endpoint_url, 'embeddings')
|
||||
|
||||
payload = {
|
||||
'input': 'ping',
|
||||
'model': model
|
||||
}
|
||||
|
||||
response = requests.post(
|
||||
url=endpoint_url,
|
||||
headers=headers,
|
||||
data=json.dumps(payload),
|
||||
timeout=(10, 300)
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise CredentialsValidateFailedError(
|
||||
f'Credentials validation failed with status code {response.status_code}')
|
||||
|
||||
try:
|
||||
json_result = response.json()
|
||||
except json.JSONDecodeError as e:
|
||||
raise CredentialsValidateFailedError('Credentials validation failed: JSON decode error')
|
||||
|
||||
if 'model' not in json_result:
|
||||
raise CredentialsValidateFailedError(
|
||||
'Credentials validation failed: invalid response')
|
||||
except CredentialsValidateFailedError:
|
||||
raise
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
|
||||
"""
|
||||
generate custom model entities from credentials
|
||||
"""
|
||||
entity = AIModelEntity(
|
||||
model=model,
|
||||
label=I18nObject(en_US=model),
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||
model_properties={
|
||||
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size')),
|
||||
ModelPropertyKey.MAX_CHUNKS: 1,
|
||||
},
|
||||
parameter_rules=[],
|
||||
pricing=PriceConfig(
|
||||
input=Decimal(credentials.get('input_price', 0)),
|
||||
unit=Decimal(credentials.get('unit', 0)),
|
||||
currency=credentials.get('currency', "USD")
|
||||
)
|
||||
)
|
||||
|
||||
return entity
|
||||
|
||||
|
||||
def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
|
||||
"""
|
||||
Calculate response usage
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param tokens: input tokens
|
||||
:return: usage
|
||||
"""
|
||||
# get input price info
|
||||
input_price_info = self.get_price(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
price_type=PriceType.INPUT,
|
||||
tokens=tokens
|
||||
)
|
||||
|
||||
# transform usage
|
||||
usage = EmbeddingUsage(
|
||||
tokens=tokens,
|
||||
total_tokens=tokens,
|
||||
unit_price=input_price_info.unit_price,
|
||||
price_unit=input_price_info.unit,
|
||||
total_price=input_price_info.total_amount,
|
||||
currency=input_price_info.currency,
|
||||
latency=time.perf_counter() - self.started_at
|
||||
)
|
||||
|
||||
return usage
|
@ -129,7 +129,7 @@ model_properties:
|
||||
- mode: "sambert-waan-v1"
|
||||
name: "Waan(泰语女声)"
|
||||
language: [ "th-TH" ]
|
||||
word_limit: 120
|
||||
word_limit: 7000
|
||||
audio_type: 'mp3'
|
||||
max_workers: 5
|
||||
pricing:
|
||||
|
@ -1,17 +1,21 @@
|
||||
import concurrent.futures
|
||||
import threading
|
||||
from functools import reduce
|
||||
from io import BytesIO
|
||||
from queue import Queue
|
||||
from typing import Optional
|
||||
|
||||
import dashscope
|
||||
from flask import Response, stream_with_context
|
||||
from dashscope import SpeechSynthesizer
|
||||
from dashscope.api_entities.dashscope_response import SpeechSynthesisResponse
|
||||
from dashscope.audio.tts import ResultCallback, SpeechSynthesisResult
|
||||
from flask import Response
|
||||
from pydub import AudioSegment
|
||||
|
||||
from core.model_runtime.errors.invoke import InvokeBadRequestError
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.tts_model import TTSModel
|
||||
from core.model_runtime.model_providers.tongyi._common import _CommonTongyi
|
||||
from extensions.ext_storage import storage
|
||||
|
||||
|
||||
class TongyiText2SpeechModel(_CommonTongyi, TTSModel):
|
||||
@ -19,7 +23,7 @@ class TongyiText2SpeechModel(_CommonTongyi, TTSModel):
|
||||
Model class for Tongyi Speech to text model.
|
||||
"""
|
||||
|
||||
def _invoke(self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, streaming: bool,
|
||||
def _invoke(self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str,
|
||||
user: Optional[str] = None) -> any:
|
||||
"""
|
||||
_invoke text2speech model
|
||||
@ -29,22 +33,17 @@ class TongyiText2SpeechModel(_CommonTongyi, TTSModel):
|
||||
:param credentials: model credentials
|
||||
:param voice: model timbre
|
||||
:param content_text: text content to be translated
|
||||
:param streaming: output is streaming
|
||||
:param user: unique user id
|
||||
:return: text translated to audio file
|
||||
"""
|
||||
audio_type = self._get_model_audio_type(model, credentials)
|
||||
if not voice or voice not in [d['value'] for d in self.get_tts_model_voices(model=model, credentials=credentials)]:
|
||||
if not voice or voice not in [d['value'] for d in
|
||||
self.get_tts_model_voices(model=model, credentials=credentials)]:
|
||||
voice = self._get_model_default_voice(model, credentials)
|
||||
if streaming:
|
||||
return Response(stream_with_context(self._tts_invoke_streaming(model=model,
|
||||
|
||||
return self._tts_invoke_streaming(model=model,
|
||||
credentials=credentials,
|
||||
content_text=content_text,
|
||||
voice=voice,
|
||||
tenant_id=tenant_id)),
|
||||
status=200, mimetype=f'audio/{audio_type}')
|
||||
else:
|
||||
return self._tts_invoke(model=model, credentials=credentials, content_text=content_text, voice=voice)
|
||||
voice=voice)
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict, user: Optional[str] = None) -> None:
|
||||
"""
|
||||
@ -79,7 +78,7 @@ class TongyiText2SpeechModel(_CommonTongyi, TTSModel):
|
||||
word_limit = self._get_model_word_limit(model, credentials)
|
||||
max_workers = self._get_model_workers_limit(model, credentials)
|
||||
try:
|
||||
sentences = list(self._split_text_into_sentences(text=content_text, limit=word_limit))
|
||||
sentences = list(self._split_text_into_sentences(org_text=content_text, max_length=word_limit))
|
||||
audio_bytes_list = []
|
||||
|
||||
# Create a thread pool and map the function to the list of sentences
|
||||
@ -105,14 +104,12 @@ class TongyiText2SpeechModel(_CommonTongyi, TTSModel):
|
||||
except Exception as ex:
|
||||
raise InvokeBadRequestError(str(ex))
|
||||
|
||||
# Todo: To improve the streaming function
|
||||
def _tts_invoke_streaming(self, model: str, tenant_id: str, credentials: dict, content_text: str,
|
||||
def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str,
|
||||
voice: str) -> any:
|
||||
"""
|
||||
_tts_invoke_streaming text2speech model
|
||||
|
||||
:param model: model name
|
||||
:param tenant_id: user tenant id
|
||||
:param credentials: model credentials
|
||||
:param voice: model timbre
|
||||
:param content_text: text content to be translated
|
||||
@ -120,18 +117,32 @@ class TongyiText2SpeechModel(_CommonTongyi, TTSModel):
|
||||
"""
|
||||
word_limit = self._get_model_word_limit(model, credentials)
|
||||
audio_type = self._get_model_audio_type(model, credentials)
|
||||
tts_file_id = self._get_file_name(content_text)
|
||||
file_path = f'generate_files/audio/{tenant_id}/{tts_file_id}.{audio_type}'
|
||||
try:
|
||||
sentences = list(self._split_text_into_sentences(text=content_text, limit=word_limit))
|
||||
audio_queue: Queue = Queue()
|
||||
callback = Callback(queue=audio_queue)
|
||||
|
||||
def invoke_remote(content, v, api_key, cb, at, wl):
|
||||
if len(content) < word_limit:
|
||||
sentences = [content]
|
||||
else:
|
||||
sentences = list(self._split_text_into_sentences(org_text=content, max_length=wl))
|
||||
for sentence in sentences:
|
||||
response = dashscope.audio.tts.SpeechSynthesizer.call(model=voice, sample_rate=48000,
|
||||
api_key=credentials.get('dashscope_api_key'),
|
||||
SpeechSynthesizer.call(model=v, sample_rate=16000,
|
||||
api_key=api_key,
|
||||
text=sentence.strip(),
|
||||
format=audio_type, word_timestamp_enabled=True,
|
||||
callback=cb,
|
||||
format=at, word_timestamp_enabled=True,
|
||||
phoneme_timestamp_enabled=True)
|
||||
if isinstance(response.get_audio_data(), bytes):
|
||||
storage.save(file_path, response.get_audio_data())
|
||||
|
||||
threading.Thread(target=invoke_remote, args=(
|
||||
content_text, voice, credentials.get('dashscope_api_key'), callback, audio_type, word_limit)).start()
|
||||
|
||||
while True:
|
||||
audio = audio_queue.get()
|
||||
if audio is None:
|
||||
break
|
||||
yield audio
|
||||
|
||||
except Exception as ex:
|
||||
raise InvokeBadRequestError(str(ex))
|
||||
|
||||
@ -152,3 +163,29 @@ class TongyiText2SpeechModel(_CommonTongyi, TTSModel):
|
||||
format=audio_type)
|
||||
if isinstance(response.get_audio_data(), bytes):
|
||||
return response.get_audio_data()
|
||||
|
||||
|
||||
class Callback(ResultCallback):
|
||||
|
||||
def __init__(self, queue: Queue):
|
||||
self._queue = queue
|
||||
|
||||
def on_open(self):
|
||||
pass
|
||||
|
||||
def on_complete(self):
|
||||
self._queue.put(None)
|
||||
self._queue.task_done()
|
||||
|
||||
def on_error(self, response: SpeechSynthesisResponse):
|
||||
self._queue.put(None)
|
||||
self._queue.task_done()
|
||||
|
||||
def on_close(self):
|
||||
self._queue.put(None)
|
||||
self._queue.task_done()
|
||||
|
||||
def on_event(self, result: SpeechSynthesisResult):
|
||||
ad = result.get_audio_frame()
|
||||
if ad:
|
||||
self._queue.put(ad)
|
||||
|
@ -29,7 +29,7 @@ model_credential_schema:
|
||||
label:
|
||||
zh_Hans: 服务器URL
|
||||
en_US: Server url
|
||||
type: secret-input
|
||||
type: text-input
|
||||
required: true
|
||||
placeholder:
|
||||
zh_Hans: 在此输入 Triton Inference Server 的服务器地址,如 http://192.168.1.100:8000
|
||||
|
@ -0,0 +1,40 @@
|
||||
model: ernie-4.0-turbo-8k-preview
|
||||
label:
|
||||
en_US: Ernie-4.0-turbo-8k-preview
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 8192
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
min: 0.1
|
||||
max: 1.0
|
||||
default: 0.8
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 1024
|
||||
min: 2
|
||||
max: 2048
|
||||
- name: presence_penalty
|
||||
use_template: presence_penalty
|
||||
default: 1.0
|
||||
min: 1.0
|
||||
max: 2.0
|
||||
- name: frequency_penalty
|
||||
use_template: frequency_penalty
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
- name: disable_search
|
||||
label:
|
||||
zh_Hans: 禁用搜索
|
||||
en_US: Disable Search
|
||||
type: boolean
|
||||
help:
|
||||
zh_Hans: 禁用模型自行进行外部搜索。
|
||||
en_US: Disable the model to perform external search.
|
||||
required: false
|
@ -138,6 +138,7 @@ class ErnieBotModel:
|
||||
'ernie-lite-8k-0922': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant',
|
||||
'ernie-lite-8k-0308': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-lite-8k',
|
||||
'ernie-character-8k-0321': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k',
|
||||
'ernie-4.0-tutbo-8k-preview': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k-preview',
|
||||
}
|
||||
|
||||
function_calling_supports = [
|
||||
@ -149,6 +150,7 @@ class ErnieBotModel:
|
||||
'ernie-3.5-4k-0205',
|
||||
'ernie-3.5-128k',
|
||||
'ernie-4.0-8k'
|
||||
'ernie-4.0-turbo-8k-preview'
|
||||
]
|
||||
|
||||
api_key: str = ''
|
||||
|
@ -32,7 +32,7 @@ model_credential_schema:
|
||||
label:
|
||||
zh_Hans: 服务器URL
|
||||
en_US: Server url
|
||||
type: secret-input
|
||||
type: text-input
|
||||
required: true
|
||||
placeholder:
|
||||
zh_Hans: 在此输入Xinference的服务器地址,如 http://192.168.1.100:9997
|
||||
|
@ -20,7 +20,7 @@ class ZhipuaiProvider(ModelProvider):
|
||||
model_instance = self.get_model_instance(ModelType.LLM)
|
||||
|
||||
model_instance.validate_credentials(
|
||||
model='chatglm_turbo',
|
||||
model='glm-4',
|
||||
credentials=credentials
|
||||
)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
|
@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Optional
|
||||
import httpx
|
||||
|
||||
from ..core._base_api import BaseAPI
|
||||
from ..core._base_type import NOT_GIVEN, Headers, NotGiven
|
||||
from ..core._base_type import NOT_GIVEN, Body, Headers, NotGiven
|
||||
from ..core._http_client import make_user_request_input
|
||||
from ..types.image import ImagesResponded
|
||||
|
||||
@ -28,7 +28,9 @@ class Images(BaseAPI):
|
||||
size: Optional[str] | NotGiven = NOT_GIVEN,
|
||||
style: Optional[str] | NotGiven = NOT_GIVEN,
|
||||
user: str | NotGiven = NOT_GIVEN,
|
||||
request_id: Optional[str] | NotGiven = NOT_GIVEN,
|
||||
extra_headers: Headers | None = None,
|
||||
extra_body: Body | None = None,
|
||||
disable_strict_validation: Optional[bool] | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
) -> ImagesResponded:
|
||||
@ -46,9 +48,12 @@ class Images(BaseAPI):
|
||||
"size": size,
|
||||
"style": style,
|
||||
"user": user,
|
||||
"request_id": request_id,
|
||||
},
|
||||
options=make_user_request_input(
|
||||
extra_headers=extra_headers, timeout=timeout
|
||||
extra_headers=extra_headers,
|
||||
extra_body=extra_body,
|
||||
timeout=timeout
|
||||
),
|
||||
cast_type=_cast_type,
|
||||
enable_stream=False,
|
||||
|
@ -11,7 +11,7 @@ from tenacity import retry
|
||||
from tenacity.stop import stop_after_attempt
|
||||
|
||||
from . import _errors
|
||||
from ._base_type import NOT_GIVEN, Body, Data, Headers, NotGiven, Query, RequestFiles, ResponseT
|
||||
from ._base_type import NOT_GIVEN, AnyMapping, Body, Data, Headers, NotGiven, Query, RequestFiles, ResponseT
|
||||
from ._errors import APIResponseValidationError, APIStatusError, APITimeoutError
|
||||
from ._files import make_httpx_files
|
||||
from ._request_opt import ClientRequestParam, UserRequestInput
|
||||
@ -358,6 +358,7 @@ def make_user_request_input(
|
||||
max_retries: int | None = None,
|
||||
timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
|
||||
extra_headers: Headers = None,
|
||||
extra_body: Body | None = None,
|
||||
query: Query | None = None,
|
||||
) -> UserRequestInput:
|
||||
options: UserRequestInput = {}
|
||||
@ -370,5 +371,7 @@ def make_user_request_input(
|
||||
options['timeout'] = timeout
|
||||
if query is not None:
|
||||
options["params"] = query
|
||||
if extra_body is not None:
|
||||
options["extra_json"] = cast(AnyMapping, extra_body)
|
||||
|
||||
return options
|
||||
|
@ -1,7 +1,6 @@
|
||||
from typing import Any
|
||||
|
||||
from flask import current_app
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.datasource.keyword.jieba.jieba import Jieba
|
||||
from core.rag.datasource.keyword.keyword_base import BaseKeyword
|
||||
from core.rag.models.document import Document
|
||||
@ -14,8 +13,8 @@ class Keyword:
|
||||
self._keyword_processor = self._init_keyword()
|
||||
|
||||
def _init_keyword(self) -> BaseKeyword:
|
||||
config = current_app.config
|
||||
keyword_type = config.get('KEYWORD_STORE')
|
||||
config = dify_config
|
||||
keyword_type = config.KEYWORD_STORE
|
||||
|
||||
if not keyword_type:
|
||||
raise ValueError("Keyword store must be specified.")
|
||||
|
0
api/core/rag/datasource/vdb/analyticdb/__init__.py
Normal file
0
api/core/rag/datasource/vdb/analyticdb/__init__.py
Normal file
332
api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py
Normal file
332
api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py
Normal file
@ -0,0 +1,332 @@
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
_import_err_msg = (
|
||||
"`alibabacloud_gpdb20160503` and `alibabacloud_tea_openapi` packages not found, "
|
||||
"please run `pip install alibabacloud_gpdb20160503 alibabacloud_tea_openapi`"
|
||||
)
|
||||
from flask import current_app
|
||||
|
||||
from core.rag.datasource.entity.embedding import Embeddings
|
||||
from core.rag.datasource.vdb.vector_base import BaseVector
|
||||
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
from core.rag.models.document import Document
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.dataset import Dataset
|
||||
|
||||
|
||||
class AnalyticdbConfig(BaseModel):
|
||||
access_key_id: str
|
||||
access_key_secret: str
|
||||
region_id: str
|
||||
instance_id: str
|
||||
account: str
|
||||
account_password: str
|
||||
namespace: str = ("dify",)
|
||||
namespace_password: str = (None,)
|
||||
metrics: str = ("cosine",)
|
||||
read_timeout: int = 60000
|
||||
def to_analyticdb_client_params(self):
|
||||
return {
|
||||
"access_key_id": self.access_key_id,
|
||||
"access_key_secret": self.access_key_secret,
|
||||
"region_id": self.region_id,
|
||||
"read_timeout": self.read_timeout,
|
||||
}
|
||||
|
||||
class AnalyticdbVector(BaseVector):
|
||||
_instance = None
|
||||
_init = False
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def __init__(self, collection_name: str, config: AnalyticdbConfig):
|
||||
# collection_name must be updated every time
|
||||
self._collection_name = collection_name.lower()
|
||||
if AnalyticdbVector._init:
|
||||
return
|
||||
try:
|
||||
from alibabacloud_gpdb20160503.client import Client
|
||||
from alibabacloud_tea_openapi import models as open_api_models
|
||||
except:
|
||||
raise ImportError(_import_err_msg)
|
||||
self.config = config
|
||||
self._client_config = open_api_models.Config(
|
||||
user_agent="dify", **config.to_analyticdb_client_params()
|
||||
)
|
||||
self._client = Client(self._client_config)
|
||||
self._initialize()
|
||||
AnalyticdbVector._init = True
|
||||
|
||||
def _initialize(self) -> None:
|
||||
self._initialize_vector_database()
|
||||
self._create_namespace_if_not_exists()
|
||||
|
||||
def _initialize_vector_database(self) -> None:
|
||||
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
|
||||
request = gpdb_20160503_models.InitVectorDatabaseRequest(
|
||||
dbinstance_id=self.config.instance_id,
|
||||
region_id=self.config.region_id,
|
||||
manager_account=self.config.account,
|
||||
manager_account_password=self.config.account_password,
|
||||
)
|
||||
self._client.init_vector_database(request)
|
||||
|
||||
def _create_namespace_if_not_exists(self) -> None:
|
||||
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
|
||||
from Tea.exceptions import TeaException
|
||||
try:
|
||||
request = gpdb_20160503_models.DescribeNamespaceRequest(
|
||||
dbinstance_id=self.config.instance_id,
|
||||
region_id=self.config.region_id,
|
||||
namespace=self.config.namespace,
|
||||
manager_account=self.config.account,
|
||||
manager_account_password=self.config.account_password,
|
||||
)
|
||||
self._client.describe_namespace(request)
|
||||
except TeaException as e:
|
||||
if e.statusCode == 404:
|
||||
request = gpdb_20160503_models.CreateNamespaceRequest(
|
||||
dbinstance_id=self.config.instance_id,
|
||||
region_id=self.config.region_id,
|
||||
manager_account=self.config.account,
|
||||
manager_account_password=self.config.account_password,
|
||||
namespace=self.config.namespace,
|
||||
namespace_password=self.config.namespace_password,
|
||||
)
|
||||
self._client.create_namespace(request)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"failed to create namespace {self.config.namespace}: {e}"
|
||||
)
|
||||
|
||||
def _create_collection_if_not_exists(self, embedding_dimension: int):
|
||||
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
|
||||
from Tea.exceptions import TeaException
|
||||
cache_key = f"vector_indexing_{self._collection_name}"
|
||||
lock_name = f"{cache_key}_lock"
|
||||
with redis_client.lock(lock_name, timeout=20):
|
||||
collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
|
||||
if redis_client.get(collection_exist_cache_key):
|
||||
return
|
||||
try:
|
||||
request = gpdb_20160503_models.DescribeCollectionRequest(
|
||||
dbinstance_id=self.config.instance_id,
|
||||
region_id=self.config.region_id,
|
||||
namespace=self.config.namespace,
|
||||
namespace_password=self.config.namespace_password,
|
||||
collection=self._collection_name,
|
||||
)
|
||||
self._client.describe_collection(request)
|
||||
except TeaException as e:
|
||||
if e.statusCode == 404:
|
||||
metadata = '{"ref_doc_id":"text","page_content":"text","metadata_":"jsonb"}'
|
||||
full_text_retrieval_fields = "page_content"
|
||||
request = gpdb_20160503_models.CreateCollectionRequest(
|
||||
dbinstance_id=self.config.instance_id,
|
||||
region_id=self.config.region_id,
|
||||
manager_account=self.config.account,
|
||||
manager_account_password=self.config.account_password,
|
||||
namespace=self.config.namespace,
|
||||
collection=self._collection_name,
|
||||
dimension=embedding_dimension,
|
||||
metrics=self.config.metrics,
|
||||
metadata=metadata,
|
||||
full_text_retrieval_fields=full_text_retrieval_fields,
|
||||
)
|
||||
self._client.create_collection(request)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"failed to create collection {self._collection_name}: {e}"
|
||||
)
|
||||
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
||||
|
||||
def get_type(self) -> str:
|
||||
return VectorType.ANALYTICDB
|
||||
|
||||
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
dimension = len(embeddings[0])
|
||||
self._create_collection_if_not_exists(dimension)
|
||||
self.add_texts(texts, embeddings)
|
||||
|
||||
def add_texts(
|
||||
self, documents: list[Document], embeddings: list[list[float]], **kwargs
|
||||
):
|
||||
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
|
||||
rows: list[gpdb_20160503_models.UpsertCollectionDataRequestRows] = []
|
||||
for doc, embedding in zip(documents, embeddings, strict=True):
|
||||
metadata = {
|
||||
"ref_doc_id": doc.metadata["doc_id"],
|
||||
"page_content": doc.page_content,
|
||||
"metadata_": json.dumps(doc.metadata),
|
||||
}
|
||||
rows.append(
|
||||
gpdb_20160503_models.UpsertCollectionDataRequestRows(
|
||||
vector=embedding,
|
||||
metadata=metadata,
|
||||
)
|
||||
)
|
||||
request = gpdb_20160503_models.UpsertCollectionDataRequest(
|
||||
dbinstance_id=self.config.instance_id,
|
||||
region_id=self.config.region_id,
|
||||
namespace=self.config.namespace,
|
||||
namespace_password=self.config.namespace_password,
|
||||
collection=self._collection_name,
|
||||
rows=rows,
|
||||
)
|
||||
self._client.upsert_collection_data(request)
|
||||
|
||||
def text_exists(self, id: str) -> bool:
|
||||
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
|
||||
request = gpdb_20160503_models.QueryCollectionDataRequest(
|
||||
dbinstance_id=self.config.instance_id,
|
||||
region_id=self.config.region_id,
|
||||
namespace=self.config.namespace,
|
||||
namespace_password=self.config.namespace_password,
|
||||
collection=self._collection_name,
|
||||
metrics=self.config.metrics,
|
||||
include_values=True,
|
||||
vector=None,
|
||||
content=None,
|
||||
top_k=1,
|
||||
filter=f"ref_doc_id='{id}'"
|
||||
)
|
||||
response = self._client.query_collection_data(request)
|
||||
return len(response.body.matches.match) > 0
|
||||
|
||||
def delete_by_ids(self, ids: list[str]) -> None:
|
||||
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
|
||||
ids_str = ",".join(f"'{id}'" for id in ids)
|
||||
ids_str = f"({ids_str})"
|
||||
request = gpdb_20160503_models.DeleteCollectionDataRequest(
|
||||
dbinstance_id=self.config.instance_id,
|
||||
region_id=self.config.region_id,
|
||||
namespace=self.config.namespace,
|
||||
namespace_password=self.config.namespace_password,
|
||||
collection=self._collection_name,
|
||||
collection_data=None,
|
||||
collection_data_filter=f"ref_doc_id IN {ids_str}",
|
||||
)
|
||||
self._client.delete_collection_data(request)
|
||||
|
||||
def delete_by_metadata_field(self, key: str, value: str) -> None:
|
||||
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
|
||||
request = gpdb_20160503_models.DeleteCollectionDataRequest(
|
||||
dbinstance_id=self.config.instance_id,
|
||||
region_id=self.config.region_id,
|
||||
namespace=self.config.namespace,
|
||||
namespace_password=self.config.namespace_password,
|
||||
collection=self._collection_name,
|
||||
collection_data=None,
|
||||
collection_data_filter=f"metadata_ ->> '{key}' = '{value}'",
|
||||
)
|
||||
self._client.delete_collection_data(request)
|
||||
|
||||
def search_by_vector(
|
||||
self, query_vector: list[float], **kwargs: Any
|
||||
) -> list[Document]:
|
||||
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
|
||||
score_threshold = (
|
||||
kwargs.get("score_threshold", 0.0)
|
||||
if kwargs.get("score_threshold", 0.0)
|
||||
else 0.0
|
||||
)
|
||||
request = gpdb_20160503_models.QueryCollectionDataRequest(
|
||||
dbinstance_id=self.config.instance_id,
|
||||
region_id=self.config.region_id,
|
||||
namespace=self.config.namespace,
|
||||
namespace_password=self.config.namespace_password,
|
||||
collection=self._collection_name,
|
||||
include_values=kwargs.pop("include_values", True),
|
||||
metrics=self.config.metrics,
|
||||
vector=query_vector,
|
||||
content=None,
|
||||
top_k=kwargs.get("top_k", 4),
|
||||
filter=None,
|
||||
)
|
||||
response = self._client.query_collection_data(request)
|
||||
documents = []
|
||||
for match in response.body.matches.match:
|
||||
if match.score > score_threshold:
|
||||
doc = Document(
|
||||
page_content=match.metadata.get("page_content"),
|
||||
metadata=json.loads(match.metadata.get("metadata_")),
|
||||
)
|
||||
documents.append(doc)
|
||||
return documents
|
||||
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
|
||||
score_threshold = (
|
||||
kwargs.get("score_threshold", 0.0)
|
||||
if kwargs.get("score_threshold", 0.0)
|
||||
else 0.0
|
||||
)
|
||||
request = gpdb_20160503_models.QueryCollectionDataRequest(
|
||||
dbinstance_id=self.config.instance_id,
|
||||
region_id=self.config.region_id,
|
||||
namespace=self.config.namespace,
|
||||
namespace_password=self.config.namespace_password,
|
||||
collection=self._collection_name,
|
||||
include_values=kwargs.pop("include_values", True),
|
||||
metrics=self.config.metrics,
|
||||
vector=None,
|
||||
content=query,
|
||||
top_k=kwargs.get("top_k", 4),
|
||||
filter=None,
|
||||
)
|
||||
response = self._client.query_collection_data(request)
|
||||
documents = []
|
||||
for match in response.body.matches.match:
|
||||
if match.score > score_threshold:
|
||||
doc = Document(
|
||||
page_content=match.metadata.get("page_content"),
|
||||
metadata=json.loads(match.metadata.get("metadata_")),
|
||||
)
|
||||
documents.append(doc)
|
||||
return documents
|
||||
|
||||
def delete(self) -> None:
|
||||
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
|
||||
request = gpdb_20160503_models.DeleteCollectionRequest(
|
||||
collection=self._collection_name,
|
||||
dbinstance_id=self.config.instance_id,
|
||||
namespace=self.config.namespace,
|
||||
namespace_password=self.config.namespace_password,
|
||||
region_id=self.config.region_id,
|
||||
)
|
||||
self._client.delete_collection(request)
|
||||
|
||||
class AnalyticdbVectorFactory(AbstractVectorFactory):
|
||||
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings):
|
||||
if dataset.index_struct_dict:
|
||||
class_prefix: str = dataset.index_struct_dict["vector_store"][
|
||||
"class_prefix"
|
||||
]
|
||||
collection_name = class_prefix.lower()
|
||||
else:
|
||||
dataset_id = dataset.id
|
||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
|
||||
dataset.index_struct = json.dumps(
|
||||
self.gen_index_struct_dict(VectorType.ANALYTICDB, collection_name)
|
||||
)
|
||||
config = current_app.config
|
||||
return AnalyticdbVector(
|
||||
collection_name,
|
||||
AnalyticdbConfig(
|
||||
access_key_id=config.get("ANALYTICDB_KEY_ID"),
|
||||
access_key_secret=config.get("ANALYTICDB_KEY_SECRET"),
|
||||
region_id=config.get("ANALYTICDB_REGION_ID"),
|
||||
instance_id=config.get("ANALYTICDB_INSTANCE_ID"),
|
||||
account=config.get("ANALYTICDB_ACCOUNT"),
|
||||
account_password=config.get("ANALYTICDB_PASSWORD"),
|
||||
namespace=config.get("ANALYTICDB_NAMESPACE"),
|
||||
namespace_password=config.get("ANALYTICDB_NAMESPACE_PASSWORD"),
|
||||
),
|
||||
)
|
0
api/core/rag/datasource/vdb/myscale/__init__.py
Normal file
0
api/core/rag/datasource/vdb/myscale/__init__.py
Normal file
170
api/core/rag/datasource/vdb/myscale/myscale_vector.py
Normal file
170
api/core/rag/datasource/vdb/myscale/myscale_vector.py
Normal file
@ -0,0 +1,170 @@
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from clickhouse_connect import get_client
|
||||
from flask import current_app
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.rag.datasource.entity.embedding import Embeddings
|
||||
from core.rag.datasource.vdb.vector_base import BaseVector
|
||||
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
from core.rag.models.document import Document
|
||||
from models.dataset import Dataset
|
||||
|
||||
|
||||
class MyScaleConfig(BaseModel):
|
||||
host: str
|
||||
port: int
|
||||
user: str
|
||||
password: str
|
||||
database: str
|
||||
fts_params: str
|
||||
|
||||
|
||||
class SortOrder(Enum):
|
||||
ASC = "ASC"
|
||||
DESC = "DESC"
|
||||
|
||||
|
||||
class MyScaleVector(BaseVector):
|
||||
|
||||
def __init__(self, collection_name: str, config: MyScaleConfig, metric: str = "Cosine"):
|
||||
super().__init__(collection_name)
|
||||
self._config = config
|
||||
self._metric = metric
|
||||
self._vec_order = SortOrder.ASC if metric.upper() in ["COSINE", "L2"] else SortOrder.DESC
|
||||
self._client = get_client(
|
||||
host=config.host,
|
||||
port=config.port,
|
||||
username=config.user,
|
||||
password=config.password,
|
||||
)
|
||||
self._client.command("SET allow_experimental_object_type=1")
|
||||
|
||||
def get_type(self) -> str:
|
||||
return VectorType.MYSCALE
|
||||
|
||||
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
dimension = len(embeddings[0])
|
||||
self._create_collection(dimension)
|
||||
return self.add_texts(documents=texts, embeddings=embeddings, **kwargs)
|
||||
|
||||
def _create_collection(self, dimension: int):
|
||||
logging.info(f"create MyScale collection {self._collection_name} with dimension {dimension}")
|
||||
self._client.command(f"CREATE DATABASE IF NOT EXISTS {self._config.database}")
|
||||
fts_params = f"('{self._config.fts_params}')" if self._config.fts_params else ""
|
||||
sql = f"""
|
||||
CREATE TABLE IF NOT EXISTS {self._config.database}.{self._collection_name}(
|
||||
id String,
|
||||
text String,
|
||||
vector Array(Float32),
|
||||
metadata JSON,
|
||||
CONSTRAINT cons_vec_len CHECK length(vector) = {dimension},
|
||||
VECTOR INDEX vidx vector TYPE DEFAULT('metric_type = {self._metric}'),
|
||||
INDEX text_idx text TYPE fts{fts_params}
|
||||
) ENGINE = MergeTree ORDER BY id
|
||||
"""
|
||||
self._client.command(sql)
|
||||
|
||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
ids = []
|
||||
columns = ["id", "text", "vector", "metadata"]
|
||||
values = []
|
||||
for i, doc in enumerate(documents):
|
||||
doc_id = doc.metadata.get("doc_id", str(uuid.uuid4()))
|
||||
row = (
|
||||
doc_id,
|
||||
self.escape_str(doc.page_content),
|
||||
embeddings[i],
|
||||
json.dumps(doc.metadata) if doc.metadata else {}
|
||||
)
|
||||
values.append(str(row))
|
||||
ids.append(doc_id)
|
||||
sql = f"""
|
||||
INSERT INTO {self._config.database}.{self._collection_name}
|
||||
({",".join(columns)}) VALUES {",".join(values)}
|
||||
"""
|
||||
self._client.command(sql)
|
||||
return ids
|
||||
|
||||
@staticmethod
|
||||
def escape_str(value: Any) -> str:
|
||||
return "".join(f"\\{c}" if c in ("\\", "'") else c for c in str(value))
|
||||
|
||||
def text_exists(self, id: str) -> bool:
|
||||
results = self._client.query(f"SELECT id FROM {self._config.database}.{self._collection_name} WHERE id='{id}'")
|
||||
return results.row_count > 0
|
||||
|
||||
def delete_by_ids(self, ids: list[str]) -> None:
|
||||
self._client.command(
|
||||
f"DELETE FROM {self._config.database}.{self._collection_name} WHERE id IN {str(tuple(ids))}")
|
||||
|
||||
def get_ids_by_metadata_field(self, key: str, value: str):
|
||||
rows = self._client.query(
|
||||
f"SELECT DISTINCT id FROM {self._config.database}.{self._collection_name} WHERE metadata.{key}='{value}'"
|
||||
).result_rows
|
||||
return [row[0] for row in rows]
|
||||
|
||||
def delete_by_metadata_field(self, key: str, value: str) -> None:
|
||||
self._client.command(
|
||||
f"DELETE FROM {self._config.database}.{self._collection_name} WHERE metadata.{key}='{value}'"
|
||||
)
|
||||
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
return self._search(f"distance(vector, {str(query_vector)})", self._vec_order, **kwargs)
|
||||
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
return self._search(f"TextSearch(text, '{query}')", SortOrder.DESC, **kwargs)
|
||||
|
||||
def _search(self, dist: str, order: SortOrder, **kwargs: Any) -> list[Document]:
|
||||
top_k = kwargs.get("top_k", 5)
|
||||
score_threshold = kwargs.get("score_threshold", 0.0)
|
||||
where_str = f"WHERE dist < {1 - score_threshold}" if \
|
||||
self._metric.upper() == "COSINE" and order == SortOrder.ASC and score_threshold > 0.0 else ""
|
||||
sql = f"""
|
||||
SELECT text, metadata, {dist} as dist FROM {self._config.database}.{self._collection_name}
|
||||
{where_str} ORDER BY dist {order.value} LIMIT {top_k}
|
||||
"""
|
||||
try:
|
||||
return [
|
||||
Document(
|
||||
page_content=r["text"],
|
||||
metadata=r["metadata"],
|
||||
)
|
||||
for r in self._client.query(sql).named_results()
|
||||
]
|
||||
except Exception as e:
|
||||
logging.error(f"\033[91m\033[1m{type(e)}\033[0m \033[95m{str(e)}\033[0m")
|
||||
return []
|
||||
|
||||
def delete(self) -> None:
|
||||
self._client.command(f"DROP TABLE IF EXISTS {self._config.database}.{self._collection_name}")
|
||||
|
||||
|
||||
class MyScaleVectorFactory(AbstractVectorFactory):
|
||||
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> MyScaleVector:
|
||||
if dataset.index_struct_dict:
|
||||
class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
|
||||
collection_name = class_prefix.lower()
|
||||
else:
|
||||
dataset_id = dataset.id
|
||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
|
||||
dataset.index_struct = json.dumps(
|
||||
self.gen_index_struct_dict(VectorType.MYSCALE, collection_name))
|
||||
|
||||
config = current_app.config
|
||||
return MyScaleVector(
|
||||
collection_name=collection_name,
|
||||
config=MyScaleConfig(
|
||||
host=config.get("MYSCALE_HOST", "localhost"),
|
||||
port=int(config.get("MYSCALE_PORT", 8123)),
|
||||
user=config.get("MYSCALE_USER", "default"),
|
||||
password=config.get("MYSCALE_PASSWORD", ""),
|
||||
database=config.get("MYSCALE_DATABASE", "default"),
|
||||
fts_params=config.get("MYSCALE_FTS_PARAMS", ""),
|
||||
),
|
||||
)
|
@ -57,6 +57,9 @@ class Vector:
|
||||
case VectorType.MILVUS:
|
||||
from core.rag.datasource.vdb.milvus.milvus_vector import MilvusVectorFactory
|
||||
return MilvusVectorFactory
|
||||
case VectorType.MYSCALE:
|
||||
from core.rag.datasource.vdb.myscale.myscale_vector import MyScaleVectorFactory
|
||||
return MyScaleVectorFactory
|
||||
case VectorType.PGVECTOR:
|
||||
from core.rag.datasource.vdb.pgvector.pgvector import PGVectorFactory
|
||||
return PGVectorFactory
|
||||
@ -84,6 +87,9 @@ class Vector:
|
||||
case VectorType.OPENSEARCH:
|
||||
from core.rag.datasource.vdb.opensearch.opensearch_vector import OpenSearchVectorFactory
|
||||
return OpenSearchVectorFactory
|
||||
case VectorType.ANALYTICDB:
|
||||
from core.rag.datasource.vdb.analyticdb.analyticdb_vector import AnalyticdbVectorFactory
|
||||
return AnalyticdbVectorFactory
|
||||
case _:
|
||||
raise ValueError(f"Vector store {vector_type} is not supported.")
|
||||
|
||||
|
@ -2,8 +2,10 @@ from enum import Enum
|
||||
|
||||
|
||||
class VectorType(str, Enum):
|
||||
ANALYTICDB = 'analyticdb'
|
||||
CHROMA = 'chroma'
|
||||
MILVUS = 'milvus'
|
||||
MYSCALE = 'myscale'
|
||||
PGVECTOR = 'pgvector'
|
||||
PGVECTO_RS = 'pgvecto-rs'
|
||||
QDRANT = 'qdrant'
|
||||
|
@ -46,7 +46,6 @@ class FirecrawlApp:
|
||||
raise Exception(f'Failed to scrape URL. Status code: {response.status_code}')
|
||||
|
||||
def crawl_url(self, url, params=None) -> str:
|
||||
start_time = time.time()
|
||||
headers = self._prepare_headers()
|
||||
json_data = {'url': url}
|
||||
if params:
|
||||
|
@ -18,8 +18,8 @@ class MarkdownExtractor(BaseExtractor):
|
||||
def __init__(
|
||||
self,
|
||||
file_path: str,
|
||||
remove_hyperlinks: bool = True,
|
||||
remove_images: bool = True,
|
||||
remove_hyperlinks: bool = False,
|
||||
remove_images: bool = False,
|
||||
encoding: Optional[str] = None,
|
||||
autodetect_encoding: bool = True,
|
||||
):
|
||||
|
@ -8,7 +8,7 @@ We have defined a series of helper methods in the `Tool` class to help developer
|
||||
|
||||
### Message Return
|
||||
|
||||
Dify supports various message types such as `text`, `link`, `image`, and `file BLOB`. You can return different types of messages to the LLM and users through the following interfaces.
|
||||
Dify supports various message types such as `text`, `link`, `json`, `image`, and `file BLOB`. You can return different types of messages to the LLM and users through the following interfaces.
|
||||
|
||||
Please note, some parameters in the following interfaces will be introduced in later sections.
|
||||
|
||||
@ -67,6 +67,18 @@ If you need to return the raw data of a file, such as images, audio, video, PPT,
|
||||
"""
|
||||
```
|
||||
|
||||
#### JSON
|
||||
If you need to return a formatted JSON, you can use the following interface. This is commonly used for data transmission between nodes in a workflow, of course, in agent mode, most LLM are also able to read and understand JSON.
|
||||
|
||||
- `object` A Python dictionary object will be automatically serialized into JSON
|
||||
|
||||
```python
|
||||
def create_json_message(self, object: dict) -> ToolInvokeMessage:
|
||||
"""
|
||||
create a json message
|
||||
"""
|
||||
```
|
||||
|
||||
### Shortcut Tools
|
||||
|
||||
In large model applications, we have two common needs:
|
||||
|
@ -145,19 +145,25 @@ parameters: # Parameter list
|
||||
|
||||
- The `identity` field is mandatory, it contains the basic information of the tool, including name, author, label, description, etc.
|
||||
- `parameters` Parameter list
|
||||
- `name` Parameter name, unique, no duplication with other parameters
|
||||
- `type` Parameter type, currently supports `string`, `number`, `boolean`, `select`, `secret-input` four types, corresponding to string, number, boolean, drop-down box, and encrypted input box, respectively. For sensitive information, we recommend using `secret-input` type
|
||||
- `required` Required or not
|
||||
- `name` (Mandatory) Parameter name, must be unique and not duplicate with other parameters.
|
||||
- `type` (Mandatory) Parameter type, currently supports `string`, `number`, `boolean`, `select`, `secret-input` five types, corresponding to string, number, boolean, drop-down box, and encrypted input box, respectively. For sensitive information, we recommend using the `secret-input` type
|
||||
- `label` (Mandatory) Parameter label, for frontend display
|
||||
- `form` (Mandatory) Form type, currently supports `llm`, `form` two types.
|
||||
- In an agent app, `llm` indicates that the parameter is inferred by the LLM itself, while `form` indicates that the parameter can be pre-set for the tool.
|
||||
- In a workflow app, both `llm` and `form` need to be filled out by the front end, but the parameters of `llm` will be used as input variables for the tool node.
|
||||
- `required` Indicates whether the parameter is required or not
|
||||
- In `llm` mode, if the parameter is required, the Agent is required to infer this parameter
|
||||
- In `form` mode, if the parameter is required, the user is required to fill in this parameter on the frontend before the conversation starts
|
||||
- `options` Parameter options
|
||||
- In `llm` mode, Dify will pass all options to LLM, LLM can infer based on these options
|
||||
- In `form` mode, when `type` is `select`, the frontend will display these options
|
||||
- `default` Default value
|
||||
- `label` Parameter label, for frontend display
|
||||
- `min` Minimum value, can be set when the parameter type is `number`.
|
||||
- `max` Maximum value, can be set when the parameter type is `number`.
|
||||
- `placeholder` The prompt text for input boxes. It can be set when the form type is `form`, and the parameter type is `string`, `number`, or `secret-input`. It supports multiple languages.
|
||||
- `human_description` Introduction for frontend display, supports multiple languages
|
||||
- `llm_description` Introduction passed to LLM, in order to make LLM better understand this parameter, we suggest to write as detailed information about this parameter as possible here, so that LLM can understand this parameter
|
||||
- `form` Form type, currently supports `llm`, `form` two types, corresponding to Agent self-inference and frontend filling
|
||||
|
||||
|
||||
## 4. Add Tool Logic
|
||||
|
||||
@ -196,7 +202,7 @@ The overall logic of the tool is in the `_invoke` method, this method accepts tw
|
||||
|
||||
### Return Data
|
||||
|
||||
When the tool returns, you can choose to return one message or multiple messages, here we return one message, using `create_text_message` and `create_link_message` can create a text message or a link message.
|
||||
When the tool returns, you can choose to return one message or multiple messages, here we return one message, using `create_text_message` and `create_link_message` can create a text message or a link message. If you want to return multiple messages, you can use `[self.create_text_message('msg1'), self.create_text_message('msg2')]` to create a list of messages.
|
||||
|
||||
## 5. Add Provider Code
|
||||
|
||||
@ -205,8 +211,6 @@ Finally, we need to create a provider class under the provider module to impleme
|
||||
Create `google.py` under the `google` module, the content is as follows.
|
||||
|
||||
```python
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType
|
||||
from core.tools.tool.tool import Tool
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
|
||||
|
@ -8,7 +8,7 @@
|
||||
|
||||
### 消息返回
|
||||
|
||||
Dify支持`文本` `链接` `图片` `文件BLOB` 等多种消息类型,你可以通过以下几个接口返回不同类型的消息给LLM和用户。
|
||||
Dify支持`文本` `链接` `图片` `文件BLOB` `JSON` 等多种消息类型,你可以通过以下几个接口返回不同类型的消息给LLM和用户。
|
||||
|
||||
注意,在下面的接口中的部分参数将在后面的章节中介绍。
|
||||
|
||||
@ -67,6 +67,18 @@ Dify支持`文本` `链接` `图片` `文件BLOB` 等多种消息类型,你可
|
||||
"""
|
||||
```
|
||||
|
||||
#### JSON
|
||||
如果你需要返回一个格式化的JSON,可以使用以下接口。这通常用于workflow中的节点间的数据传递,当然agent模式中,大部分大模型也都能够阅读和理解JSON。
|
||||
|
||||
- `object` 一个Python的字典对象,会被自动序列化为JSON
|
||||
|
||||
```python
|
||||
def create_json_message(self, object: dict) -> ToolInvokeMessage:
|
||||
"""
|
||||
create a json message
|
||||
"""
|
||||
```
|
||||
|
||||
### 快捷工具
|
||||
|
||||
在大模型应用中,我们有两种常见的需求:
|
||||
@ -97,8 +109,8 @@ Dify支持`文本` `链接` `图片` `文件BLOB` 等多种消息类型,你可
|
||||
```python
|
||||
def get_url(self, url: str, user_agent: str = None) -> str:
|
||||
"""
|
||||
get url
|
||||
""" the crawled result
|
||||
get url from the crawled result
|
||||
"""
|
||||
```
|
||||
|
||||
### 变量池
|
||||
|
@ -140,8 +140,12 @@ parameters: # 参数列表
|
||||
|
||||
- `identity` 字段是必须的,它包含了工具的基本信息,包括名称、作者、标签、描述等
|
||||
- `parameters` 参数列表
|
||||
- `name` 参数名称,唯一,不允许和其他参数重名
|
||||
- `type` 参数类型,目前支持`string`、`number`、`boolean`、`select`、`secret-input` 五种类型,分别对应字符串、数字、布尔值、下拉框、加密输入框,对于敏感信息,我们建议使用`secret-input`类型
|
||||
- `name` (必填)参数名称,唯一,不允许和其他参数重名
|
||||
- `type` (必填)参数类型,目前支持`string`、`number`、`boolean`、`select`、`secret-input` 五种类型,分别对应字符串、数字、布尔值、下拉框、加密输入框,对于敏感信息,我们建议使用`secret-input`类型
|
||||
- `label`(必填)参数标签,用于前端展示
|
||||
- `form` (必填)表单类型,目前支持`llm`、`form`两种类型
|
||||
- 在Agent应用中,`llm`表示该参数LLM自行推理,`form`表示要使用该工具可提前设定的参数
|
||||
- 在workflow应用中,`llm`和`form`均需要前端填写,但`llm`的参数会做为工具节点的输入变量
|
||||
- `required` 是否必填
|
||||
- 在`llm`模式下,如果参数为必填,则会要求Agent必须要推理出这个参数
|
||||
- 在`form`模式下,如果参数为必填,则会要求用户在对话开始前在前端填写这个参数
|
||||
@ -149,10 +153,12 @@ parameters: # 参数列表
|
||||
- 在`llm`模式下,Dify会将所有选项传递给LLM,LLM可以根据这些选项进行推理
|
||||
- 在`form`模式下,`type`为`select`时,前端会展示这些选项
|
||||
- `default` 默认值
|
||||
- `label` 参数标签,用于前端展示
|
||||
- `min` 最小值,当参数类型为`number`时可以设定
|
||||
- `max` 最大值,当参数类型为`number`时可以设定
|
||||
- `human_description` 用于前端展示的介绍,支持多语言
|
||||
- `placeholder` 字段输入框的提示文字,在表单类型为`form`,参数类型为`string`、`number`、`secret-input`时,可以设定,支持多语言
|
||||
- `llm_description` 传递给LLM的介绍,为了使得LLM更好理解这个参数,我们建议在这里写上关于这个参数尽可能详细的信息,让LLM能够理解这个参数
|
||||
- `form` 表单类型,目前支持`llm`、`form`两种类型,分别对应Agent自行推理和前端填写
|
||||
|
||||
|
||||
## 4. 准备工具代码
|
||||
当完成工具的配置以后,我们就可以开始编写工具代码了,主要用于实现工具的逻辑。
|
||||
@ -176,7 +182,6 @@ class GoogleSearchTool(BuiltinTool):
|
||||
query = tool_parameters['query']
|
||||
result_type = tool_parameters['result_type']
|
||||
api_key = self.runtime.credentials['serpapi_api_key']
|
||||
# TODO: search with serpapi
|
||||
result = SerpAPI(api_key).run(query, result_type=result_type)
|
||||
|
||||
if result_type == 'text':
|
||||
@ -188,7 +193,7 @@ class GoogleSearchTool(BuiltinTool):
|
||||
工具的整体逻辑都在`_invoke`方法中,这个方法接收两个参数:`user_id`和`tool_parameters`,分别表示用户ID和工具参数
|
||||
|
||||
### 返回数据
|
||||
在工具返回时,你可以选择返回一个消息或者多个消息,这里我们返回一个消息,使用`create_text_message`和`create_link_message`可以创建一个文本消息或者一个链接消息。
|
||||
在工具返回时,你可以选择返回一条消息或者多个消息,这里我们返回一条消息,使用`create_text_message`和`create_link_message`可以创建一条文本消息或者一条链接消息。如需返回多条消息,可以使用列表构建,例如`[self.create_text_message('msg1'), self.create_text_message('msg2')]`
|
||||
|
||||
## 5. 准备供应商代码
|
||||
最后,我们需要在供应商模块下创建一个供应商类,用于实现供应商的凭据验证逻辑,如果凭据验证失败,将会抛出`ToolProviderCredentialValidationError`异常。
|
||||
@ -196,8 +201,6 @@ class GoogleSearchTool(BuiltinTool):
|
||||
在`google`模块下创建`google.py`,内容如下。
|
||||
|
||||
```python
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType
|
||||
from core.tools.tool.tool import Tool
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
|
||||
|
@ -142,7 +142,8 @@ class ToolParameter(BaseModel):
|
||||
|
||||
name: str = Field(..., description="The name of the parameter")
|
||||
label: I18nObject = Field(..., description="The label presented to the user")
|
||||
human_description: I18nObject = Field(..., description="The description presented to the user")
|
||||
human_description: Optional[I18nObject] = Field(None, description="The description presented to the user")
|
||||
placeholder: Optional[I18nObject] = Field(None, description="The placeholder presented to the user")
|
||||
type: ToolParameterType = Field(..., description="The type of the parameter")
|
||||
form: ToolParameterForm = Field(..., description="The form of the parameter, schema/form/llm")
|
||||
llm_description: Optional[str] = None
|
||||
|
0
api/core/tools/provider/builtin/cogview/__init__.py
Normal file
0
api/core/tools/provider/builtin/cogview/__init__.py
Normal file
BIN
api/core/tools/provider/builtin/cogview/_assets/icon.png
Normal file
BIN
api/core/tools/provider/builtin/cogview/_assets/icon.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 22 KiB |
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user