takatost 2024-07-22 19:57:32 +08:00
commit a603e01f5e
372 changed files with 9779 additions and 1678 deletions

.gitignore
View File

@@ -174,5 +174,6 @@ sdks/python-client/dify_client.egg-info
 .vscode/*
 !.vscode/launch.json
 pyrightconfig.json
+api/.vscode
 .idea/

View File

@@ -81,7 +81,7 @@ Dify requires the following dependencies to build, make sure they're installed o
 Dify is composed of a backend and a frontend. Navigate to the backend directory by `cd api/`, then follow the [Backend README](api/README.md) to install it. In a separate terminal, navigate to the frontend directory by `cd web/`, then follow the [Frontend README](web/README.md) to install.
-Check the [installation FAQ](https://docs.dify.ai/getting-started/faq/install-faq) for a list of common issues and steps to troubleshoot.
+Check the [installation FAQ](https://docs.dify.ai/learn-more/faq/self-host-faq) for a list of common issues and steps to troubleshoot.
 ### 5. Visit dify in your browser

View File

@@ -77,7 +77,7 @@ Dify depends on the following tools and libraries:
 Dify is composed of a backend and a frontend. Navigate to the backend directory with `cd api/`, then follow the [Backend README](api/README.md) to install it. In a separate terminal, navigate to the frontend directory with `cd web/`, then follow the [Frontend README](web/README.md) to install it.
-See the [installation FAQ](https://docs.dify.ai/getting-started/faq/install-faq) for a list of common issues and troubleshooting steps.
+See the [installation FAQ](https://docs.dify.ai/v/zh-hans/learn-more/faq/install-faq) for a list of common issues and troubleshooting steps.
 ### 5. Visit Dify in your browser

View File

@@ -82,7 +82,7 @@ Dify consists of a backend and a frontend
 First, move to the backend directory with `cd api/` and install it following the [Backend README](api/README.md).
 Next, in a separate terminal, move to the frontend directory with `cd web/` and install it following the [Frontend README](web/README.md).
-For common issues and troubleshooting steps, see the [installation FAQ](https://docs.dify.ai/getting-started/faq/install-faq).
+For common issues and troubleshooting steps, see the [installation FAQ](https://docs.dify.ai/v/japanese/learn-more/faq/install-faq).
 ### 5. Access dify in your browser

View File

@@ -256,3 +256,7 @@ WORKFLOW_CALL_MAX_DEPTH=5
 # App configuration
 APP_MAX_EXECUTION_TIME=1200
 APP_MAX_ACTIVE_REQUESTS=0
+
+# Celery beat configuration
+CELERY_BEAT_SCHEDULER_TIME=1

View File

@@ -1,7 +1,5 @@
 import os

-from configs import dify_config
-
 if os.environ.get("DEBUG", "false").lower() != 'true':
     from gevent import monkey
@@ -23,7 +21,9 @@ from flask import Flask, Response, request
 from flask_cors import CORS
 from werkzeug.exceptions import Unauthorized

+import contexts
 from commands import register_commands
+from configs import dify_config

 # DO NOT REMOVE BELOW
 from events import event_handlers
@@ -181,7 +181,10 @@ def load_user_from_request(request_from_flask_login):
     decoded = PassportService().verify(auth_token)
     user_id = decoded.get('user_id')

-    return AccountService.load_logged_in_account(account_id=user_id, token=auth_token)
+    account = AccountService.load_logged_in_account(account_id=user_id, token=auth_token)
+    if account:
+        contexts.tenant_id.set(account.current_tenant_id)
+    return account

 @login_manager.unauthorized_handler

View File

@@ -23,6 +23,7 @@ class SecurityConfig(BaseSettings):
         default=24,
     )

+
 class AppExecutionConfig(BaseSettings):
     """
     App Execution configs
@@ -405,7 +406,6 @@ class DataSetConfig(BaseSettings):
         default=False,
     )

-
 class WorkspaceConfig(BaseSettings):
     """
     Workspace configs
@@ -435,6 +435,13 @@ class ImageFormatConfig(BaseSettings):
     )

+
+class CeleryBeatConfig(BaseSettings):
+    CELERY_BEAT_SCHEDULER_TIME: int = Field(
+        description='the time of the celery scheduler, default to 1 day',
+        default=1,
+    )
+

 class FeatureConfig(
     # place the configs in alphabet order
     AppExecutionConfig,
@@ -462,5 +469,6 @@ class FeatureConfig(
     # hosted services config
     HostedServiceConfig,
+    CeleryBeatConfig,
 ):
     pass

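CeleryBeatConfig follows the same pattern as every other settings group in this module: a small BaseSettings class whose fields are read from environment variables, merged into FeatureConfig by multiple inheritance. A minimal, self-contained sketch of that composition pattern (assuming pydantic-settings, with illustrative class names rather than Dify's real ones):

    import os

    from pydantic import Field
    from pydantic_settings import BaseSettings


    class CeleryBeatDemoConfig(BaseSettings):
        # Populated from the CELERY_BEAT_SCHEDULER_TIME env var; defaults to 1 (day).
        CELERY_BEAT_SCHEDULER_TIME: int = Field(default=1)


    class AppDemoConfig(BaseSettings):
        APP_MAX_EXECUTION_TIME: int = Field(default=1200)


    class DemoFeatureConfig(AppDemoConfig, CeleryBeatDemoConfig):
        # Multiple inheritance merges all fields into one settings object, so
        # config.CELERY_BEAT_SCHEDULER_TIME and config.APP_MAX_EXECUTION_TIME
        # both resolve on a single instance.
        pass


    os.environ['CELERY_BEAT_SCHEDULER_TIME'] = '7'
    print(DemoFeatureConfig().CELERY_BEAT_SCHEDULER_TIME)  # 7, coerced to int
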
View File

@@ -79,7 +79,7 @@ class HostedAzureOpenAiConfig(BaseSettings):
         default=False,
     )

-    HOSTED_OPENAI_API_KEY: Optional[str] = Field(
+    HOSTED_AZURE_OPENAI_API_KEY: Optional[str] = Field(
         description='',
         default=None,
     )

View File

@@ -1,4 +1,3 @@
-from typing import Optional

 from pydantic import BaseModel, Field, PositiveInt
@@ -8,32 +7,32 @@ class MyScaleConfig(BaseModel):
     MyScale configs
     """

-    MYSCALE_HOST: Optional[str] = Field(
+    MYSCALE_HOST: str = Field(
         description='MyScale host',
-        default=None,
+        default='localhost',
     )

-    MYSCALE_PORT: Optional[PositiveInt] = Field(
+    MYSCALE_PORT: PositiveInt = Field(
         description='MyScale port',
         default=8123,
     )

-    MYSCALE_USER: Optional[str] = Field(
+    MYSCALE_USER: str = Field(
         description='MyScale user',
-        default=None,
+        default='default',
     )

-    MYSCALE_PASSWORD: Optional[str] = Field(
+    MYSCALE_PASSWORD: str = Field(
         description='MyScale password',
-        default=None,
+        default='',
     )

-    MYSCALE_DATABASE: Optional[str] = Field(
+    MYSCALE_DATABASE: str = Field(
         description='MyScale database name',
-        default=None,
+        default='default',
     )

-    MYSCALE_FTS_PARAMS: Optional[str] = Field(
+    MYSCALE_FTS_PARAMS: str = Field(
         description='MyScale fts index parameters',
-        default=None,
+        default='',
     )

View File

@@ -0,0 +1,2 @@
+# TODO: Update all string in code to use this constant
+HIDDEN_VALUE = '[__HIDDEN__]'

api/contexts/__init__.py
View File

@@ -0,0 +1,3 @@
+from contextvars import ContextVar
+
+tenant_id: ContextVar[str] = ContextVar('tenant_id')

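This two-line module is the backbone of the tenant-isolation changes in this commit: a ContextVar holds a per-context value, so the tenant id set once in app.py's load_user_from_request can be read anywhere downstream without threading it through every signature. A standalone sketch of the pattern:

    from contextvars import ContextVar

    tenant_id: ContextVar[str] = ContextVar('tenant_id')


    def handle_request(current_tenant: str) -> None:
        # Set once at the edge of the request, e.g. in a user loader.
        tenant_id.set(current_tenant)
        do_work()


    def do_work() -> None:
        # Read anywhere downstream; .get() raises LookupError if never set,
        # which surfaces code paths that forgot to establish a tenant.
        print(tenant_id.get())


    handle_request('tenant-123')  # prints tenant-123
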
View File

@@ -212,7 +212,7 @@ class AppCopyApi(Resource):
         parser.add_argument('icon_background', type=str, location='json')
         args = parser.parse_args()

-        data = AppDslService.export_dsl(app_model=app_model)
+        data = AppDslService.export_dsl(app_model=app_model, include_secret=True)
         app = AppDslService.import_and_create_new_app(
             tenant_id=current_user.current_tenant_id,
             data=data,
@@ -234,8 +234,13 @@ class AppExportApi(Resource):
         if not current_user.is_editor:
             raise Forbidden()

+        # Add include_secret params
+        parser = reqparse.RequestParser()
+        parser.add_argument('include_secret', type=inputs.boolean, default=False, location='args')
+        args = parser.parse_args()
+
         return {
-            "data": AppDslService.export_dsl(app_model=app_model)
+            "data": AppDslService.export_dsl(app_model=app_model, include_secret=args['include_secret'])
         }

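The `type=inputs.boolean` choice matters for a query-string flag: flask_restful's inputs.boolean maps 'true'/'false'/'1'/'0' to real booleans, whereas a plain `type=bool` would turn any non-empty string, including 'false', into True. A quick illustration:

    from flask_restful import inputs

    print(inputs.boolean('true'))    # True
    print(inputs.boolean('0'))       # False
    print(inputs.boolean('false'))   # False
    print(bool('false'))             # True -- why type=bool is a trap here
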
View File

@@ -13,6 +13,7 @@ from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
 from core.app.apps.base_app_queue_manager import AppQueueManager
 from core.app.entities.app_invoke_entities import InvokeFrom
+from core.app.segments import factory
 from core.errors.error import AppInvokeQuotaExceededError
 from fields.workflow_fields import workflow_fields
 from fields.workflow_run_fields import workflow_run_node_execution_fields
@@ -41,7 +42,7 @@ class DraftWorkflowApi(Resource):
         # The role of the current user in the ta table must be admin, owner, or editor
         if not current_user.is_editor:
             raise Forbidden()
         # fetch draft workflow by app_model
         workflow_service = WorkflowService()
         workflow = workflow_service.get_draft_workflow(app_model=app_model)
@@ -64,13 +65,15 @@ class DraftWorkflowApi(Resource):
         if not current_user.is_editor:
             raise Forbidden()

-        content_type = request.headers.get('Content-Type')
+        content_type = request.headers.get('Content-Type', '')

         if 'application/json' in content_type:
             parser = reqparse.RequestParser()
             parser.add_argument('graph', type=dict, required=True, nullable=False, location='json')
             parser.add_argument('features', type=dict, required=True, nullable=False, location='json')
             parser.add_argument('hash', type=str, required=False, location='json')
+            # TODO: set this to required=True after frontend is updated
+            parser.add_argument('environment_variables', type=list, required=False, location='json')
             args = parser.parse_args()
         elif 'text/plain' in content_type:
             try:
@@ -84,7 +87,8 @@ class DraftWorkflowApi(Resource):
                 args = {
                     'graph': data.get('graph'),
                     'features': data.get('features'),
-                    'hash': data.get('hash')
+                    'hash': data.get('hash'),
+                    'environment_variables': data.get('environment_variables')
                 }
             except json.JSONDecodeError:
                 return {'message': 'Invalid JSON data'}, 400
@@ -94,12 +98,15 @@ class DraftWorkflowApi(Resource):
         workflow_service = WorkflowService()

         try:
+            environment_variables_list = args.get('environment_variables') or []
+            environment_variables = [factory.build_variable_from_mapping(obj) for obj in environment_variables_list]
             workflow = workflow_service.sync_draft_workflow(
                 app_model=app_model,
-                graph=args.get('graph'),
-                features=args.get('features'),
+                graph=args['graph'],
+                features=args['features'],
                 unique_hash=args.get('hash'),
-                account=current_user
+                account=current_user,
+                environment_variables=environment_variables,
             )
         except WorkflowHashNotEqualError:
             raise DraftWorkflowNotSync()

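Each entry of `environment_variables` arrives as a plain JSON mapping and is rebuilt into a typed variable object by `factory.build_variable_from_mapping` from core.app.segments. The real schema lives in that module; the following is only a hedged sketch of the mapping-to-object step, with hypothetical field names:

    from collections.abc import Mapping
    from dataclasses import dataclass
    from typing import Any


    @dataclass
    class EnvVariable:
        # Hypothetical shape; the actual segment types live in core.app.segments.
        name: str
        value_type: str
        value: Any


    def build_variable_from_mapping(mapping: Mapping[str, Any]) -> EnvVariable:
        if 'name' not in mapping:
            raise ValueError('variable mapping requires a name')
        return EnvVariable(
            name=mapping['name'],
            value_type=mapping.get('value_type', 'string'),
            value=mapping.get('value'),
        )


    payload = [{'name': 'API_BASE', 'value_type': 'string', 'value': 'https://example.com'}]
    environment_variables = [build_variable_from_mapping(obj) for obj in payload]
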
View File

@@ -1,10 +1,11 @@
 import flask_restful
-from flask import current_app, request
+from flask import request
 from flask_login import current_user
 from flask_restful import Resource, marshal, marshal_with, reqparse
 from werkzeug.exceptions import Forbidden, NotFound

 import services
+from configs import dify_config
 from controllers.console import api
 from controllers.console.apikey import api_key_fields, api_key_list
 from controllers.console.app.error import ProviderNotInitializeError
@@ -530,7 +531,7 @@ class DatasetApiBaseUrlApi(Resource):
     @account_initialization_required
     def get(self):
         return {
-            'api_base_url': (current_app.config['SERVICE_API_URL'] if current_app.config['SERVICE_API_URL']
+            'api_base_url': (dify_config.SERVICE_API_URL if dify_config.SERVICE_API_URL
                              else request.host_url.rstrip('/')) + '/v1'
         }
@@ -540,20 +541,20 @@ class DatasetRetrievalSettingApi(Resource):
     @login_required
     @account_initialization_required
     def get(self):
-        vector_type = current_app.config['VECTOR_STORE']
+        vector_type = dify_config.VECTOR_STORE
         match vector_type:
             case VectorType.MILVUS | VectorType.RELYT | VectorType.PGVECTOR | VectorType.TIDB_VECTOR | VectorType.CHROMA | VectorType.TENCENT | VectorType.ORACLE:
                 return {
                     'retrieval_method': [
-                        RetrievalMethod.SEMANTIC_SEARCH
+                        RetrievalMethod.SEMANTIC_SEARCH.value
                     ]
                 }
             case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH | VectorType.ANALYTICDB | VectorType.MYSCALE:
                 return {
                     'retrieval_method': [
-                        RetrievalMethod.SEMANTIC_SEARCH,
-                        RetrievalMethod.FULL_TEXT_SEARCH,
-                        RetrievalMethod.HYBRID_SEARCH,
+                        RetrievalMethod.SEMANTIC_SEARCH.value,
+                        RetrievalMethod.FULL_TEXT_SEARCH.value,
+                        RetrievalMethod.HYBRID_SEARCH.value,
                     ]
                 }
             case _:
@@ -569,15 +570,15 @@ class DatasetRetrievalSettingMockApi(Resource):
             case VectorType.MILVUS | VectorType.RELYT | VectorType.PGVECTOR | VectorType.TIDB_VECTOR | VectorType.CHROMA | VectorType.TENCENT | VectorType.ORACLE:
                 return {
                     'retrieval_method': [
-                        RetrievalMethod.SEMANTIC_SEARCH
+                        RetrievalMethod.SEMANTIC_SEARCH.value
                    ]
                 }
             case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH| VectorType.ANALYTICDB | VectorType.MYSCALE:
                 return {
                     'retrieval_method': [
-                        RetrievalMethod.SEMANTIC_SEARCH,
-                        RetrievalMethod.FULL_TEXT_SEARCH,
-                        RetrievalMethod.HYBRID_SEARCH,
+                        RetrievalMethod.SEMANTIC_SEARCH.value,
+                        RetrievalMethod.FULL_TEXT_SEARCH.value,
+                        RetrievalMethod.HYBRID_SEARCH.value,
                     ]
                 }
             case _:

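The `.value` suffix added throughout these handlers matters because Enum members are not JSON-serializable by the default encoder; `.value` yields the underlying string. A demonstration of the failure mode being fixed (assuming RetrievalMethod is a plain Enum rather than a str subclass):

    import json
    from enum import Enum


    class RetrievalMethod(Enum):
        SEMANTIC_SEARCH = 'semantic_search'


    print(json.dumps({'retrieval_method': [RetrievalMethod.SEMANTIC_SEARCH.value]}))
    # {"retrieval_method": ["semantic_search"]}

    try:
        json.dumps({'retrieval_method': [RetrievalMethod.SEMANTIC_SEARCH]})
    except TypeError as e:
        print(e)  # Object of type RetrievalMethod is not JSON serializable
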
View File

@@ -75,7 +75,7 @@ class DatasetDocumentSegmentListApi(Resource):
         )

         if last_id is not None:
-            last_segment = DocumentSegment.query.get(str(last_id))
+            last_segment = db.session.get(DocumentSegment, str(last_id))
             if last_segment:
                 query = query.filter(
                     DocumentSegment.position > last_segment.position)

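`Model.query.get()` is legacy API: SQLAlchemy 1.4 deprecated it and 2.0 removes it in favour of `Session.get()`, which performs the same primary-key lookup (consulting the identity map first and returning None on a miss). A self-contained sketch of the replacement, using a stand-in model rather than Dify's DocumentSegment:

    from sqlalchemy import Integer, String, create_engine
    from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


    class Base(DeclarativeBase):
        pass


    class Segment(Base):
        __tablename__ = 'segments'
        id: Mapped[int] = mapped_column(Integer, primary_key=True)
        text: Mapped[str] = mapped_column(String)


    engine = create_engine('sqlite://')
    Base.metadata.create_all(engine)

    with Session(engine) as session:
        session.add(Segment(id=1, text='hello'))
        session.commit()
        print(session.get(Segment, 1).text)  # 'hello' -- primary-key lookup
        print(session.get(Segment, 999))     # None -- no exception on a miss
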
View File

@@ -1,8 +1,9 @@
-from flask import current_app, request
+from flask import request
 from flask_login import current_user
 from flask_restful import Resource, marshal_with

 import services
+from configs import dify_config
 from controllers.console import api
 from controllers.console.datasets.error import (
     FileTooLargeError,
@@ -26,9 +27,9 @@ class FileApi(Resource):
     @account_initialization_required
     @marshal_with(upload_config_fields)
     def get(self):
-        file_size_limit = current_app.config.get("UPLOAD_FILE_SIZE_LIMIT")
-        batch_count_limit = current_app.config.get("UPLOAD_FILE_BATCH_LIMIT")
-        image_file_size_limit = current_app.config.get("UPLOAD_IMAGE_FILE_SIZE_LIMIT")
+        file_size_limit = dify_config.UPLOAD_FILE_SIZE_LIMIT
+        batch_count_limit = dify_config.UPLOAD_FILE_BATCH_LIMIT
+        image_file_size_limit = dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT
         return {
             'file_size_limit': file_size_limit,
             'batch_count_limit': batch_count_limit,
@@ -76,7 +77,7 @@ class FileSupportTypeApi(Resource):
     @login_required
     @account_initialization_required
     def get(self):
-        etl_type = current_app.config['ETL_TYPE']
+        etl_type = dify_config.ETL_TYPE
         allowed_extensions = UNSTRUCTURED_ALLOWED_EXTENSIONS if etl_type == 'Unstructured' else ALLOWED_EXTENSIONS
         return {'allowed_extensions': allowed_extensions}

View File

@@ -78,10 +78,12 @@ class ChatTextApi(InstalledAppResource):
         parser = reqparse.RequestParser()
         parser.add_argument('message_id', type=str, required=False, location='json')
         parser.add_argument('voice', type=str, location='json')
+        parser.add_argument('text', type=str, location='json')
         parser.add_argument('streaming', type=bool, location='json')
         args = parser.parse_args()

-        message_id = args.get('message_id')
+        message_id = args.get('message_id', None)
+        text = args.get('text', None)
         if (app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
                 and app_model.workflow
                 and app_model.workflow.features_dict):
@@ -95,7 +97,8 @@ class ChatTextApi(InstalledAppResource):
             response = AudioService.transcript_tts(
                 app_model=app_model,
                 message_id=message_id,
-                voice=voice
+                voice=voice,
+                text=text
             )
             return response
         except services.errors.app_model_config.AppModelConfigBrokenError:

View File

@@ -1,7 +1,7 @@
-from flask import current_app
 from flask_restful import fields, marshal_with

+from configs import dify_config
 from controllers.console import api
 from controllers.console.app.error import AppUnavailableError
 from controllers.console.explore.wraps import InstalledAppResource
@@ -78,7 +78,7 @@ class AppParameterApi(InstalledAppResource):
                 "transfer_methods": ["remote_url", "local_file"]
             }}),
         'system_parameters': {
-            'image_file_size_limit': current_app.config.get('UPLOAD_IMAGE_FILE_SIZE_LIMIT')
+            'image_file_size_limit': dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT
         }
     }

View File

@@ -1,8 +1,9 @@
 import os

-from flask import current_app, session
+from flask import session
 from flask_restful import Resource, reqparse

+from configs import dify_config
 from libs.helper import str_len
 from models.model import DifySetup
 from services.account_service import TenantService
@@ -40,7 +41,7 @@ class InitValidateAPI(Resource):
         return {'result': 'success'}, 201

 def get_init_validate_status():
-    if current_app.config['EDITION'] == 'SELF_HOSTED':
+    if dify_config.EDITION == 'SELF_HOSTED':
         if os.environ.get('INIT_PASSWORD'):
             return session.get('is_init_validated') or DifySetup.query.first()

View File

@@ -1,8 +1,9 @@
 from functools import wraps

-from flask import current_app, request
+from flask import request
 from flask_restful import Resource, reqparse

+from configs import dify_config
 from libs.helper import email, get_remote_ip, str_len
 from libs.password import valid_password
 from models.model import DifySetup
@@ -17,7 +18,7 @@ from .wraps import only_edition_self_hosted
 class SetupApi(Resource):

     def get(self):
-        if current_app.config['EDITION'] == 'SELF_HOSTED':
+        if dify_config.EDITION == 'SELF_HOSTED':
             setup_status = get_setup_status()
             if setup_status:
                 return {
@@ -77,7 +78,7 @@ def setup_required(view):

 def get_setup_status():
-    if current_app.config['EDITION'] == 'SELF_HOSTED':
+    if dify_config.EDITION == 'SELF_HOSTED':
         return DifySetup.query.first()
     else:
         return True

View File

@@ -3,9 +3,10 @@ import json
 import logging

 import requests
-from flask import current_app
 from flask_restful import Resource, reqparse

+from configs import dify_config
 from . import api
@@ -15,16 +16,16 @@ class VersionApi(Resource):
         parser = reqparse.RequestParser()
         parser.add_argument('current_version', type=str, required=True, location='args')
         args = parser.parse_args()
-        check_update_url = current_app.config['CHECK_UPDATE_URL']
+        check_update_url = dify_config.CHECK_UPDATE_URL

         result = {
-            'version': current_app.config['CURRENT_VERSION'],
+            'version': dify_config.CURRENT_VERSION,
             'release_date': '',
             'release_notes': '',
             'can_auto_update': False,
             'features': {
-                'can_replace_logo': current_app.config['CAN_REPLACE_LOGO'],
-                'model_load_balancing_enabled': current_app.config['MODEL_LB_ENABLED']
+                'can_replace_logo': dify_config.CAN_REPLACE_LOGO,
+                'model_load_balancing_enabled': dify_config.MODEL_LB_ENABLED
             }
         }

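All of the current_app.config['KEY'] → dify_config.KEY rewrites in this commit trade Flask's string-keyed dict for attribute access on a typed settings object: values arrive pre-coerced, typos fail immediately as AttributeError, and no Flask application context is required. A rough sketch of the difference (assuming pydantic-settings; names are illustrative):

    from pydantic import Field
    from pydantic_settings import BaseSettings


    class DemoDifyConfig(BaseSettings):
        CURRENT_VERSION: str = Field(default='0.6.14')
        CAN_REPLACE_LOGO: bool = Field(default=False)


    dify_config = DemoDifyConfig()

    print(dify_config.CURRENT_VERSION)   # typed, IDE-completable, and usable
                                         # outside any Flask app context
    # dify_config.CURENT_VERSION         # would raise AttributeError at once
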
View File

@@ -1,10 +1,11 @@
 import datetime

 import pytz
-from flask import current_app, request
+from flask import request
 from flask_login import current_user
 from flask_restful import Resource, fields, marshal_with, reqparse

+from configs import dify_config
 from constants.languages import supported_language
 from controllers.console import api
 from controllers.console.setup import setup_required
@@ -36,7 +37,7 @@ class AccountInitApi(Resource):
         parser = reqparse.RequestParser()

-        if current_app.config['EDITION'] == 'CLOUD':
+        if dify_config.EDITION == 'CLOUD':
             parser.add_argument('invitation_code', type=str, location='json')

         parser.add_argument(
@@ -45,7 +46,7 @@ class AccountInitApi(Resource):
             required=True, location='json')
         args = parser.parse_args()

-        if current_app.config['EDITION'] == 'CLOUD':
+        if dify_config.EDITION == 'CLOUD':
             if not args['invitation_code']:
                 raise ValueError('invitation_code is required')

View File

@@ -1,8 +1,8 @@
-from flask import current_app
 from flask_login import current_user
 from flask_restful import Resource, abort, marshal_with, reqparse

 import services
+from configs import dify_config
 from controllers.console import api
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
@@ -48,7 +48,7 @@ class MemberInviteEmailApi(Resource):
         inviter = current_user
         invitation_results = []
-        console_web_url = current_app.config.get("CONSOLE_WEB_URL")
+        console_web_url = dify_config.CONSOLE_WEB_URL
         for invitee_email in invitee_emails:
             try:
                 token = RegisterService.invite_new_member(inviter.current_tenant, invitee_email, interface_language, role=invitee_role, inviter=inviter)
@@ -117,7 +117,7 @@ class MemberUpdateRoleApi(Resource):
         if not TenantAccountRole.is_valid_role(new_role):
             return {'code': 'invalid-role', 'message': 'Invalid role'}, 400

-        member = Account.query.get(str(member_id))
+        member = db.session.get(Account, str(member_id))
         if not member:
             abort(404)

View File

@@ -1,10 +1,11 @@
 import io

-from flask import current_app, send_file
+from flask import send_file
 from flask_login import current_user
 from flask_restful import Resource, reqparse
 from werkzeug.exceptions import Forbidden

+from configs import dify_config
 from controllers.console import api
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
@@ -104,7 +105,7 @@ class ToolBuiltinProviderIconApi(Resource):
     @setup_required
     def get(self, provider):
         icon_bytes, mimetype = BuiltinToolManageService.get_builtin_tool_provider_icon(provider)
-        icon_cache_max_age = current_app.config.get('TOOL_ICON_CACHE_MAX_AGE')
+        icon_cache_max_age = dify_config.TOOL_ICON_CACHE_MAX_AGE
         return send_file(io.BytesIO(icon_bytes), mimetype=mimetype, max_age=icon_cache_max_age)

 class ToolApiProviderAddApi(Resource):

View File

@@ -1,9 +1,10 @@
 import json
 from functools import wraps

-from flask import abort, current_app, request
+from flask import abort, request
 from flask_login import current_user

+from configs import dify_config
 from controllers.console.workspace.error import AccountNotInitializedError
 from services.feature_service import FeatureService
 from services.operation_service import OperationService
@@ -26,7 +27,7 @@ def account_initialization_required(view):
 def only_edition_cloud(view):
     @wraps(view)
     def decorated(*args, **kwargs):
-        if current_app.config['EDITION'] != 'CLOUD':
+        if dify_config.EDITION != 'CLOUD':
             abort(404)

         return view(*args, **kwargs)
@@ -37,7 +38,7 @@ def only_edition_cloud(view):
 def only_edition_self_hosted(view):
     @wraps(view)
     def decorated(*args, **kwargs):
-        if current_app.config['EDITION'] != 'SELF_HOSTED':
+        if dify_config.EDITION != 'SELF_HOSTED':
             abort(404)

         return view(*args, **kwargs)

View File

@@ -76,10 +76,12 @@ class TextApi(Resource):
         parser = reqparse.RequestParser()
         parser.add_argument('message_id', type=str, required=False, location='json')
         parser.add_argument('voice', type=str, location='json')
+        parser.add_argument('text', type=str, location='json')
         parser.add_argument('streaming', type=bool, location='json')
         args = parser.parse_args()

-        message_id = args.get('message_id')
+        message_id = args.get('message_id', None)
+        text = args.get('text', None)
         if (app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
                 and app_model.workflow
                 and app_model.workflow.features_dict):
@@ -87,15 +89,15 @@ class TextApi(Resource):
             voice = args.get('voice') if args.get('voice') else text_to_speech.get('voice')
         else:
             try:
-                voice = args.get('voice') if args.get('voice') else app_model.app_model_config.text_to_speech_dict.get(
-                    'voice')
+                voice = args.get('voice') if args.get('voice') else app_model.app_model_config.text_to_speech_dict.get('voice')
             except Exception:
                 voice = None

         response = AudioService.transcript_tts(
             app_model=app_model,
             message_id=message_id,
             end_user=end_user.external_user_id,
-            voice=voice
+            voice=voice,
+            text=text
         )
         return response

View File

@@ -1,6 +1,6 @@
 import logging

-from flask_restful import Resource, reqparse
+from flask_restful import Resource, fields, marshal_with, reqparse
 from werkzeug.exceptions import InternalServerError

 from controllers.service_api import api
@@ -21,14 +21,43 @@ from core.errors.error import (
     QuotaExceededError,
 )
 from core.model_runtime.errors.invoke import InvokeError
+from extensions.ext_database import db
 from libs import helper
 from models.model import App, AppMode, EndUser
+from models.workflow import WorkflowRun
 from services.app_generate_service import AppGenerateService

 logger = logging.getLogger(__name__)

 class WorkflowRunApi(Resource):
+    workflow_run_fields = {
+        'id': fields.String,
+        'workflow_id': fields.String,
+        'status': fields.String,
+        'inputs': fields.Raw,
+        'outputs': fields.Raw,
+        'error': fields.String,
+        'total_steps': fields.Integer,
+        'total_tokens': fields.Integer,
+        'created_at': fields.DateTime,
+        'finished_at': fields.DateTime,
+        'elapsed_time': fields.Float,
+    }
+
+    @validate_app_token
+    @marshal_with(workflow_run_fields)
+    def get(self, app_model: App, workflow_id: str):
+        """
+        Get a workflow task running detail
+        """
+        app_mode = AppMode.value_of(app_model.mode)
+        if app_mode != AppMode.WORKFLOW:
+            raise NotWorkflowAppError()
+
+        workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_id).first()
+        return workflow_run
+
     @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
     def post(self, app_model: App, end_user: EndUser):
         """
@@ -88,5 +117,5 @@ class WorkflowTaskStopApi(Resource):
         }

-api.add_resource(WorkflowRunApi, '/workflows/run')
+api.add_resource(WorkflowRunApi, '/workflows/run/<string:workflow_id>', '/workflows/run')
 api.add_resource(WorkflowTaskStopApi, '/workflows/tasks/<string:task_id>/stop')

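With the extra route registered, a finished run can be fetched by id through the Service API and marshalled with the field set above. A hedged example call (base URL, token, and run id are placeholders; note the handler looks the path segment up as WorkflowRun.id, so it expects the run id returned by POST /workflows/run):

    import requests

    BASE_URL = 'http://localhost:5001/v1'   # placeholder Service API base
    API_KEY = 'app-xxxxxxxx'                # placeholder app token
    run_id = '3c90c3cc-0d44-4b50-8888-8dd25736052a'  # placeholder run id

    resp = requests.get(
        f'{BASE_URL}/workflows/run/{run_id}',
        headers={'Authorization': f'Bearer {API_KEY}'},
    )
    print(resp.json())  # id, workflow_id, status, inputs, outputs, error,
                        # total_steps, total_tokens, created_at, finished_at,
                        # elapsed_time
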
View File

@@ -74,10 +74,12 @@ class TextApi(WebApiResource):
         parser = reqparse.RequestParser()
         parser.add_argument('message_id', type=str, required=False, location='json')
         parser.add_argument('voice', type=str, location='json')
+        parser.add_argument('text', type=str, location='json')
         parser.add_argument('streaming', type=bool, location='json')
         args = parser.parse_args()

-        message_id = args.get('message_id')
+        message_id = args.get('message_id', None)
+        text = args.get('text', None)
         if (app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
                 and app_model.workflow
                 and app_model.workflow.features_dict):
@@ -94,7 +96,8 @@ class TextApi(WebApiResource):
                 app_model=app_model,
                 message_id=message_id,
                 end_user=end_user.external_user_id,
-                voice=voice
+                voice=voice,
+                text=text
             )
             return response

View File

@@ -342,10 +342,14 @@ class FunctionCallAgentRunner(BaseAgentRunner):
         """
         tool_calls = []
         for prompt_message in llm_result_chunk.delta.message.tool_calls:
+            args = {}
+            if prompt_message.function.arguments != '':
+                args = json.loads(prompt_message.function.arguments)
+
             tool_calls.append((
                 prompt_message.id,
                 prompt_message.function.name,
-                json.loads(prompt_message.function.arguments),
+                args,
             ))

         return tool_calls
@@ -359,10 +363,14 @@ class FunctionCallAgentRunner(BaseAgentRunner):
         """
         tool_calls = []
         for prompt_message in llm_result.message.tool_calls:
+            args = {}
+            if prompt_message.function.arguments != '':
+                args = json.loads(prompt_message.function.arguments)
+
             tool_calls.append((
                 prompt_message.id,
                 prompt_message.function.name,
-                json.loads(prompt_message.function.arguments),
+                args,
             ))

         return tool_calls

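The new guard exists because streamed function-call deltas may carry an empty arguments string, and json.loads('') raises JSONDecodeError instead of returning an empty dict:

    import json


    def parse_tool_args(arguments: str) -> dict:
        # Mirrors the guard in the diff: treat empty arguments as {}.
        if arguments == '':
            return {}
        return json.loads(arguments)


    print(parse_tool_args(''))                    # {}
    print(parse_tool_args('{"city": "Berlin"}'))  # {'city': 'Berlin'}

    try:
        json.loads('')
    except json.JSONDecodeError as e:
        print(e)  # Expecting value: line 1 column 1 (char 0)
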
View File

@@ -1,6 +1,7 @@
-from typing import Optional, Union
+from collections.abc import Mapping
+from typing import Any

-from core.app.app_config.entities import AppAdditionalFeatures, EasyUIBasedAppModelConfigFrom
+from core.app.app_config.entities import AppAdditionalFeatures
 from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
 from core.app.app_config.features.more_like_this.manager import MoreLikeThisConfigManager
 from core.app.app_config.features.opening_statement.manager import OpeningStatementConfigManager
@@ -10,37 +11,19 @@ from core.app.app_config.features.suggested_questions_after_answer.manager impor
     SuggestedQuestionsAfterAnswerConfigManager,
 )
 from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager
-from models.model import AppMode, AppModelConfig
+from models.model import AppMode

 class BaseAppConfigManager:

     @classmethod
-    def convert_to_config_dict(cls, config_from: EasyUIBasedAppModelConfigFrom,
-                               app_model_config: Union[AppModelConfig, dict],
-                               config_dict: Optional[dict] = None) -> dict:
-        """
-        Convert app model config to config dict
-
-        :param config_from: app model config from
-        :param app_model_config: app model config
-        :param config_dict: app model config dict
-        :return:
-        """
-        if config_from != EasyUIBasedAppModelConfigFrom.ARGS:
-            app_model_config_dict = app_model_config.to_dict()
-            config_dict = app_model_config_dict.copy()
-
-        return config_dict
-
-    @classmethod
-    def convert_features(cls, config_dict: dict, app_mode: AppMode) -> AppAdditionalFeatures:
+    def convert_features(cls, config_dict: Mapping[str, Any], app_mode: AppMode) -> AppAdditionalFeatures:
         """
         Convert app config to app model config

         :param config_dict: app config
         :param app_mode: app mode
         """
-        config_dict = config_dict.copy()
+        config_dict = dict(config_dict.items())

         additional_features = AppAdditionalFeatures()
         additional_features.show_retrieve_source = RetrievalResourceConfigManager.convert(
View File

@ -1,11 +1,12 @@
from typing import Optional from collections.abc import Mapping
from typing import Any, Optional
from core.app.app_config.entities import FileExtraConfig from core.app.app_config.entities import FileExtraConfig
class FileUploadConfigManager: class FileUploadConfigManager:
@classmethod @classmethod
def convert(cls, config: dict, is_vision: bool = True) -> Optional[FileExtraConfig]: def convert(cls, config: Mapping[str, Any], is_vision: bool = True) -> Optional[FileExtraConfig]:
""" """
Convert model config to model config Convert model config to model config

View File

@ -3,13 +3,13 @@ from core.app.app_config.entities import TextToSpeechEntity
class TextToSpeechConfigManager: class TextToSpeechConfigManager:
@classmethod @classmethod
def convert(cls, config: dict) -> bool: def convert(cls, config: dict):
""" """
Convert model config to model config Convert model config to model config
:param config: model config args :param config: model config args
""" """
text_to_speech = False text_to_speech = None
text_to_speech_dict = config.get('text_to_speech') text_to_speech_dict = config.get('text_to_speech')
if text_to_speech_dict: if text_to_speech_dict:
if text_to_speech_dict.get('enabled'): if text_to_speech_dict.get('enabled'):

View File

@@ -1,3 +1,4 @@
+import contextvars
 import logging
 import os
 import threading
@@ -8,6 +9,7 @@ from typing import Union
 from flask import Flask, current_app
 from pydantic import ValidationError

+import contexts
 from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
 from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
 from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner
@@ -107,6 +109,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
             extras=extras,
             trace_manager=trace_manager
         )
+        contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)

         return self._generate(
             app_model=app_model,
@@ -173,6 +176,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
                 inputs=args['inputs']
             )
         )
+        contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)

         return self._generate(
             app_model=app_model,
@@ -225,6 +229,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
             'queue_manager': queue_manager,
             'conversation_id': conversation.id,
             'message_id': message.id,
+            'user': user,
+            'context': contextvars.copy_context()
         })

         worker_thread.start()
@@ -249,7 +255,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
                          application_generate_entity: AdvancedChatAppGenerateEntity,
                          queue_manager: AppQueueManager,
                          conversation_id: str,
-                         message_id: str) -> None:
+                         message_id: str,
+                         user: Account,
+                         context: contextvars.Context) -> None:
         """
         Generate worker in a new thread.
         :param flask_app: Flask app
@@ -259,6 +267,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
         :param message_id: message ID
         :return:
        """
+        for var, val in context.items():
+            var.set(val)
         with flask_app.app_context():
             try:
                 runner = AdvancedChatAppRunner()

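ContextVar values do not flow into a newly started thread, so the generator snapshots contextvars.copy_context() on the request thread and replays it with var.set(val) in the worker; without this, the tenant_id established during login would be lost inside the worker. A standalone demonstration of the hand-off:

    import contextvars
    import threading

    tenant_id: contextvars.ContextVar[str] = contextvars.ContextVar('tenant_id')


    def worker(context: contextvars.Context) -> None:
        # A fresh thread starts with an empty context; replay the snapshot.
        for var, val in context.items():
            var.set(val)
        print(tenant_id.get())  # tenant-123


    tenant_id.set('tenant-123')
    t = threading.Thread(target=worker, kwargs={'context': contextvars.copy_context()})
    t.start()
    t.join()

Context.run(...) would be the other common way to enter the snapshot; setting each variable individually, as the commit does, leaves the existing flask_app.app_context() block unchanged.
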
View File

@@ -1,7 +1,8 @@
 import logging
 import os
 import time
-from typing import Optional, cast
+from collections.abc import Mapping
+from typing import Any, Optional, cast

 from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
 from core.app.apps.advanced_chat.workflow_event_trigger_callback import WorkflowEventTriggerCallback
@@ -14,6 +15,7 @@ from core.app.entities.app_invoke_entities import (
 )
 from core.app.entities.queue_entities import QueueAnnotationReplyEvent, QueueStopEvent, QueueTextChunkEvent
 from core.moderation.base import ModerationException
+from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
 from core.workflow.entities.node_entities import SystemVariable, UserFrom
 from core.workflow.workflow_engine_manager import WorkflowEngineManager
 from extensions.ext_database import db
@@ -86,7 +88,7 @@ class AdvancedChatAppRunner(AppRunner):
             db.session.close()

-        workflow_callbacks = [WorkflowEventTriggerCallback(
+        workflow_callbacks: list[WorkflowCallback] = [WorkflowEventTriggerCallback(
             queue_manager=queue_manager,
             workflow=workflow
         )]
@@ -160,7 +162,7 @@ class AdvancedChatAppRunner(AppRunner):
             self, queue_manager: AppQueueManager,
             app_record: App,
             app_generate_entity: AdvancedChatAppGenerateEntity,
-            inputs: dict,
+            inputs: Mapping[str, Any],
             query: str,
             message_id: str
     ) -> bool:

View File

@@ -1,9 +1,11 @@
 import json
 from collections.abc import Generator
-from typing import cast
+from typing import Any, cast

 from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
 from core.app.entities.task_entities import (
+    AppBlockingResponse,
+    AppStreamResponse,
     ChatbotAppBlockingResponse,
     ChatbotAppStreamResponse,
     ErrorStreamResponse,
@@ -18,12 +20,13 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
     _blocking_response_type = ChatbotAppBlockingResponse

     @classmethod
-    def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict:
+    def convert_blocking_full_response(cls, blocking_response: AppBlockingResponse) -> dict[str, Any]:
         """
         Convert blocking full response.
         :param blocking_response: blocking response
         :return:
         """
+        blocking_response = cast(ChatbotAppBlockingResponse, blocking_response)
         response = {
             'event': 'message',
             'task_id': blocking_response.task_id,
@@ -39,7 +42,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
         return response

     @classmethod
-    def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict:
+    def convert_blocking_simple_response(cls, blocking_response: AppBlockingResponse) -> dict[str, Any]:
         """
         Convert blocking simple response.
         :param blocking_response: blocking response
@@ -53,8 +56,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
         return response

     @classmethod
-    def convert_stream_full_response(cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]) \
-            -> Generator[str, None, None]:
+    def convert_stream_full_response(cls, stream_response: Generator[AppStreamResponse, None, None]) -> Generator[str, Any, None]:
         """
         Convert stream full response.
         :param stream_response: stream response
@@ -83,8 +85,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
             yield json.dumps(response_chunk)

     @classmethod
-    def convert_stream_simple_response(cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]) \
-            -> Generator[str, None, None]:
+    def convert_stream_simple_response(cls, stream_response: Generator[AppStreamResponse, None, None]) -> Generator[str, Any, None]:
         """
         Convert stream simple response.
         :param stream_response: stream response

View File

@@ -113,7 +113,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
         self._stream_generate_routes = self._get_stream_generate_routes()
         self._conversation_name_generate_thread = None

-    def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
+    def process(self):
         """
         Process generate task pipeline.
         :return:
@@ -136,8 +136,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
         else:
             return self._to_blocking_response(generator)

-    def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]) \
-            -> ChatbotAppBlockingResponse:
+    def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]) -> ChatbotAppBlockingResponse:
         """
         Process blocking response.
         :return:
@@ -167,8 +166,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
             raise Exception('Queue listening stopped unexpectedly.')

-    def _to_stream_response(self, generator: Generator[StreamResponse, None, None]) \
-            -> Generator[ChatbotAppStreamResponse, None, None]:
+    def _to_stream_response(self, generator: Generator[StreamResponse, None, None]) -> Generator[ChatbotAppStreamResponse, Any, None]:
         """
         To stream response.
         :return:

View File

@@ -14,13 +14,13 @@ from core.app.entities.queue_entities import (
     QueueWorkflowStartedEvent,
     QueueWorkflowSucceededEvent,
 )
-from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
+from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
 from core.workflow.entities.base_node_data_entities import BaseNodeData
 from core.workflow.entities.node_entities import NodeType
 from models.workflow import Workflow

-class WorkflowEventTriggerCallback(BaseWorkflowCallback):
+class WorkflowEventTriggerCallback(WorkflowCallback):

     def __init__(self, queue_manager: AppQueueManager, workflow: Workflow):
         self._queue_manager = queue_manager

View File

@@ -1,7 +1,7 @@
 import logging
 from abc import ABC, abstractmethod
 from collections.abc import Generator
-from typing import Union
+from typing import Any, Union

 from core.app.entities.app_invoke_entities import InvokeFrom
 from core.app.entities.task_entities import AppBlockingResponse, AppStreamResponse
@@ -15,44 +15,41 @@ class AppGenerateResponseConverter(ABC):

     @classmethod
     def convert(cls, response: Union[
         AppBlockingResponse,
-        Generator[AppStreamResponse, None, None]
-    ], invoke_from: InvokeFrom) -> Union[
-        dict,
-        Generator[str, None, None]
-    ]:
+        Generator[AppStreamResponse, Any, None]
+    ], invoke_from: InvokeFrom):
         if invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]:
-            if isinstance(response, cls._blocking_response_type):
+            if isinstance(response, AppBlockingResponse):
                 return cls.convert_blocking_full_response(response)
             else:
-                def _generate():
+                def _generate_full_response() -> Generator[str, Any, None]:
                     for chunk in cls.convert_stream_full_response(response):
                         if chunk == 'ping':
                             yield f'event: {chunk}\n\n'
                         else:
                             yield f'data: {chunk}\n\n'

-                return _generate()
+                return _generate_full_response()
         else:
-            if isinstance(response, cls._blocking_response_type):
+            if isinstance(response, AppBlockingResponse):
                 return cls.convert_blocking_simple_response(response)
             else:
-                def _generate():
+                def _generate_simple_response() -> Generator[str, Any, None]:
                     for chunk in cls.convert_stream_simple_response(response):
                         if chunk == 'ping':
                             yield f'event: {chunk}\n\n'
                         else:
                             yield f'data: {chunk}\n\n'

-                return _generate()
+                return _generate_simple_response()

     @classmethod
     @abstractmethod
-    def convert_blocking_full_response(cls, blocking_response: AppBlockingResponse) -> dict:
+    def convert_blocking_full_response(cls, blocking_response: AppBlockingResponse) -> dict[str, Any]:
         raise NotImplementedError

     @classmethod
     @abstractmethod
-    def convert_blocking_simple_response(cls, blocking_response: AppBlockingResponse) -> dict:
+    def convert_blocking_simple_response(cls, blocking_response: AppBlockingResponse) -> dict[str, Any]:
         raise NotImplementedError

     @classmethod
@@ -68,7 +65,7 @@ class AppGenerateResponseConverter(ABC):
         raise NotImplementedError

     @classmethod
-    def _get_simple_metadata(cls, metadata: dict) -> dict:
+    def _get_simple_metadata(cls, metadata: dict[str, Any]):
         """
         Get simple metadata.
         :param metadata: metadata

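The `data: ...\n\n` and `event: ping\n\n` strings these `_generate_*` helpers yield are server-sent-events framing: each frame must end with a blank line before the browser's EventSource will deliver it. A minimal sketch of how such a generator is typically attached to a streaming response (plain Flask shown as an assumption; Dify wires this up elsewhere):

    from collections.abc import Generator

    from flask import Flask, Response

    app = Flask(__name__)


    def _generate() -> Generator[str, None, None]:
        # One SSE frame per yield; the trailing blank line closes the frame.
        yield 'data: {"event": "message", "answer": "hi"}\n\n'
        yield 'event: ping\n\n'


    @app.route('/stream')
    def stream() -> Response:
        return Response(_generate(), mimetype='text/event-stream')
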
View File

@ -1,3 +1,4 @@
import contextvars
import logging import logging
import os import os
import threading import threading
@ -8,6 +9,7 @@ from typing import Union
from flask import Flask, current_app from flask import Flask, current_app
from pydantic import ValidationError from pydantic import ValidationError
import contexts
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.base_app_generator import BaseAppGenerator from core.app.apps.base_app_generator import BaseAppGenerator
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom
@ -38,7 +40,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
stream: bool = True, stream: bool = True,
call_depth: int = 0, call_depth: int = 0,
) -> Union[dict, Generator[dict, None, None]]: ):
""" """
Generate App response. Generate App response.
@ -86,6 +88,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
call_depth=call_depth, call_depth=call_depth,
trace_manager=trace_manager trace_manager=trace_manager
) )
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
return self._generate( return self._generate(
app_model=app_model, app_model=app_model,
@ -126,7 +129,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
worker_thread = threading.Thread(target=self._generate_worker, kwargs={ worker_thread = threading.Thread(target=self._generate_worker, kwargs={
'flask_app': current_app._get_current_object(), 'flask_app': current_app._get_current_object(),
'application_generate_entity': application_generate_entity, 'application_generate_entity': application_generate_entity,
'queue_manager': queue_manager 'queue_manager': queue_manager,
'context': contextvars.copy_context()
}) })
worker_thread.start() worker_thread.start()
@ -150,8 +154,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
node_id: str, node_id: str,
user: Account, user: Account,
args: dict, args: dict,
stream: bool = True) \ stream: bool = True):
-> Union[dict, Generator[dict, None, None]]:
""" """
Generate App response. Generate App response.
@ -193,6 +196,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
inputs=args['inputs'] inputs=args['inputs']
) )
) )
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
return self._generate( return self._generate(
app_model=app_model, app_model=app_model,
@ -205,7 +209,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
def _generate_worker(self, flask_app: Flask, def _generate_worker(self, flask_app: Flask,
application_generate_entity: WorkflowAppGenerateEntity, application_generate_entity: WorkflowAppGenerateEntity,
queue_manager: AppQueueManager) -> None: queue_manager: AppQueueManager,
context: contextvars.Context) -> None:
""" """
Generate worker in a new thread. Generate worker in a new thread.
:param flask_app: Flask app :param flask_app: Flask app
@ -213,6 +218,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
:param queue_manager: queue manager :param queue_manager: queue manager
:return: :return:
""" """
for var, val in context.items():
var.set(val)
with flask_app.app_context(): with flask_app.app_context():
try: try:
# workflow app # workflow app
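The two added lines replay every contextvar captured by `contextvars.copy_context()` in the request thread, because context variables do not propagate to newly spawned threads on their own. A minimal standalone sketch of the same pattern (names illustrative):

import contextvars
import threading

tenant_id: contextvars.ContextVar[str] = contextvars.ContextVar('tenant_id')

def worker(context: contextvars.Context) -> None:
    # New threads start with an empty context, so replay the parent's values.
    for var, val in context.items():
        var.set(val)
    print(tenant_id.get())  # -> 'tenant-123'

tenant_id.set('tenant-123')
t = threading.Thread(target=worker, kwargs={'context': contextvars.copy_context()})
t.start()
t.join()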


@ -10,6 +10,7 @@ from core.app.entities.app_invoke_entities import (
InvokeFrom, InvokeFrom,
WorkflowAppGenerateEntity, WorkflowAppGenerateEntity,
) )
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
from core.workflow.entities.node_entities import SystemVariable, UserFrom from core.workflow.entities.node_entities import SystemVariable, UserFrom
from core.workflow.workflow_engine_manager import WorkflowEngineManager from core.workflow.workflow_engine_manager import WorkflowEngineManager
from extensions.ext_database import db from extensions.ext_database import db
@ -56,7 +57,7 @@ class WorkflowAppRunner:
db.session.close() db.session.close()
workflow_callbacks = [WorkflowEventTriggerCallback( workflow_callbacks: list[WorkflowCallback] = [WorkflowEventTriggerCallback(
queue_manager=queue_manager, queue_manager=queue_manager,
workflow=workflow workflow=workflow
)] )]


@ -14,13 +14,13 @@ from core.app.entities.queue_entities import (
QueueWorkflowStartedEvent, QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent, QueueWorkflowSucceededEvent,
) )
from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeType from core.workflow.entities.node_entities import NodeType
from models.workflow import Workflow from models.workflow import Workflow
class WorkflowEventTriggerCallback(BaseWorkflowCallback): class WorkflowEventTriggerCallback(WorkflowCallback):
def __init__(self, queue_manager: AppQueueManager, workflow: Workflow): def __init__(self, queue_manager: AppQueueManager, workflow: Workflow):
self._queue_manager = queue_manager self._queue_manager = queue_manager


@ -2,7 +2,7 @@ from typing import Optional
from core.app.entities.queue_entities import AppQueueEvent from core.app.entities.queue_entities import AppQueueEvent
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeType from core.workflow.entities.node_entities import NodeType
@ -15,7 +15,7 @@ _TEXT_COLOR_MAPPING = {
} }
class WorkflowLoggingCallback(BaseWorkflowCallback): class WorkflowLoggingCallback(WorkflowCallback):
def __init__(self) -> None: def __init__(self) -> None:
self.current_node_id = None self.current_node_id = None


@ -1,3 +1,4 @@
from collections.abc import Mapping
from enum import Enum from enum import Enum
from typing import Any, Optional from typing import Any, Optional
@ -76,7 +77,7 @@ class AppGenerateEntity(BaseModel):
# app config # app config
app_config: AppConfig app_config: AppConfig
inputs: dict[str, Any] inputs: Mapping[str, Any]
files: list[FileVar] = [] files: list[FileVar] = []
user_id: str user_id: str
@ -140,7 +141,7 @@ class AdvancedChatAppGenerateEntity(AppGenerateEntity):
app_config: WorkflowUIBasedAppConfig app_config: WorkflowUIBasedAppConfig
conversation_id: Optional[str] = None conversation_id: Optional[str] = None
query: Optional[str] = None query: str
class SingleIterationRunEntity(BaseModel): class SingleIterationRunEntity(BaseModel):
""" """


@ -0,0 +1,27 @@
from .segment_group import SegmentGroup
from .segments import Segment
from .types import SegmentType
from .variables import (
ArrayVariable,
FileVariable,
FloatVariable,
IntegerVariable,
ObjectVariable,
SecretVariable,
StringVariable,
Variable,
)
__all__ = [
'IntegerVariable',
'FloatVariable',
'ObjectVariable',
'SecretVariable',
'FileVariable',
'StringVariable',
'ArrayVariable',
'Variable',
'SegmentType',
'SegmentGroup',
'Segment'
]


@ -0,0 +1,64 @@
from collections.abc import Mapping
from typing import Any
from core.file.file_obj import FileVar
from .segments import Segment, StringSegment
from .types import SegmentType
from .variables import (
ArrayVariable,
FileVariable,
FloatVariable,
IntegerVariable,
ObjectVariable,
SecretVariable,
StringVariable,
Variable,
)
def build_variable_from_mapping(m: Mapping[str, Any], /) -> Variable:
if (value_type := m.get('value_type')) is None:
raise ValueError('missing value type')
if not m.get('name'):
raise ValueError('missing name')
if (value := m.get('value')) is None:
raise ValueError('missing value')
match value_type:
case SegmentType.STRING:
return StringVariable.model_validate(m)
case SegmentType.NUMBER if isinstance(value, int):
return IntegerVariable.model_validate(m)
case SegmentType.NUMBER if isinstance(value, float):
return FloatVariable.model_validate(m)
case SegmentType.SECRET:
return SecretVariable.model_validate(m)
case SegmentType.NUMBER if not isinstance(value, float | int):
raise ValueError(f'invalid number value {value}')
raise ValueError(f'not supported value type {value_type}')
def build_anonymous_variable(value: Any, /) -> Variable:
if isinstance(value, str):
return StringVariable(name='anonymous', value=value)
if isinstance(value, int):
return IntegerVariable(name='anonymous', value=value)
if isinstance(value, float):
return FloatVariable(name='anonymous', value=value)
if isinstance(value, dict):
# TODO: Limit the depth of the object
obj = {k: build_anonymous_variable(v) for k, v in value.items()}
return ObjectVariable(name='anonymous', value=obj)
if isinstance(value, list):
# TODO: Limit the depth of the array
elements = [build_anonymous_variable(v) for v in value]
return ArrayVariable(name='anonymous', value=elements)
if isinstance(value, FileVar):
return FileVariable(name='anonymous', value=value)
raise ValueError(f'not supported value {value}')
def build_segment(value: Any, /) -> Segment:
if isinstance(value, str):
return StringSegment(value=value)
raise ValueError(f'not supported value {value}')
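A quick usage sketch for the factory above (assuming the package is importable as `core.app.segments`):

from core.app.segments import factory

var = factory.build_variable_from_mapping({
    'value_type': 'string',
    'name': 'greeting',
    'value': 'hello',
})
assert var.text == 'hello'

pi = factory.build_anonymous_variable(3.14)  # FloatVariable
seg = factory.build_segment('plain text')    # StringSegment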


@ -0,0 +1,17 @@
import re
from core.app.segments import SegmentGroup, factory
from core.workflow.entities.variable_pool import VariablePool
VARIABLE_PATTERN = re.compile(r'\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}')
def convert_template(*, template: str, variable_pool: VariablePool):
parts = re.split(VARIABLE_PATTERN, template)
segments = []
for part in parts:
if '.' in part and (value := variable_pool.get(part.split('.'))):
segments.append(value)
else:
segments.append(factory.build_segment(part))
return SegmentGroup(segments=segments)
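Because `VARIABLE_PATTERN` contains a capturing group, `re.split` keeps each matched variable selector in the result list, which is what lets `convert_template` interleave resolved variables with the literal text around them. For example:

import re

parts = re.split(VARIABLE_PATTERN, 'Hi {{#sys.user_name#}}, today is {{#env.weekday#}}.')
# -> ['Hi ', 'sys.user_name', ', today is ', 'env.weekday', '.']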


@ -0,0 +1,19 @@
from pydantic import BaseModel
from .segments import Segment
class SegmentGroup(BaseModel):
segments: list[Segment]
@property
def text(self):
return ''.join([segment.text for segment in self.segments])
@property
def log(self):
return ''.join([segment.log for segment in self.segments])
@property
def markdown(self):
return ''.join([segment.markdown for segment in self.segments])


@ -0,0 +1,39 @@
from typing import Any
from pydantic import BaseModel, ConfigDict, field_validator
from .types import SegmentType
class Segment(BaseModel):
model_config = ConfigDict(frozen=True)
value_type: SegmentType
value: Any
@field_validator('value_type')
def validate_value_type(cls, value):
"""
This validator checks if the provided value is equal to the default value of the 'value_type' field.
If the value is different, a ValueError is raised.
"""
if value != cls.model_fields['value_type'].default:
raise ValueError("Cannot modify 'value_type'")
return value
@property
def text(self) -> str:
return str(self.value)
@property
def log(self) -> str:
return str(self.value)
@property
def markdown(self) -> str:
return str(self.value)
class StringSegment(Segment):
value_type: SegmentType = SegmentType.STRING
value: str
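The frozen model config plus the `value_type` validator pins each subclass to its declared type: instances are immutable, and supplying a conflicting `value_type` fails validation. A small sketch:

seg = StringSegment(value='hello')
assert seg.value_type is SegmentType.STRING
assert seg.text == 'hello'

StringSegment(value_type=SegmentType.NUMBER, value='hello')  # raises a pydantic ValidationError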


@ -0,0 +1,17 @@
from enum import Enum
class SegmentType(str, Enum):
STRING = 'string'
NUMBER = 'number'
FILE = 'file'
SECRET = 'secret'
OBJECT = 'object'
ARRAY = 'array'
ARRAY_STRING = 'array[string]'
ARRAY_NUMBER = 'array[number]'
ARRAY_OBJECT = 'array[object]'
ARRAY_FILE = 'array[file]'


@ -0,0 +1,83 @@
import json
from collections.abc import Mapping, Sequence
from pydantic import Field
from core.file.file_obj import FileVar
from core.helper import encrypter
from .segments import Segment, StringSegment
from .types import SegmentType
class Variable(Segment):
"""
A variable is a segment that has a name.
"""
id: str = Field(
default='',
description="Unique identity for variable. It's only used by environment variables now.",
)
name: str
class StringVariable(StringSegment, Variable):
pass
class FloatVariable(Variable):
value_type: SegmentType = SegmentType.NUMBER
value: float
class IntegerVariable(Variable):
value_type: SegmentType = SegmentType.NUMBER
value: int
class ObjectVariable(Variable):
value_type: SegmentType = SegmentType.OBJECT
value: Mapping[str, Variable]
@property
def text(self) -> str:
# TODO: Process variables.
return json.dumps(self.model_dump()['value'], ensure_ascii=False)
@property
def log(self) -> str:
# TODO: Process variables.
return json.dumps(self.model_dump()['value'], ensure_ascii=False, indent=2)
@property
def markdown(self) -> str:
# TODO: Use markdown code block
return json.dumps(self.model_dump()['value'], ensure_ascii=False, indent=2)
class ArrayVariable(Variable):
value_type: SegmentType = SegmentType.ARRAY
value: Sequence[Variable]
@property
def markdown(self) -> str:
return '\n'.join(['- ' + item.markdown for item in self.value])
class FileVariable(Variable):
value_type: SegmentType = SegmentType.FILE
# TODO: embed FileVar in this model.
value: FileVar
@property
def markdown(self) -> str:
return self.value.to_markdown()
class SecretVariable(StringVariable):
value_type: SegmentType = SegmentType.SECRET
@property
def log(self) -> str:
return encrypter.obfuscated_token(self.value)
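Note that `SecretVariable` masks its value only in the `log` representation; `text` still returns the raw secret for actual use. Roughly (mask shape per `encrypter.obfuscated_token`):

secret = SecretVariable(name='api_key', value='sk-abcdef1234567890xyz')
secret.text  # 'sk-abcdef1234567890xyz' (raw value, used at runtime)
secret.log   # 'sk-abc************yz'   (obfuscated for logs)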


@ -1,9 +1,11 @@
import os import os
from collections.abc import Mapping, Sequence
from typing import Any, Optional, TextIO, Union from typing import Any, Optional, TextIO, Union
from pydantic import BaseModel from pydantic import BaseModel
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask, TraceTaskName from core.ops.ops_trace_manager import TraceQueueManager, TraceTask, TraceTaskName
from core.tools.entities.tool_entities import ToolInvokeMessage
_TEXT_COLOR_MAPPING = { _TEXT_COLOR_MAPPING = {
"blue": "36;1", "blue": "36;1",
@ -43,7 +45,7 @@ class DifyAgentCallbackHandler(BaseModel):
def on_tool_start( def on_tool_start(
self, self,
tool_name: str, tool_name: str,
tool_inputs: dict[str, Any], tool_inputs: Mapping[str, Any],
) -> None: ) -> None:
"""Do nothing.""" """Do nothing."""
print_text("\n[on_tool_start] ToolCall:" + tool_name + "\n" + str(tool_inputs) + "\n", color=self.color) print_text("\n[on_tool_start] ToolCall:" + tool_name + "\n" + str(tool_inputs) + "\n", color=self.color)
@ -51,8 +53,8 @@ class DifyAgentCallbackHandler(BaseModel):
def on_tool_end( def on_tool_end(
self, self,
tool_name: str, tool_name: str,
tool_inputs: dict[str, Any], tool_inputs: Mapping[str, Any],
tool_outputs: str, tool_outputs: Sequence[ToolInvokeMessage],
message_id: Optional[str] = None, message_id: Optional[str] = None,
timer: Optional[Any] = None, timer: Optional[Any] = None,
trace_manager: Optional[TraceQueueManager] = None trace_manager: Optional[TraceQueueManager] = None


@ -1,4 +1,5 @@
from typing import Union from collections.abc import Mapping, Sequence
from typing import Any, Union
import requests import requests
@ -16,7 +17,7 @@ class MessageFileParser:
self.tenant_id = tenant_id self.tenant_id = tenant_id
self.app_id = app_id self.app_id = app_id
def validate_and_transform_files_arg(self, files: list[dict], file_extra_config: FileExtraConfig, def validate_and_transform_files_arg(self, files: Sequence[Mapping[str, Any]], file_extra_config: FileExtraConfig,
user: Union[Account, EndUser]) -> list[FileVar]: user: Union[Account, EndUser]) -> list[FileVar]:
""" """
validate and transform files arg validate and transform files arg


@ -21,7 +21,7 @@ logger = logging.getLogger(__name__)
CODE_EXECUTION_ENDPOINT = dify_config.CODE_EXECUTION_ENDPOINT CODE_EXECUTION_ENDPOINT = dify_config.CODE_EXECUTION_ENDPOINT
CODE_EXECUTION_API_KEY = dify_config.CODE_EXECUTION_API_KEY CODE_EXECUTION_API_KEY = dify_config.CODE_EXECUTION_API_KEY
CODE_EXECUTION_TIMEOUT= (10, 60) CODE_EXECUTION_TIMEOUT = (10, 60)
class CodeExecutionException(Exception): class CodeExecutionException(Exception):
pass pass
@ -64,7 +64,7 @@ class CodeExecutor:
@classmethod @classmethod
def execute_code(cls, def execute_code(cls,
language: Literal['python3', 'javascript', 'jinja2'], language: CodeLanguage,
preload: str, preload: str,
code: str, code: str,
dependencies: Optional[list[CodeDependency]] = None) -> str: dependencies: Optional[list[CodeDependency]] = None) -> str:
@ -119,7 +119,7 @@ class CodeExecutor:
return response.data.stdout return response.data.stdout
@classmethod @classmethod
def execute_workflow_code_template(cls, language: Literal['python3', 'javascript', 'jinja2'], code: str, inputs: dict, dependencies: Optional[list[CodeDependency]] = None) -> dict: def execute_workflow_code_template(cls, language: CodeLanguage, code: str, inputs: dict, dependencies: Optional[list[CodeDependency]] = None) -> dict:
""" """
Execute code Execute code
:param language: code language :param language: code language


@ -6,11 +6,16 @@ from models.account import Tenant
def obfuscated_token(token: str): def obfuscated_token(token: str):
return token[:6] + '*' * (len(token) - 8) + token[-2:] if not token:
return token
if len(token) <= 8:
return '*' * 20
return token[:6] + '*' * 12 + token[-2:]
def encrypt_token(tenant_id: str, token: str): def encrypt_token(tenant_id: str, token: str):
tenant = db.session.query(Tenant).filter(Tenant.id == tenant_id).first() if not (tenant := db.session.query(Tenant).filter(Tenant.id == tenant_id).first()):
raise ValueError(f'Tenant with id {tenant_id} not found')
encrypted_token = rsa.encrypt(token, tenant.encrypt_public_key) encrypted_token = rsa.encrypt(token, tenant.encrypt_public_key)
return base64.b64encode(encrypted_token).decode() return base64.b64encode(encrypted_token).decode()
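The rewritten `obfuscated_token` emits a fixed-width mask, so the true token length can no longer be inferred from logs, and short or empty tokens are handled explicitly:

obfuscated_token('sk-abcdef1234567890xyz')  # 'sk-abc************yz' (always 20 chars)
obfuscated_token('short')                   # '********************' (8 chars or fewer: fully masked)
obfuscated_token('')                        # ''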


@ -14,6 +14,9 @@ def get_position_map(folder_path: str, *, file_name: str = "_position.yaml") ->
:return: a dict with name as key and index as value :return: a dict with name as key and index as value
""" """
position_file_name = os.path.join(folder_path, file_name) position_file_name = os.path.join(folder_path, file_name)
if not position_file_name or not os.path.exists(position_file_name):
return {}
positions = load_yaml_file(position_file_name, ignore_error=True) positions = load_yaml_file(position_file_name, ignore_error=True)
position_map = {} position_map = {}
index = 0 index = 0


@ -64,6 +64,7 @@ User Input:
SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = ( SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = (
"Please help me predict the three most likely questions that human would ask, " "Please help me predict the three most likely questions that human would ask, "
"and keeping each question under 20 characters.\n" "and keeping each question under 20 characters.\n"
"MAKE SURE your output is the SAME language as the Assistant's latest response(if the main response is written in Chinese, then the language of your output must be using Chinese.)!\n"
"The output must be an array in JSON format following the specified schema:\n" "The output must be an array in JSON format following the specified schema:\n"
"[\"question1\",\"question2\",\"question3\"]\n" "[\"question1\",\"question2\",\"question3\"]\n"
) )


@ -103,7 +103,7 @@ class TokenBufferMemory:
if curr_message_tokens > max_token_limit: if curr_message_tokens > max_token_limit:
pruned_memory = [] pruned_memory = []
while curr_message_tokens > max_token_limit and prompt_messages: while curr_message_tokens > max_token_limit and len(prompt_messages) > 1:
pruned_memory.append(prompt_messages.pop(0)) pruned_memory.append(prompt_messages.pop(0))
curr_message_tokens = self.model_instance.get_llm_num_tokens( curr_message_tokens = self.model_instance.get_llm_num_tokens(
prompt_messages prompt_messages


@ -413,6 +413,7 @@ class LBModelManager:
for load_balancing_config in self._load_balancing_configs: for load_balancing_config in self._load_balancing_configs:
if load_balancing_config.name == "__inherit__": if load_balancing_config.name == "__inherit__":
if not managed_credentials: if not managed_credentials:
# FIXME: Mutation to loop iterable `self._load_balancing_configs` during iteration
# remove __inherit__ if managed credentials is not provided # remove __inherit__ if managed credentials is not provided
self._load_balancing_configs.remove(load_balancing_config) self._load_balancing_configs.remove(load_balancing_config)
else: else:
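The FIXME is warranted: removing items from a list while iterating over it skips elements. The conventional fix, sketched here on a plain list, is to iterate over a snapshot:

configs = ['__inherit__', 'a', 'b']
for config in list(configs):    # iterate over a copy
    if config == '__inherit__':
        configs.remove(config)  # safe: the loop walks the snapshot, not the mutated list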


@ -27,9 +27,9 @@ parameter_rules:
- name: max_tokens - name: max_tokens
use_template: max_tokens use_template: max_tokens
required: true required: true
default: 4096 default: 8192
min: 1 min: 1
max: 4096 max: 8192
- name: response_format - name: response_format
use_template: response_format use_template: response_format
pricing: pricing:


@ -113,6 +113,11 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
if system: if system:
extra_model_kwargs['system'] = system extra_model_kwargs['system'] = system
# Add the new header for claude-3-5-sonnet-20240620 model
extra_headers = {}
if model == "claude-3-5-sonnet-20240620":
extra_headers["anthropic-beta"] = "max-tokens-3-5-sonnet-2024-07-15"
if tools: if tools:
extra_model_kwargs['tools'] = [ extra_model_kwargs['tools'] = [
self._transform_tool_prompt(tool) for tool in tools self._transform_tool_prompt(tool) for tool in tools
@ -121,6 +126,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
model=model, model=model,
messages=prompt_message_dicts, messages=prompt_message_dicts,
stream=stream, stream=stream,
extra_headers=extra_headers,
**model_parameters, **model_parameters,
**extra_model_kwargs **extra_model_kwargs
) )
@ -130,6 +136,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
model=model, model=model,
messages=prompt_message_dicts, messages=prompt_message_dicts,
stream=stream, stream=stream,
extra_headers=extra_headers,
**model_parameters, **model_parameters,
**extra_model_kwargs **extra_model_kwargs
) )
@ -138,7 +145,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages) return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages)
return self._handle_chat_generate_response(model, credentials, response, prompt_messages) return self._handle_chat_generate_response(model, credentials, response, prompt_messages)
def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None,
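For reference, the same beta header can be sent straight through the Anthropic SDK, which forwards `extra_headers` onto the HTTP request (a sketch, not Dify code):

import anthropic

client = anthropic.Anthropic()  # reads ANTHROPIC_API_KEY from the environment
response = client.messages.create(
    model='claude-3-5-sonnet-20240620',
    max_tokens=8192,
    extra_headers={'anthropic-beta': 'max-tokens-3-5-sonnet-2024-07-15'},
    messages=[{'role': 'user', 'content': 'Hello'}],
)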


@ -501,7 +501,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
sub_messages.append(sub_message_dict) sub_messages.append(sub_message_dict)
message_dict = {"role": "user", "content": sub_messages} message_dict = {"role": "user", "content": sub_messages}
elif isinstance(message, AssistantPromptMessage): elif isinstance(message, AssistantPromptMessage):
message = cast(AssistantPromptMessage, message) # message = cast(AssistantPromptMessage, message)
message_dict = {"role": "assistant", "content": message.content} message_dict = {"role": "assistant", "content": message.content}
if message.tool_calls: if message.tool_calls:
message_dict["tool_calls"] = [helper.dump_model(tool_call) for tool_call in message.tool_calls] message_dict["tool_calls"] = [helper.dump_model(tool_call) for tool_call in message.tool_calls]


@ -1,6 +1,8 @@
- gpt-4 - gpt-4
- gpt-4o - gpt-4o
- gpt-4o-2024-05-13 - gpt-4o-2024-05-13
- gpt-4o-mini
- gpt-4o-mini-2024-07-18
- gpt-4-turbo - gpt-4-turbo
- gpt-4-turbo-2024-04-09 - gpt-4-turbo-2024-04-09
- gpt-4-turbo-preview - gpt-4-turbo-preview


@ -0,0 +1,44 @@
model: gpt-4o-mini-2024-07-18
label:
zh_Hans: gpt-4o-mini-2024-07-18
en_US: gpt-4o-mini-2024-07-18
model_type: llm
features:
- multi-tool-call
- agent-thought
- stream-tool-call
- vision
model_properties:
mode: chat
context_size: 128000
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: presence_penalty
use_template: presence_penalty
- name: frequency_penalty
use_template: frequency_penalty
- name: max_tokens
use_template: max_tokens
default: 512
min: 1
max: 16384
- name: response_format
label:
zh_Hans: 回复格式
en_US: response_format
type: string
help:
zh_Hans: 指定模型必须输出的格式
en_US: specifying the format that the model must output
required: false
options:
- text
- json_object
pricing:
input: '0.15'
output: '0.60'
unit: '0.000001'
currency: USD


@ -0,0 +1,44 @@
model: gpt-4o-mini
label:
zh_Hans: gpt-4o-mini
en_US: gpt-4o-mini
model_type: llm
features:
- multi-tool-call
- agent-thought
- stream-tool-call
- vision
model_properties:
mode: chat
context_size: 128000
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: presence_penalty
use_template: presence_penalty
- name: frequency_penalty
use_template: frequency_penalty
- name: max_tokens
use_template: max_tokens
default: 512
min: 1
max: 16384
- name: response_format
label:
zh_Hans: 回复格式
en_US: response_format
type: string
help:
zh_Hans: 指定模型必须输出的格式
en_US: specifying the format that the model must output
required: false
options:
- text
- json_object
pricing:
input: '0.15'
output: '0.60'
unit: '0.000001'
currency: USD


@ -1,4 +1,5 @@
- openai/gpt-4o - openai/gpt-4o
- openai/gpt-4o-mini
- openai/gpt-4 - openai/gpt-4
- openai/gpt-4-32k - openai/gpt-4-32k
- openai/gpt-3.5-turbo - openai/gpt-3.5-turbo


@ -0,0 +1,43 @@
model: openai/gpt-4o-mini
label:
en_US: gpt-4o-mini
model_type: llm
features:
- multi-tool-call
- agent-thought
- stream-tool-call
- vision
model_properties:
mode: chat
context_size: 128000
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: presence_penalty
use_template: presence_penalty
- name: frequency_penalty
use_template: frequency_penalty
- name: max_tokens
use_template: max_tokens
default: 512
min: 1
max: 16384
- name: response_format
label:
zh_Hans: 回复格式
en_US: response_format
type: string
help:
zh_Hans: 指定模型必须输出的格式
en_US: specifying the format that the model must output
required: false
options:
- text
- json_object
pricing:
input: "0.15"
output: "0.60"
unit: "0.000001"
currency: USD

Binary file added (not shown): 9.2 KiB

Binary file added (not shown): 9.5 KiB


@ -0,0 +1,238 @@
import json
import logging
from collections.abc import Generator
from typing import Any, Optional, Union
import boto3
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessage,
PromptMessageTool,
)
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import (
AIModelEntity,
FetchFrom,
ModelFeature,
ModelPropertyKey,
ModelType,
ParameterRule,
ParameterType,
)
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
InvokeConnectionError,
InvokeError,
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
logger = logging.getLogger(__name__)
class SageMakerLargeLanguageModel(LargeLanguageModel):
"""
Model class for a large language model hosted on an Amazon SageMaker endpoint.
"""
sagemaker_client: Any = None
def _invoke(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None) \
-> Union[LLMResult, Generator]:
"""
Invoke large language model
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param tools: tools for tool calling
:param stop: stop words
:param stream: is stream response
:param user: unique user id
:return: full response or stream response chunk generator result
"""
# get model mode
model_mode = self.get_model_mode(model, credentials)
if not self.sagemaker_client:
access_key = credentials.get('access_key')
secret_key = credentials.get('secret_key')
aws_region = credentials.get('aws_region')
if aws_region:
if access_key and secret_key:
self.sagemaker_client = boto3.client("sagemaker-runtime",
aws_access_key_id=access_key,
aws_secret_access_key=secret_key,
region_name=aws_region)
else:
self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region)
else:
self.sagemaker_client = boto3.client("sagemaker-runtime")
sagemaker_endpoint = credentials.get('sagemaker_endpoint')
response_model = self.sagemaker_client.invoke_endpoint(
EndpointName=sagemaker_endpoint,
Body=json.dumps(
{
"inputs": prompt_messages[0].content,
"parameters": { "stop" : stop},
"history" : []
}
),
ContentType="application/json",
)
assistant_text = response_model['Body'].read().decode('utf8')
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(
content=assistant_text
)
usage = self._calc_response_usage(model, credentials, 0, 0)
response = LLMResult(
model=model,
prompt_messages=prompt_messages,
message=assistant_prompt_message,
usage=usage
)
return response
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None) -> int:
"""
Get number of tokens for given prompt messages
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param tools: tools for tool calling
:return:
"""
# A generic SageMaker endpoint does not expose a tokenizer,
# so the token count cannot be computed here; report 0.
return 0
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials
:param model: model name
:param credentials: model credentials
:return:
"""
try:
# get model mode
model_mode = self.get_model_mode(model)
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller
The value is the error type thrown by the model,
which needs to be converted into a unified error type for the caller.
:return: Invoke error mapping
"""
return {
InvokeConnectionError: [
InvokeConnectionError
],
InvokeServerUnavailableError: [
InvokeServerUnavailableError
],
InvokeRateLimitError: [
InvokeRateLimitError
],
InvokeAuthorizationError: [
InvokeAuthorizationError
],
InvokeBadRequestError: [
InvokeBadRequestError,
KeyError,
ValueError
]
}
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
"""
used to define customizable model schema
"""
rules = [
ParameterRule(
name='temperature',
type=ParameterType.FLOAT,
use_template='temperature',
label=I18nObject(
zh_Hans='温度',
en_US='Temperature'
),
),
ParameterRule(
name='top_p',
type=ParameterType.FLOAT,
use_template='top_p',
label=I18nObject(
zh_Hans='Top P',
en_US='Top P'
)
),
ParameterRule(
name='max_tokens',
type=ParameterType.INT,
use_template='max_tokens',
min=1,
max=credentials.get('context_length', 2048),
default=512,
label=I18nObject(
zh_Hans='最大生成长度',
en_US='Max Tokens'
)
)
]
completion_type = LLMMode.value_of(credentials["mode"])
if completion_type == LLMMode.CHAT:
logger.debug(f"completion_type : {LLMMode.CHAT.value}")
if completion_type == LLMMode.COMPLETION:
logger.debug(f"completion_type : {LLMMode.COMPLETION.value}")
features = []
support_function_call = credentials.get('support_function_call', False)
if support_function_call:
features.append(ModelFeature.TOOL_CALL)
support_vision = credentials.get('support_vision', False)
if support_vision:
features.append(ModelFeature.VISION)
context_length = credentials.get('context_length', 2048)
entity = AIModelEntity(
model=model,
label=I18nObject(
en_US=model
),
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=ModelType.LLM,
features=features,
model_properties={
ModelPropertyKey.MODE: completion_type,
ModelPropertyKey.CONTEXT_SIZE: context_length
},
parameter_rules=rules
)
return entity
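Outside Dify, the endpoint contract assumed by `_invoke` can be smoke-tested with boto3 directly (the endpoint name is hypothetical; the JSON body mirrors what the handler above sends):

import json

import boto3

client = boto3.client('sagemaker-runtime', region_name='us-east-1')
response = client.invoke_endpoint(
    EndpointName='my-llm-endpoint',
    Body=json.dumps({'inputs': 'Hello', 'parameters': {'stop': []}, 'history': []}),
    ContentType='application/json',
)
print(response['Body'].read().decode('utf8'))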


@ -0,0 +1,190 @@
import json
import logging
from typing import Any, Optional
import boto3
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
InvokeConnectionError,
InvokeError,
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.rerank_model import RerankModel
logger = logging.getLogger(__name__)
class SageMakerRerankModel(RerankModel):
"""
Model class for a rerank model hosted on an Amazon SageMaker endpoint.
"""
sagemaker_client: Any = None
def _sagemaker_rerank(self, query_input: str, docs: list[str], rerank_endpoint: str):
inputs = [query_input] * len(docs)
response_model = self.sagemaker_client.invoke_endpoint(
EndpointName=rerank_endpoint,
Body=json.dumps(
{
"inputs": inputs,
"docs": docs
}
),
ContentType="application/json",
)
json_str = response_model['Body'].read().decode('utf8')
json_obj = json.loads(json_str)
scores = json_obj['scores']
return scores if isinstance(scores, list) else [scores]
def _invoke(self, model: str, credentials: dict,
query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None,
user: Optional[str] = None) \
-> RerankResult:
"""
Invoke rerank model
:param model: model name
:param credentials: model credentials
:param query: search query
:param docs: docs for reranking
:param score_threshold: score threshold
:param top_n: top n
:param user: unique user id
:return: rerank result
"""
line = 0
try:
if len(docs) == 0:
return RerankResult(
model=model,
docs=docs
)
line = 1
if not self.sagemaker_client:
access_key = credentials.get('aws_access_key_id')
secret_key = credentials.get('aws_secret_access_key')
aws_region = credentials.get('aws_region')
if aws_region:
if access_key and secret_key:
self.sagemaker_client = boto3.client("sagemaker-runtime",
aws_access_key_id=access_key,
aws_secret_access_key=secret_key,
region_name=aws_region)
else:
self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region)
else:
self.sagemaker_client = boto3.client("sagemaker-runtime")
line = 2
sagemaker_endpoint = credentials.get('sagemaker_endpoint')
candidate_docs = []
scores = self._sagemaker_rerank(query, docs, sagemaker_endpoint)
for idx in range(len(scores)):
candidate_docs.append({"content" : docs[idx], "score": scores[idx]})
candidate_docs.sort(key=lambda x: x['score'], reverse=True)  # sort in place; the bare sorted() call discarded its result
line = 3
rerank_documents = []
for idx, result in enumerate(candidate_docs):
rerank_document = RerankDocument(
index=idx,
text=result.get('content'),
score=result.get('score', -100.0)
)
if score_threshold is not None:
if rerank_document.score >= score_threshold:
rerank_documents.append(rerank_document)
else:
rerank_documents.append(rerank_document)
return RerankResult(
model=model,
docs=rerank_documents
)
except Exception as e:
logger.exception(f'Exception {e}, line : {line}')
raise
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials
:param model: model name
:param credentials: model credentials
:return:
"""
try:
self._invoke(
model=model,
credentials=credentials,
query="What is the capital of the United States?",
docs=[
"Carson City is the capital city of the American state of Nevada. At the 2010 United States "
"Census, Carson City had a population of 55,274.",
"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
"are a political division controlled by the United States. Its capital is Saipan.",
],
score_threshold=0.8
)
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller
The value is the error type thrown by the model,
which needs to be converted into a unified error type for the caller.
:return: Invoke error mapping
"""
return {
InvokeConnectionError: [
InvokeConnectionError
],
InvokeServerUnavailableError: [
InvokeServerUnavailableError
],
InvokeRateLimitError: [
InvokeRateLimitError
],
InvokeAuthorizationError: [
InvokeAuthorizationError
],
InvokeBadRequestError: [
InvokeBadRequestError,
KeyError,
ValueError
]
}
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
"""
used to define customizable model schema
"""
entity = AIModelEntity(
model=model,
label=I18nObject(
en_US=model
),
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=ModelType.RERANK,
model_properties={ },
parameter_rules=[]
)
return entity


@ -0,0 +1,17 @@
import logging
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
logger = logging.getLogger(__name__)
class SageMakerProvider(ModelProvider):
def validate_provider_credentials(self, credentials: dict) -> None:
"""
Validate provider credentials
if validate failed, raise exception
:param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
"""
pass


@ -0,0 +1,125 @@
provider: sagemaker
label:
zh_Hans: Sagemaker
en_US: Sagemaker
icon_small:
en_US: icon_s_en.png
icon_large:
en_US: icon_l_en.png
description:
en_US: Customized model on Sagemaker
zh_Hans: Sagemaker上的私有化部署的模型
background: "#ECE9E3"
help:
title:
en_US: How to deploy customized model on Sagemaker
zh_Hans: 如何在Sagemaker上的私有化部署的模型
url:
en_US: https://github.com/aws-samples/dify-aws-tool/blob/main/README.md#how-to-deploy-sagemaker-endpoint
zh_Hans: https://github.com/aws-samples/dify-aws-tool/blob/main/README_ZH.md#%E5%A6%82%E4%BD%95%E9%83%A8%E7%BD%B2sagemaker%E6%8E%A8%E7%90%86%E7%AB%AF%E7%82%B9
supported_model_types:
- llm
- text-embedding
- rerank
configurate_methods:
- customizable-model
model_credential_schema:
model:
label:
en_US: Model Name
zh_Hans: 模型名称
placeholder:
en_US: Enter your model name
zh_Hans: 输入模型名称
credential_form_schemas:
- variable: mode
show_on:
- variable: __model_type
value: llm
label:
en_US: Completion mode
type: select
required: false
default: chat
placeholder:
zh_Hans: 选择对话类型
en_US: Select completion mode
options:
- value: completion
label:
en_US: Completion
zh_Hans: 补全
- value: chat
label:
en_US: Chat
zh_Hans: 对话
- variable: sagemaker_endpoint
label:
en_US: sagemaker endpoint
type: text-input
required: true
placeholder:
zh_Hans: 请输入你的Sagemaker推理端点
en_US: Enter your Sagemaker Inference endpoint
- variable: aws_access_key_id
required: false
label:
en_US: Access Key (If not provided, credentials are obtained from the running environment.)
zh_Hans: Access Key (如果未提供,凭证将从运行环境中获取。)
type: secret-input
placeholder:
en_US: Enter your Access Key
zh_Hans: 在此输入您的 Access Key
- variable: aws_secret_access_key
required: false
label:
en_US: Secret Access Key
zh_Hans: Secret Access Key
type: secret-input
placeholder:
en_US: Enter your Secret Access Key
zh_Hans: 在此输入您的 Secret Access Key
- variable: aws_region
required: false
label:
en_US: AWS Region
zh_Hans: AWS 地区
type: select
default: us-east-1
options:
- value: us-east-1
label:
en_US: US East (N. Virginia)
zh_Hans: 美国东部 (弗吉尼亚北部)
- value: us-west-2
label:
en_US: US West (Oregon)
zh_Hans: 美国西部 (俄勒冈州)
- value: ap-southeast-1
label:
en_US: Asia Pacific (Singapore)
zh_Hans: 亚太地区 (新加坡)
- value: ap-northeast-1
label:
en_US: Asia Pacific (Tokyo)
zh_Hans: 亚太地区 (东京)
- value: eu-central-1
label:
en_US: Europe (Frankfurt)
zh_Hans: 欧洲 (法兰克福)
- value: us-gov-west-1
label:
en_US: AWS GovCloud (US-West)
zh_Hans: AWS GovCloud (US-West)
- value: ap-southeast-2
label:
en_US: Asia Pacific (Sydney)
zh_Hans: 亚太地区 (悉尼)
- value: cn-north-1
label:
en_US: AWS Beijing (cn-north-1)
zh_Hans: 中国北京 (cn-north-1)
- value: cn-northwest-1
label:
en_US: AWS Ningxia (cn-northwest-1)
zh_Hans: 中国宁夏 (cn-northwest-1)


@ -0,0 +1,214 @@
import itertools
import json
import logging
import time
from typing import Any, Optional
import boto3
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
InvokeConnectionError,
InvokeError,
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
BATCH_SIZE = 20
CONTEXT_SIZE = 8192
logger = logging.getLogger(__name__)
def batch_generator(generator, batch_size):
while True:
batch = list(itertools.islice(generator, batch_size))
if not batch:
break
yield batch
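# e.g. list(batch_generator(iter(range(7)), batch_size=3)) -> [[0, 1, 2], [3, 4, 5], [6]]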
class SageMakerEmbeddingModel(TextEmbeddingModel):
"""
Model class for a text embedding model hosted on an Amazon SageMaker endpoint.
"""
sagemaker_client: Any = None
def _sagemaker_embedding(self, sm_client, endpoint_name, content_list:list[str]):
response_model = sm_client.invoke_endpoint(
EndpointName=endpoint_name,
Body=json.dumps(
{
"inputs": content_list,
"parameters": {},
"is_query" : False,
"instruction" : ''
}
),
ContentType="application/json",
)
json_str = response_model['Body'].read().decode('utf8')
json_obj = json.loads(json_str)
embeddings = json_obj['embeddings']
return embeddings
def _invoke(self, model: str, credentials: dict,
texts: list[str], user: Optional[str] = None) \
-> TextEmbeddingResult:
"""
Invoke text embedding model
:param model: model name
:param credentials: model credentials
:param texts: texts to embed
:param user: unique user id
:return: embeddings result
"""
# get model properties
try:
line = 1
if not self.sagemaker_client:
access_key = credentials.get('aws_access_key_id')
secret_key = credentials.get('aws_secret_access_key')
aws_region = credentials.get('aws_region')
if aws_region:
if access_key and secret_key:
self.sagemaker_client = boto3.client("sagemaker-runtime",
aws_access_key_id=access_key,
aws_secret_access_key=secret_key,
region_name=aws_region)
else:
self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region)
else:
self.sagemaker_client = boto3.client("sagemaker-runtime")
line = 2
sagemaker_endpoint = credentials.get('sagemaker_endpoint')
line = 3
truncated_texts = [ item[:CONTEXT_SIZE] for item in texts ]
batches = batch_generator((text for text in truncated_texts), batch_size=BATCH_SIZE)
all_embeddings = []
line = 4
for batch in batches:
embeddings = self._sagemaker_embedding(self.sagemaker_client, sagemaker_endpoint, batch)
all_embeddings.extend(embeddings)
line = 5
# calc usage
usage = self._calc_response_usage(
model=model,
credentials=credentials,
tokens=0 # It's not SAAS API, usage is meaningless
)
line = 6
return TextEmbeddingResult(
embeddings=all_embeddings,
usage=usage,
model=model
)
except Exception as e:
logger.exception(f'Exception {e}, line : {line}')
raise
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
"""
Get number of tokens for given prompt messages
:param model: model name
:param credentials: model credentials
:param texts: texts to embed
:return:
"""
return 0
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials
:param model: model name
:param credentials: model credentials
:return:
"""
try:
# A generic SageMaker endpoint exposes no lightweight health check;
# credential problems surface on the first real invocation instead.
pass
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
"""
Calculate response usage
:param model: model name
:param credentials: model credentials
:param tokens: input tokens
:return: usage
"""
# get input price info
input_price_info = self.get_price(
model=model,
credentials=credentials,
price_type=PriceType.INPUT,
tokens=tokens
)
# transform usage
usage = EmbeddingUsage(
tokens=tokens,
total_tokens=tokens,
unit_price=input_price_info.unit_price,
price_unit=input_price_info.unit,
total_price=input_price_info.total_amount,
currency=input_price_info.currency,
latency=time.perf_counter() - self.started_at
)
return usage
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
return {
InvokeConnectionError: [
InvokeConnectionError
],
InvokeServerUnavailableError: [
InvokeServerUnavailableError
],
InvokeRateLimitError: [
InvokeRateLimitError
],
InvokeAuthorizationError: [
InvokeAuthorizationError
],
InvokeBadRequestError: [
KeyError
]
}
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
"""
used to define customizable model schema
"""
entity = AIModelEntity(
model=model,
label=I18nObject(
en_US=model
),
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=ModelType.TEXT_EMBEDDING,
model_properties={
ModelPropertyKey.CONTEXT_SIZE: CONTEXT_SIZE,
ModelPropertyKey.MAX_CHUNKS: BATCH_SIZE,
},
parameter_rules=[]
)
return entity

Binary file added (not shown): 9.0 KiB

Binary file added (not shown): 1.9 KiB


@ -0,0 +1,6 @@
- step-1-8k
- step-1-32k
- step-1-128k
- step-1-256k
- step-1v-8k
- step-1v-32k


@ -0,0 +1,328 @@
import json
from collections.abc import Generator
from typing import Optional, Union, cast
import requests
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
ImagePromptMessageContent,
PromptMessage,
PromptMessageContent,
PromptMessageContentType,
PromptMessageTool,
SystemPromptMessage,
ToolPromptMessage,
UserPromptMessage,
)
from core.model_runtime.entities.model_entities import (
AIModelEntity,
FetchFrom,
ModelFeature,
ModelPropertyKey,
ModelType,
ParameterRule,
ParameterType,
)
from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel
class StepfunLargeLanguageModel(OAIAPICompatLargeLanguageModel):
def _invoke(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None) \
-> Union[LLMResult, Generator]:
self._add_custom_parameters(credentials)
self._add_function_call(model, credentials)
user = user[:32] if user else None
return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
def validate_credentials(self, model: str, credentials: dict) -> None:
self._add_custom_parameters(credentials)
super().validate_credentials(model, credentials)
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
return AIModelEntity(
model=model,
label=I18nObject(en_US=model, zh_Hans=model),
model_type=ModelType.LLM,
features=[ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL, ModelFeature.STREAM_TOOL_CALL]
if credentials.get('function_calling_type') == 'tool_call'
else [],
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties={
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', 8000)),
ModelPropertyKey.MODE: LLMMode.CHAT.value,
},
parameter_rules=[
ParameterRule(
name='temperature',
use_template='temperature',
label=I18nObject(en_US='Temperature', zh_Hans='温度'),
type=ParameterType.FLOAT,
),
ParameterRule(
name='max_tokens',
use_template='max_tokens',
default=512,
min=1,
max=int(credentials.get('max_tokens', 1024)),
label=I18nObject(en_US='Max Tokens', zh_Hans='最大标记'),
type=ParameterType.INT,
),
ParameterRule(
name='top_p',
use_template='top_p',
label=I18nObject(en_US='Top P', zh_Hans='Top P'),
type=ParameterType.FLOAT,
),
]
)
def _add_custom_parameters(self, credentials: dict) -> None:
credentials['mode'] = 'chat'
credentials['endpoint_url'] = 'https://api.stepfun.com/v1'
def _add_function_call(self, model: str, credentials: dict) -> None:
model_schema = self.get_model_schema(model, credentials)
if model_schema and {
ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL
}.intersection(model_schema.features or []):
credentials['function_calling_type'] = 'tool_call'
def _convert_prompt_message_to_dict(self, message: PromptMessage, credentials: Optional[dict] = None) -> dict:
"""
Convert PromptMessage to dict for OpenAI API format
"""
if isinstance(message, UserPromptMessage):
message = cast(UserPromptMessage, message)
if isinstance(message.content, str):
message_dict = {"role": "user", "content": message.content}
else:
sub_messages = []
for message_content in message.content:
if message_content.type == PromptMessageContentType.TEXT:
message_content = cast(PromptMessageContent, message_content)
sub_message_dict = {
"type": "text",
"text": message_content.data
}
sub_messages.append(sub_message_dict)
elif message_content.type == PromptMessageContentType.IMAGE:
message_content = cast(ImagePromptMessageContent, message_content)
sub_message_dict = {
"type": "image_url",
"image_url": {
"url": message_content.data,
}
}
sub_messages.append(sub_message_dict)
message_dict = {"role": "user", "content": sub_messages}
elif isinstance(message, AssistantPromptMessage):
message = cast(AssistantPromptMessage, message)
message_dict = {"role": "assistant", "content": message.content}
if message.tool_calls:
message_dict["tool_calls"] = []
for function_call in message.tool_calls:
message_dict["tool_calls"].append({
"id": function_call.id,
"type": function_call.type,
"function": {
"name": function_call.function.name,
"arguments": function_call.function.arguments
}
})
elif isinstance(message, ToolPromptMessage):
message = cast(ToolPromptMessage, message)
message_dict = {"role": "tool", "content": message.content, "tool_call_id": message.tool_call_id}
elif isinstance(message, SystemPromptMessage):
message = cast(SystemPromptMessage, message)
message_dict = {"role": "system", "content": message.content}
else:
raise ValueError(f"Got unknown type {message}")
if message.name:
message_dict["name"] = message.name
return message_dict
def _extract_response_tool_calls(self, response_tool_calls: list[dict]) -> list[AssistantPromptMessage.ToolCall]:
"""
Extract tool calls from response
:param response_tool_calls: response tool calls
:return: list of tool calls
"""
tool_calls = []
if response_tool_calls:
for response_tool_call in response_tool_calls:
function = AssistantPromptMessage.ToolCall.ToolCallFunction(
name=response_tool_call.get("function", {}).get("name") or "",
arguments=response_tool_call.get("function", {}).get("arguments") or ""
)
tool_call = AssistantPromptMessage.ToolCall(
id=response_tool_call.get("id") or "",
type=response_tool_call.get("type") or "",
function=function
)
tool_calls.append(tool_call)
return tool_calls
def _handle_generate_stream_response(self, model: str, credentials: dict, response: requests.Response,
prompt_messages: list[PromptMessage]) -> Generator:
"""
Handle llm stream response
:param model: model name
:param credentials: model credentials
:param response: streamed response
:param prompt_messages: prompt messages
:return: llm response chunk generator
"""
full_assistant_content = ''
chunk_index = 0
def create_final_llm_result_chunk(index: int, message: AssistantPromptMessage, finish_reason: str) \
-> LLMResultChunk:
# calculate num tokens
prompt_tokens = self._num_tokens_from_string(model, prompt_messages[0].content)
completion_tokens = self._num_tokens_from_string(model, full_assistant_content)
# transform usage
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
return LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=index,
message=message,
finish_reason=finish_reason,
usage=usage
)
)
tools_calls: list[AssistantPromptMessage.ToolCall] = []
finish_reason = "Unknown"
def increase_tool_call(new_tool_calls: list[AssistantPromptMessage.ToolCall]):
def get_tool_call(tool_name: str):
if not tool_name:
return tools_calls[-1]
tool_call = next((tool_call for tool_call in tools_calls if tool_call.function.name == tool_name), None)
if tool_call is None:
tool_call = AssistantPromptMessage.ToolCall(
id='',
type='',
function=AssistantPromptMessage.ToolCall.ToolCallFunction(name=tool_name, arguments="")
)
tools_calls.append(tool_call)
return tool_call
for new_tool_call in new_tool_calls:
# get tool call
tool_call = get_tool_call(new_tool_call.function.name)
# update tool call
if new_tool_call.id:
tool_call.id = new_tool_call.id
if new_tool_call.type:
tool_call.type = new_tool_call.type
if new_tool_call.function.name:
tool_call.function.name = new_tool_call.function.name
if new_tool_call.function.arguments:
tool_call.function.arguments += new_tool_call.function.arguments
for chunk in response.iter_lines(decode_unicode=True, delimiter="\n\n"):
if chunk:
# ignore sse comments
if chunk.startswith(':'):
continue
# str.removeprefix strips the exact 'data: ' marker; lstrip('data: ') would strip any of those characters
decoded_chunk = chunk.strip().removeprefix('data: ').lstrip()
chunk_json = None
try:
chunk_json = json.loads(decoded_chunk)
# stream ended
except json.JSONDecodeError as e:
yield create_final_llm_result_chunk(
index=chunk_index + 1,
message=AssistantPromptMessage(content=""),
finish_reason="Non-JSON encountered."
)
break
if not chunk_json or len(chunk_json['choices']) == 0:
continue
choice = chunk_json['choices'][0]
finish_reason = chunk_json['choices'][0].get('finish_reason')
chunk_index += 1
if 'delta' in choice:
delta = choice['delta']
delta_content = delta.get('content')
assistant_message_tool_calls = delta.get('tool_calls', None)
# assistant_message_function_call = delta.delta.function_call
# extract tool calls from response
if assistant_message_tool_calls:
tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
increase_tool_call(tool_calls)
if delta_content is None or delta_content == '':
continue
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(
content=delta_content,
tool_calls=tool_calls if assistant_message_tool_calls else []
)
full_assistant_content += delta_content
elif 'text' in choice:
choice_text = choice.get('text', '')
if choice_text == '':
continue
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(content=choice_text)
full_assistant_content += choice_text
else:
continue
# check payload indicator for completion
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=chunk_index,
message=assistant_prompt_message,
)
)
chunk_index += 1
if tools_calls:
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=chunk_index,
message=AssistantPromptMessage(
tool_calls=tools_calls,
content=""
),
)
)
yield create_final_llm_result_chunk(
index=chunk_index,
message=AssistantPromptMessage(content=""),
finish_reason=finish_reason
)


@ -0,0 +1,25 @@
model: step-1-128k
label:
zh_Hans: step-1-128k
en_US: step-1-128k
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 128000
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: max_tokens
use_template: max_tokens
default: 1024
min: 1
max: 128000
pricing:
input: '0.04'
output: '0.20'
unit: '0.001'
currency: RMB


@ -0,0 +1,25 @@
model: step-1-256k
label:
zh_Hans: step-1-256k
en_US: step-1-256k
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 256000
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: max_tokens
use_template: max_tokens
default: 1024
min: 1
max: 256000
pricing:
input: '0.095'
output: '0.300'
unit: '0.001'
currency: RMB


@ -0,0 +1,28 @@
model: step-1-32k
label:
zh_Hans: step-1-32k
en_US: step-1-32k
model_type: llm
features:
- agent-thought
- tool-call
- multi-tool-call
- stream-tool-call
model_properties:
mode: chat
context_size: 32000
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: max_tokens
use_template: max_tokens
default: 1024
min: 1
max: 32000
pricing:
input: '0.015'
output: '0.070'
unit: '0.001'
currency: RMB


@ -0,0 +1,28 @@
model: step-1-8k
label:
zh_Hans: step-1-8k
en_US: step-1-8k
model_type: llm
features:
- agent-thought
- tool-call
- multi-tool-call
- stream-tool-call
model_properties:
mode: chat
context_size: 8000
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: max_tokens
use_template: max_tokens
default: 512
min: 1
max: 8000
pricing:
input: '0.005'
output: '0.020'
unit: '0.001'
currency: RMB


@ -0,0 +1,25 @@
model: step-1v-32k
label:
zh_Hans: step-1v-32k
en_US: step-1v-32k
model_type: llm
features:
- vision
model_properties:
mode: chat
context_size: 32000
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: max_tokens
use_template: max_tokens
default: 1024
min: 1
max: 32000
pricing:
input: '0.015'
output: '0.070'
unit: '0.001'
currency: RMB


@ -0,0 +1,25 @@
model: step-1v-8k
label:
zh_Hans: step-1v-8k
en_US: step-1v-8k
model_type: llm
features:
- vision
model_properties:
mode: chat
context_size: 8192
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: max_tokens
use_template: max_tokens
default: 512
min: 1
max: 8192
pricing:
input: '0.005'
output: '0.020'
unit: '0.001'
currency: RMB


@ -0,0 +1,30 @@
import logging
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
logger = logging.getLogger(__name__)
class StepfunProvider(ModelProvider):
def validate_provider_credentials(self, credentials: dict) -> None:
"""
Validate provider credentials
if validate failed, raise exception
:param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
"""
try:
model_instance = self.get_model_instance(ModelType.LLM)
model_instance.validate_credentials(
model='step-1-8k',
credentials=credentials
)
except CredentialsValidateFailedError as ex:
raise ex
except Exception as ex:
logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
raise ex


@ -0,0 +1,81 @@
provider: stepfun
label:
zh_Hans: 阶跃星辰
en_US: Stepfun
description:
en_US: Models provided by stepfun, such as step-1-8k, step-1-32k, step-1v-8k, step-1v-32k, step-1-128k and step-1-256k.
zh_Hans: 阶跃星辰提供的模型,例如 step-1-8k、step-1-32k、step-1v-8k、step-1v-32k、step-1-128k 和 step-1-256k。
icon_small:
en_US: icon_s_en.png
icon_large:
en_US: icon_l_en.png
background: "#FFFFFF"
help:
title:
en_US: Get your API Key from stepfun
zh_Hans: 从 stepfun 获取 API Key
url:
en_US: https://platform.stepfun.com/interface-key
supported_model_types:
- llm
configurate_methods:
- predefined-model
- customizable-model
provider_credential_schema:
credential_form_schemas:
- variable: api_key
label:
en_US: API Key
type: secret-input
required: true
placeholder:
zh_Hans: 在此输入您的 API Key
en_US: Enter your API Key
model_credential_schema:
model:
label:
en_US: Model Name
zh_Hans: 模型名称
placeholder:
en_US: Enter your model name
zh_Hans: 输入模型名称
credential_form_schemas:
- variable: api_key
label:
en_US: API Key
type: secret-input
required: true
placeholder:
zh_Hans: 在此输入您的 API Key
en_US: Enter your API Key
- variable: context_size
label:
zh_Hans: 模型上下文长度
en_US: Model context size
required: true
type: text-input
default: '8192'
placeholder:
zh_Hans: 在此输入您的模型上下文长度
en_US: Enter your model context size
- variable: max_tokens
label:
zh_Hans: 最大 token 上限
en_US: Upper bound for max tokens
default: '8192'
type: text-input
- variable: function_calling_type
label:
en_US: Function calling
type: select
required: false
default: no_call
options:
- value: no_call
label:
en_US: Not supported
zh_Hans: 不支持
- value: tool_call
label:
en_US: Tool Call
zh_Hans: Tool Call
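
Taken together, the `model_credential_schema` above means a user-defined Stepfun model is configured with a payload along these lines (values are examples only; the API key is a placeholder):

# Example credentials dict matching model_credential_schema above.
credentials = {
    'api_key': 'sk-...',
    'context_size': '8192',
    'max_tokens': '8192',
    'function_calling_type': 'tool_call',  # or 'no_call', per the options list
}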

View File

@ -35,3 +35,4 @@ parameter_rules:
      zh_Hans: 禁用模型自行进行外部搜索。
      en_US: Disable the model to perform external search.
    required: false
+deprecated: true

View File

@ -1,4 +1,4 @@
-model: ernie-4.0-8k-Latest
+model: ernie-4.0-8k-latest
label:
  en_US: Ernie-4.0-8K-Latest
model_type: llm

View File

@ -0,0 +1,40 @@
model: ernie-4.0-turbo-8k
label:
en_US: Ernie-4.0-turbo-8K
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 8192
parameter_rules:
- name: temperature
use_template: temperature
min: 0.1
max: 1.0
default: 0.8
- name: top_p
use_template: top_p
- name: max_tokens
use_template: max_tokens
default: 1024
min: 2
max: 2048
- name: presence_penalty
use_template: presence_penalty
default: 1.0
min: 1.0
max: 2.0
- name: frequency_penalty
use_template: frequency_penalty
- name: response_format
use_template: response_format
- name: disable_search
label:
zh_Hans: 禁用搜索
en_US: Disable Search
type: boolean
help:
zh_Hans: 禁用模型自行进行外部搜索。
en_US: Disable the model to perform external search.
required: false

View File

@ -28,3 +28,4 @@ parameter_rules:
    default: 1.0
    min: 1.0
    max: 2.0
+deprecated: true

View File

@ -0,0 +1,30 @@
model: ernie-character-8k-0321
label:
en_US: ERNIE-Character-8K
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 8192
parameter_rules:
- name: temperature
use_template: temperature
min: 0.1
max: 1.0
default: 0.95
- name: top_p
use_template: top_p
min: 0
max: 1.0
default: 0.7
- name: max_tokens
use_template: max_tokens
default: 1024
min: 2
max: 1024
- name: presence_penalty
use_template: presence_penalty
default: 1.0
min: 1.0
max: 2.0

View File

@ -28,3 +28,4 @@ parameter_rules:
    default: 1.0
    min: 1.0
    max: 2.0
+deprecated: true

View File

@ -28,3 +28,4 @@ parameter_rules:
    default: 1.0
    min: 1.0
    max: 2.0
+deprecated: true

View File

@ -97,6 +97,7 @@ class BaiduAccessToken:
        baidu_access_tokens_lock.release()
        return token

class ErnieMessage:
    class Role(Enum):
        USER = 'user'
@ -137,7 +138,9 @@ class ErnieBotModel:
        'ernie-speed-appbuilder': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ai_apaas',
        'ernie-lite-8k-0922': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant',
        'ernie-lite-8k-0308': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-lite-8k',
+        'ernie-character-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k',
        'ernie-character-8k-0321': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k',
+        'ernie-4.0-turbo-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k',
        'ernie-4.0-turbo-8k-preview': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k-preview',
    }
@ -149,7 +152,8 @@ class ErnieBotModel:
        'ernie-3.5-8k-1222',
        'ernie-3.5-4k-0205',
        'ernie-3.5-128k',
-        'ernie-4.0-8k'
+        'ernie-4.0-8k',
+        'ernie-4.0-turbo-8k',
        'ernie-4.0-turbo-8k-preview'
    ]
Some files were not shown because too many files have changed in this diff.