mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-19 04:55:55 +08:00
fix bug
This commit is contained in:
commit
a603e01f5e
1
.gitignore
vendored
1
.gitignore
vendored
@ -174,5 +174,6 @@ sdks/python-client/dify_client.egg-info
|
|||||||
.vscode/*
|
.vscode/*
|
||||||
!.vscode/launch.json
|
!.vscode/launch.json
|
||||||
pyrightconfig.json
|
pyrightconfig.json
|
||||||
|
api/.vscode
|
||||||
|
|
||||||
.idea/
|
.idea/
|
||||||
|
@ -81,7 +81,7 @@ Dify requires the following dependencies to build, make sure they're installed o
|
|||||||
|
|
||||||
Dify is composed of a backend and a frontend. Navigate to the backend directory by `cd api/`, then follow the [Backend README](api/README.md) to install it. In a separate terminal, navigate to the frontend directory by `cd web/`, then follow the [Frontend README](web/README.md) to install.
|
Dify is composed of a backend and a frontend. Navigate to the backend directory by `cd api/`, then follow the [Backend README](api/README.md) to install it. In a separate terminal, navigate to the frontend directory by `cd web/`, then follow the [Frontend README](web/README.md) to install.
|
||||||
|
|
||||||
Check the [installation FAQ](https://docs.dify.ai/getting-started/faq/install-faq) for a list of common issues and steps to troubleshoot.
|
Check the [installation FAQ](https://docs.dify.ai/learn-more/faq/self-host-faq) for a list of common issues and steps to troubleshoot.
|
||||||
|
|
||||||
### 5. Visit dify in your browser
|
### 5. Visit dify in your browser
|
||||||
|
|
||||||
|
@ -77,7 +77,7 @@ Dify 依赖以下工具和库:
|
|||||||
|
|
||||||
Dify 由后端和前端组成。通过 `cd api/` 导航到后端目录,然后按照 [后端 README](api/README.md) 进行安装。在另一个终端中,通过 `cd web/` 导航到前端目录,然后按照 [前端 README](web/README.md) 进行安装。
|
Dify 由后端和前端组成。通过 `cd api/` 导航到后端目录,然后按照 [后端 README](api/README.md) 进行安装。在另一个终端中,通过 `cd web/` 导航到前端目录,然后按照 [前端 README](web/README.md) 进行安装。
|
||||||
|
|
||||||
查看 [安装常见问题解答](https://docs.dify.ai/getting-started/faq/install-faq) 以获取常见问题列表和故障排除步骤。
|
查看 [安装常见问题解答](https://docs.dify.ai/v/zh-hans/learn-more/faq/install-faq) 以获取常见问题列表和故障排除步骤。
|
||||||
|
|
||||||
### 5. 在浏览器中访问 Dify
|
### 5. 在浏览器中访问 Dify
|
||||||
|
|
||||||
|
@ -82,7 +82,7 @@ Dify はバックエンドとフロントエンドから構成されています
|
|||||||
まず`cd api/`でバックエンドのディレクトリに移動し、[Backend README](api/README.md)に従ってインストールします。
|
まず`cd api/`でバックエンドのディレクトリに移動し、[Backend README](api/README.md)に従ってインストールします。
|
||||||
次に別のターミナルで、`cd web/`でフロントエンドのディレクトリに移動し、[Frontend README](web/README.md)に従ってインストールしてください。
|
次に別のターミナルで、`cd web/`でフロントエンドのディレクトリに移動し、[Frontend README](web/README.md)に従ってインストールしてください。
|
||||||
|
|
||||||
よくある問題とトラブルシューティングの手順については、[installation FAQ](https://docs.dify.ai/getting-started/faq/install-faq) を確認してください。
|
よくある問題とトラブルシューティングの手順については、[installation FAQ](https://docs.dify.ai/v/japanese/learn-more/faq/install-faq) を確認してください。
|
||||||
|
|
||||||
### 5. ブラウザで dify にアクセスする
|
### 5. ブラウザで dify にアクセスする
|
||||||
|
|
||||||
|
@ -256,3 +256,7 @@ WORKFLOW_CALL_MAX_DEPTH=5
|
|||||||
# App configuration
|
# App configuration
|
||||||
APP_MAX_EXECUTION_TIME=1200
|
APP_MAX_EXECUTION_TIME=1200
|
||||||
APP_MAX_ACTIVE_REQUESTS=0
|
APP_MAX_ACTIVE_REQUESTS=0
|
||||||
|
|
||||||
|
|
||||||
|
# Celery beat configuration
|
||||||
|
CELERY_BEAT_SCHEDULER_TIME=1
|
@ -1,7 +1,5 @@
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
from configs import dify_config
|
|
||||||
|
|
||||||
if os.environ.get("DEBUG", "false").lower() != 'true':
|
if os.environ.get("DEBUG", "false").lower() != 'true':
|
||||||
from gevent import monkey
|
from gevent import monkey
|
||||||
|
|
||||||
@ -23,7 +21,9 @@ from flask import Flask, Response, request
|
|||||||
from flask_cors import CORS
|
from flask_cors import CORS
|
||||||
from werkzeug.exceptions import Unauthorized
|
from werkzeug.exceptions import Unauthorized
|
||||||
|
|
||||||
|
import contexts
|
||||||
from commands import register_commands
|
from commands import register_commands
|
||||||
|
from configs import dify_config
|
||||||
|
|
||||||
# DO NOT REMOVE BELOW
|
# DO NOT REMOVE BELOW
|
||||||
from events import event_handlers
|
from events import event_handlers
|
||||||
@ -181,7 +181,10 @@ def load_user_from_request(request_from_flask_login):
|
|||||||
decoded = PassportService().verify(auth_token)
|
decoded = PassportService().verify(auth_token)
|
||||||
user_id = decoded.get('user_id')
|
user_id = decoded.get('user_id')
|
||||||
|
|
||||||
return AccountService.load_logged_in_account(account_id=user_id, token=auth_token)
|
account = AccountService.load_logged_in_account(account_id=user_id, token=auth_token)
|
||||||
|
if account:
|
||||||
|
contexts.tenant_id.set(account.current_tenant_id)
|
||||||
|
return account
|
||||||
|
|
||||||
|
|
||||||
@login_manager.unauthorized_handler
|
@login_manager.unauthorized_handler
|
||||||
|
@ -23,6 +23,7 @@ class SecurityConfig(BaseSettings):
|
|||||||
default=24,
|
default=24,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class AppExecutionConfig(BaseSettings):
|
class AppExecutionConfig(BaseSettings):
|
||||||
"""
|
"""
|
||||||
App Execution configs
|
App Execution configs
|
||||||
@ -405,7 +406,6 @@ class DataSetConfig(BaseSettings):
|
|||||||
default=False,
|
default=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class WorkspaceConfig(BaseSettings):
|
class WorkspaceConfig(BaseSettings):
|
||||||
"""
|
"""
|
||||||
Workspace configs
|
Workspace configs
|
||||||
@ -435,6 +435,13 @@ class ImageFormatConfig(BaseSettings):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class CeleryBeatConfig(BaseSettings):
|
||||||
|
CELERY_BEAT_SCHEDULER_TIME: int = Field(
|
||||||
|
description='the time of the celery scheduler, default to 1 day',
|
||||||
|
default=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class FeatureConfig(
|
class FeatureConfig(
|
||||||
# place the configs in alphabet order
|
# place the configs in alphabet order
|
||||||
AppExecutionConfig,
|
AppExecutionConfig,
|
||||||
@ -462,5 +469,6 @@ class FeatureConfig(
|
|||||||
|
|
||||||
# hosted services config
|
# hosted services config
|
||||||
HostedServiceConfig,
|
HostedServiceConfig,
|
||||||
|
CeleryBeatConfig,
|
||||||
):
|
):
|
||||||
pass
|
pass
|
||||||
|
@ -79,7 +79,7 @@ class HostedAzureOpenAiConfig(BaseSettings):
|
|||||||
default=False,
|
default=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
HOSTED_OPENAI_API_KEY: Optional[str] = Field(
|
HOSTED_AZURE_OPENAI_API_KEY: Optional[str] = Field(
|
||||||
description='',
|
description='',
|
||||||
default=None,
|
default=None,
|
||||||
)
|
)
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
from typing import Optional
|
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, PositiveInt
|
from pydantic import BaseModel, Field, PositiveInt
|
||||||
|
|
||||||
@ -8,32 +7,32 @@ class MyScaleConfig(BaseModel):
|
|||||||
MyScale configs
|
MyScale configs
|
||||||
"""
|
"""
|
||||||
|
|
||||||
MYSCALE_HOST: Optional[str] = Field(
|
MYSCALE_HOST: str = Field(
|
||||||
description='MyScale host',
|
description='MyScale host',
|
||||||
default=None,
|
default='localhost',
|
||||||
)
|
)
|
||||||
|
|
||||||
MYSCALE_PORT: Optional[PositiveInt] = Field(
|
MYSCALE_PORT: PositiveInt = Field(
|
||||||
description='MyScale port',
|
description='MyScale port',
|
||||||
default=8123,
|
default=8123,
|
||||||
)
|
)
|
||||||
|
|
||||||
MYSCALE_USER: Optional[str] = Field(
|
MYSCALE_USER: str = Field(
|
||||||
description='MyScale user',
|
description='MyScale user',
|
||||||
default=None,
|
default='default',
|
||||||
)
|
)
|
||||||
|
|
||||||
MYSCALE_PASSWORD: Optional[str] = Field(
|
MYSCALE_PASSWORD: str = Field(
|
||||||
description='MyScale password',
|
description='MyScale password',
|
||||||
default=None,
|
default='',
|
||||||
)
|
)
|
||||||
|
|
||||||
MYSCALE_DATABASE: Optional[str] = Field(
|
MYSCALE_DATABASE: str = Field(
|
||||||
description='MyScale database name',
|
description='MyScale database name',
|
||||||
default=None,
|
default='default',
|
||||||
)
|
)
|
||||||
|
|
||||||
MYSCALE_FTS_PARAMS: Optional[str] = Field(
|
MYSCALE_FTS_PARAMS: str = Field(
|
||||||
description='MyScale fts index parameters',
|
description='MyScale fts index parameters',
|
||||||
default=None,
|
default='',
|
||||||
)
|
)
|
||||||
|
@ -0,0 +1,2 @@
|
|||||||
|
# TODO: Update all string in code to use this constant
|
||||||
|
HIDDEN_VALUE = '[__HIDDEN__]'
|
3
api/contexts/__init__.py
Normal file
3
api/contexts/__init__.py
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
from contextvars import ContextVar
|
||||||
|
|
||||||
|
tenant_id: ContextVar[str] = ContextVar('tenant_id')
|
@ -212,7 +212,7 @@ class AppCopyApi(Resource):
|
|||||||
parser.add_argument('icon_background', type=str, location='json')
|
parser.add_argument('icon_background', type=str, location='json')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
data = AppDslService.export_dsl(app_model=app_model)
|
data = AppDslService.export_dsl(app_model=app_model, include_secret=True)
|
||||||
app = AppDslService.import_and_create_new_app(
|
app = AppDslService.import_and_create_new_app(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_user.current_tenant_id,
|
||||||
data=data,
|
data=data,
|
||||||
@ -234,8 +234,13 @@ class AppExportApi(Resource):
|
|||||||
if not current_user.is_editor:
|
if not current_user.is_editor:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
|
# Add include_secret params
|
||||||
|
parser = reqparse.RequestParser()
|
||||||
|
parser.add_argument('include_secret', type=inputs.boolean, default=False, location='args')
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"data": AppDslService.export_dsl(app_model=app_model)
|
"data": AppDslService.export_dsl(app_model=app_model, include_secret=args['include_secret'])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -13,6 +13,7 @@ from controllers.console.setup import setup_required
|
|||||||
from controllers.console.wraps import account_initialization_required
|
from controllers.console.wraps import account_initialization_required
|
||||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
|
from core.app.segments import factory
|
||||||
from core.errors.error import AppInvokeQuotaExceededError
|
from core.errors.error import AppInvokeQuotaExceededError
|
||||||
from fields.workflow_fields import workflow_fields
|
from fields.workflow_fields import workflow_fields
|
||||||
from fields.workflow_run_fields import workflow_run_node_execution_fields
|
from fields.workflow_run_fields import workflow_run_node_execution_fields
|
||||||
@ -41,7 +42,7 @@ class DraftWorkflowApi(Resource):
|
|||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
# The role of the current user in the ta table must be admin, owner, or editor
|
||||||
if not current_user.is_editor:
|
if not current_user.is_editor:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
# fetch draft workflow by app_model
|
# fetch draft workflow by app_model
|
||||||
workflow_service = WorkflowService()
|
workflow_service = WorkflowService()
|
||||||
workflow = workflow_service.get_draft_workflow(app_model=app_model)
|
workflow = workflow_service.get_draft_workflow(app_model=app_model)
|
||||||
@ -64,13 +65,15 @@ class DraftWorkflowApi(Resource):
|
|||||||
if not current_user.is_editor:
|
if not current_user.is_editor:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
content_type = request.headers.get('Content-Type')
|
content_type = request.headers.get('Content-Type', '')
|
||||||
|
|
||||||
if 'application/json' in content_type:
|
if 'application/json' in content_type:
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('graph', type=dict, required=True, nullable=False, location='json')
|
parser.add_argument('graph', type=dict, required=True, nullable=False, location='json')
|
||||||
parser.add_argument('features', type=dict, required=True, nullable=False, location='json')
|
parser.add_argument('features', type=dict, required=True, nullable=False, location='json')
|
||||||
parser.add_argument('hash', type=str, required=False, location='json')
|
parser.add_argument('hash', type=str, required=False, location='json')
|
||||||
|
# TODO: set this to required=True after frontend is updated
|
||||||
|
parser.add_argument('environment_variables', type=list, required=False, location='json')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
elif 'text/plain' in content_type:
|
elif 'text/plain' in content_type:
|
||||||
try:
|
try:
|
||||||
@ -84,7 +87,8 @@ class DraftWorkflowApi(Resource):
|
|||||||
args = {
|
args = {
|
||||||
'graph': data.get('graph'),
|
'graph': data.get('graph'),
|
||||||
'features': data.get('features'),
|
'features': data.get('features'),
|
||||||
'hash': data.get('hash')
|
'hash': data.get('hash'),
|
||||||
|
'environment_variables': data.get('environment_variables')
|
||||||
}
|
}
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
return {'message': 'Invalid JSON data'}, 400
|
return {'message': 'Invalid JSON data'}, 400
|
||||||
@ -94,12 +98,15 @@ class DraftWorkflowApi(Resource):
|
|||||||
workflow_service = WorkflowService()
|
workflow_service = WorkflowService()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
environment_variables_list = args.get('environment_variables') or []
|
||||||
|
environment_variables = [factory.build_variable_from_mapping(obj) for obj in environment_variables_list]
|
||||||
workflow = workflow_service.sync_draft_workflow(
|
workflow = workflow_service.sync_draft_workflow(
|
||||||
app_model=app_model,
|
app_model=app_model,
|
||||||
graph=args.get('graph'),
|
graph=args['graph'],
|
||||||
features=args.get('features'),
|
features=args['features'],
|
||||||
unique_hash=args.get('hash'),
|
unique_hash=args.get('hash'),
|
||||||
account=current_user
|
account=current_user,
|
||||||
|
environment_variables=environment_variables,
|
||||||
)
|
)
|
||||||
except WorkflowHashNotEqualError:
|
except WorkflowHashNotEqualError:
|
||||||
raise DraftWorkflowNotSync()
|
raise DraftWorkflowNotSync()
|
||||||
|
@ -1,10 +1,11 @@
|
|||||||
import flask_restful
|
import flask_restful
|
||||||
from flask import current_app, request
|
from flask import request
|
||||||
from flask_login import current_user
|
from flask_login import current_user
|
||||||
from flask_restful import Resource, marshal, marshal_with, reqparse
|
from flask_restful import Resource, marshal, marshal_with, reqparse
|
||||||
from werkzeug.exceptions import Forbidden, NotFound
|
from werkzeug.exceptions import Forbidden, NotFound
|
||||||
|
|
||||||
import services
|
import services
|
||||||
|
from configs import dify_config
|
||||||
from controllers.console import api
|
from controllers.console import api
|
||||||
from controllers.console.apikey import api_key_fields, api_key_list
|
from controllers.console.apikey import api_key_fields, api_key_list
|
||||||
from controllers.console.app.error import ProviderNotInitializeError
|
from controllers.console.app.error import ProviderNotInitializeError
|
||||||
@ -530,7 +531,7 @@ class DatasetApiBaseUrlApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self):
|
def get(self):
|
||||||
return {
|
return {
|
||||||
'api_base_url': (current_app.config['SERVICE_API_URL'] if current_app.config['SERVICE_API_URL']
|
'api_base_url': (dify_config.SERVICE_API_URL if dify_config.SERVICE_API_URL
|
||||||
else request.host_url.rstrip('/')) + '/v1'
|
else request.host_url.rstrip('/')) + '/v1'
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -540,20 +541,20 @@ class DatasetRetrievalSettingApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self):
|
def get(self):
|
||||||
vector_type = current_app.config['VECTOR_STORE']
|
vector_type = dify_config.VECTOR_STORE
|
||||||
match vector_type:
|
match vector_type:
|
||||||
case VectorType.MILVUS | VectorType.RELYT | VectorType.PGVECTOR | VectorType.TIDB_VECTOR | VectorType.CHROMA | VectorType.TENCENT | VectorType.ORACLE:
|
case VectorType.MILVUS | VectorType.RELYT | VectorType.PGVECTOR | VectorType.TIDB_VECTOR | VectorType.CHROMA | VectorType.TENCENT | VectorType.ORACLE:
|
||||||
return {
|
return {
|
||||||
'retrieval_method': [
|
'retrieval_method': [
|
||||||
RetrievalMethod.SEMANTIC_SEARCH
|
RetrievalMethod.SEMANTIC_SEARCH.value
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH | VectorType.ANALYTICDB | VectorType.MYSCALE:
|
case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH | VectorType.ANALYTICDB | VectorType.MYSCALE:
|
||||||
return {
|
return {
|
||||||
'retrieval_method': [
|
'retrieval_method': [
|
||||||
RetrievalMethod.SEMANTIC_SEARCH,
|
RetrievalMethod.SEMANTIC_SEARCH.value,
|
||||||
RetrievalMethod.FULL_TEXT_SEARCH,
|
RetrievalMethod.FULL_TEXT_SEARCH.value,
|
||||||
RetrievalMethod.HYBRID_SEARCH,
|
RetrievalMethod.HYBRID_SEARCH.value,
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
case _:
|
case _:
|
||||||
@ -569,15 +570,15 @@ class DatasetRetrievalSettingMockApi(Resource):
|
|||||||
case VectorType.MILVUS | VectorType.RELYT | VectorType.PGVECTOR | VectorType.TIDB_VECTOR | VectorType.CHROMA | VectorType.TENCENT | VectorType.ORACLE:
|
case VectorType.MILVUS | VectorType.RELYT | VectorType.PGVECTOR | VectorType.TIDB_VECTOR | VectorType.CHROMA | VectorType.TENCENT | VectorType.ORACLE:
|
||||||
return {
|
return {
|
||||||
'retrieval_method': [
|
'retrieval_method': [
|
||||||
RetrievalMethod.SEMANTIC_SEARCH
|
RetrievalMethod.SEMANTIC_SEARCH.value
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH| VectorType.ANALYTICDB | VectorType.MYSCALE:
|
case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH| VectorType.ANALYTICDB | VectorType.MYSCALE:
|
||||||
return {
|
return {
|
||||||
'retrieval_method': [
|
'retrieval_method': [
|
||||||
RetrievalMethod.SEMANTIC_SEARCH,
|
RetrievalMethod.SEMANTIC_SEARCH.value,
|
||||||
RetrievalMethod.FULL_TEXT_SEARCH,
|
RetrievalMethod.FULL_TEXT_SEARCH.value,
|
||||||
RetrievalMethod.HYBRID_SEARCH,
|
RetrievalMethod.HYBRID_SEARCH.value,
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
case _:
|
case _:
|
||||||
|
@ -75,7 +75,7 @@ class DatasetDocumentSegmentListApi(Resource):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if last_id is not None:
|
if last_id is not None:
|
||||||
last_segment = DocumentSegment.query.get(str(last_id))
|
last_segment = db.session.get(DocumentSegment, str(last_id))
|
||||||
if last_segment:
|
if last_segment:
|
||||||
query = query.filter(
|
query = query.filter(
|
||||||
DocumentSegment.position > last_segment.position)
|
DocumentSegment.position > last_segment.position)
|
||||||
|
@ -1,8 +1,9 @@
|
|||||||
from flask import current_app, request
|
from flask import request
|
||||||
from flask_login import current_user
|
from flask_login import current_user
|
||||||
from flask_restful import Resource, marshal_with
|
from flask_restful import Resource, marshal_with
|
||||||
|
|
||||||
import services
|
import services
|
||||||
|
from configs import dify_config
|
||||||
from controllers.console import api
|
from controllers.console import api
|
||||||
from controllers.console.datasets.error import (
|
from controllers.console.datasets.error import (
|
||||||
FileTooLargeError,
|
FileTooLargeError,
|
||||||
@ -26,9 +27,9 @@ class FileApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@marshal_with(upload_config_fields)
|
@marshal_with(upload_config_fields)
|
||||||
def get(self):
|
def get(self):
|
||||||
file_size_limit = current_app.config.get("UPLOAD_FILE_SIZE_LIMIT")
|
file_size_limit = dify_config.UPLOAD_FILE_SIZE_LIMIT
|
||||||
batch_count_limit = current_app.config.get("UPLOAD_FILE_BATCH_LIMIT")
|
batch_count_limit = dify_config.UPLOAD_FILE_BATCH_LIMIT
|
||||||
image_file_size_limit = current_app.config.get("UPLOAD_IMAGE_FILE_SIZE_LIMIT")
|
image_file_size_limit = dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT
|
||||||
return {
|
return {
|
||||||
'file_size_limit': file_size_limit,
|
'file_size_limit': file_size_limit,
|
||||||
'batch_count_limit': batch_count_limit,
|
'batch_count_limit': batch_count_limit,
|
||||||
@ -76,7 +77,7 @@ class FileSupportTypeApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self):
|
def get(self):
|
||||||
etl_type = current_app.config['ETL_TYPE']
|
etl_type = dify_config.ETL_TYPE
|
||||||
allowed_extensions = UNSTRUCTURED_ALLOWED_EXTENSIONS if etl_type == 'Unstructured' else ALLOWED_EXTENSIONS
|
allowed_extensions = UNSTRUCTURED_ALLOWED_EXTENSIONS if etl_type == 'Unstructured' else ALLOWED_EXTENSIONS
|
||||||
return {'allowed_extensions': allowed_extensions}
|
return {'allowed_extensions': allowed_extensions}
|
||||||
|
|
||||||
|
@ -78,10 +78,12 @@ class ChatTextApi(InstalledAppResource):
|
|||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('message_id', type=str, required=False, location='json')
|
parser.add_argument('message_id', type=str, required=False, location='json')
|
||||||
parser.add_argument('voice', type=str, location='json')
|
parser.add_argument('voice', type=str, location='json')
|
||||||
|
parser.add_argument('text', type=str, location='json')
|
||||||
parser.add_argument('streaming', type=bool, location='json')
|
parser.add_argument('streaming', type=bool, location='json')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
message_id = args.get('message_id')
|
message_id = args.get('message_id', None)
|
||||||
|
text = args.get('text', None)
|
||||||
if (app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
|
if (app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
|
||||||
and app_model.workflow
|
and app_model.workflow
|
||||||
and app_model.workflow.features_dict):
|
and app_model.workflow.features_dict):
|
||||||
@ -95,7 +97,8 @@ class ChatTextApi(InstalledAppResource):
|
|||||||
response = AudioService.transcript_tts(
|
response = AudioService.transcript_tts(
|
||||||
app_model=app_model,
|
app_model=app_model,
|
||||||
message_id=message_id,
|
message_id=message_id,
|
||||||
voice=voice
|
voice=voice,
|
||||||
|
text=text
|
||||||
)
|
)
|
||||||
return response
|
return response
|
||||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
|
|
||||||
from flask import current_app
|
|
||||||
from flask_restful import fields, marshal_with
|
from flask_restful import fields, marshal_with
|
||||||
|
|
||||||
|
from configs import dify_config
|
||||||
from controllers.console import api
|
from controllers.console import api
|
||||||
from controllers.console.app.error import AppUnavailableError
|
from controllers.console.app.error import AppUnavailableError
|
||||||
from controllers.console.explore.wraps import InstalledAppResource
|
from controllers.console.explore.wraps import InstalledAppResource
|
||||||
@ -78,7 +78,7 @@ class AppParameterApi(InstalledAppResource):
|
|||||||
"transfer_methods": ["remote_url", "local_file"]
|
"transfer_methods": ["remote_url", "local_file"]
|
||||||
}}),
|
}}),
|
||||||
'system_parameters': {
|
'system_parameters': {
|
||||||
'image_file_size_limit': current_app.config.get('UPLOAD_IMAGE_FILE_SIZE_LIMIT')
|
'image_file_size_limit': dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,8 +1,9 @@
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
from flask import current_app, session
|
from flask import session
|
||||||
from flask_restful import Resource, reqparse
|
from flask_restful import Resource, reqparse
|
||||||
|
|
||||||
|
from configs import dify_config
|
||||||
from libs.helper import str_len
|
from libs.helper import str_len
|
||||||
from models.model import DifySetup
|
from models.model import DifySetup
|
||||||
from services.account_service import TenantService
|
from services.account_service import TenantService
|
||||||
@ -40,7 +41,7 @@ class InitValidateAPI(Resource):
|
|||||||
return {'result': 'success'}, 201
|
return {'result': 'success'}, 201
|
||||||
|
|
||||||
def get_init_validate_status():
|
def get_init_validate_status():
|
||||||
if current_app.config['EDITION'] == 'SELF_HOSTED':
|
if dify_config.EDITION == 'SELF_HOSTED':
|
||||||
if os.environ.get('INIT_PASSWORD'):
|
if os.environ.get('INIT_PASSWORD'):
|
||||||
return session.get('is_init_validated') or DifySetup.query.first()
|
return session.get('is_init_validated') or DifySetup.query.first()
|
||||||
|
|
||||||
|
@ -1,8 +1,9 @@
|
|||||||
from functools import wraps
|
from functools import wraps
|
||||||
|
|
||||||
from flask import current_app, request
|
from flask import request
|
||||||
from flask_restful import Resource, reqparse
|
from flask_restful import Resource, reqparse
|
||||||
|
|
||||||
|
from configs import dify_config
|
||||||
from libs.helper import email, get_remote_ip, str_len
|
from libs.helper import email, get_remote_ip, str_len
|
||||||
from libs.password import valid_password
|
from libs.password import valid_password
|
||||||
from models.model import DifySetup
|
from models.model import DifySetup
|
||||||
@ -17,7 +18,7 @@ from .wraps import only_edition_self_hosted
|
|||||||
class SetupApi(Resource):
|
class SetupApi(Resource):
|
||||||
|
|
||||||
def get(self):
|
def get(self):
|
||||||
if current_app.config['EDITION'] == 'SELF_HOSTED':
|
if dify_config.EDITION == 'SELF_HOSTED':
|
||||||
setup_status = get_setup_status()
|
setup_status = get_setup_status()
|
||||||
if setup_status:
|
if setup_status:
|
||||||
return {
|
return {
|
||||||
@ -77,7 +78,7 @@ def setup_required(view):
|
|||||||
|
|
||||||
|
|
||||||
def get_setup_status():
|
def get_setup_status():
|
||||||
if current_app.config['EDITION'] == 'SELF_HOSTED':
|
if dify_config.EDITION == 'SELF_HOSTED':
|
||||||
return DifySetup.query.first()
|
return DifySetup.query.first()
|
||||||
else:
|
else:
|
||||||
return True
|
return True
|
||||||
|
@ -3,9 +3,10 @@ import json
|
|||||||
import logging
|
import logging
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
from flask import current_app
|
|
||||||
from flask_restful import Resource, reqparse
|
from flask_restful import Resource, reqparse
|
||||||
|
|
||||||
|
from configs import dify_config
|
||||||
|
|
||||||
from . import api
|
from . import api
|
||||||
|
|
||||||
|
|
||||||
@ -15,16 +16,16 @@ class VersionApi(Resource):
|
|||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('current_version', type=str, required=True, location='args')
|
parser.add_argument('current_version', type=str, required=True, location='args')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
check_update_url = current_app.config['CHECK_UPDATE_URL']
|
check_update_url = dify_config.CHECK_UPDATE_URL
|
||||||
|
|
||||||
result = {
|
result = {
|
||||||
'version': current_app.config['CURRENT_VERSION'],
|
'version': dify_config.CURRENT_VERSION,
|
||||||
'release_date': '',
|
'release_date': '',
|
||||||
'release_notes': '',
|
'release_notes': '',
|
||||||
'can_auto_update': False,
|
'can_auto_update': False,
|
||||||
'features': {
|
'features': {
|
||||||
'can_replace_logo': current_app.config['CAN_REPLACE_LOGO'],
|
'can_replace_logo': dify_config.CAN_REPLACE_LOGO,
|
||||||
'model_load_balancing_enabled': current_app.config['MODEL_LB_ENABLED']
|
'model_load_balancing_enabled': dify_config.MODEL_LB_ENABLED
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,10 +1,11 @@
|
|||||||
import datetime
|
import datetime
|
||||||
|
|
||||||
import pytz
|
import pytz
|
||||||
from flask import current_app, request
|
from flask import request
|
||||||
from flask_login import current_user
|
from flask_login import current_user
|
||||||
from flask_restful import Resource, fields, marshal_with, reqparse
|
from flask_restful import Resource, fields, marshal_with, reqparse
|
||||||
|
|
||||||
|
from configs import dify_config
|
||||||
from constants.languages import supported_language
|
from constants.languages import supported_language
|
||||||
from controllers.console import api
|
from controllers.console import api
|
||||||
from controllers.console.setup import setup_required
|
from controllers.console.setup import setup_required
|
||||||
@ -36,7 +37,7 @@ class AccountInitApi(Resource):
|
|||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
|
|
||||||
if current_app.config['EDITION'] == 'CLOUD':
|
if dify_config.EDITION == 'CLOUD':
|
||||||
parser.add_argument('invitation_code', type=str, location='json')
|
parser.add_argument('invitation_code', type=str, location='json')
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -45,7 +46,7 @@ class AccountInitApi(Resource):
|
|||||||
required=True, location='json')
|
required=True, location='json')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if current_app.config['EDITION'] == 'CLOUD':
|
if dify_config.EDITION == 'CLOUD':
|
||||||
if not args['invitation_code']:
|
if not args['invitation_code']:
|
||||||
raise ValueError('invitation_code is required')
|
raise ValueError('invitation_code is required')
|
||||||
|
|
||||||
|
@ -1,8 +1,8 @@
|
|||||||
from flask import current_app
|
|
||||||
from flask_login import current_user
|
from flask_login import current_user
|
||||||
from flask_restful import Resource, abort, marshal_with, reqparse
|
from flask_restful import Resource, abort, marshal_with, reqparse
|
||||||
|
|
||||||
import services
|
import services
|
||||||
|
from configs import dify_config
|
||||||
from controllers.console import api
|
from controllers.console import api
|
||||||
from controllers.console.setup import setup_required
|
from controllers.console.setup import setup_required
|
||||||
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
|
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
|
||||||
@ -48,7 +48,7 @@ class MemberInviteEmailApi(Resource):
|
|||||||
|
|
||||||
inviter = current_user
|
inviter = current_user
|
||||||
invitation_results = []
|
invitation_results = []
|
||||||
console_web_url = current_app.config.get("CONSOLE_WEB_URL")
|
console_web_url = dify_config.CONSOLE_WEB_URL
|
||||||
for invitee_email in invitee_emails:
|
for invitee_email in invitee_emails:
|
||||||
try:
|
try:
|
||||||
token = RegisterService.invite_new_member(inviter.current_tenant, invitee_email, interface_language, role=invitee_role, inviter=inviter)
|
token = RegisterService.invite_new_member(inviter.current_tenant, invitee_email, interface_language, role=invitee_role, inviter=inviter)
|
||||||
@ -117,7 +117,7 @@ class MemberUpdateRoleApi(Resource):
|
|||||||
if not TenantAccountRole.is_valid_role(new_role):
|
if not TenantAccountRole.is_valid_role(new_role):
|
||||||
return {'code': 'invalid-role', 'message': 'Invalid role'}, 400
|
return {'code': 'invalid-role', 'message': 'Invalid role'}, 400
|
||||||
|
|
||||||
member = Account.query.get(str(member_id))
|
member = db.session.get(Account, str(member_id))
|
||||||
if not member:
|
if not member:
|
||||||
abort(404)
|
abort(404)
|
||||||
|
|
||||||
|
@ -1,10 +1,11 @@
|
|||||||
import io
|
import io
|
||||||
|
|
||||||
from flask import current_app, send_file
|
from flask import send_file
|
||||||
from flask_login import current_user
|
from flask_login import current_user
|
||||||
from flask_restful import Resource, reqparse
|
from flask_restful import Resource, reqparse
|
||||||
from werkzeug.exceptions import Forbidden
|
from werkzeug.exceptions import Forbidden
|
||||||
|
|
||||||
|
from configs import dify_config
|
||||||
from controllers.console import api
|
from controllers.console import api
|
||||||
from controllers.console.setup import setup_required
|
from controllers.console.setup import setup_required
|
||||||
from controllers.console.wraps import account_initialization_required
|
from controllers.console.wraps import account_initialization_required
|
||||||
@ -104,7 +105,7 @@ class ToolBuiltinProviderIconApi(Resource):
|
|||||||
@setup_required
|
@setup_required
|
||||||
def get(self, provider):
|
def get(self, provider):
|
||||||
icon_bytes, mimetype = BuiltinToolManageService.get_builtin_tool_provider_icon(provider)
|
icon_bytes, mimetype = BuiltinToolManageService.get_builtin_tool_provider_icon(provider)
|
||||||
icon_cache_max_age = current_app.config.get('TOOL_ICON_CACHE_MAX_AGE')
|
icon_cache_max_age = dify_config.TOOL_ICON_CACHE_MAX_AGE
|
||||||
return send_file(io.BytesIO(icon_bytes), mimetype=mimetype, max_age=icon_cache_max_age)
|
return send_file(io.BytesIO(icon_bytes), mimetype=mimetype, max_age=icon_cache_max_age)
|
||||||
|
|
||||||
class ToolApiProviderAddApi(Resource):
|
class ToolApiProviderAddApi(Resource):
|
||||||
|
@ -1,9 +1,10 @@
|
|||||||
import json
|
import json
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
|
|
||||||
from flask import abort, current_app, request
|
from flask import abort, request
|
||||||
from flask_login import current_user
|
from flask_login import current_user
|
||||||
|
|
||||||
|
from configs import dify_config
|
||||||
from controllers.console.workspace.error import AccountNotInitializedError
|
from controllers.console.workspace.error import AccountNotInitializedError
|
||||||
from services.feature_service import FeatureService
|
from services.feature_service import FeatureService
|
||||||
from services.operation_service import OperationService
|
from services.operation_service import OperationService
|
||||||
@ -26,7 +27,7 @@ def account_initialization_required(view):
|
|||||||
def only_edition_cloud(view):
|
def only_edition_cloud(view):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args, **kwargs):
|
def decorated(*args, **kwargs):
|
||||||
if current_app.config['EDITION'] != 'CLOUD':
|
if dify_config.EDITION != 'CLOUD':
|
||||||
abort(404)
|
abort(404)
|
||||||
|
|
||||||
return view(*args, **kwargs)
|
return view(*args, **kwargs)
|
||||||
@ -37,7 +38,7 @@ def only_edition_cloud(view):
|
|||||||
def only_edition_self_hosted(view):
|
def only_edition_self_hosted(view):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args, **kwargs):
|
def decorated(*args, **kwargs):
|
||||||
if current_app.config['EDITION'] != 'SELF_HOSTED':
|
if dify_config.EDITION != 'SELF_HOSTED':
|
||||||
abort(404)
|
abort(404)
|
||||||
|
|
||||||
return view(*args, **kwargs)
|
return view(*args, **kwargs)
|
||||||
|
@ -76,10 +76,12 @@ class TextApi(Resource):
|
|||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('message_id', type=str, required=False, location='json')
|
parser.add_argument('message_id', type=str, required=False, location='json')
|
||||||
parser.add_argument('voice', type=str, location='json')
|
parser.add_argument('voice', type=str, location='json')
|
||||||
|
parser.add_argument('text', type=str, location='json')
|
||||||
parser.add_argument('streaming', type=bool, location='json')
|
parser.add_argument('streaming', type=bool, location='json')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
message_id = args.get('message_id')
|
message_id = args.get('message_id', None)
|
||||||
|
text = args.get('text', None)
|
||||||
if (app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
|
if (app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
|
||||||
and app_model.workflow
|
and app_model.workflow
|
||||||
and app_model.workflow.features_dict):
|
and app_model.workflow.features_dict):
|
||||||
@ -87,15 +89,15 @@ class TextApi(Resource):
|
|||||||
voice = args.get('voice') if args.get('voice') else text_to_speech.get('voice')
|
voice = args.get('voice') if args.get('voice') else text_to_speech.get('voice')
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
voice = args.get('voice') if args.get('voice') else app_model.app_model_config.text_to_speech_dict.get(
|
voice = args.get('voice') if args.get('voice') else app_model.app_model_config.text_to_speech_dict.get('voice')
|
||||||
'voice')
|
|
||||||
except Exception:
|
except Exception:
|
||||||
voice = None
|
voice = None
|
||||||
response = AudioService.transcript_tts(
|
response = AudioService.transcript_tts(
|
||||||
app_model=app_model,
|
app_model=app_model,
|
||||||
message_id=message_id,
|
message_id=message_id,
|
||||||
end_user=end_user.external_user_id,
|
end_user=end_user.external_user_id,
|
||||||
voice=voice
|
voice=voice,
|
||||||
|
text=text
|
||||||
)
|
)
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
import logging
|
import logging
|
||||||
|
|
||||||
from flask_restful import Resource, reqparse
|
from flask_restful import Resource, fields, marshal_with, reqparse
|
||||||
from werkzeug.exceptions import InternalServerError
|
from werkzeug.exceptions import InternalServerError
|
||||||
|
|
||||||
from controllers.service_api import api
|
from controllers.service_api import api
|
||||||
@ -21,14 +21,43 @@ from core.errors.error import (
|
|||||||
QuotaExceededError,
|
QuotaExceededError,
|
||||||
)
|
)
|
||||||
from core.model_runtime.errors.invoke import InvokeError
|
from core.model_runtime.errors.invoke import InvokeError
|
||||||
|
from extensions.ext_database import db
|
||||||
from libs import helper
|
from libs import helper
|
||||||
from models.model import App, AppMode, EndUser
|
from models.model import App, AppMode, EndUser
|
||||||
|
from models.workflow import WorkflowRun
|
||||||
from services.app_generate_service import AppGenerateService
|
from services.app_generate_service import AppGenerateService
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class WorkflowRunApi(Resource):
|
class WorkflowRunApi(Resource):
|
||||||
|
workflow_run_fields = {
|
||||||
|
'id': fields.String,
|
||||||
|
'workflow_id': fields.String,
|
||||||
|
'status': fields.String,
|
||||||
|
'inputs': fields.Raw,
|
||||||
|
'outputs': fields.Raw,
|
||||||
|
'error': fields.String,
|
||||||
|
'total_steps': fields.Integer,
|
||||||
|
'total_tokens': fields.Integer,
|
||||||
|
'created_at': fields.DateTime,
|
||||||
|
'finished_at': fields.DateTime,
|
||||||
|
'elapsed_time': fields.Float,
|
||||||
|
}
|
||||||
|
|
||||||
|
@validate_app_token
|
||||||
|
@marshal_with(workflow_run_fields)
|
||||||
|
def get(self, app_model: App, workflow_id: str):
|
||||||
|
"""
|
||||||
|
Get a workflow task running detail
|
||||||
|
"""
|
||||||
|
app_mode = AppMode.value_of(app_model.mode)
|
||||||
|
if app_mode != AppMode.WORKFLOW:
|
||||||
|
raise NotWorkflowAppError()
|
||||||
|
|
||||||
|
workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_id).first()
|
||||||
|
return workflow_run
|
||||||
|
|
||||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
|
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
|
||||||
def post(self, app_model: App, end_user: EndUser):
|
def post(self, app_model: App, end_user: EndUser):
|
||||||
"""
|
"""
|
||||||
@ -88,5 +117,5 @@ class WorkflowTaskStopApi(Resource):
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(WorkflowRunApi, '/workflows/run')
|
api.add_resource(WorkflowRunApi, '/workflows/run/<string:workflow_id>', '/workflows/run')
|
||||||
api.add_resource(WorkflowTaskStopApi, '/workflows/tasks/<string:task_id>/stop')
|
api.add_resource(WorkflowTaskStopApi, '/workflows/tasks/<string:task_id>/stop')
|
||||||
|
@ -74,10 +74,12 @@ class TextApi(WebApiResource):
|
|||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument('message_id', type=str, required=False, location='json')
|
parser.add_argument('message_id', type=str, required=False, location='json')
|
||||||
parser.add_argument('voice', type=str, location='json')
|
parser.add_argument('voice', type=str, location='json')
|
||||||
|
parser.add_argument('text', type=str, location='json')
|
||||||
parser.add_argument('streaming', type=bool, location='json')
|
parser.add_argument('streaming', type=bool, location='json')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
message_id = args.get('message_id')
|
message_id = args.get('message_id', None)
|
||||||
|
text = args.get('text', None)
|
||||||
if (app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
|
if (app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
|
||||||
and app_model.workflow
|
and app_model.workflow
|
||||||
and app_model.workflow.features_dict):
|
and app_model.workflow.features_dict):
|
||||||
@ -94,7 +96,8 @@ class TextApi(WebApiResource):
|
|||||||
app_model=app_model,
|
app_model=app_model,
|
||||||
message_id=message_id,
|
message_id=message_id,
|
||||||
end_user=end_user.external_user_id,
|
end_user=end_user.external_user_id,
|
||||||
voice=voice
|
voice=voice,
|
||||||
|
text=text
|
||||||
)
|
)
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
@ -342,10 +342,14 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
|||||||
"""
|
"""
|
||||||
tool_calls = []
|
tool_calls = []
|
||||||
for prompt_message in llm_result_chunk.delta.message.tool_calls:
|
for prompt_message in llm_result_chunk.delta.message.tool_calls:
|
||||||
|
args = {}
|
||||||
|
if prompt_message.function.arguments != '':
|
||||||
|
args = json.loads(prompt_message.function.arguments)
|
||||||
|
|
||||||
tool_calls.append((
|
tool_calls.append((
|
||||||
prompt_message.id,
|
prompt_message.id,
|
||||||
prompt_message.function.name,
|
prompt_message.function.name,
|
||||||
json.loads(prompt_message.function.arguments),
|
args,
|
||||||
))
|
))
|
||||||
|
|
||||||
return tool_calls
|
return tool_calls
|
||||||
@ -359,10 +363,14 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
|||||||
"""
|
"""
|
||||||
tool_calls = []
|
tool_calls = []
|
||||||
for prompt_message in llm_result.message.tool_calls:
|
for prompt_message in llm_result.message.tool_calls:
|
||||||
|
args = {}
|
||||||
|
if prompt_message.function.arguments != '':
|
||||||
|
args = json.loads(prompt_message.function.arguments)
|
||||||
|
|
||||||
tool_calls.append((
|
tool_calls.append((
|
||||||
prompt_message.id,
|
prompt_message.id,
|
||||||
prompt_message.function.name,
|
prompt_message.function.name,
|
||||||
json.loads(prompt_message.function.arguments),
|
args,
|
||||||
))
|
))
|
||||||
|
|
||||||
return tool_calls
|
return tool_calls
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
from typing import Optional, Union
|
from collections.abc import Mapping
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from core.app.app_config.entities import AppAdditionalFeatures, EasyUIBasedAppModelConfigFrom
|
from core.app.app_config.entities import AppAdditionalFeatures
|
||||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||||
from core.app.app_config.features.more_like_this.manager import MoreLikeThisConfigManager
|
from core.app.app_config.features.more_like_this.manager import MoreLikeThisConfigManager
|
||||||
from core.app.app_config.features.opening_statement.manager import OpeningStatementConfigManager
|
from core.app.app_config.features.opening_statement.manager import OpeningStatementConfigManager
|
||||||
@ -10,37 +11,19 @@ from core.app.app_config.features.suggested_questions_after_answer.manager impor
|
|||||||
SuggestedQuestionsAfterAnswerConfigManager,
|
SuggestedQuestionsAfterAnswerConfigManager,
|
||||||
)
|
)
|
||||||
from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager
|
from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager
|
||||||
from models.model import AppMode, AppModelConfig
|
from models.model import AppMode
|
||||||
|
|
||||||
|
|
||||||
class BaseAppConfigManager:
|
class BaseAppConfigManager:
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def convert_to_config_dict(cls, config_from: EasyUIBasedAppModelConfigFrom,
|
def convert_features(cls, config_dict: Mapping[str, Any], app_mode: AppMode) -> AppAdditionalFeatures:
|
||||||
app_model_config: Union[AppModelConfig, dict],
|
|
||||||
config_dict: Optional[dict] = None) -> dict:
|
|
||||||
"""
|
|
||||||
Convert app model config to config dict
|
|
||||||
:param config_from: app model config from
|
|
||||||
:param app_model_config: app model config
|
|
||||||
:param config_dict: app model config dict
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
if config_from != EasyUIBasedAppModelConfigFrom.ARGS:
|
|
||||||
app_model_config_dict = app_model_config.to_dict()
|
|
||||||
config_dict = app_model_config_dict.copy()
|
|
||||||
|
|
||||||
return config_dict
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def convert_features(cls, config_dict: dict, app_mode: AppMode) -> AppAdditionalFeatures:
|
|
||||||
"""
|
"""
|
||||||
Convert app config to app model config
|
Convert app config to app model config
|
||||||
|
|
||||||
:param config_dict: app config
|
:param config_dict: app config
|
||||||
:param app_mode: app mode
|
:param app_mode: app mode
|
||||||
"""
|
"""
|
||||||
config_dict = config_dict.copy()
|
config_dict = dict(config_dict.items())
|
||||||
|
|
||||||
additional_features = AppAdditionalFeatures()
|
additional_features = AppAdditionalFeatures()
|
||||||
additional_features.show_retrieve_source = RetrievalResourceConfigManager.convert(
|
additional_features.show_retrieve_source = RetrievalResourceConfigManager.convert(
|
||||||
|
@ -1,11 +1,12 @@
|
|||||||
from typing import Optional
|
from collections.abc import Mapping
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
from core.app.app_config.entities import FileExtraConfig
|
from core.app.app_config.entities import FileExtraConfig
|
||||||
|
|
||||||
|
|
||||||
class FileUploadConfigManager:
|
class FileUploadConfigManager:
|
||||||
@classmethod
|
@classmethod
|
||||||
def convert(cls, config: dict, is_vision: bool = True) -> Optional[FileExtraConfig]:
|
def convert(cls, config: Mapping[str, Any], is_vision: bool = True) -> Optional[FileExtraConfig]:
|
||||||
"""
|
"""
|
||||||
Convert model config to model config
|
Convert model config to model config
|
||||||
|
|
||||||
|
@ -3,13 +3,13 @@ from core.app.app_config.entities import TextToSpeechEntity
|
|||||||
|
|
||||||
class TextToSpeechConfigManager:
|
class TextToSpeechConfigManager:
|
||||||
@classmethod
|
@classmethod
|
||||||
def convert(cls, config: dict) -> bool:
|
def convert(cls, config: dict):
|
||||||
"""
|
"""
|
||||||
Convert model config to model config
|
Convert model config to model config
|
||||||
|
|
||||||
:param config: model config args
|
:param config: model config args
|
||||||
"""
|
"""
|
||||||
text_to_speech = False
|
text_to_speech = None
|
||||||
text_to_speech_dict = config.get('text_to_speech')
|
text_to_speech_dict = config.get('text_to_speech')
|
||||||
if text_to_speech_dict:
|
if text_to_speech_dict:
|
||||||
if text_to_speech_dict.get('enabled'):
|
if text_to_speech_dict.get('enabled'):
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
import contextvars
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import threading
|
import threading
|
||||||
@ -8,6 +9,7 @@ from typing import Union
|
|||||||
from flask import Flask, current_app
|
from flask import Flask, current_app
|
||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
|
|
||||||
|
import contexts
|
||||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||||
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
|
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
|
||||||
from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner
|
from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner
|
||||||
@ -107,6 +109,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||||||
extras=extras,
|
extras=extras,
|
||||||
trace_manager=trace_manager
|
trace_manager=trace_manager
|
||||||
)
|
)
|
||||||
|
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
|
||||||
|
|
||||||
return self._generate(
|
return self._generate(
|
||||||
app_model=app_model,
|
app_model=app_model,
|
||||||
@ -173,6 +176,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||||||
inputs=args['inputs']
|
inputs=args['inputs']
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
|
||||||
|
|
||||||
return self._generate(
|
return self._generate(
|
||||||
app_model=app_model,
|
app_model=app_model,
|
||||||
@ -225,6 +229,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||||||
'queue_manager': queue_manager,
|
'queue_manager': queue_manager,
|
||||||
'conversation_id': conversation.id,
|
'conversation_id': conversation.id,
|
||||||
'message_id': message.id,
|
'message_id': message.id,
|
||||||
|
'user': user,
|
||||||
|
'context': contextvars.copy_context()
|
||||||
})
|
})
|
||||||
|
|
||||||
worker_thread.start()
|
worker_thread.start()
|
||||||
@ -249,7 +255,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||||||
application_generate_entity: AdvancedChatAppGenerateEntity,
|
application_generate_entity: AdvancedChatAppGenerateEntity,
|
||||||
queue_manager: AppQueueManager,
|
queue_manager: AppQueueManager,
|
||||||
conversation_id: str,
|
conversation_id: str,
|
||||||
message_id: str) -> None:
|
message_id: str,
|
||||||
|
user: Account,
|
||||||
|
context: contextvars.Context) -> None:
|
||||||
"""
|
"""
|
||||||
Generate worker in a new thread.
|
Generate worker in a new thread.
|
||||||
:param flask_app: Flask app
|
:param flask_app: Flask app
|
||||||
@ -259,6 +267,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||||||
:param message_id: message ID
|
:param message_id: message ID
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
|
for var, val in context.items():
|
||||||
|
var.set(val)
|
||||||
with flask_app.app_context():
|
with flask_app.app_context():
|
||||||
try:
|
try:
|
||||||
runner = AdvancedChatAppRunner()
|
runner = AdvancedChatAppRunner()
|
||||||
|
@ -1,7 +1,8 @@
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
from typing import Optional, cast
|
from collections.abc import Mapping
|
||||||
|
from typing import Any, Optional, cast
|
||||||
|
|
||||||
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
|
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
|
||||||
from core.app.apps.advanced_chat.workflow_event_trigger_callback import WorkflowEventTriggerCallback
|
from core.app.apps.advanced_chat.workflow_event_trigger_callback import WorkflowEventTriggerCallback
|
||||||
@ -14,6 +15,7 @@ from core.app.entities.app_invoke_entities import (
|
|||||||
)
|
)
|
||||||
from core.app.entities.queue_entities import QueueAnnotationReplyEvent, QueueStopEvent, QueueTextChunkEvent
|
from core.app.entities.queue_entities import QueueAnnotationReplyEvent, QueueStopEvent, QueueTextChunkEvent
|
||||||
from core.moderation.base import ModerationException
|
from core.moderation.base import ModerationException
|
||||||
|
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
|
||||||
from core.workflow.entities.node_entities import SystemVariable, UserFrom
|
from core.workflow.entities.node_entities import SystemVariable, UserFrom
|
||||||
from core.workflow.workflow_engine_manager import WorkflowEngineManager
|
from core.workflow.workflow_engine_manager import WorkflowEngineManager
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
@ -86,7 +88,7 @@ class AdvancedChatAppRunner(AppRunner):
|
|||||||
|
|
||||||
db.session.close()
|
db.session.close()
|
||||||
|
|
||||||
workflow_callbacks = [WorkflowEventTriggerCallback(
|
workflow_callbacks: list[WorkflowCallback] = [WorkflowEventTriggerCallback(
|
||||||
queue_manager=queue_manager,
|
queue_manager=queue_manager,
|
||||||
workflow=workflow
|
workflow=workflow
|
||||||
)]
|
)]
|
||||||
@ -160,7 +162,7 @@ class AdvancedChatAppRunner(AppRunner):
|
|||||||
self, queue_manager: AppQueueManager,
|
self, queue_manager: AppQueueManager,
|
||||||
app_record: App,
|
app_record: App,
|
||||||
app_generate_entity: AdvancedChatAppGenerateEntity,
|
app_generate_entity: AdvancedChatAppGenerateEntity,
|
||||||
inputs: dict,
|
inputs: Mapping[str, Any],
|
||||||
query: str,
|
query: str,
|
||||||
message_id: str
|
message_id: str
|
||||||
) -> bool:
|
) -> bool:
|
||||||
|
@ -1,9 +1,11 @@
|
|||||||
import json
|
import json
|
||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
from typing import cast
|
from typing import Any, cast
|
||||||
|
|
||||||
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
|
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
|
||||||
from core.app.entities.task_entities import (
|
from core.app.entities.task_entities import (
|
||||||
|
AppBlockingResponse,
|
||||||
|
AppStreamResponse,
|
||||||
ChatbotAppBlockingResponse,
|
ChatbotAppBlockingResponse,
|
||||||
ChatbotAppStreamResponse,
|
ChatbotAppStreamResponse,
|
||||||
ErrorStreamResponse,
|
ErrorStreamResponse,
|
||||||
@ -18,12 +20,13 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
|||||||
_blocking_response_type = ChatbotAppBlockingResponse
|
_blocking_response_type = ChatbotAppBlockingResponse
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict:
|
def convert_blocking_full_response(cls, blocking_response: AppBlockingResponse) -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Convert blocking full response.
|
Convert blocking full response.
|
||||||
:param blocking_response: blocking response
|
:param blocking_response: blocking response
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
|
blocking_response = cast(ChatbotAppBlockingResponse, blocking_response)
|
||||||
response = {
|
response = {
|
||||||
'event': 'message',
|
'event': 'message',
|
||||||
'task_id': blocking_response.task_id,
|
'task_id': blocking_response.task_id,
|
||||||
@ -39,7 +42,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
|||||||
return response
|
return response
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict:
|
def convert_blocking_simple_response(cls, blocking_response: AppBlockingResponse) -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Convert blocking simple response.
|
Convert blocking simple response.
|
||||||
:param blocking_response: blocking response
|
:param blocking_response: blocking response
|
||||||
@ -53,8 +56,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
|||||||
return response
|
return response
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def convert_stream_full_response(cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]) \
|
def convert_stream_full_response(cls, stream_response: Generator[AppStreamResponse, None, None]) -> Generator[str, Any, None]:
|
||||||
-> Generator[str, None, None]:
|
|
||||||
"""
|
"""
|
||||||
Convert stream full response.
|
Convert stream full response.
|
||||||
:param stream_response: stream response
|
:param stream_response: stream response
|
||||||
@ -83,8 +85,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
|||||||
yield json.dumps(response_chunk)
|
yield json.dumps(response_chunk)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def convert_stream_simple_response(cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]) \
|
def convert_stream_simple_response(cls, stream_response: Generator[AppStreamResponse, None, None]) -> Generator[str, Any, None]:
|
||||||
-> Generator[str, None, None]:
|
|
||||||
"""
|
"""
|
||||||
Convert stream simple response.
|
Convert stream simple response.
|
||||||
:param stream_response: stream response
|
:param stream_response: stream response
|
||||||
|
@ -113,7 +113,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||||||
self._stream_generate_routes = self._get_stream_generate_routes()
|
self._stream_generate_routes = self._get_stream_generate_routes()
|
||||||
self._conversation_name_generate_thread = None
|
self._conversation_name_generate_thread = None
|
||||||
|
|
||||||
def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
|
def process(self):
|
||||||
"""
|
"""
|
||||||
Process generate task pipeline.
|
Process generate task pipeline.
|
||||||
:return:
|
:return:
|
||||||
@ -136,8 +136,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||||||
else:
|
else:
|
||||||
return self._to_blocking_response(generator)
|
return self._to_blocking_response(generator)
|
||||||
|
|
||||||
def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]) \
|
def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]) -> ChatbotAppBlockingResponse:
|
||||||
-> ChatbotAppBlockingResponse:
|
|
||||||
"""
|
"""
|
||||||
Process blocking response.
|
Process blocking response.
|
||||||
:return:
|
:return:
|
||||||
@ -167,8 +166,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||||||
|
|
||||||
raise Exception('Queue listening stopped unexpectedly.')
|
raise Exception('Queue listening stopped unexpectedly.')
|
||||||
|
|
||||||
def _to_stream_response(self, generator: Generator[StreamResponse, None, None]) \
|
def _to_stream_response(self, generator: Generator[StreamResponse, None, None]) -> Generator[ChatbotAppStreamResponse, Any, None]:
|
||||||
-> Generator[ChatbotAppStreamResponse, None, None]:
|
|
||||||
"""
|
"""
|
||||||
To stream response.
|
To stream response.
|
||||||
:return:
|
:return:
|
||||||
|
@ -14,13 +14,13 @@ from core.app.entities.queue_entities import (
|
|||||||
QueueWorkflowStartedEvent,
|
QueueWorkflowStartedEvent,
|
||||||
QueueWorkflowSucceededEvent,
|
QueueWorkflowSucceededEvent,
|
||||||
)
|
)
|
||||||
from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
|
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
|
||||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||||
from core.workflow.entities.node_entities import NodeType
|
from core.workflow.entities.node_entities import NodeType
|
||||||
from models.workflow import Workflow
|
from models.workflow import Workflow
|
||||||
|
|
||||||
|
|
||||||
class WorkflowEventTriggerCallback(BaseWorkflowCallback):
|
class WorkflowEventTriggerCallback(WorkflowCallback):
|
||||||
|
|
||||||
def __init__(self, queue_manager: AppQueueManager, workflow: Workflow):
|
def __init__(self, queue_manager: AppQueueManager, workflow: Workflow):
|
||||||
self._queue_manager = queue_manager
|
self._queue_manager = queue_manager
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
import logging
|
import logging
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
from typing import Union
|
from typing import Any, Union
|
||||||
|
|
||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
from core.app.entities.task_entities import AppBlockingResponse, AppStreamResponse
|
from core.app.entities.task_entities import AppBlockingResponse, AppStreamResponse
|
||||||
@ -15,44 +15,41 @@ class AppGenerateResponseConverter(ABC):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def convert(cls, response: Union[
|
def convert(cls, response: Union[
|
||||||
AppBlockingResponse,
|
AppBlockingResponse,
|
||||||
Generator[AppStreamResponse, None, None]
|
Generator[AppStreamResponse, Any, None]
|
||||||
], invoke_from: InvokeFrom) -> Union[
|
], invoke_from: InvokeFrom):
|
||||||
dict,
|
|
||||||
Generator[str, None, None]
|
|
||||||
]:
|
|
||||||
if invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]:
|
if invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]:
|
||||||
if isinstance(response, cls._blocking_response_type):
|
if isinstance(response, AppBlockingResponse):
|
||||||
return cls.convert_blocking_full_response(response)
|
return cls.convert_blocking_full_response(response)
|
||||||
else:
|
else:
|
||||||
def _generate():
|
def _generate_full_response() -> Generator[str, Any, None]:
|
||||||
for chunk in cls.convert_stream_full_response(response):
|
for chunk in cls.convert_stream_full_response(response):
|
||||||
if chunk == 'ping':
|
if chunk == 'ping':
|
||||||
yield f'event: {chunk}\n\n'
|
yield f'event: {chunk}\n\n'
|
||||||
else:
|
else:
|
||||||
yield f'data: {chunk}\n\n'
|
yield f'data: {chunk}\n\n'
|
||||||
|
|
||||||
return _generate()
|
return _generate_full_response()
|
||||||
else:
|
else:
|
||||||
if isinstance(response, cls._blocking_response_type):
|
if isinstance(response, AppBlockingResponse):
|
||||||
return cls.convert_blocking_simple_response(response)
|
return cls.convert_blocking_simple_response(response)
|
||||||
else:
|
else:
|
||||||
def _generate():
|
def _generate_simple_response() -> Generator[str, Any, None]:
|
||||||
for chunk in cls.convert_stream_simple_response(response):
|
for chunk in cls.convert_stream_simple_response(response):
|
||||||
if chunk == 'ping':
|
if chunk == 'ping':
|
||||||
yield f'event: {chunk}\n\n'
|
yield f'event: {chunk}\n\n'
|
||||||
else:
|
else:
|
||||||
yield f'data: {chunk}\n\n'
|
yield f'data: {chunk}\n\n'
|
||||||
|
|
||||||
return _generate()
|
return _generate_simple_response()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def convert_blocking_full_response(cls, blocking_response: AppBlockingResponse) -> dict:
|
def convert_blocking_full_response(cls, blocking_response: AppBlockingResponse) -> dict[str, Any]:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def convert_blocking_simple_response(cls, blocking_response: AppBlockingResponse) -> dict:
|
def convert_blocking_simple_response(cls, blocking_response: AppBlockingResponse) -> dict[str, Any]:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -68,7 +65,7 @@ class AppGenerateResponseConverter(ABC):
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _get_simple_metadata(cls, metadata: dict) -> dict:
|
def _get_simple_metadata(cls, metadata: dict[str, Any]):
|
||||||
"""
|
"""
|
||||||
Get simple metadata.
|
Get simple metadata.
|
||||||
:param metadata: metadata
|
:param metadata: metadata
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
import contextvars
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import threading
|
import threading
|
||||||
@ -8,6 +9,7 @@ from typing import Union
|
|||||||
from flask import Flask, current_app
|
from flask import Flask, current_app
|
||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
|
|
||||||
|
import contexts
|
||||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||||
from core.app.apps.base_app_generator import BaseAppGenerator
|
from core.app.apps.base_app_generator import BaseAppGenerator
|
||||||
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom
|
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom
|
||||||
@ -38,7 +40,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||||||
invoke_from: InvokeFrom,
|
invoke_from: InvokeFrom,
|
||||||
stream: bool = True,
|
stream: bool = True,
|
||||||
call_depth: int = 0,
|
call_depth: int = 0,
|
||||||
) -> Union[dict, Generator[dict, None, None]]:
|
):
|
||||||
"""
|
"""
|
||||||
Generate App response.
|
Generate App response.
|
||||||
|
|
||||||
@ -86,6 +88,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||||||
call_depth=call_depth,
|
call_depth=call_depth,
|
||||||
trace_manager=trace_manager
|
trace_manager=trace_manager
|
||||||
)
|
)
|
||||||
|
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
|
||||||
|
|
||||||
return self._generate(
|
return self._generate(
|
||||||
app_model=app_model,
|
app_model=app_model,
|
||||||
@ -126,7 +129,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||||||
worker_thread = threading.Thread(target=self._generate_worker, kwargs={
|
worker_thread = threading.Thread(target=self._generate_worker, kwargs={
|
||||||
'flask_app': current_app._get_current_object(),
|
'flask_app': current_app._get_current_object(),
|
||||||
'application_generate_entity': application_generate_entity,
|
'application_generate_entity': application_generate_entity,
|
||||||
'queue_manager': queue_manager
|
'queue_manager': queue_manager,
|
||||||
|
'context': contextvars.copy_context()
|
||||||
})
|
})
|
||||||
|
|
||||||
worker_thread.start()
|
worker_thread.start()
|
||||||
@ -150,8 +154,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||||||
node_id: str,
|
node_id: str,
|
||||||
user: Account,
|
user: Account,
|
||||||
args: dict,
|
args: dict,
|
||||||
stream: bool = True) \
|
stream: bool = True):
|
||||||
-> Union[dict, Generator[dict, None, None]]:
|
|
||||||
"""
|
"""
|
||||||
Generate App response.
|
Generate App response.
|
||||||
|
|
||||||
@ -193,6 +196,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||||||
inputs=args['inputs']
|
inputs=args['inputs']
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
|
||||||
|
|
||||||
return self._generate(
|
return self._generate(
|
||||||
app_model=app_model,
|
app_model=app_model,
|
||||||
@ -205,7 +209,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||||||
|
|
||||||
def _generate_worker(self, flask_app: Flask,
|
def _generate_worker(self, flask_app: Flask,
|
||||||
application_generate_entity: WorkflowAppGenerateEntity,
|
application_generate_entity: WorkflowAppGenerateEntity,
|
||||||
queue_manager: AppQueueManager) -> None:
|
queue_manager: AppQueueManager,
|
||||||
|
context: contextvars.Context) -> None:
|
||||||
"""
|
"""
|
||||||
Generate worker in a new thread.
|
Generate worker in a new thread.
|
||||||
:param flask_app: Flask app
|
:param flask_app: Flask app
|
||||||
@ -213,6 +218,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||||||
:param queue_manager: queue manager
|
:param queue_manager: queue manager
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
|
for var, val in context.items():
|
||||||
|
var.set(val)
|
||||||
with flask_app.app_context():
|
with flask_app.app_context():
|
||||||
try:
|
try:
|
||||||
# workflow app
|
# workflow app
|
||||||
|
@ -10,6 +10,7 @@ from core.app.entities.app_invoke_entities import (
|
|||||||
InvokeFrom,
|
InvokeFrom,
|
||||||
WorkflowAppGenerateEntity,
|
WorkflowAppGenerateEntity,
|
||||||
)
|
)
|
||||||
|
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
|
||||||
from core.workflow.entities.node_entities import SystemVariable, UserFrom
|
from core.workflow.entities.node_entities import SystemVariable, UserFrom
|
||||||
from core.workflow.workflow_engine_manager import WorkflowEngineManager
|
from core.workflow.workflow_engine_manager import WorkflowEngineManager
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
@ -56,7 +57,7 @@ class WorkflowAppRunner:
|
|||||||
|
|
||||||
db.session.close()
|
db.session.close()
|
||||||
|
|
||||||
workflow_callbacks = [WorkflowEventTriggerCallback(
|
workflow_callbacks: list[WorkflowCallback] = [WorkflowEventTriggerCallback(
|
||||||
queue_manager=queue_manager,
|
queue_manager=queue_manager,
|
||||||
workflow=workflow
|
workflow=workflow
|
||||||
)]
|
)]
|
||||||
|
@ -14,13 +14,13 @@ from core.app.entities.queue_entities import (
|
|||||||
QueueWorkflowStartedEvent,
|
QueueWorkflowStartedEvent,
|
||||||
QueueWorkflowSucceededEvent,
|
QueueWorkflowSucceededEvent,
|
||||||
)
|
)
|
||||||
from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
|
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
|
||||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||||
from core.workflow.entities.node_entities import NodeType
|
from core.workflow.entities.node_entities import NodeType
|
||||||
from models.workflow import Workflow
|
from models.workflow import Workflow
|
||||||
|
|
||||||
|
|
||||||
class WorkflowEventTriggerCallback(BaseWorkflowCallback):
|
class WorkflowEventTriggerCallback(WorkflowCallback):
|
||||||
|
|
||||||
def __init__(self, queue_manager: AppQueueManager, workflow: Workflow):
|
def __init__(self, queue_manager: AppQueueManager, workflow: Workflow):
|
||||||
self._queue_manager = queue_manager
|
self._queue_manager = queue_manager
|
||||||
|
@ -2,7 +2,7 @@ from typing import Optional
|
|||||||
|
|
||||||
from core.app.entities.queue_entities import AppQueueEvent
|
from core.app.entities.queue_entities import AppQueueEvent
|
||||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||||
from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
|
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
|
||||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||||
from core.workflow.entities.node_entities import NodeType
|
from core.workflow.entities.node_entities import NodeType
|
||||||
|
|
||||||
@ -15,7 +15,7 @@ _TEXT_COLOR_MAPPING = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class WorkflowLoggingCallback(BaseWorkflowCallback):
|
class WorkflowLoggingCallback(WorkflowCallback):
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self.current_node_id = None
|
self.current_node_id = None
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
from collections.abc import Mapping
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
@ -76,7 +77,7 @@ class AppGenerateEntity(BaseModel):
|
|||||||
# app config
|
# app config
|
||||||
app_config: AppConfig
|
app_config: AppConfig
|
||||||
|
|
||||||
inputs: dict[str, Any]
|
inputs: Mapping[str, Any]
|
||||||
files: list[FileVar] = []
|
files: list[FileVar] = []
|
||||||
user_id: str
|
user_id: str
|
||||||
|
|
||||||
@ -140,7 +141,7 @@ class AdvancedChatAppGenerateEntity(AppGenerateEntity):
|
|||||||
app_config: WorkflowUIBasedAppConfig
|
app_config: WorkflowUIBasedAppConfig
|
||||||
|
|
||||||
conversation_id: Optional[str] = None
|
conversation_id: Optional[str] = None
|
||||||
query: Optional[str] = None
|
query: str
|
||||||
|
|
||||||
class SingleIterationRunEntity(BaseModel):
|
class SingleIterationRunEntity(BaseModel):
|
||||||
"""
|
"""
|
||||||
|
27
api/core/app/segments/__init__.py
Normal file
27
api/core/app/segments/__init__.py
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
from .segment_group import SegmentGroup
|
||||||
|
from .segments import Segment
|
||||||
|
from .types import SegmentType
|
||||||
|
from .variables import (
|
||||||
|
ArrayVariable,
|
||||||
|
FileVariable,
|
||||||
|
FloatVariable,
|
||||||
|
IntegerVariable,
|
||||||
|
ObjectVariable,
|
||||||
|
SecretVariable,
|
||||||
|
StringVariable,
|
||||||
|
Variable,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
'IntegerVariable',
|
||||||
|
'FloatVariable',
|
||||||
|
'ObjectVariable',
|
||||||
|
'SecretVariable',
|
||||||
|
'FileVariable',
|
||||||
|
'StringVariable',
|
||||||
|
'ArrayVariable',
|
||||||
|
'Variable',
|
||||||
|
'SegmentType',
|
||||||
|
'SegmentGroup',
|
||||||
|
'Segment'
|
||||||
|
]
|
64
api/core/app/segments/factory.py
Normal file
64
api/core/app/segments/factory.py
Normal file
@ -0,0 +1,64 @@
|
|||||||
|
from collections.abc import Mapping
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from core.file.file_obj import FileVar
|
||||||
|
|
||||||
|
from .segments import Segment, StringSegment
|
||||||
|
from .types import SegmentType
|
||||||
|
from .variables import (
|
||||||
|
ArrayVariable,
|
||||||
|
FileVariable,
|
||||||
|
FloatVariable,
|
||||||
|
IntegerVariable,
|
||||||
|
ObjectVariable,
|
||||||
|
SecretVariable,
|
||||||
|
StringVariable,
|
||||||
|
Variable,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def build_variable_from_mapping(m: Mapping[str, Any], /) -> Variable:
|
||||||
|
if (value_type := m.get('value_type')) is None:
|
||||||
|
raise ValueError('missing value type')
|
||||||
|
if not m.get('name'):
|
||||||
|
raise ValueError('missing name')
|
||||||
|
if (value := m.get('value')) is None:
|
||||||
|
raise ValueError('missing value')
|
||||||
|
match value_type:
|
||||||
|
case SegmentType.STRING:
|
||||||
|
return StringVariable.model_validate(m)
|
||||||
|
case SegmentType.NUMBER if isinstance(value, int):
|
||||||
|
return IntegerVariable.model_validate(m)
|
||||||
|
case SegmentType.NUMBER if isinstance(value, float):
|
||||||
|
return FloatVariable.model_validate(m)
|
||||||
|
case SegmentType.SECRET:
|
||||||
|
return SecretVariable.model_validate(m)
|
||||||
|
case SegmentType.NUMBER if not isinstance(value, float | int):
|
||||||
|
raise ValueError(f'invalid number value {value}')
|
||||||
|
raise ValueError(f'not supported value type {value_type}')
|
||||||
|
|
||||||
|
|
||||||
|
def build_anonymous_variable(value: Any, /) -> Variable:
|
||||||
|
if isinstance(value, str):
|
||||||
|
return StringVariable(name='anonymous', value=value)
|
||||||
|
if isinstance(value, int):
|
||||||
|
return IntegerVariable(name='anonymous', value=value)
|
||||||
|
if isinstance(value, float):
|
||||||
|
return FloatVariable(name='anonymous', value=value)
|
||||||
|
if isinstance(value, dict):
|
||||||
|
# TODO: Limit the depth of the object
|
||||||
|
obj = {k: build_anonymous_variable(v) for k, v in value.items()}
|
||||||
|
return ObjectVariable(name='anonymous', value=obj)
|
||||||
|
if isinstance(value, list):
|
||||||
|
# TODO: Limit the depth of the array
|
||||||
|
elements = [build_anonymous_variable(v) for v in value]
|
||||||
|
return ArrayVariable(name='anonymous', value=elements)
|
||||||
|
if isinstance(value, FileVar):
|
||||||
|
return FileVariable(name='anonymous', value=value)
|
||||||
|
raise ValueError(f'not supported value {value}')
|
||||||
|
|
||||||
|
|
||||||
|
def build_segment(value: Any, /) -> Segment:
|
||||||
|
if isinstance(value, str):
|
||||||
|
return StringSegment(value=value)
|
||||||
|
raise ValueError(f'not supported value {value}')
|
17
api/core/app/segments/parser.py
Normal file
17
api/core/app/segments/parser.py
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
import re
|
||||||
|
|
||||||
|
from core.app.segments import SegmentGroup, factory
|
||||||
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
|
|
||||||
|
VARIABLE_PATTERN = re.compile(r'\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}')
|
||||||
|
|
||||||
|
|
||||||
|
def convert_template(*, template: str, variable_pool: VariablePool):
|
||||||
|
parts = re.split(VARIABLE_PATTERN, template)
|
||||||
|
segments = []
|
||||||
|
for part in parts:
|
||||||
|
if '.' in part and (value := variable_pool.get(part.split('.'))):
|
||||||
|
segments.append(value)
|
||||||
|
else:
|
||||||
|
segments.append(factory.build_segment(part))
|
||||||
|
return SegmentGroup(segments=segments)
|
19
api/core/app/segments/segment_group.py
Normal file
19
api/core/app/segments/segment_group.py
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from .segments import Segment
|
||||||
|
|
||||||
|
|
||||||
|
class SegmentGroup(BaseModel):
|
||||||
|
segments: list[Segment]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def text(self):
|
||||||
|
return ''.join([segment.text for segment in self.segments])
|
||||||
|
|
||||||
|
@property
|
||||||
|
def log(self):
|
||||||
|
return ''.join([segment.log for segment in self.segments])
|
||||||
|
|
||||||
|
@property
|
||||||
|
def markdown(self):
|
||||||
|
return ''.join([segment.markdown for segment in self.segments])
|
39
api/core/app/segments/segments.py
Normal file
39
api/core/app/segments/segments.py
Normal file
@ -0,0 +1,39 @@
|
|||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from pydantic import BaseModel, ConfigDict, field_validator
|
||||||
|
|
||||||
|
from .types import SegmentType
|
||||||
|
|
||||||
|
|
||||||
|
class Segment(BaseModel):
|
||||||
|
model_config = ConfigDict(frozen=True)
|
||||||
|
|
||||||
|
value_type: SegmentType
|
||||||
|
value: Any
|
||||||
|
|
||||||
|
@field_validator('value_type')
|
||||||
|
def validate_value_type(cls, value):
|
||||||
|
"""
|
||||||
|
This validator checks if the provided value is equal to the default value of the 'value_type' field.
|
||||||
|
If the value is different, a ValueError is raised.
|
||||||
|
"""
|
||||||
|
if value != cls.model_fields['value_type'].default:
|
||||||
|
raise ValueError("Cannot modify 'value_type'")
|
||||||
|
return value
|
||||||
|
|
||||||
|
@property
|
||||||
|
def text(self) -> str:
|
||||||
|
return str(self.value)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def log(self) -> str:
|
||||||
|
return str(self.value)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def markdown(self) -> str:
|
||||||
|
return str(self.value)
|
||||||
|
|
||||||
|
|
||||||
|
class StringSegment(Segment):
|
||||||
|
value_type: SegmentType = SegmentType.STRING
|
||||||
|
value: str
|
17
api/core/app/segments/types.py
Normal file
17
api/core/app/segments/types.py
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
|
||||||
|
class SegmentType(str, Enum):
|
||||||
|
STRING = 'string'
|
||||||
|
NUMBER = 'number'
|
||||||
|
FILE = 'file'
|
||||||
|
|
||||||
|
SECRET = 'secret'
|
||||||
|
|
||||||
|
OBJECT = 'object'
|
||||||
|
|
||||||
|
ARRAY = 'array'
|
||||||
|
ARRAY_STRING = 'array[string]'
|
||||||
|
ARRAY_NUMBER = 'array[number]'
|
||||||
|
ARRAY_OBJECT = 'array[object]'
|
||||||
|
ARRAY_FILE = 'array[file]'
|
83
api/core/app/segments/variables.py
Normal file
83
api/core/app/segments/variables.py
Normal file
@ -0,0 +1,83 @@
|
|||||||
|
import json
|
||||||
|
from collections.abc import Mapping, Sequence
|
||||||
|
|
||||||
|
from pydantic import Field
|
||||||
|
|
||||||
|
from core.file.file_obj import FileVar
|
||||||
|
from core.helper import encrypter
|
||||||
|
|
||||||
|
from .segments import Segment, StringSegment
|
||||||
|
from .types import SegmentType
|
||||||
|
|
||||||
|
|
||||||
|
class Variable(Segment):
|
||||||
|
"""
|
||||||
|
A variable is a segment that has a name.
|
||||||
|
"""
|
||||||
|
|
||||||
|
id: str = Field(
|
||||||
|
default='',
|
||||||
|
description="Unique identity for variable. It's only used by environment variables now.",
|
||||||
|
)
|
||||||
|
name: str
|
||||||
|
|
||||||
|
|
||||||
|
class StringVariable(StringSegment, Variable):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class FloatVariable(Variable):
|
||||||
|
value_type: SegmentType = SegmentType.NUMBER
|
||||||
|
value: float
|
||||||
|
|
||||||
|
|
||||||
|
class IntegerVariable(Variable):
|
||||||
|
value_type: SegmentType = SegmentType.NUMBER
|
||||||
|
value: int
|
||||||
|
|
||||||
|
|
||||||
|
class ObjectVariable(Variable):
|
||||||
|
value_type: SegmentType = SegmentType.OBJECT
|
||||||
|
value: Mapping[str, Variable]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def text(self) -> str:
|
||||||
|
# TODO: Process variables.
|
||||||
|
return json.dumps(self.model_dump()['value'], ensure_ascii=False)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def log(self) -> str:
|
||||||
|
# TODO: Process variables.
|
||||||
|
return json.dumps(self.model_dump()['value'], ensure_ascii=False, indent=2)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def markdown(self) -> str:
|
||||||
|
# TODO: Use markdown code block
|
||||||
|
return json.dumps(self.model_dump()['value'], ensure_ascii=False, indent=2)
|
||||||
|
|
||||||
|
|
||||||
|
class ArrayVariable(Variable):
|
||||||
|
value_type: SegmentType = SegmentType.ARRAY
|
||||||
|
value: Sequence[Variable]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def markdown(self) -> str:
|
||||||
|
return '\n'.join(['- ' + item.markdown for item in self.value])
|
||||||
|
|
||||||
|
|
||||||
|
class FileVariable(Variable):
|
||||||
|
value_type: SegmentType = SegmentType.FILE
|
||||||
|
# TODO: embed FileVar in this model.
|
||||||
|
value: FileVar
|
||||||
|
|
||||||
|
@property
|
||||||
|
def markdown(self) -> str:
|
||||||
|
return self.value.to_markdown()
|
||||||
|
|
||||||
|
|
||||||
|
class SecretVariable(StringVariable):
|
||||||
|
value_type: SegmentType = SegmentType.SECRET
|
||||||
|
|
||||||
|
@property
|
||||||
|
def log(self) -> str:
|
||||||
|
return encrypter.obfuscated_token(self.value)
|
@ -1,9 +1,11 @@
|
|||||||
import os
|
import os
|
||||||
|
from collections.abc import Mapping, Sequence
|
||||||
from typing import Any, Optional, TextIO, Union
|
from typing import Any, Optional, TextIO, Union
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask, TraceTaskName
|
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask, TraceTaskName
|
||||||
|
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||||
|
|
||||||
_TEXT_COLOR_MAPPING = {
|
_TEXT_COLOR_MAPPING = {
|
||||||
"blue": "36;1",
|
"blue": "36;1",
|
||||||
@ -43,7 +45,7 @@ class DifyAgentCallbackHandler(BaseModel):
|
|||||||
def on_tool_start(
|
def on_tool_start(
|
||||||
self,
|
self,
|
||||||
tool_name: str,
|
tool_name: str,
|
||||||
tool_inputs: dict[str, Any],
|
tool_inputs: Mapping[str, Any],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Do nothing."""
|
"""Do nothing."""
|
||||||
print_text("\n[on_tool_start] ToolCall:" + tool_name + "\n" + str(tool_inputs) + "\n", color=self.color)
|
print_text("\n[on_tool_start] ToolCall:" + tool_name + "\n" + str(tool_inputs) + "\n", color=self.color)
|
||||||
@ -51,8 +53,8 @@ class DifyAgentCallbackHandler(BaseModel):
|
|||||||
def on_tool_end(
|
def on_tool_end(
|
||||||
self,
|
self,
|
||||||
tool_name: str,
|
tool_name: str,
|
||||||
tool_inputs: dict[str, Any],
|
tool_inputs: Mapping[str, Any],
|
||||||
tool_outputs: str,
|
tool_outputs: Sequence[ToolInvokeMessage],
|
||||||
message_id: Optional[str] = None,
|
message_id: Optional[str] = None,
|
||||||
timer: Optional[Any] = None,
|
timer: Optional[Any] = None,
|
||||||
trace_manager: Optional[TraceQueueManager] = None
|
trace_manager: Optional[TraceQueueManager] = None
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
from typing import Union
|
from collections.abc import Mapping, Sequence
|
||||||
|
from typing import Any, Union
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
@ -16,7 +17,7 @@ class MessageFileParser:
|
|||||||
self.tenant_id = tenant_id
|
self.tenant_id = tenant_id
|
||||||
self.app_id = app_id
|
self.app_id = app_id
|
||||||
|
|
||||||
def validate_and_transform_files_arg(self, files: list[dict], file_extra_config: FileExtraConfig,
|
def validate_and_transform_files_arg(self, files: Sequence[Mapping[str, Any]], file_extra_config: FileExtraConfig,
|
||||||
user: Union[Account, EndUser]) -> list[FileVar]:
|
user: Union[Account, EndUser]) -> list[FileVar]:
|
||||||
"""
|
"""
|
||||||
validate and transform files arg
|
validate and transform files arg
|
||||||
|
@ -21,7 +21,7 @@ logger = logging.getLogger(__name__)
|
|||||||
CODE_EXECUTION_ENDPOINT = dify_config.CODE_EXECUTION_ENDPOINT
|
CODE_EXECUTION_ENDPOINT = dify_config.CODE_EXECUTION_ENDPOINT
|
||||||
CODE_EXECUTION_API_KEY = dify_config.CODE_EXECUTION_API_KEY
|
CODE_EXECUTION_API_KEY = dify_config.CODE_EXECUTION_API_KEY
|
||||||
|
|
||||||
CODE_EXECUTION_TIMEOUT= (10, 60)
|
CODE_EXECUTION_TIMEOUT = (10, 60)
|
||||||
|
|
||||||
class CodeExecutionException(Exception):
|
class CodeExecutionException(Exception):
|
||||||
pass
|
pass
|
||||||
@ -64,7 +64,7 @@ class CodeExecutor:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute_code(cls,
|
def execute_code(cls,
|
||||||
language: Literal['python3', 'javascript', 'jinja2'],
|
language: CodeLanguage,
|
||||||
preload: str,
|
preload: str,
|
||||||
code: str,
|
code: str,
|
||||||
dependencies: Optional[list[CodeDependency]] = None) -> str:
|
dependencies: Optional[list[CodeDependency]] = None) -> str:
|
||||||
@ -119,7 +119,7 @@ class CodeExecutor:
|
|||||||
return response.data.stdout
|
return response.data.stdout
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute_workflow_code_template(cls, language: Literal['python3', 'javascript', 'jinja2'], code: str, inputs: dict, dependencies: Optional[list[CodeDependency]] = None) -> dict:
|
def execute_workflow_code_template(cls, language: CodeLanguage, code: str, inputs: dict, dependencies: Optional[list[CodeDependency]] = None) -> dict:
|
||||||
"""
|
"""
|
||||||
Execute code
|
Execute code
|
||||||
:param language: code language
|
:param language: code language
|
||||||
|
@ -6,11 +6,16 @@ from models.account import Tenant
|
|||||||
|
|
||||||
|
|
||||||
def obfuscated_token(token: str):
|
def obfuscated_token(token: str):
|
||||||
return token[:6] + '*' * (len(token) - 8) + token[-2:]
|
if not token:
|
||||||
|
return token
|
||||||
|
if len(token) <= 8:
|
||||||
|
return '*' * 20
|
||||||
|
return token[:6] + '*' * 12 + token[-2:]
|
||||||
|
|
||||||
|
|
||||||
def encrypt_token(tenant_id: str, token: str):
|
def encrypt_token(tenant_id: str, token: str):
|
||||||
tenant = db.session.query(Tenant).filter(Tenant.id == tenant_id).first()
|
if not (tenant := db.session.query(Tenant).filter(Tenant.id == tenant_id).first()):
|
||||||
|
raise ValueError(f'Tenant with id {tenant_id} not found')
|
||||||
encrypted_token = rsa.encrypt(token, tenant.encrypt_public_key)
|
encrypted_token = rsa.encrypt(token, tenant.encrypt_public_key)
|
||||||
return base64.b64encode(encrypted_token).decode()
|
return base64.b64encode(encrypted_token).decode()
|
||||||
|
|
||||||
|
@ -14,6 +14,9 @@ def get_position_map(folder_path: str, *, file_name: str = "_position.yaml") ->
|
|||||||
:return: a dict with name as key and index as value
|
:return: a dict with name as key and index as value
|
||||||
"""
|
"""
|
||||||
position_file_name = os.path.join(folder_path, file_name)
|
position_file_name = os.path.join(folder_path, file_name)
|
||||||
|
if not position_file_name or not os.path.exists(position_file_name):
|
||||||
|
return {}
|
||||||
|
|
||||||
positions = load_yaml_file(position_file_name, ignore_error=True)
|
positions = load_yaml_file(position_file_name, ignore_error=True)
|
||||||
position_map = {}
|
position_map = {}
|
||||||
index = 0
|
index = 0
|
||||||
|
@ -64,6 +64,7 @@ User Input:
|
|||||||
SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = (
|
SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = (
|
||||||
"Please help me predict the three most likely questions that human would ask, "
|
"Please help me predict the three most likely questions that human would ask, "
|
||||||
"and keeping each question under 20 characters.\n"
|
"and keeping each question under 20 characters.\n"
|
||||||
|
"MAKE SURE your output is the SAME language as the Assistant's latest response(if the main response is written in Chinese, then the language of your output must be using Chinese.)!\n"
|
||||||
"The output must be an array in JSON format following the specified schema:\n"
|
"The output must be an array in JSON format following the specified schema:\n"
|
||||||
"[\"question1\",\"question2\",\"question3\"]\n"
|
"[\"question1\",\"question2\",\"question3\"]\n"
|
||||||
)
|
)
|
||||||
|
@ -103,7 +103,7 @@ class TokenBufferMemory:
|
|||||||
|
|
||||||
if curr_message_tokens > max_token_limit:
|
if curr_message_tokens > max_token_limit:
|
||||||
pruned_memory = []
|
pruned_memory = []
|
||||||
while curr_message_tokens > max_token_limit and prompt_messages:
|
while curr_message_tokens > max_token_limit and len(prompt_messages)>1:
|
||||||
pruned_memory.append(prompt_messages.pop(0))
|
pruned_memory.append(prompt_messages.pop(0))
|
||||||
curr_message_tokens = self.model_instance.get_llm_num_tokens(
|
curr_message_tokens = self.model_instance.get_llm_num_tokens(
|
||||||
prompt_messages
|
prompt_messages
|
||||||
|
@ -413,6 +413,7 @@ class LBModelManager:
|
|||||||
for load_balancing_config in self._load_balancing_configs:
|
for load_balancing_config in self._load_balancing_configs:
|
||||||
if load_balancing_config.name == "__inherit__":
|
if load_balancing_config.name == "__inherit__":
|
||||||
if not managed_credentials:
|
if not managed_credentials:
|
||||||
|
# FIXME: Mutation to loop iterable `self._load_balancing_configs` during iteration
|
||||||
# remove __inherit__ if managed credentials is not provided
|
# remove __inherit__ if managed credentials is not provided
|
||||||
self._load_balancing_configs.remove(load_balancing_config)
|
self._load_balancing_configs.remove(load_balancing_config)
|
||||||
else:
|
else:
|
||||||
|
@ -27,9 +27,9 @@ parameter_rules:
|
|||||||
- name: max_tokens
|
- name: max_tokens
|
||||||
use_template: max_tokens
|
use_template: max_tokens
|
||||||
required: true
|
required: true
|
||||||
default: 4096
|
default: 8192
|
||||||
min: 1
|
min: 1
|
||||||
max: 4096
|
max: 8192
|
||||||
- name: response_format
|
- name: response_format
|
||||||
use_template: response_format
|
use_template: response_format
|
||||||
pricing:
|
pricing:
|
||||||
|
@ -113,6 +113,11 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
|
|||||||
if system:
|
if system:
|
||||||
extra_model_kwargs['system'] = system
|
extra_model_kwargs['system'] = system
|
||||||
|
|
||||||
|
# Add the new header for claude-3-5-sonnet-20240620 model
|
||||||
|
extra_headers = {}
|
||||||
|
if model == "claude-3-5-sonnet-20240620":
|
||||||
|
extra_headers["anthropic-beta"] = "max-tokens-3-5-sonnet-2024-07-15"
|
||||||
|
|
||||||
if tools:
|
if tools:
|
||||||
extra_model_kwargs['tools'] = [
|
extra_model_kwargs['tools'] = [
|
||||||
self._transform_tool_prompt(tool) for tool in tools
|
self._transform_tool_prompt(tool) for tool in tools
|
||||||
@ -121,6 +126,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
|
|||||||
model=model,
|
model=model,
|
||||||
messages=prompt_message_dicts,
|
messages=prompt_message_dicts,
|
||||||
stream=stream,
|
stream=stream,
|
||||||
|
extra_headers=extra_headers,
|
||||||
**model_parameters,
|
**model_parameters,
|
||||||
**extra_model_kwargs
|
**extra_model_kwargs
|
||||||
)
|
)
|
||||||
@ -130,6 +136,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
|
|||||||
model=model,
|
model=model,
|
||||||
messages=prompt_message_dicts,
|
messages=prompt_message_dicts,
|
||||||
stream=stream,
|
stream=stream,
|
||||||
|
extra_headers=extra_headers,
|
||||||
**model_parameters,
|
**model_parameters,
|
||||||
**extra_model_kwargs
|
**extra_model_kwargs
|
||||||
)
|
)
|
||||||
@ -138,7 +145,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
|
|||||||
return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages)
|
return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages)
|
||||||
|
|
||||||
return self._handle_chat_generate_response(model, credentials, response, prompt_messages)
|
return self._handle_chat_generate_response(model, credentials, response, prompt_messages)
|
||||||
|
|
||||||
def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
|
def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
|
||||||
model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None,
|
model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None,
|
||||||
stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None,
|
stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None,
|
||||||
|
@ -501,7 +501,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
|||||||
sub_messages.append(sub_message_dict)
|
sub_messages.append(sub_message_dict)
|
||||||
message_dict = {"role": "user", "content": sub_messages}
|
message_dict = {"role": "user", "content": sub_messages}
|
||||||
elif isinstance(message, AssistantPromptMessage):
|
elif isinstance(message, AssistantPromptMessage):
|
||||||
message = cast(AssistantPromptMessage, message)
|
# message = cast(AssistantPromptMessage, message)
|
||||||
message_dict = {"role": "assistant", "content": message.content}
|
message_dict = {"role": "assistant", "content": message.content}
|
||||||
if message.tool_calls:
|
if message.tool_calls:
|
||||||
message_dict["tool_calls"] = [helper.dump_model(tool_call) for tool_call in message.tool_calls]
|
message_dict["tool_calls"] = [helper.dump_model(tool_call) for tool_call in message.tool_calls]
|
||||||
|
@ -1,6 +1,8 @@
|
|||||||
- gpt-4
|
- gpt-4
|
||||||
- gpt-4o
|
- gpt-4o
|
||||||
- gpt-4o-2024-05-13
|
- gpt-4o-2024-05-13
|
||||||
|
- gpt-4o-mini
|
||||||
|
- gpt-4o-mini-2024-07-18
|
||||||
- gpt-4-turbo
|
- gpt-4-turbo
|
||||||
- gpt-4-turbo-2024-04-09
|
- gpt-4-turbo-2024-04-09
|
||||||
- gpt-4-turbo-preview
|
- gpt-4-turbo-preview
|
||||||
|
@ -0,0 +1,44 @@
|
|||||||
|
model: gpt-4o-mini-2024-07-18
|
||||||
|
label:
|
||||||
|
zh_Hans: gpt-4o-mini-2024-07-18
|
||||||
|
en_US: gpt-4o-mini-2024-07-18
|
||||||
|
model_type: llm
|
||||||
|
features:
|
||||||
|
- multi-tool-call
|
||||||
|
- agent-thought
|
||||||
|
- stream-tool-call
|
||||||
|
- vision
|
||||||
|
model_properties:
|
||||||
|
mode: chat
|
||||||
|
context_size: 128000
|
||||||
|
parameter_rules:
|
||||||
|
- name: temperature
|
||||||
|
use_template: temperature
|
||||||
|
- name: top_p
|
||||||
|
use_template: top_p
|
||||||
|
- name: presence_penalty
|
||||||
|
use_template: presence_penalty
|
||||||
|
- name: frequency_penalty
|
||||||
|
use_template: frequency_penalty
|
||||||
|
- name: max_tokens
|
||||||
|
use_template: max_tokens
|
||||||
|
default: 512
|
||||||
|
min: 1
|
||||||
|
max: 16384
|
||||||
|
- name: response_format
|
||||||
|
label:
|
||||||
|
zh_Hans: 回复格式
|
||||||
|
en_US: response_format
|
||||||
|
type: string
|
||||||
|
help:
|
||||||
|
zh_Hans: 指定模型必须输出的格式
|
||||||
|
en_US: specifying the format that the model must output
|
||||||
|
required: false
|
||||||
|
options:
|
||||||
|
- text
|
||||||
|
- json_object
|
||||||
|
pricing:
|
||||||
|
input: '0.15'
|
||||||
|
output: '0.60'
|
||||||
|
unit: '0.000001'
|
||||||
|
currency: USD
|
@ -0,0 +1,44 @@
|
|||||||
|
model: gpt-4o-mini
|
||||||
|
label:
|
||||||
|
zh_Hans: gpt-4o-mini
|
||||||
|
en_US: gpt-4o-mini
|
||||||
|
model_type: llm
|
||||||
|
features:
|
||||||
|
- multi-tool-call
|
||||||
|
- agent-thought
|
||||||
|
- stream-tool-call
|
||||||
|
- vision
|
||||||
|
model_properties:
|
||||||
|
mode: chat
|
||||||
|
context_size: 128000
|
||||||
|
parameter_rules:
|
||||||
|
- name: temperature
|
||||||
|
use_template: temperature
|
||||||
|
- name: top_p
|
||||||
|
use_template: top_p
|
||||||
|
- name: presence_penalty
|
||||||
|
use_template: presence_penalty
|
||||||
|
- name: frequency_penalty
|
||||||
|
use_template: frequency_penalty
|
||||||
|
- name: max_tokens
|
||||||
|
use_template: max_tokens
|
||||||
|
default: 512
|
||||||
|
min: 1
|
||||||
|
max: 16384
|
||||||
|
- name: response_format
|
||||||
|
label:
|
||||||
|
zh_Hans: 回复格式
|
||||||
|
en_US: response_format
|
||||||
|
type: string
|
||||||
|
help:
|
||||||
|
zh_Hans: 指定模型必须输出的格式
|
||||||
|
en_US: specifying the format that the model must output
|
||||||
|
required: false
|
||||||
|
options:
|
||||||
|
- text
|
||||||
|
- json_object
|
||||||
|
pricing:
|
||||||
|
input: '0.15'
|
||||||
|
output: '0.60'
|
||||||
|
unit: '0.000001'
|
||||||
|
currency: USD
|
@ -1,4 +1,5 @@
|
|||||||
- openai/gpt-4o
|
- openai/gpt-4o
|
||||||
|
- openai/gpt-4o-mini
|
||||||
- openai/gpt-4
|
- openai/gpt-4
|
||||||
- openai/gpt-4-32k
|
- openai/gpt-4-32k
|
||||||
- openai/gpt-3.5-turbo
|
- openai/gpt-3.5-turbo
|
||||||
|
@ -0,0 +1,43 @@
|
|||||||
|
model: openai/gpt-4o-mini
|
||||||
|
label:
|
||||||
|
en_US: gpt-4o-mini
|
||||||
|
model_type: llm
|
||||||
|
features:
|
||||||
|
- multi-tool-call
|
||||||
|
- agent-thought
|
||||||
|
- stream-tool-call
|
||||||
|
- vision
|
||||||
|
model_properties:
|
||||||
|
mode: chat
|
||||||
|
context_size: 128000
|
||||||
|
parameter_rules:
|
||||||
|
- name: temperature
|
||||||
|
use_template: temperature
|
||||||
|
- name: top_p
|
||||||
|
use_template: top_p
|
||||||
|
- name: presence_penalty
|
||||||
|
use_template: presence_penalty
|
||||||
|
- name: frequency_penalty
|
||||||
|
use_template: frequency_penalty
|
||||||
|
- name: max_tokens
|
||||||
|
use_template: max_tokens
|
||||||
|
default: 512
|
||||||
|
min: 1
|
||||||
|
max: 16384
|
||||||
|
- name: response_format
|
||||||
|
label:
|
||||||
|
zh_Hans: 回复格式
|
||||||
|
en_US: response_format
|
||||||
|
type: string
|
||||||
|
help:
|
||||||
|
zh_Hans: 指定模型必须输出的格式
|
||||||
|
en_US: specifying the format that the model must output
|
||||||
|
required: false
|
||||||
|
options:
|
||||||
|
- text
|
||||||
|
- json_object
|
||||||
|
pricing:
|
||||||
|
input: "0.15"
|
||||||
|
output: "0.60"
|
||||||
|
unit: "0.000001"
|
||||||
|
currency: USD
|
Binary file not shown.
After Width: | Height: | Size: 9.2 KiB |
Binary file not shown.
After Width: | Height: | Size: 9.5 KiB |
238
api/core/model_runtime/model_providers/sagemaker/llm/llm.py
Normal file
238
api/core/model_runtime/model_providers/sagemaker/llm/llm.py
Normal file
@ -0,0 +1,238 @@
|
|||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from collections.abc import Generator
|
||||||
|
from typing import Any, Optional, Union
|
||||||
|
|
||||||
|
import boto3
|
||||||
|
|
||||||
|
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult
|
||||||
|
from core.model_runtime.entities.message_entities import (
|
||||||
|
AssistantPromptMessage,
|
||||||
|
PromptMessage,
|
||||||
|
PromptMessageTool,
|
||||||
|
)
|
||||||
|
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, I18nObject, ModelType
|
||||||
|
from core.model_runtime.errors.invoke import (
|
||||||
|
InvokeAuthorizationError,
|
||||||
|
InvokeBadRequestError,
|
||||||
|
InvokeConnectionError,
|
||||||
|
InvokeError,
|
||||||
|
InvokeRateLimitError,
|
||||||
|
InvokeServerUnavailableError,
|
||||||
|
)
|
||||||
|
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||||
|
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class SageMakerLargeLanguageModel(LargeLanguageModel):
|
||||||
|
"""
|
||||||
|
Model class for Cohere large language model.
|
||||||
|
"""
|
||||||
|
sagemaker_client: Any = None
|
||||||
|
|
||||||
|
def _invoke(self, model: str, credentials: dict,
|
||||||
|
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||||
|
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
|
||||||
|
stream: bool = True, user: Optional[str] = None) \
|
||||||
|
-> Union[LLMResult, Generator]:
|
||||||
|
"""
|
||||||
|
Invoke large language model
|
||||||
|
|
||||||
|
:param model: model name
|
||||||
|
:param credentials: model credentials
|
||||||
|
:param prompt_messages: prompt messages
|
||||||
|
:param model_parameters: model parameters
|
||||||
|
:param tools: tools for tool calling
|
||||||
|
:param stop: stop words
|
||||||
|
:param stream: is stream response
|
||||||
|
:param user: unique user id
|
||||||
|
:return: full response or stream response chunk generator result
|
||||||
|
"""
|
||||||
|
# get model mode
|
||||||
|
model_mode = self.get_model_mode(model, credentials)
|
||||||
|
|
||||||
|
if not self.sagemaker_client:
|
||||||
|
access_key = credentials.get('access_key')
|
||||||
|
secret_key = credentials.get('secret_key')
|
||||||
|
aws_region = credentials.get('aws_region')
|
||||||
|
if aws_region:
|
||||||
|
if access_key and secret_key:
|
||||||
|
self.sagemaker_client = boto3.client("sagemaker-runtime",
|
||||||
|
aws_access_key_id=access_key,
|
||||||
|
aws_secret_access_key=secret_key,
|
||||||
|
region_name=aws_region)
|
||||||
|
else:
|
||||||
|
self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region)
|
||||||
|
else:
|
||||||
|
self.sagemaker_client = boto3.client("sagemaker-runtime")
|
||||||
|
|
||||||
|
|
||||||
|
sagemaker_endpoint = credentials.get('sagemaker_endpoint')
|
||||||
|
response_model = self.sagemaker_client.invoke_endpoint(
|
||||||
|
EndpointName=sagemaker_endpoint,
|
||||||
|
Body=json.dumps(
|
||||||
|
{
|
||||||
|
"inputs": prompt_messages[0].content,
|
||||||
|
"parameters": { "stop" : stop},
|
||||||
|
"history" : []
|
||||||
|
}
|
||||||
|
),
|
||||||
|
ContentType="application/json",
|
||||||
|
)
|
||||||
|
|
||||||
|
assistant_text = response_model['Body'].read().decode('utf8')
|
||||||
|
|
||||||
|
# transform assistant message to prompt message
|
||||||
|
assistant_prompt_message = AssistantPromptMessage(
|
||||||
|
content=assistant_text
|
||||||
|
)
|
||||||
|
|
||||||
|
usage = self._calc_response_usage(model, credentials, 0, 0)
|
||||||
|
|
||||||
|
response = LLMResult(
|
||||||
|
model=model,
|
||||||
|
prompt_messages=prompt_messages,
|
||||||
|
message=assistant_prompt_message,
|
||||||
|
usage=usage
|
||||||
|
)
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
|
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
|
||||||
|
tools: Optional[list[PromptMessageTool]] = None) -> int:
|
||||||
|
"""
|
||||||
|
Get number of tokens for given prompt messages
|
||||||
|
|
||||||
|
:param model: model name
|
||||||
|
:param credentials: model credentials
|
||||||
|
:param prompt_messages: prompt messages
|
||||||
|
:param tools: tools for tool calling
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
# get model mode
|
||||||
|
model_mode = self.get_model_mode(model)
|
||||||
|
|
||||||
|
try:
|
||||||
|
return 0
|
||||||
|
except Exception as e:
|
||||||
|
raise self._transform_invoke_error(e)
|
||||||
|
|
||||||
|
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||||
|
"""
|
||||||
|
Validate model credentials
|
||||||
|
|
||||||
|
:param model: model name
|
||||||
|
:param credentials: model credentials
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# get model mode
|
||||||
|
model_mode = self.get_model_mode(model)
|
||||||
|
except Exception as ex:
|
||||||
|
raise CredentialsValidateFailedError(str(ex))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||||
|
"""
|
||||||
|
Map model invoke error to unified error
|
||||||
|
The key is the error type thrown to the caller
|
||||||
|
The value is the error type thrown by the model,
|
||||||
|
which needs to be converted into a unified error type for the caller.
|
||||||
|
|
||||||
|
:return: Invoke error mapping
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
InvokeConnectionError: [
|
||||||
|
InvokeConnectionError
|
||||||
|
],
|
||||||
|
InvokeServerUnavailableError: [
|
||||||
|
InvokeServerUnavailableError
|
||||||
|
],
|
||||||
|
InvokeRateLimitError: [
|
||||||
|
InvokeRateLimitError
|
||||||
|
],
|
||||||
|
InvokeAuthorizationError: [
|
||||||
|
InvokeAuthorizationError
|
||||||
|
],
|
||||||
|
InvokeBadRequestError: [
|
||||||
|
InvokeBadRequestError,
|
||||||
|
KeyError,
|
||||||
|
ValueError
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
|
||||||
|
"""
|
||||||
|
used to define customizable model schema
|
||||||
|
"""
|
||||||
|
rules = [
|
||||||
|
ParameterRule(
|
||||||
|
name='temperature',
|
||||||
|
type=ParameterType.FLOAT,
|
||||||
|
use_template='temperature',
|
||||||
|
label=I18nObject(
|
||||||
|
zh_Hans='温度',
|
||||||
|
en_US='Temperature'
|
||||||
|
),
|
||||||
|
),
|
||||||
|
ParameterRule(
|
||||||
|
name='top_p',
|
||||||
|
type=ParameterType.FLOAT,
|
||||||
|
use_template='top_p',
|
||||||
|
label=I18nObject(
|
||||||
|
zh_Hans='Top P',
|
||||||
|
en_US='Top P'
|
||||||
|
)
|
||||||
|
),
|
||||||
|
ParameterRule(
|
||||||
|
name='max_tokens',
|
||||||
|
type=ParameterType.INT,
|
||||||
|
use_template='max_tokens',
|
||||||
|
min=1,
|
||||||
|
max=credentials.get('context_length', 2048),
|
||||||
|
default=512,
|
||||||
|
label=I18nObject(
|
||||||
|
zh_Hans='最大生成长度',
|
||||||
|
en_US='Max Tokens'
|
||||||
|
)
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
completion_type = LLMMode.value_of(credentials["mode"])
|
||||||
|
|
||||||
|
if completion_type == LLMMode.CHAT:
|
||||||
|
print(f"completion_type : {LLMMode.CHAT.value}")
|
||||||
|
|
||||||
|
if completion_type == LLMMode.COMPLETION:
|
||||||
|
print(f"completion_type : {LLMMode.COMPLETION.value}")
|
||||||
|
|
||||||
|
features = []
|
||||||
|
|
||||||
|
support_function_call = credentials.get('support_function_call', False)
|
||||||
|
if support_function_call:
|
||||||
|
features.append(ModelFeature.TOOL_CALL)
|
||||||
|
|
||||||
|
support_vision = credentials.get('support_vision', False)
|
||||||
|
if support_vision:
|
||||||
|
features.append(ModelFeature.VISION)
|
||||||
|
|
||||||
|
context_length = credentials.get('context_length', 2048)
|
||||||
|
|
||||||
|
entity = AIModelEntity(
|
||||||
|
model=model,
|
||||||
|
label=I18nObject(
|
||||||
|
en_US=model
|
||||||
|
),
|
||||||
|
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||||
|
model_type=ModelType.LLM,
|
||||||
|
features=features,
|
||||||
|
model_properties={
|
||||||
|
ModelPropertyKey.MODE: completion_type,
|
||||||
|
ModelPropertyKey.CONTEXT_SIZE: context_length
|
||||||
|
},
|
||||||
|
parameter_rules=rules
|
||||||
|
)
|
||||||
|
|
||||||
|
return entity
|
@ -0,0 +1,190 @@
|
|||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
import boto3
|
||||||
|
|
||||||
|
from core.model_runtime.entities.common_entities import I18nObject
|
||||||
|
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
|
||||||
|
from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
|
||||||
|
from core.model_runtime.errors.invoke import (
|
||||||
|
InvokeAuthorizationError,
|
||||||
|
InvokeBadRequestError,
|
||||||
|
InvokeConnectionError,
|
||||||
|
InvokeError,
|
||||||
|
InvokeRateLimitError,
|
||||||
|
InvokeServerUnavailableError,
|
||||||
|
)
|
||||||
|
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||||
|
from core.model_runtime.model_providers.__base.rerank_model import RerankModel
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class SageMakerRerankModel(RerankModel):
|
||||||
|
"""
|
||||||
|
Model class for Cohere rerank model.
|
||||||
|
"""
|
||||||
|
sagemaker_client: Any = None
|
||||||
|
|
||||||
|
def _sagemaker_rerank(self, query_input: str, docs: list[str], rerank_endpoint:str):
|
||||||
|
inputs = [query_input]*len(docs)
|
||||||
|
response_model = self.sagemaker_client.invoke_endpoint(
|
||||||
|
EndpointName=rerank_endpoint,
|
||||||
|
Body=json.dumps(
|
||||||
|
{
|
||||||
|
"inputs": inputs,
|
||||||
|
"docs": docs
|
||||||
|
}
|
||||||
|
),
|
||||||
|
ContentType="application/json",
|
||||||
|
)
|
||||||
|
json_str = response_model['Body'].read().decode('utf8')
|
||||||
|
json_obj = json.loads(json_str)
|
||||||
|
scores = json_obj['scores']
|
||||||
|
return scores if isinstance(scores, list) else [scores]
|
||||||
|
|
||||||
|
|
||||||
|
def _invoke(self, model: str, credentials: dict,
|
||||||
|
query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None,
|
||||||
|
user: Optional[str] = None) \
|
||||||
|
-> RerankResult:
|
||||||
|
"""
|
||||||
|
Invoke rerank model
|
||||||
|
|
||||||
|
:param model: model name
|
||||||
|
:param credentials: model credentials
|
||||||
|
:param query: search query
|
||||||
|
:param docs: docs for reranking
|
||||||
|
:param score_threshold: score threshold
|
||||||
|
:param top_n: top n
|
||||||
|
:param user: unique user id
|
||||||
|
:return: rerank result
|
||||||
|
"""
|
||||||
|
line = 0
|
||||||
|
try:
|
||||||
|
if len(docs) == 0:
|
||||||
|
return RerankResult(
|
||||||
|
model=model,
|
||||||
|
docs=docs
|
||||||
|
)
|
||||||
|
|
||||||
|
line = 1
|
||||||
|
if not self.sagemaker_client:
|
||||||
|
access_key = credentials.get('aws_access_key_id')
|
||||||
|
secret_key = credentials.get('aws_secret_access_key')
|
||||||
|
aws_region = credentials.get('aws_region')
|
||||||
|
if aws_region:
|
||||||
|
if access_key and secret_key:
|
||||||
|
self.sagemaker_client = boto3.client("sagemaker-runtime",
|
||||||
|
aws_access_key_id=access_key,
|
||||||
|
aws_secret_access_key=secret_key,
|
||||||
|
region_name=aws_region)
|
||||||
|
else:
|
||||||
|
self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region)
|
||||||
|
else:
|
||||||
|
self.sagemaker_client = boto3.client("sagemaker-runtime")
|
||||||
|
|
||||||
|
line = 2
|
||||||
|
|
||||||
|
sagemaker_endpoint = credentials.get('sagemaker_endpoint')
|
||||||
|
candidate_docs = []
|
||||||
|
|
||||||
|
scores = self._sagemaker_rerank(query, docs, sagemaker_endpoint)
|
||||||
|
for idx in range(len(scores)):
|
||||||
|
candidate_docs.append({"content" : docs[idx], "score": scores[idx]})
|
||||||
|
|
||||||
|
sorted(candidate_docs, key=lambda x: x['score'], reverse=True)
|
||||||
|
|
||||||
|
line = 3
|
||||||
|
rerank_documents = []
|
||||||
|
for idx, result in enumerate(candidate_docs):
|
||||||
|
rerank_document = RerankDocument(
|
||||||
|
index=idx,
|
||||||
|
text=result.get('content'),
|
||||||
|
score=result.get('score', -100.0)
|
||||||
|
)
|
||||||
|
|
||||||
|
if score_threshold is not None:
|
||||||
|
if rerank_document.score >= score_threshold:
|
||||||
|
rerank_documents.append(rerank_document)
|
||||||
|
else:
|
||||||
|
rerank_documents.append(rerank_document)
|
||||||
|
|
||||||
|
return RerankResult(
|
||||||
|
model=model,
|
||||||
|
docs=rerank_documents
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(f'Exception {e}, line : {line}')
|
||||||
|
|
||||||
|
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||||
|
"""
|
||||||
|
Validate model credentials
|
||||||
|
|
||||||
|
:param model: model name
|
||||||
|
:param credentials: model credentials
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
self._invoke(
|
||||||
|
model=model,
|
||||||
|
credentials=credentials,
|
||||||
|
query="What is the capital of the United States?",
|
||||||
|
docs=[
|
||||||
|
"Carson City is the capital city of the American state of Nevada. At the 2010 United States "
|
||||||
|
"Census, Carson City had a population of 55,274.",
|
||||||
|
"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
|
||||||
|
"are a political division controlled by the United States. Its capital is Saipan.",
|
||||||
|
],
|
||||||
|
score_threshold=0.8
|
||||||
|
)
|
||||||
|
except Exception as ex:
|
||||||
|
raise CredentialsValidateFailedError(str(ex))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||||
|
"""
|
||||||
|
Map model invoke error to unified error
|
||||||
|
The key is the error type thrown to the caller
|
||||||
|
The value is the error type thrown by the model,
|
||||||
|
which needs to be converted into a unified error type for the caller.
|
||||||
|
|
||||||
|
:return: Invoke error mapping
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
InvokeConnectionError: [
|
||||||
|
InvokeConnectionError
|
||||||
|
],
|
||||||
|
InvokeServerUnavailableError: [
|
||||||
|
InvokeServerUnavailableError
|
||||||
|
],
|
||||||
|
InvokeRateLimitError: [
|
||||||
|
InvokeRateLimitError
|
||||||
|
],
|
||||||
|
InvokeAuthorizationError: [
|
||||||
|
InvokeAuthorizationError
|
||||||
|
],
|
||||||
|
InvokeBadRequestError: [
|
||||||
|
InvokeBadRequestError,
|
||||||
|
KeyError,
|
||||||
|
ValueError
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
|
||||||
|
"""
|
||||||
|
used to define customizable model schema
|
||||||
|
"""
|
||||||
|
entity = AIModelEntity(
|
||||||
|
model=model,
|
||||||
|
label=I18nObject(
|
||||||
|
en_US=model
|
||||||
|
),
|
||||||
|
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||||
|
model_type=ModelType.RERANK,
|
||||||
|
model_properties={ },
|
||||||
|
parameter_rules=[]
|
||||||
|
)
|
||||||
|
|
||||||
|
return entity
|
@ -0,0 +1,17 @@
|
|||||||
|
import logging
|
||||||
|
|
||||||
|
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class SageMakerProvider(ModelProvider):
|
||||||
|
def validate_provider_credentials(self, credentials: dict) -> None:
|
||||||
|
"""
|
||||||
|
Validate provider credentials
|
||||||
|
|
||||||
|
if validate failed, raise exception
|
||||||
|
|
||||||
|
:param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
|
||||||
|
"""
|
||||||
|
pass
|
125
api/core/model_runtime/model_providers/sagemaker/sagemaker.yaml
Normal file
125
api/core/model_runtime/model_providers/sagemaker/sagemaker.yaml
Normal file
@ -0,0 +1,125 @@
|
|||||||
|
provider: sagemaker
|
||||||
|
label:
|
||||||
|
zh_Hans: Sagemaker
|
||||||
|
en_US: Sagemaker
|
||||||
|
icon_small:
|
||||||
|
en_US: icon_s_en.png
|
||||||
|
icon_large:
|
||||||
|
en_US: icon_l_en.png
|
||||||
|
description:
|
||||||
|
en_US: Customized model on Sagemaker
|
||||||
|
zh_Hans: Sagemaker上的私有化部署的模型
|
||||||
|
background: "#ECE9E3"
|
||||||
|
help:
|
||||||
|
title:
|
||||||
|
en_US: How to deploy customized model on Sagemaker
|
||||||
|
zh_Hans: 如何在Sagemaker上的私有化部署的模型
|
||||||
|
url:
|
||||||
|
en_US: https://github.com/aws-samples/dify-aws-tool/blob/main/README.md#how-to-deploy-sagemaker-endpoint
|
||||||
|
zh_Hans: https://github.com/aws-samples/dify-aws-tool/blob/main/README_ZH.md#%E5%A6%82%E4%BD%95%E9%83%A8%E7%BD%B2sagemaker%E6%8E%A8%E7%90%86%E7%AB%AF%E7%82%B9
|
||||||
|
supported_model_types:
|
||||||
|
- llm
|
||||||
|
- text-embedding
|
||||||
|
- rerank
|
||||||
|
configurate_methods:
|
||||||
|
- customizable-model
|
||||||
|
model_credential_schema:
|
||||||
|
model:
|
||||||
|
label:
|
||||||
|
en_US: Model Name
|
||||||
|
zh_Hans: 模型名称
|
||||||
|
placeholder:
|
||||||
|
en_US: Enter your model name
|
||||||
|
zh_Hans: 输入模型名称
|
||||||
|
credential_form_schemas:
|
||||||
|
- variable: mode
|
||||||
|
show_on:
|
||||||
|
- variable: __model_type
|
||||||
|
value: llm
|
||||||
|
label:
|
||||||
|
en_US: Completion mode
|
||||||
|
type: select
|
||||||
|
required: false
|
||||||
|
default: chat
|
||||||
|
placeholder:
|
||||||
|
zh_Hans: 选择对话类型
|
||||||
|
en_US: Select completion mode
|
||||||
|
options:
|
||||||
|
- value: completion
|
||||||
|
label:
|
||||||
|
en_US: Completion
|
||||||
|
zh_Hans: 补全
|
||||||
|
- value: chat
|
||||||
|
label:
|
||||||
|
en_US: Chat
|
||||||
|
zh_Hans: 对话
|
||||||
|
- variable: sagemaker_endpoint
|
||||||
|
label:
|
||||||
|
en_US: sagemaker endpoint
|
||||||
|
type: text-input
|
||||||
|
required: true
|
||||||
|
placeholder:
|
||||||
|
zh_Hans: 请输出你的Sagemaker推理端点
|
||||||
|
en_US: Enter your Sagemaker Inference endpoint
|
||||||
|
- variable: aws_access_key_id
|
||||||
|
required: false
|
||||||
|
label:
|
||||||
|
en_US: Access Key (If not provided, credentials are obtained from the running environment.)
|
||||||
|
zh_Hans: Access Key (如果未提供,凭证将从运行环境中获取。)
|
||||||
|
type: secret-input
|
||||||
|
placeholder:
|
||||||
|
en_US: Enter your Access Key
|
||||||
|
zh_Hans: 在此输入您的 Access Key
|
||||||
|
- variable: aws_secret_access_key
|
||||||
|
required: false
|
||||||
|
label:
|
||||||
|
en_US: Secret Access Key
|
||||||
|
zh_Hans: Secret Access Key
|
||||||
|
type: secret-input
|
||||||
|
placeholder:
|
||||||
|
en_US: Enter your Secret Access Key
|
||||||
|
zh_Hans: 在此输入您的 Secret Access Key
|
||||||
|
- variable: aws_region
|
||||||
|
required: false
|
||||||
|
label:
|
||||||
|
en_US: AWS Region
|
||||||
|
zh_Hans: AWS 地区
|
||||||
|
type: select
|
||||||
|
default: us-east-1
|
||||||
|
options:
|
||||||
|
- value: us-east-1
|
||||||
|
label:
|
||||||
|
en_US: US East (N. Virginia)
|
||||||
|
zh_Hans: 美国东部 (弗吉尼亚北部)
|
||||||
|
- value: us-west-2
|
||||||
|
label:
|
||||||
|
en_US: US West (Oregon)
|
||||||
|
zh_Hans: 美国西部 (俄勒冈州)
|
||||||
|
- value: ap-southeast-1
|
||||||
|
label:
|
||||||
|
en_US: Asia Pacific (Singapore)
|
||||||
|
zh_Hans: 亚太地区 (新加坡)
|
||||||
|
- value: ap-northeast-1
|
||||||
|
label:
|
||||||
|
en_US: Asia Pacific (Tokyo)
|
||||||
|
zh_Hans: 亚太地区 (东京)
|
||||||
|
- value: eu-central-1
|
||||||
|
label:
|
||||||
|
en_US: Europe (Frankfurt)
|
||||||
|
zh_Hans: 欧洲 (法兰克福)
|
||||||
|
- value: us-gov-west-1
|
||||||
|
label:
|
||||||
|
en_US: AWS GovCloud (US-West)
|
||||||
|
zh_Hans: AWS GovCloud (US-West)
|
||||||
|
- value: ap-southeast-2
|
||||||
|
label:
|
||||||
|
en_US: Asia Pacific (Sydney)
|
||||||
|
zh_Hans: 亚太地区 (悉尼)
|
||||||
|
- value: cn-north-1
|
||||||
|
label:
|
||||||
|
en_US: AWS Beijing (cn-north-1)
|
||||||
|
zh_Hans: 中国北京 (cn-north-1)
|
||||||
|
- value: cn-northwest-1
|
||||||
|
label:
|
||||||
|
en_US: AWS Ningxia (cn-northwest-1)
|
||||||
|
zh_Hans: 中国宁夏 (cn-northwest-1)
|
@ -0,0 +1,214 @@
|
|||||||
|
import itertools
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
import boto3
|
||||||
|
|
||||||
|
from core.model_runtime.entities.common_entities import I18nObject
|
||||||
|
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType
|
||||||
|
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
||||||
|
from core.model_runtime.errors.invoke import (
|
||||||
|
InvokeAuthorizationError,
|
||||||
|
InvokeBadRequestError,
|
||||||
|
InvokeConnectionError,
|
||||||
|
InvokeError,
|
||||||
|
InvokeRateLimitError,
|
||||||
|
InvokeServerUnavailableError,
|
||||||
|
)
|
||||||
|
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||||
|
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||||
|
|
||||||
|
BATCH_SIZE = 20
|
||||||
|
CONTEXT_SIZE=8192
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
def batch_generator(generator, batch_size):
|
||||||
|
while True:
|
||||||
|
batch = list(itertools.islice(generator, batch_size))
|
||||||
|
if not batch:
|
||||||
|
break
|
||||||
|
yield batch
|
||||||
|
|
||||||
|
class SageMakerEmbeddingModel(TextEmbeddingModel):
|
||||||
|
"""
|
||||||
|
Model class for Cohere text embedding model.
|
||||||
|
"""
|
||||||
|
sagemaker_client: Any = None
|
||||||
|
|
||||||
|
def _sagemaker_embedding(self, sm_client, endpoint_name, content_list:list[str]):
|
||||||
|
response_model = sm_client.invoke_endpoint(
|
||||||
|
EndpointName=endpoint_name,
|
||||||
|
Body=json.dumps(
|
||||||
|
{
|
||||||
|
"inputs": content_list,
|
||||||
|
"parameters": {},
|
||||||
|
"is_query" : False,
|
||||||
|
"instruction" : ''
|
||||||
|
}
|
||||||
|
),
|
||||||
|
ContentType="application/json",
|
||||||
|
)
|
||||||
|
json_str = response_model['Body'].read().decode('utf8')
|
||||||
|
json_obj = json.loads(json_str)
|
||||||
|
embeddings = json_obj['embeddings']
|
||||||
|
return embeddings
|
||||||
|
|
||||||
|
def _invoke(self, model: str, credentials: dict,
|
||||||
|
texts: list[str], user: Optional[str] = None) \
|
||||||
|
-> TextEmbeddingResult:
|
||||||
|
"""
|
||||||
|
Invoke text embedding model
|
||||||
|
|
||||||
|
:param model: model name
|
||||||
|
:param credentials: model credentials
|
||||||
|
:param texts: texts to embed
|
||||||
|
:param user: unique user id
|
||||||
|
:return: embeddings result
|
||||||
|
"""
|
||||||
|
# get model properties
|
||||||
|
try:
|
||||||
|
line = 1
|
||||||
|
if not self.sagemaker_client:
|
||||||
|
access_key = credentials.get('aws_access_key_id')
|
||||||
|
secret_key = credentials.get('aws_secret_access_key')
|
||||||
|
aws_region = credentials.get('aws_region')
|
||||||
|
if aws_region:
|
||||||
|
if access_key and secret_key:
|
||||||
|
self.sagemaker_client = boto3.client("sagemaker-runtime",
|
||||||
|
aws_access_key_id=access_key,
|
||||||
|
aws_secret_access_key=secret_key,
|
||||||
|
region_name=aws_region)
|
||||||
|
else:
|
||||||
|
self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region)
|
||||||
|
else:
|
||||||
|
self.sagemaker_client = boto3.client("sagemaker-runtime")
|
||||||
|
|
||||||
|
line = 2
|
||||||
|
sagemaker_endpoint = credentials.get('sagemaker_endpoint')
|
||||||
|
|
||||||
|
line = 3
|
||||||
|
truncated_texts = [ item[:CONTEXT_SIZE] for item in texts ]
|
||||||
|
|
||||||
|
batches = batch_generator((text for text in truncated_texts), batch_size=BATCH_SIZE)
|
||||||
|
all_embeddings = []
|
||||||
|
|
||||||
|
line = 4
|
||||||
|
for batch in batches:
|
||||||
|
embeddings = self._sagemaker_embedding(self.sagemaker_client, sagemaker_endpoint, batch)
|
||||||
|
all_embeddings.extend(embeddings)
|
||||||
|
|
||||||
|
line = 5
|
||||||
|
# calc usage
|
||||||
|
usage = self._calc_response_usage(
|
||||||
|
model=model,
|
||||||
|
credentials=credentials,
|
||||||
|
tokens=0 # It's not SAAS API, usage is meaningless
|
||||||
|
)
|
||||||
|
line = 6
|
||||||
|
|
||||||
|
return TextEmbeddingResult(
|
||||||
|
embeddings=all_embeddings,
|
||||||
|
usage=usage,
|
||||||
|
model=model
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(f'Exception {e}, line : {line}')
|
||||||
|
|
||||||
|
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
|
||||||
|
"""
|
||||||
|
Get number of tokens for given prompt messages
|
||||||
|
|
||||||
|
:param model: model name
|
||||||
|
:param credentials: model credentials
|
||||||
|
:param texts: texts to embed
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
return 0
|
||||||
|
|
||||||
|
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||||
|
"""
|
||||||
|
Validate model credentials
|
||||||
|
|
||||||
|
:param model: model name
|
||||||
|
:param credentials: model credentials
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
print("validate_credentials ok....")
|
||||||
|
except Exception as ex:
|
||||||
|
raise CredentialsValidateFailedError(str(ex))
|
||||||
|
|
||||||
|
def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
|
||||||
|
"""
|
||||||
|
Calculate response usage
|
||||||
|
|
||||||
|
:param model: model name
|
||||||
|
:param credentials: model credentials
|
||||||
|
:param tokens: input tokens
|
||||||
|
:return: usage
|
||||||
|
"""
|
||||||
|
# get input price info
|
||||||
|
input_price_info = self.get_price(
|
||||||
|
model=model,
|
||||||
|
credentials=credentials,
|
||||||
|
price_type=PriceType.INPUT,
|
||||||
|
tokens=tokens
|
||||||
|
)
|
||||||
|
|
||||||
|
# transform usage
|
||||||
|
usage = EmbeddingUsage(
|
||||||
|
tokens=tokens,
|
||||||
|
total_tokens=tokens,
|
||||||
|
unit_price=input_price_info.unit_price,
|
||||||
|
price_unit=input_price_info.unit,
|
||||||
|
total_price=input_price_info.total_amount,
|
||||||
|
currency=input_price_info.currency,
|
||||||
|
latency=time.perf_counter() - self.started_at
|
||||||
|
)
|
||||||
|
|
||||||
|
return usage
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||||
|
return {
|
||||||
|
InvokeConnectionError: [
|
||||||
|
InvokeConnectionError
|
||||||
|
],
|
||||||
|
InvokeServerUnavailableError: [
|
||||||
|
InvokeServerUnavailableError
|
||||||
|
],
|
||||||
|
InvokeRateLimitError: [
|
||||||
|
InvokeRateLimitError
|
||||||
|
],
|
||||||
|
InvokeAuthorizationError: [
|
||||||
|
InvokeAuthorizationError
|
||||||
|
],
|
||||||
|
InvokeBadRequestError: [
|
||||||
|
KeyError
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
|
||||||
|
"""
|
||||||
|
used to define customizable model schema
|
||||||
|
"""
|
||||||
|
|
||||||
|
entity = AIModelEntity(
|
||||||
|
model=model,
|
||||||
|
label=I18nObject(
|
||||||
|
en_US=model
|
||||||
|
),
|
||||||
|
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||||
|
model_type=ModelType.TEXT_EMBEDDING,
|
||||||
|
model_properties={
|
||||||
|
ModelPropertyKey.CONTEXT_SIZE: CONTEXT_SIZE,
|
||||||
|
ModelPropertyKey.MAX_CHUNKS: BATCH_SIZE,
|
||||||
|
},
|
||||||
|
parameter_rules=[]
|
||||||
|
)
|
||||||
|
|
||||||
|
return entity
|
Binary file not shown.
After Width: | Height: | Size: 9.0 KiB |
Binary file not shown.
After Width: | Height: | Size: 1.9 KiB |
@ -0,0 +1,6 @@
|
|||||||
|
- step-1-8k
|
||||||
|
- step-1-32k
|
||||||
|
- step-1-128k
|
||||||
|
- step-1-256k
|
||||||
|
- step-1v-8k
|
||||||
|
- step-1v-32k
|
328
api/core/model_runtime/model_providers/stepfun/llm/llm.py
Normal file
328
api/core/model_runtime/model_providers/stepfun/llm/llm.py
Normal file
@ -0,0 +1,328 @@
|
|||||||
|
import json
|
||||||
|
from collections.abc import Generator
|
||||||
|
from typing import Optional, Union, cast
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from core.model_runtime.entities.common_entities import I18nObject
|
||||||
|
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
|
||||||
|
from core.model_runtime.entities.message_entities import (
|
||||||
|
AssistantPromptMessage,
|
||||||
|
ImagePromptMessageContent,
|
||||||
|
PromptMessage,
|
||||||
|
PromptMessageContent,
|
||||||
|
PromptMessageContentType,
|
||||||
|
PromptMessageTool,
|
||||||
|
SystemPromptMessage,
|
||||||
|
ToolPromptMessage,
|
||||||
|
UserPromptMessage,
|
||||||
|
)
|
||||||
|
from core.model_runtime.entities.model_entities import (
|
||||||
|
AIModelEntity,
|
||||||
|
FetchFrom,
|
||||||
|
ModelFeature,
|
||||||
|
ModelPropertyKey,
|
||||||
|
ModelType,
|
||||||
|
ParameterRule,
|
||||||
|
ParameterType,
|
||||||
|
)
|
||||||
|
from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel
|
||||||
|
|
||||||
|
|
||||||
|
class StepfunLargeLanguageModel(OAIAPICompatLargeLanguageModel):
|
||||||
|
def _invoke(self, model: str, credentials: dict,
|
||||||
|
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||||
|
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
|
||||||
|
stream: bool = True, user: Optional[str] = None) \
|
||||||
|
-> Union[LLMResult, Generator]:
|
||||||
|
self._add_custom_parameters(credentials)
|
||||||
|
self._add_function_call(model, credentials)
|
||||||
|
user = user[:32] if user else None
|
||||||
|
return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
|
||||||
|
|
||||||
|
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||||
|
self._add_custom_parameters(credentials)
|
||||||
|
super().validate_credentials(model, credentials)
|
||||||
|
|
||||||
|
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
|
||||||
|
return AIModelEntity(
|
||||||
|
model=model,
|
||||||
|
label=I18nObject(en_US=model, zh_Hans=model),
|
||||||
|
model_type=ModelType.LLM,
|
||||||
|
features=[ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL, ModelFeature.STREAM_TOOL_CALL]
|
||||||
|
if credentials.get('function_calling_type') == 'tool_call'
|
||||||
|
else [],
|
||||||
|
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||||
|
model_properties={
|
||||||
|
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', 8000)),
|
||||||
|
ModelPropertyKey.MODE: LLMMode.CHAT.value,
|
||||||
|
},
|
||||||
|
parameter_rules=[
|
||||||
|
ParameterRule(
|
||||||
|
name='temperature',
|
||||||
|
use_template='temperature',
|
||||||
|
label=I18nObject(en_US='Temperature', zh_Hans='温度'),
|
||||||
|
type=ParameterType.FLOAT,
|
||||||
|
),
|
||||||
|
ParameterRule(
|
||||||
|
name='max_tokens',
|
||||||
|
use_template='max_tokens',
|
||||||
|
default=512,
|
||||||
|
min=1,
|
||||||
|
max=int(credentials.get('max_tokens', 1024)),
|
||||||
|
label=I18nObject(en_US='Max Tokens', zh_Hans='最大标记'),
|
||||||
|
type=ParameterType.INT,
|
||||||
|
),
|
||||||
|
ParameterRule(
|
||||||
|
name='top_p',
|
||||||
|
use_template='top_p',
|
||||||
|
label=I18nObject(en_US='Top P', zh_Hans='Top P'),
|
||||||
|
type=ParameterType.FLOAT,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
def _add_custom_parameters(self, credentials: dict) -> None:
|
||||||
|
credentials['mode'] = 'chat'
|
||||||
|
credentials['endpoint_url'] = 'https://api.stepfun.com/v1'
|
||||||
|
|
||||||
|
def _add_function_call(self, model: str, credentials: dict) -> None:
|
||||||
|
model_schema = self.get_model_schema(model, credentials)
|
||||||
|
if model_schema and {
|
||||||
|
ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL
|
||||||
|
}.intersection(model_schema.features or []):
|
||||||
|
credentials['function_calling_type'] = 'tool_call'
|
||||||
|
|
||||||
|
def _convert_prompt_message_to_dict(self, message: PromptMessage,credentials: Optional[dict] = None) -> dict:
|
||||||
|
"""
|
||||||
|
Convert PromptMessage to dict for OpenAI API format
|
||||||
|
"""
|
||||||
|
if isinstance(message, UserPromptMessage):
|
||||||
|
message = cast(UserPromptMessage, message)
|
||||||
|
if isinstance(message.content, str):
|
||||||
|
message_dict = {"role": "user", "content": message.content}
|
||||||
|
else:
|
||||||
|
sub_messages = []
|
||||||
|
for message_content in message.content:
|
||||||
|
if message_content.type == PromptMessageContentType.TEXT:
|
||||||
|
message_content = cast(PromptMessageContent, message_content)
|
||||||
|
sub_message_dict = {
|
||||||
|
"type": "text",
|
||||||
|
"text": message_content.data
|
||||||
|
}
|
||||||
|
sub_messages.append(sub_message_dict)
|
||||||
|
elif message_content.type == PromptMessageContentType.IMAGE:
|
||||||
|
message_content = cast(ImagePromptMessageContent, message_content)
|
||||||
|
sub_message_dict = {
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {
|
||||||
|
"url": message_content.data,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sub_messages.append(sub_message_dict)
|
||||||
|
message_dict = {"role": "user", "content": sub_messages}
|
||||||
|
elif isinstance(message, AssistantPromptMessage):
|
||||||
|
message = cast(AssistantPromptMessage, message)
|
||||||
|
message_dict = {"role": "assistant", "content": message.content}
|
||||||
|
if message.tool_calls:
|
||||||
|
message_dict["tool_calls"] = []
|
||||||
|
for function_call in message.tool_calls:
|
||||||
|
message_dict["tool_calls"].append({
|
||||||
|
"id": function_call.id,
|
||||||
|
"type": function_call.type,
|
||||||
|
"function": {
|
||||||
|
"name": function_call.function.name,
|
||||||
|
"arguments": function_call.function.arguments
|
||||||
|
}
|
||||||
|
})
|
||||||
|
elif isinstance(message, ToolPromptMessage):
|
||||||
|
message = cast(ToolPromptMessage, message)
|
||||||
|
message_dict = {"role": "tool", "content": message.content, "tool_call_id": message.tool_call_id}
|
||||||
|
elif isinstance(message, SystemPromptMessage):
|
||||||
|
message = cast(SystemPromptMessage, message)
|
||||||
|
message_dict = {"role": "system", "content": message.content}
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Got unknown type {message}")
|
||||||
|
|
||||||
|
if message.name:
|
||||||
|
message_dict["name"] = message.name
|
||||||
|
|
||||||
|
return message_dict
|
||||||
|
|
||||||
|
def _extract_response_tool_calls(self, response_tool_calls: list[dict]) -> list[AssistantPromptMessage.ToolCall]:
|
||||||
|
"""
|
||||||
|
Extract tool calls from response
|
||||||
|
|
||||||
|
:param response_tool_calls: response tool calls
|
||||||
|
:return: list of tool calls
|
||||||
|
"""
|
||||||
|
tool_calls = []
|
||||||
|
if response_tool_calls:
|
||||||
|
for response_tool_call in response_tool_calls:
|
||||||
|
function = AssistantPromptMessage.ToolCall.ToolCallFunction(
|
||||||
|
name=response_tool_call["function"]["name"] if response_tool_call.get("function", {}).get("name") else "",
|
||||||
|
arguments=response_tool_call["function"]["arguments"] if response_tool_call.get("function", {}).get("arguments") else ""
|
||||||
|
)
|
||||||
|
|
||||||
|
tool_call = AssistantPromptMessage.ToolCall(
|
||||||
|
id=response_tool_call["id"] if response_tool_call.get("id") else "",
|
||||||
|
type=response_tool_call["type"] if response_tool_call.get("type") else "",
|
||||||
|
function=function
|
||||||
|
)
|
||||||
|
tool_calls.append(tool_call)
|
||||||
|
|
||||||
|
return tool_calls
|
||||||
|
|
||||||
|
def _handle_generate_stream_response(self, model: str, credentials: dict, response: requests.Response,
|
||||||
|
prompt_messages: list[PromptMessage]) -> Generator:
|
||||||
|
"""
|
||||||
|
Handle llm stream response
|
||||||
|
|
||||||
|
:param model: model name
|
||||||
|
:param credentials: model credentials
|
||||||
|
:param response: streamed response
|
||||||
|
:param prompt_messages: prompt messages
|
||||||
|
:return: llm response chunk generator
|
||||||
|
"""
|
||||||
|
full_assistant_content = ''
|
||||||
|
chunk_index = 0
|
||||||
|
|
||||||
|
def create_final_llm_result_chunk(index: int, message: AssistantPromptMessage, finish_reason: str) \
|
||||||
|
-> LLMResultChunk:
|
||||||
|
# calculate num tokens
|
||||||
|
prompt_tokens = self._num_tokens_from_string(model, prompt_messages[0].content)
|
||||||
|
completion_tokens = self._num_tokens_from_string(model, full_assistant_content)
|
||||||
|
|
||||||
|
# transform usage
|
||||||
|
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
|
||||||
|
|
||||||
|
return LLMResultChunk(
|
||||||
|
model=model,
|
||||||
|
prompt_messages=prompt_messages,
|
||||||
|
delta=LLMResultChunkDelta(
|
||||||
|
index=index,
|
||||||
|
message=message,
|
||||||
|
finish_reason=finish_reason,
|
||||||
|
usage=usage
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
tools_calls: list[AssistantPromptMessage.ToolCall] = []
|
||||||
|
finish_reason = "Unknown"
|
||||||
|
|
||||||
|
def increase_tool_call(new_tool_calls: list[AssistantPromptMessage.ToolCall]):
|
||||||
|
def get_tool_call(tool_name: str):
|
||||||
|
if not tool_name:
|
||||||
|
return tools_calls[-1]
|
||||||
|
|
||||||
|
tool_call = next((tool_call for tool_call in tools_calls if tool_call.function.name == tool_name), None)
|
||||||
|
if tool_call is None:
|
||||||
|
tool_call = AssistantPromptMessage.ToolCall(
|
||||||
|
id='',
|
||||||
|
type='',
|
||||||
|
function=AssistantPromptMessage.ToolCall.ToolCallFunction(name=tool_name, arguments="")
|
||||||
|
)
|
||||||
|
tools_calls.append(tool_call)
|
||||||
|
|
||||||
|
return tool_call
|
||||||
|
|
||||||
|
for new_tool_call in new_tool_calls:
|
||||||
|
# get tool call
|
||||||
|
tool_call = get_tool_call(new_tool_call.function.name)
|
||||||
|
# update tool call
|
||||||
|
if new_tool_call.id:
|
||||||
|
tool_call.id = new_tool_call.id
|
||||||
|
if new_tool_call.type:
|
||||||
|
tool_call.type = new_tool_call.type
|
||||||
|
if new_tool_call.function.name:
|
||||||
|
tool_call.function.name = new_tool_call.function.name
|
||||||
|
if new_tool_call.function.arguments:
|
||||||
|
tool_call.function.arguments += new_tool_call.function.arguments
|
||||||
|
|
||||||
|
for chunk in response.iter_lines(decode_unicode=True, delimiter="\n\n"):
|
||||||
|
if chunk:
|
||||||
|
# ignore sse comments
|
||||||
|
if chunk.startswith(':'):
|
||||||
|
continue
|
||||||
|
decoded_chunk = chunk.strip().lstrip('data: ').lstrip()
|
||||||
|
chunk_json = None
|
||||||
|
try:
|
||||||
|
chunk_json = json.loads(decoded_chunk)
|
||||||
|
# stream ended
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
yield create_final_llm_result_chunk(
|
||||||
|
index=chunk_index + 1,
|
||||||
|
message=AssistantPromptMessage(content=""),
|
||||||
|
finish_reason="Non-JSON encountered."
|
||||||
|
)
|
||||||
|
break
|
||||||
|
if not chunk_json or len(chunk_json['choices']) == 0:
|
||||||
|
continue
|
||||||
|
|
||||||
|
choice = chunk_json['choices'][0]
|
||||||
|
finish_reason = chunk_json['choices'][0].get('finish_reason')
|
||||||
|
chunk_index += 1
|
||||||
|
|
||||||
|
if 'delta' in choice:
|
||||||
|
delta = choice['delta']
|
||||||
|
delta_content = delta.get('content')
|
||||||
|
|
||||||
|
assistant_message_tool_calls = delta.get('tool_calls', None)
|
||||||
|
# assistant_message_function_call = delta.delta.function_call
|
||||||
|
|
||||||
|
# extract tool calls from response
|
||||||
|
if assistant_message_tool_calls:
|
||||||
|
tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
|
||||||
|
increase_tool_call(tool_calls)
|
||||||
|
|
||||||
|
if delta_content is None or delta_content == '':
|
||||||
|
continue
|
||||||
|
|
||||||
|
# transform assistant message to prompt message
|
||||||
|
assistant_prompt_message = AssistantPromptMessage(
|
||||||
|
content=delta_content,
|
||||||
|
tool_calls=tool_calls if assistant_message_tool_calls else []
|
||||||
|
)
|
||||||
|
|
||||||
|
full_assistant_content += delta_content
|
||||||
|
elif 'text' in choice:
|
||||||
|
choice_text = choice.get('text', '')
|
||||||
|
if choice_text == '':
|
||||||
|
continue
|
||||||
|
|
||||||
|
# transform assistant message to prompt message
|
||||||
|
assistant_prompt_message = AssistantPromptMessage(content=choice_text)
|
||||||
|
full_assistant_content += choice_text
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# check payload indicator for completion
|
||||||
|
yield LLMResultChunk(
|
||||||
|
model=model,
|
||||||
|
prompt_messages=prompt_messages,
|
||||||
|
delta=LLMResultChunkDelta(
|
||||||
|
index=chunk_index,
|
||||||
|
message=assistant_prompt_message,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
chunk_index += 1
|
||||||
|
|
||||||
|
if tools_calls:
|
||||||
|
yield LLMResultChunk(
|
||||||
|
model=model,
|
||||||
|
prompt_messages=prompt_messages,
|
||||||
|
delta=LLMResultChunkDelta(
|
||||||
|
index=chunk_index,
|
||||||
|
message=AssistantPromptMessage(
|
||||||
|
tool_calls=tools_calls,
|
||||||
|
content=""
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
yield create_final_llm_result_chunk(
|
||||||
|
index=chunk_index,
|
||||||
|
message=AssistantPromptMessage(content=""),
|
||||||
|
finish_reason=finish_reason
|
||||||
|
)
|
@ -0,0 +1,25 @@
|
|||||||
|
model: step-1-128k
|
||||||
|
label:
|
||||||
|
zh_Hans: step-1-128k
|
||||||
|
en_US: step-1-128k
|
||||||
|
model_type: llm
|
||||||
|
features:
|
||||||
|
- agent-thought
|
||||||
|
model_properties:
|
||||||
|
mode: chat
|
||||||
|
context_size: 128000
|
||||||
|
parameter_rules:
|
||||||
|
- name: temperature
|
||||||
|
use_template: temperature
|
||||||
|
- name: top_p
|
||||||
|
use_template: top_p
|
||||||
|
- name: max_tokens
|
||||||
|
use_template: max_tokens
|
||||||
|
default: 1024
|
||||||
|
min: 1
|
||||||
|
max: 128000
|
||||||
|
pricing:
|
||||||
|
input: '0.04'
|
||||||
|
output: '0.20'
|
||||||
|
unit: '0.001'
|
||||||
|
currency: RMB
|
@ -0,0 +1,25 @@
|
|||||||
|
model: step-1-256k
|
||||||
|
label:
|
||||||
|
zh_Hans: step-1-256k
|
||||||
|
en_US: step-1-256k
|
||||||
|
model_type: llm
|
||||||
|
features:
|
||||||
|
- agent-thought
|
||||||
|
model_properties:
|
||||||
|
mode: chat
|
||||||
|
context_size: 256000
|
||||||
|
parameter_rules:
|
||||||
|
- name: temperature
|
||||||
|
use_template: temperature
|
||||||
|
- name: top_p
|
||||||
|
use_template: top_p
|
||||||
|
- name: max_tokens
|
||||||
|
use_template: max_tokens
|
||||||
|
default: 1024
|
||||||
|
min: 1
|
||||||
|
max: 256000
|
||||||
|
pricing:
|
||||||
|
input: '0.095'
|
||||||
|
output: '0.300'
|
||||||
|
unit: '0.001'
|
||||||
|
currency: RMB
|
@ -0,0 +1,28 @@
|
|||||||
|
model: step-1-32k
|
||||||
|
label:
|
||||||
|
zh_Hans: step-1-32k
|
||||||
|
en_US: step-1-32k
|
||||||
|
model_type: llm
|
||||||
|
features:
|
||||||
|
- agent-thought
|
||||||
|
- tool-call
|
||||||
|
- multi-tool-call
|
||||||
|
- stream-tool-call
|
||||||
|
model_properties:
|
||||||
|
mode: chat
|
||||||
|
context_size: 32000
|
||||||
|
parameter_rules:
|
||||||
|
- name: temperature
|
||||||
|
use_template: temperature
|
||||||
|
- name: top_p
|
||||||
|
use_template: top_p
|
||||||
|
- name: max_tokens
|
||||||
|
use_template: max_tokens
|
||||||
|
default: 1024
|
||||||
|
min: 1
|
||||||
|
max: 32000
|
||||||
|
pricing:
|
||||||
|
input: '0.015'
|
||||||
|
output: '0.070'
|
||||||
|
unit: '0.001'
|
||||||
|
currency: RMB
|
@ -0,0 +1,28 @@
|
|||||||
|
model: step-1-8k
|
||||||
|
label:
|
||||||
|
zh_Hans: step-1-8k
|
||||||
|
en_US: step-1-8k
|
||||||
|
model_type: llm
|
||||||
|
features:
|
||||||
|
- agent-thought
|
||||||
|
- tool-call
|
||||||
|
- multi-tool-call
|
||||||
|
- stream-tool-call
|
||||||
|
model_properties:
|
||||||
|
mode: chat
|
||||||
|
context_size: 8000
|
||||||
|
parameter_rules:
|
||||||
|
- name: temperature
|
||||||
|
use_template: temperature
|
||||||
|
- name: top_p
|
||||||
|
use_template: top_p
|
||||||
|
- name: max_tokens
|
||||||
|
use_template: max_tokens
|
||||||
|
default: 512
|
||||||
|
min: 1
|
||||||
|
max: 8000
|
||||||
|
pricing:
|
||||||
|
input: '0.005'
|
||||||
|
output: '0.020'
|
||||||
|
unit: '0.001'
|
||||||
|
currency: RMB
|
@ -0,0 +1,25 @@
|
|||||||
|
model: step-1v-32k
|
||||||
|
label:
|
||||||
|
zh_Hans: step-1v-32k
|
||||||
|
en_US: step-1v-32k
|
||||||
|
model_type: llm
|
||||||
|
features:
|
||||||
|
- vision
|
||||||
|
model_properties:
|
||||||
|
mode: chat
|
||||||
|
context_size: 32000
|
||||||
|
parameter_rules:
|
||||||
|
- name: temperature
|
||||||
|
use_template: temperature
|
||||||
|
- name: top_p
|
||||||
|
use_template: top_p
|
||||||
|
- name: max_tokens
|
||||||
|
use_template: max_tokens
|
||||||
|
default: 1024
|
||||||
|
min: 1
|
||||||
|
max: 32000
|
||||||
|
pricing:
|
||||||
|
input: '0.015'
|
||||||
|
output: '0.070'
|
||||||
|
unit: '0.001'
|
||||||
|
currency: RMB
|
@ -0,0 +1,25 @@
|
|||||||
|
model: step-1v-8k
|
||||||
|
label:
|
||||||
|
zh_Hans: step-1v-8k
|
||||||
|
en_US: step-1v-8k
|
||||||
|
model_type: llm
|
||||||
|
features:
|
||||||
|
- vision
|
||||||
|
model_properties:
|
||||||
|
mode: chat
|
||||||
|
context_size: 8192
|
||||||
|
parameter_rules:
|
||||||
|
- name: temperature
|
||||||
|
use_template: temperature
|
||||||
|
- name: top_p
|
||||||
|
use_template: top_p
|
||||||
|
- name: max_tokens
|
||||||
|
use_template: max_tokens
|
||||||
|
default: 512
|
||||||
|
min: 1
|
||||||
|
max: 8192
|
||||||
|
pricing:
|
||||||
|
input: '0.005'
|
||||||
|
output: '0.020'
|
||||||
|
unit: '0.001'
|
||||||
|
currency: RMB
|
30
api/core/model_runtime/model_providers/stepfun/stepfun.py
Normal file
30
api/core/model_runtime/model_providers/stepfun/stepfun.py
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
import logging
|
||||||
|
|
||||||
|
from core.model_runtime.entities.model_entities import ModelType
|
||||||
|
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||||
|
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class StepfunProvider(ModelProvider):
|
||||||
|
|
||||||
|
def validate_provider_credentials(self, credentials: dict) -> None:
|
||||||
|
"""
|
||||||
|
Validate provider credentials
|
||||||
|
if validate failed, raise exception
|
||||||
|
|
||||||
|
:param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
model_instance = self.get_model_instance(ModelType.LLM)
|
||||||
|
|
||||||
|
model_instance.validate_credentials(
|
||||||
|
model='step-1-8k',
|
||||||
|
credentials=credentials
|
||||||
|
)
|
||||||
|
except CredentialsValidateFailedError as ex:
|
||||||
|
raise ex
|
||||||
|
except Exception as ex:
|
||||||
|
logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
|
||||||
|
raise ex
|
81
api/core/model_runtime/model_providers/stepfun/stepfun.yaml
Normal file
81
api/core/model_runtime/model_providers/stepfun/stepfun.yaml
Normal file
@ -0,0 +1,81 @@
|
|||||||
|
provider: stepfun
|
||||||
|
label:
|
||||||
|
zh_Hans: 阶跃星辰
|
||||||
|
en_US: Stepfun
|
||||||
|
description:
|
||||||
|
en_US: Models provided by stepfun, such as step-1-8k, step-1-32k、step-1v-8k、step-1v-32k, step-1-128k and step-1-256k
|
||||||
|
zh_Hans: 阶跃星辰提供的模型,例如 step-1-8k、step-1-32k、step-1v-8k、step-1v-32k、step-1-128k 和 step-1-256k。
|
||||||
|
icon_small:
|
||||||
|
en_US: icon_s_en.png
|
||||||
|
icon_large:
|
||||||
|
en_US: icon_l_en.png
|
||||||
|
background: "#FFFFFF"
|
||||||
|
help:
|
||||||
|
title:
|
||||||
|
en_US: Get your API Key from stepfun
|
||||||
|
zh_Hans: 从 stepfun 获取 API Key
|
||||||
|
url:
|
||||||
|
en_US: https://platform.stepfun.com/interface-key
|
||||||
|
supported_model_types:
|
||||||
|
- llm
|
||||||
|
configurate_methods:
|
||||||
|
- predefined-model
|
||||||
|
- customizable-model
|
||||||
|
provider_credential_schema:
|
||||||
|
credential_form_schemas:
|
||||||
|
- variable: api_key
|
||||||
|
label:
|
||||||
|
en_US: API Key
|
||||||
|
type: secret-input
|
||||||
|
required: true
|
||||||
|
placeholder:
|
||||||
|
zh_Hans: 在此输入您的 API Key
|
||||||
|
en_US: Enter your API Key
|
||||||
|
model_credential_schema:
|
||||||
|
model:
|
||||||
|
label:
|
||||||
|
en_US: Model Name
|
||||||
|
zh_Hans: 模型名称
|
||||||
|
placeholder:
|
||||||
|
en_US: Enter your model name
|
||||||
|
zh_Hans: 输入模型名称
|
||||||
|
credential_form_schemas:
|
||||||
|
- variable: api_key
|
||||||
|
label:
|
||||||
|
en_US: API Key
|
||||||
|
type: secret-input
|
||||||
|
required: true
|
||||||
|
placeholder:
|
||||||
|
zh_Hans: 在此输入您的 API Key
|
||||||
|
en_US: Enter your API Key
|
||||||
|
- variable: context_size
|
||||||
|
label:
|
||||||
|
zh_Hans: 模型上下文长度
|
||||||
|
en_US: Model context size
|
||||||
|
required: true
|
||||||
|
type: text-input
|
||||||
|
default: '8192'
|
||||||
|
placeholder:
|
||||||
|
zh_Hans: 在此输入您的模型上下文长度
|
||||||
|
en_US: Enter your Model context size
|
||||||
|
- variable: max_tokens
|
||||||
|
label:
|
||||||
|
zh_Hans: 最大 token 上限
|
||||||
|
en_US: Upper bound for max tokens
|
||||||
|
default: '8192'
|
||||||
|
type: text-input
|
||||||
|
- variable: function_calling_type
|
||||||
|
label:
|
||||||
|
en_US: Function calling
|
||||||
|
type: select
|
||||||
|
required: false
|
||||||
|
default: no_call
|
||||||
|
options:
|
||||||
|
- value: no_call
|
||||||
|
label:
|
||||||
|
en_US: Not supported
|
||||||
|
zh_Hans: 不支持
|
||||||
|
- value: tool_call
|
||||||
|
label:
|
||||||
|
en_US: Tool Call
|
||||||
|
zh_Hans: Tool Call
|
@ -35,3 +35,4 @@ parameter_rules:
|
|||||||
zh_Hans: 禁用模型自行进行外部搜索。
|
zh_Hans: 禁用模型自行进行外部搜索。
|
||||||
en_US: Disable the model to perform external search.
|
en_US: Disable the model to perform external search.
|
||||||
required: false
|
required: false
|
||||||
|
deprecated: true
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
model: ernie-4.0-8k-Latest
|
model: ernie-4.0-8k-latest
|
||||||
label:
|
label:
|
||||||
en_US: Ernie-4.0-8K-Latest
|
en_US: Ernie-4.0-8K-Latest
|
||||||
model_type: llm
|
model_type: llm
|
||||||
|
@ -0,0 +1,40 @@
|
|||||||
|
model: ernie-4.0-turbo-8k
|
||||||
|
label:
|
||||||
|
en_US: Ernie-4.0-turbo-8K
|
||||||
|
model_type: llm
|
||||||
|
features:
|
||||||
|
- agent-thought
|
||||||
|
model_properties:
|
||||||
|
mode: chat
|
||||||
|
context_size: 8192
|
||||||
|
parameter_rules:
|
||||||
|
- name: temperature
|
||||||
|
use_template: temperature
|
||||||
|
min: 0.1
|
||||||
|
max: 1.0
|
||||||
|
default: 0.8
|
||||||
|
- name: top_p
|
||||||
|
use_template: top_p
|
||||||
|
- name: max_tokens
|
||||||
|
use_template: max_tokens
|
||||||
|
default: 1024
|
||||||
|
min: 2
|
||||||
|
max: 2048
|
||||||
|
- name: presence_penalty
|
||||||
|
use_template: presence_penalty
|
||||||
|
default: 1.0
|
||||||
|
min: 1.0
|
||||||
|
max: 2.0
|
||||||
|
- name: frequency_penalty
|
||||||
|
use_template: frequency_penalty
|
||||||
|
- name: response_format
|
||||||
|
use_template: response_format
|
||||||
|
- name: disable_search
|
||||||
|
label:
|
||||||
|
zh_Hans: 禁用搜索
|
||||||
|
en_US: Disable Search
|
||||||
|
type: boolean
|
||||||
|
help:
|
||||||
|
zh_Hans: 禁用模型自行进行外部搜索。
|
||||||
|
en_US: Disable the model to perform external search.
|
||||||
|
required: false
|
@ -28,3 +28,4 @@ parameter_rules:
|
|||||||
default: 1.0
|
default: 1.0
|
||||||
min: 1.0
|
min: 1.0
|
||||||
max: 2.0
|
max: 2.0
|
||||||
|
deprecated: true
|
||||||
|
@ -0,0 +1,30 @@
|
|||||||
|
model: ernie-character-8k-0321
|
||||||
|
label:
|
||||||
|
en_US: ERNIE-Character-8K
|
||||||
|
model_type: llm
|
||||||
|
features:
|
||||||
|
- agent-thought
|
||||||
|
model_properties:
|
||||||
|
mode: chat
|
||||||
|
context_size: 8192
|
||||||
|
parameter_rules:
|
||||||
|
- name: temperature
|
||||||
|
use_template: temperature
|
||||||
|
min: 0.1
|
||||||
|
max: 1.0
|
||||||
|
default: 0.95
|
||||||
|
- name: top_p
|
||||||
|
use_template: top_p
|
||||||
|
min: 0
|
||||||
|
max: 1.0
|
||||||
|
default: 0.7
|
||||||
|
- name: max_tokens
|
||||||
|
use_template: max_tokens
|
||||||
|
default: 1024
|
||||||
|
min: 2
|
||||||
|
max: 1024
|
||||||
|
- name: presence_penalty
|
||||||
|
use_template: presence_penalty
|
||||||
|
default: 1.0
|
||||||
|
min: 1.0
|
||||||
|
max: 2.0
|
@ -28,3 +28,4 @@ parameter_rules:
|
|||||||
default: 1.0
|
default: 1.0
|
||||||
min: 1.0
|
min: 1.0
|
||||||
max: 2.0
|
max: 2.0
|
||||||
|
deprecated: true
|
||||||
|
@ -28,3 +28,4 @@ parameter_rules:
|
|||||||
default: 1.0
|
default: 1.0
|
||||||
min: 1.0
|
min: 1.0
|
||||||
max: 2.0
|
max: 2.0
|
||||||
|
deprecated: true
|
||||||
|
@ -97,6 +97,7 @@ class BaiduAccessToken:
|
|||||||
baidu_access_tokens_lock.release()
|
baidu_access_tokens_lock.release()
|
||||||
return token
|
return token
|
||||||
|
|
||||||
|
|
||||||
class ErnieMessage:
|
class ErnieMessage:
|
||||||
class Role(Enum):
|
class Role(Enum):
|
||||||
USER = 'user'
|
USER = 'user'
|
||||||
@ -137,7 +138,9 @@ class ErnieBotModel:
|
|||||||
'ernie-speed-appbuilder': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ai_apaas',
|
'ernie-speed-appbuilder': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ai_apaas',
|
||||||
'ernie-lite-8k-0922': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant',
|
'ernie-lite-8k-0922': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant',
|
||||||
'ernie-lite-8k-0308': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-lite-8k',
|
'ernie-lite-8k-0308': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-lite-8k',
|
||||||
|
'ernie-character-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k',
|
||||||
'ernie-character-8k-0321': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k',
|
'ernie-character-8k-0321': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k',
|
||||||
|
'ernie-4.0-tutbo-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k',
|
||||||
'ernie-4.0-tutbo-8k-preview': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k-preview',
|
'ernie-4.0-tutbo-8k-preview': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k-preview',
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -149,7 +152,8 @@ class ErnieBotModel:
|
|||||||
'ernie-3.5-8k-1222',
|
'ernie-3.5-8k-1222',
|
||||||
'ernie-3.5-4k-0205',
|
'ernie-3.5-4k-0205',
|
||||||
'ernie-3.5-128k',
|
'ernie-3.5-128k',
|
||||||
'ernie-4.0-8k'
|
'ernie-4.0-8k',
|
||||||
|
'ernie-4.0-turbo-8k',
|
||||||
'ernie-4.0-turbo-8k-preview'
|
'ernie-4.0-turbo-8k-preview'
|
||||||
]
|
]
|
||||||
|
|
||||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user