fix bug
This commit is contained in:
commit a603e01f5e
.gitignore (vendored)
@@ -174,5 +174,6 @@ sdks/python-client/dify_client.egg-info
 .vscode/*
 !.vscode/launch.json
 pyrightconfig.json
 api/.vscode
 
+.idea/
@@ -81,7 +81,7 @@ Dify requires the following dependencies to build, make sure they're installed o
 
 Dify is composed of a backend and a frontend. Navigate to the backend directory by `cd api/`, then follow the [Backend README](api/README.md) to install it. In a separate terminal, navigate to the frontend directory by `cd web/`, then follow the [Frontend README](web/README.md) to install.
 
-Check the [installation FAQ](https://docs.dify.ai/getting-started/faq/install-faq) for a list of common issues and steps to troubleshoot.
+Check the [installation FAQ](https://docs.dify.ai/learn-more/faq/self-host-faq) for a list of common issues and steps to troubleshoot.
 
 ### 5. Visit dify in your browser
 
@@ -77,7 +77,7 @@ Dify depends on the following tools and libraries:
 
 Dify is composed of a backend and a frontend. Navigate to the backend directory with `cd api/`, then follow the [Backend README](api/README.md) to install it. In a separate terminal, navigate to the frontend directory with `cd web/`, then follow the [Frontend README](web/README.md) to install.
 
-See the [installation FAQ](https://docs.dify.ai/getting-started/faq/install-faq) for a list of common issues and troubleshooting steps.
+See the [installation FAQ](https://docs.dify.ai/v/zh-hans/learn-more/faq/install-faq) for a list of common issues and troubleshooting steps.
 
 ### 5. Visit Dify in your browser
 
@@ -82,7 +82,7 @@ Dify is composed of a backend and a frontend.
 First, move to the backend directory with `cd api/` and install it by following the [Backend README](api/README.md).
 Next, in a separate terminal, move to the frontend directory with `cd web/` and install it by following the [Frontend README](web/README.md).
 
-Check the [installation FAQ](https://docs.dify.ai/getting-started/faq/install-faq) for common issues and troubleshooting steps.
+Check the [installation FAQ](https://docs.dify.ai/v/japanese/learn-more/faq/install-faq) for common issues and troubleshooting steps.
 
 ### 5. Access dify in your browser
 
@@ -256,3 +256,7 @@ WORKFLOW_CALL_MAX_DEPTH=5
 # App configuration
 APP_MAX_EXECUTION_TIME=1200
+APP_MAX_ACTIVE_REQUESTS=0
 
+
+# Celery beat configuration
+CELERY_BEAT_SCHEDULER_TIME=1
@@ -1,7 +1,5 @@
 import os
 
-from configs import dify_config
-
 if os.environ.get("DEBUG", "false").lower() != 'true':
     from gevent import monkey
 
@@ -23,7 +21,9 @@ from flask import Flask, Response, request
 from flask_cors import CORS
 from werkzeug.exceptions import Unauthorized
 
+import contexts
 from commands import register_commands
+from configs import dify_config
 
 # DO NOT REMOVE BELOW
 from events import event_handlers
@@ -181,7 +181,10 @@ def load_user_from_request(request_from_flask_login):
         decoded = PassportService().verify(auth_token)
         user_id = decoded.get('user_id')
 
-        return AccountService.load_logged_in_account(account_id=user_id, token=auth_token)
+        account = AccountService.load_logged_in_account(account_id=user_id, token=auth_token)
+        if account:
+            contexts.tenant_id.set(account.current_tenant_id)
+        return account
 
 
 @login_manager.unauthorized_handler
@@ -23,6 +23,7 @@ class SecurityConfig(BaseSettings):
         default=24,
     )
 
+
 class AppExecutionConfig(BaseSettings):
     """
     App Execution configs
@@ -405,7 +406,6 @@ class DataSetConfig(BaseSettings):
         default=False,
     )
 
-
 class WorkspaceConfig(BaseSettings):
     """
     Workspace configs
@@ -435,6 +435,13 @@ class ImageFormatConfig(BaseSettings):
     )
 
 
+class CeleryBeatConfig(BaseSettings):
+    CELERY_BEAT_SCHEDULER_TIME: int = Field(
+        description='the time of the celery scheduler, default to 1 day',
+        default=1,
+    )
+
+
 class FeatureConfig(
     # place the configs in alphabet order
     AppExecutionConfig,
@@ -462,5 +469,6 @@ class FeatureConfig(
 
     # hosted services config
     HostedServiceConfig,
+    CeleryBeatConfig,
 ):
     pass
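The new CeleryBeatConfig follows the same pydantic-settings pattern as the other mixins composed into FeatureConfig: a typed Field whose default can be overridden from the environment. A minimal standalone sketch of that mechanism (assuming pydantic-settings v2; the class below is a copy for illustration, not an import from Dify):

```python
import os

from pydantic import Field
from pydantic_settings import BaseSettings


class CeleryBeatConfig(BaseSettings):
    CELERY_BEAT_SCHEDULER_TIME: int = Field(
        description='the time of the celery scheduler, default to 1 day',
        default=1,
    )


print(CeleryBeatConfig().CELERY_BEAT_SCHEDULER_TIME)  # 1 (the declared default)
os.environ['CELERY_BEAT_SCHEDULER_TIME'] = '7'
print(CeleryBeatConfig().CELERY_BEAT_SCHEDULER_TIME)  # 7 (overridden from the env)
```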
@@ -79,7 +79,7 @@ class HostedAzureOpenAiConfig(BaseSettings):
         default=False,
     )
 
-    HOSTED_OPENAI_API_KEY: Optional[str] = Field(
+    HOSTED_AZURE_OPENAI_API_KEY: Optional[str] = Field(
         description='',
         default=None,
     )
@@ -1,4 +1,3 @@
-from typing import Optional
 
 from pydantic import BaseModel, Field, PositiveInt
 
@@ -8,32 +7,32 @@ class MyScaleConfig(BaseModel):
     MyScale configs
     """
 
-    MYSCALE_HOST: Optional[str] = Field(
+    MYSCALE_HOST: str = Field(
         description='MyScale host',
-        default=None,
+        default='localhost',
     )
 
-    MYSCALE_PORT: Optional[PositiveInt] = Field(
+    MYSCALE_PORT: PositiveInt = Field(
         description='MyScale port',
         default=8123,
     )
 
-    MYSCALE_USER: Optional[str] = Field(
+    MYSCALE_USER: str = Field(
         description='MyScale user',
-        default=None,
+        default='default',
     )
 
-    MYSCALE_PASSWORD: Optional[str] = Field(
+    MYSCALE_PASSWORD: str = Field(
         description='MyScale password',
-        default=None,
+        default='',
     )
 
-    MYSCALE_DATABASE: Optional[str] = Field(
+    MYSCALE_DATABASE: str = Field(
         description='MyScale database name',
-        default=None,
+        default='default',
     )
 
-    MYSCALE_FTS_PARAMS: Optional[str] = Field(
+    MYSCALE_FTS_PARAMS: str = Field(
         description='MyScale fts index parameters',
-        default=None,
+        default='',
     )
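Dropping Optional in favor of concrete defaults means MyScaleConfig() is always fully populated, so connection code no longer has to None-check every field. A small sketch of the effect (a standalone copy of three of the fields; the URL format printed is illustrative only):

```python
from pydantic import BaseModel, Field, PositiveInt


class MyScaleConfig(BaseModel):
    MYSCALE_HOST: str = Field(description='MyScale host', default='localhost')
    MYSCALE_PORT: PositiveInt = Field(description='MyScale port', default=8123)
    MYSCALE_USER: str = Field(description='MyScale user', default='default')


config = MyScaleConfig()
# Every field now has a concrete value, so callers can use them directly
# instead of guarding each access against None.
print(f'{config.MYSCALE_USER}@{config.MYSCALE_HOST}:{config.MYSCALE_PORT}')
```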
@@ -0,0 +1,2 @@
+# TODO: Update all string in code to use this constant
+HIDDEN_VALUE = '[__HIDDEN__]'
api/contexts/__init__.py (new file, 3 lines)
@@ -0,0 +1,3 @@
+from contextvars import ContextVar
+
+tenant_id: ContextVar[str] = ContextVar('tenant_id')
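contexts.tenant_id is a plain contextvars.ContextVar, so each execution context (a request, or a worker thread that replays a copied context) sees its own tenant without any parameter plumbing. A standalone sketch of that isolation:

```python
import contextvars
from contextvars import ContextVar

tenant_id: ContextVar[str] = ContextVar('tenant_id')


def run_for(tenant: str) -> str:
    tenant_id.set(tenant)   # visible only inside this execution context
    return tenant_id.get()


# Each copied context keeps its own value, mirroring per-request isolation.
ctx_a = contextvars.copy_context()
ctx_b = contextvars.copy_context()
print(ctx_a.run(run_for, 'tenant-a'))  # tenant-a
print(ctx_b.run(run_for, 'tenant-b'))  # tenant-b
print(tenant_id.get('unset'))          # 'unset': the outer context is untouched
```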
@@ -212,7 +212,7 @@ class AppCopyApi(Resource):
         parser.add_argument('icon_background', type=str, location='json')
         args = parser.parse_args()
 
-        data = AppDslService.export_dsl(app_model=app_model)
+        data = AppDslService.export_dsl(app_model=app_model, include_secret=True)
         app = AppDslService.import_and_create_new_app(
             tenant_id=current_user.current_tenant_id,
             data=data,
@@ -234,8 +234,13 @@ class AppExportApi(Resource):
         if not current_user.is_editor:
             raise Forbidden()
 
+        # Add include_secret params
+        parser = reqparse.RequestParser()
+        parser.add_argument('include_secret', type=inputs.boolean, default=False, location='args')
+        args = parser.parse_args()
+
         return {
-            "data": AppDslService.export_dsl(app_model=app_model)
+            "data": AppDslService.export_dsl(app_model=app_model, include_secret=args['include_secret'])
         }
 
 
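Callers can now opt in to secret-bearing exports through the include_secret query argument, which flask-restful's inputs.boolean parses from strings like 'true'. A hedged usage sketch; the base URL, route, app ID, and token below are placeholders, and the exact console route for AppExportApi is an assumption:

```python
import requests

BASE_URL = 'http://localhost:5001/console/api'  # placeholder console API base
APP_ID = 'your-app-id'                          # placeholder app identifier
TOKEN = 'your-console-token'                    # placeholder console session token

resp = requests.get(
    f'{BASE_URL}/apps/{APP_ID}/export',
    params={'include_secret': 'true'},  # parsed by inputs.boolean, defaults to False
    headers={'Authorization': f'Bearer {TOKEN}'},
)
print(resp.json()['data'])  # app DSL, with secrets retained when include_secret=true
```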
@@ -13,6 +13,7 @@ from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
 from core.app.apps.base_app_queue_manager import AppQueueManager
 from core.app.entities.app_invoke_entities import InvokeFrom
+from core.app.segments import factory
 from core.errors.error import AppInvokeQuotaExceededError
 from fields.workflow_fields import workflow_fields
 from fields.workflow_run_fields import workflow_run_node_execution_fields
@@ -41,7 +42,7 @@ class DraftWorkflowApi(Resource):
         # The role of the current user in the ta table must be admin, owner, or editor
         if not current_user.is_editor:
             raise Forbidden()
-
+
         # fetch draft workflow by app_model
         workflow_service = WorkflowService()
         workflow = workflow_service.get_draft_workflow(app_model=app_model)
@@ -64,13 +65,15 @@ class DraftWorkflowApi(Resource):
         if not current_user.is_editor:
             raise Forbidden()
 
-        content_type = request.headers.get('Content-Type')
+        content_type = request.headers.get('Content-Type', '')
 
         if 'application/json' in content_type:
             parser = reqparse.RequestParser()
             parser.add_argument('graph', type=dict, required=True, nullable=False, location='json')
             parser.add_argument('features', type=dict, required=True, nullable=False, location='json')
             parser.add_argument('hash', type=str, required=False, location='json')
+            # TODO: set this to required=True after frontend is updated
+            parser.add_argument('environment_variables', type=list, required=False, location='json')
             args = parser.parse_args()
         elif 'text/plain' in content_type:
             try:
@@ -84,7 +87,8 @@ class DraftWorkflowApi(Resource):
                 args = {
                     'graph': data.get('graph'),
                     'features': data.get('features'),
-                    'hash': data.get('hash')
+                    'hash': data.get('hash'),
+                    'environment_variables': data.get('environment_variables')
                 }
             except json.JSONDecodeError:
                 return {'message': 'Invalid JSON data'}, 400
@@ -94,12 +98,15 @@ class DraftWorkflowApi(Resource):
         workflow_service = WorkflowService()
 
         try:
+            environment_variables_list = args.get('environment_variables') or []
+            environment_variables = [factory.build_variable_from_mapping(obj) for obj in environment_variables_list]
             workflow = workflow_service.sync_draft_workflow(
                 app_model=app_model,
-                graph=args.get('graph'),
-                features=args.get('features'),
+                graph=args['graph'],
+                features=args['features'],
                 unique_hash=args.get('hash'),
-                account=current_user
+                account=current_user,
+                environment_variables=environment_variables,
             )
         except WorkflowHashNotEqualError:
             raise DraftWorkflowNotSync()
@@ -1,10 +1,11 @@
 import flask_restful
-from flask import current_app, request
+from flask import request
 from flask_login import current_user
 from flask_restful import Resource, marshal, marshal_with, reqparse
 from werkzeug.exceptions import Forbidden, NotFound
 
 import services
+from configs import dify_config
 from controllers.console import api
 from controllers.console.apikey import api_key_fields, api_key_list
 from controllers.console.app.error import ProviderNotInitializeError
@@ -530,7 +531,7 @@ class DatasetApiBaseUrlApi(Resource):
     @account_initialization_required
     def get(self):
         return {
-            'api_base_url': (current_app.config['SERVICE_API_URL'] if current_app.config['SERVICE_API_URL']
+            'api_base_url': (dify_config.SERVICE_API_URL if dify_config.SERVICE_API_URL
                              else request.host_url.rstrip('/')) + '/v1'
         }
 
@@ -540,20 +541,20 @@ class DatasetRetrievalSettingApi(Resource):
     @login_required
     @account_initialization_required
     def get(self):
-        vector_type = current_app.config['VECTOR_STORE']
+        vector_type = dify_config.VECTOR_STORE
         match vector_type:
             case VectorType.MILVUS | VectorType.RELYT | VectorType.PGVECTOR | VectorType.TIDB_VECTOR | VectorType.CHROMA | VectorType.TENCENT | VectorType.ORACLE:
                 return {
                     'retrieval_method': [
-                        RetrievalMethod.SEMANTIC_SEARCH
+                        RetrievalMethod.SEMANTIC_SEARCH.value
                     ]
                 }
             case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH | VectorType.ANALYTICDB | VectorType.MYSCALE:
                 return {
                     'retrieval_method': [
-                        RetrievalMethod.SEMANTIC_SEARCH,
-                        RetrievalMethod.FULL_TEXT_SEARCH,
-                        RetrievalMethod.HYBRID_SEARCH,
+                        RetrievalMethod.SEMANTIC_SEARCH.value,
+                        RetrievalMethod.FULL_TEXT_SEARCH.value,
+                        RetrievalMethod.HYBRID_SEARCH.value,
                     ]
                 }
             case _:
@@ -569,15 +570,15 @@ class DatasetRetrievalSettingMockApi(Resource):
             case VectorType.MILVUS | VectorType.RELYT | VectorType.PGVECTOR | VectorType.TIDB_VECTOR | VectorType.CHROMA | VectorType.TENCENT | VectorType.ORACLE:
                 return {
                     'retrieval_method': [
-                        RetrievalMethod.SEMANTIC_SEARCH
+                        RetrievalMethod.SEMANTIC_SEARCH.value
                     ]
                 }
             case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH| VectorType.ANALYTICDB | VectorType.MYSCALE:
                 return {
                     'retrieval_method': [
-                        RetrievalMethod.SEMANTIC_SEARCH,
-                        RetrievalMethod.FULL_TEXT_SEARCH,
-                        RetrievalMethod.HYBRID_SEARCH,
+                        RetrievalMethod.SEMANTIC_SEARCH.value,
+                        RetrievalMethod.FULL_TEXT_SEARCH.value,
+                        RetrievalMethod.HYBRID_SEARCH.value,
                    ]
                 }
             case _:
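Returning .value matters because raw enum members are not reliably JSON-serializable in the response layer, while .value is a plain string. Illustrated with a stand-in enum (the real RetrievalMethod lives in Dify's dataset retrieval service):

```python
from enum import Enum


class RetrievalMethod(Enum):  # stand-in for the real enum
    SEMANTIC_SEARCH = 'semantic_search'


print(RetrievalMethod.SEMANTIC_SEARCH)        # RetrievalMethod.SEMANTIC_SEARCH (enum object)
print(RetrievalMethod.SEMANTIC_SEARCH.value)  # 'semantic_search' (safe to serialize)
```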
@@ -75,7 +75,7 @@ class DatasetDocumentSegmentListApi(Resource):
         )
 
         if last_id is not None:
-            last_segment = DocumentSegment.query.get(str(last_id))
+            last_segment = db.session.get(DocumentSegment, str(last_id))
             if last_segment:
                 query = query.filter(
                     DocumentSegment.position > last_segment.position)
@@ -1,8 +1,9 @@
-from flask import current_app, request
+from flask import request
 from flask_login import current_user
 from flask_restful import Resource, marshal_with
 
 import services
+from configs import dify_config
 from controllers.console import api
 from controllers.console.datasets.error import (
     FileTooLargeError,
|
||||
@account_initialization_required
|
||||
@marshal_with(upload_config_fields)
|
||||
def get(self):
|
||||
file_size_limit = current_app.config.get("UPLOAD_FILE_SIZE_LIMIT")
|
||||
batch_count_limit = current_app.config.get("UPLOAD_FILE_BATCH_LIMIT")
|
||||
image_file_size_limit = current_app.config.get("UPLOAD_IMAGE_FILE_SIZE_LIMIT")
|
||||
file_size_limit = dify_config.UPLOAD_FILE_SIZE_LIMIT
|
||||
batch_count_limit = dify_config.UPLOAD_FILE_BATCH_LIMIT
|
||||
image_file_size_limit = dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT
|
||||
return {
|
||||
'file_size_limit': file_size_limit,
|
||||
'batch_count_limit': batch_count_limit,
|
||||
@@ -76,7 +77,7 @@ class FileSupportTypeApi(Resource):
     @login_required
     @account_initialization_required
     def get(self):
-        etl_type = current_app.config['ETL_TYPE']
+        etl_type = dify_config.ETL_TYPE
         allowed_extensions = UNSTRUCTURED_ALLOWED_EXTENSIONS if etl_type == 'Unstructured' else ALLOWED_EXTENSIONS
         return {'allowed_extensions': allowed_extensions}
 
@@ -78,10 +78,12 @@ class ChatTextApi(InstalledAppResource):
         parser = reqparse.RequestParser()
         parser.add_argument('message_id', type=str, required=False, location='json')
         parser.add_argument('voice', type=str, location='json')
+        parser.add_argument('text', type=str, location='json')
         parser.add_argument('streaming', type=bool, location='json')
         args = parser.parse_args()
 
-        message_id = args.get('message_id')
+        message_id = args.get('message_id', None)
+        text = args.get('text', None)
         if (app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
                 and app_model.workflow
                 and app_model.workflow.features_dict):
@@ -95,7 +97,8 @@ class ChatTextApi(InstalledAppResource):
             response = AudioService.transcript_tts(
                 app_model=app_model,
                 message_id=message_id,
-                voice=voice
+                voice=voice,
+                text=text
             )
             return response
         except services.errors.app_model_config.AppModelConfigBrokenError:
@@ -1,7 +1,7 @@
 
-from flask import current_app
 from flask_restful import fields, marshal_with
 
+from configs import dify_config
 from controllers.console import api
 from controllers.console.app.error import AppUnavailableError
 from controllers.console.explore.wraps import InstalledAppResource
@@ -78,7 +78,7 @@ class AppParameterApi(InstalledAppResource):
             "transfer_methods": ["remote_url", "local_file"]
         }}),
         'system_parameters': {
-            'image_file_size_limit': current_app.config.get('UPLOAD_IMAGE_FILE_SIZE_LIMIT')
+            'image_file_size_limit': dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT
         }
     }
 
@@ -1,8 +1,9 @@
 import os
 
-from flask import current_app, session
+from flask import session
 from flask_restful import Resource, reqparse
 
+from configs import dify_config
 from libs.helper import str_len
 from models.model import DifySetup
 from services.account_service import TenantService
@@ -40,7 +41,7 @@ class InitValidateAPI(Resource):
         return {'result': 'success'}, 201
 
 def get_init_validate_status():
-    if current_app.config['EDITION'] == 'SELF_HOSTED':
+    if dify_config.EDITION == 'SELF_HOSTED':
         if os.environ.get('INIT_PASSWORD'):
             return session.get('is_init_validated') or DifySetup.query.first()
 
@@ -1,8 +1,9 @@
 from functools import wraps
 
-from flask import current_app, request
+from flask import request
 from flask_restful import Resource, reqparse
 
+from configs import dify_config
 from libs.helper import email, get_remote_ip, str_len
 from libs.password import valid_password
 from models.model import DifySetup
@@ -17,7 +18,7 @@ from .wraps import only_edition_self_hosted
 class SetupApi(Resource):
 
     def get(self):
-        if current_app.config['EDITION'] == 'SELF_HOSTED':
+        if dify_config.EDITION == 'SELF_HOSTED':
            setup_status = get_setup_status()
            if setup_status:
                return {
@@ -77,7 +78,7 @@ def setup_required(view):
 
 
 def get_setup_status():
-    if current_app.config['EDITION'] == 'SELF_HOSTED':
+    if dify_config.EDITION == 'SELF_HOSTED':
         return DifySetup.query.first()
     else:
         return True
@@ -3,9 +3,10 @@ import json
 import logging
 
 import requests
-from flask import current_app
 from flask_restful import Resource, reqparse
 
+from configs import dify_config
+
 from . import api
 
 
@@ -15,16 +16,16 @@ class VersionApi(Resource):
         parser = reqparse.RequestParser()
         parser.add_argument('current_version', type=str, required=True, location='args')
         args = parser.parse_args()
-        check_update_url = current_app.config['CHECK_UPDATE_URL']
+        check_update_url = dify_config.CHECK_UPDATE_URL
 
         result = {
-            'version': current_app.config['CURRENT_VERSION'],
+            'version': dify_config.CURRENT_VERSION,
             'release_date': '',
             'release_notes': '',
             'can_auto_update': False,
             'features': {
-                'can_replace_logo': current_app.config['CAN_REPLACE_LOGO'],
-                'model_load_balancing_enabled': current_app.config['MODEL_LB_ENABLED']
+                'can_replace_logo': dify_config.CAN_REPLACE_LOGO,
+                'model_load_balancing_enabled': dify_config.MODEL_LB_ENABLED
             }
         }
 
@@ -1,10 +1,11 @@
 import datetime
 
 import pytz
-from flask import current_app, request
+from flask import request
 from flask_login import current_user
 from flask_restful import Resource, fields, marshal_with, reqparse
 
+from configs import dify_config
 from constants.languages import supported_language
 from controllers.console import api
 from controllers.console.setup import setup_required
@@ -36,7 +37,7 @@ class AccountInitApi(Resource):
 
         parser = reqparse.RequestParser()
 
-        if current_app.config['EDITION'] == 'CLOUD':
+        if dify_config.EDITION == 'CLOUD':
             parser.add_argument('invitation_code', type=str, location='json')
 
         parser.add_argument(
@@ -45,7 +46,7 @@ class AccountInitApi(Resource):
             required=True, location='json')
         args = parser.parse_args()
 
-        if current_app.config['EDITION'] == 'CLOUD':
+        if dify_config.EDITION == 'CLOUD':
             if not args['invitation_code']:
                 raise ValueError('invitation_code is required')
 
@@ -1,8 +1,8 @@
-from flask import current_app
 from flask_login import current_user
 from flask_restful import Resource, abort, marshal_with, reqparse
 
 import services
+from configs import dify_config
 from controllers.console import api
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
@@ -48,7 +48,7 @@ class MemberInviteEmailApi(Resource):
 
         inviter = current_user
         invitation_results = []
-        console_web_url = current_app.config.get("CONSOLE_WEB_URL")
+        console_web_url = dify_config.CONSOLE_WEB_URL
         for invitee_email in invitee_emails:
             try:
                 token = RegisterService.invite_new_member(inviter.current_tenant, invitee_email, interface_language, role=invitee_role, inviter=inviter)
@@ -117,7 +117,7 @@ class MemberUpdateRoleApi(Resource):
         if not TenantAccountRole.is_valid_role(new_role):
             return {'code': 'invalid-role', 'message': 'Invalid role'}, 400
 
-        member = Account.query.get(str(member_id))
+        member = db.session.get(Account, str(member_id))
         if not member:
             abort(404)
 
@@ -1,10 +1,11 @@
 import io
 
-from flask import current_app, send_file
+from flask import send_file
 from flask_login import current_user
 from flask_restful import Resource, reqparse
 from werkzeug.exceptions import Forbidden
 
+from configs import dify_config
 from controllers.console import api
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
@@ -104,7 +105,7 @@ class ToolBuiltinProviderIconApi(Resource):
     @setup_required
     def get(self, provider):
         icon_bytes, mimetype = BuiltinToolManageService.get_builtin_tool_provider_icon(provider)
-        icon_cache_max_age = current_app.config.get('TOOL_ICON_CACHE_MAX_AGE')
+        icon_cache_max_age = dify_config.TOOL_ICON_CACHE_MAX_AGE
         return send_file(io.BytesIO(icon_bytes), mimetype=mimetype, max_age=icon_cache_max_age)
 
 class ToolApiProviderAddApi(Resource):
@@ -1,9 +1,10 @@
 import json
 from functools import wraps
 
-from flask import abort, current_app, request
+from flask import abort, request
 from flask_login import current_user
 
+from configs import dify_config
 from controllers.console.workspace.error import AccountNotInitializedError
 from services.feature_service import FeatureService
 from services.operation_service import OperationService
@@ -26,7 +27,7 @@ def account_initialization_required(view):
 def only_edition_cloud(view):
     @wraps(view)
     def decorated(*args, **kwargs):
-        if current_app.config['EDITION'] != 'CLOUD':
+        if dify_config.EDITION != 'CLOUD':
             abort(404)
 
         return view(*args, **kwargs)
@@ -37,7 +38,7 @@ def only_edition_cloud(view):
 def only_edition_self_hosted(view):
     @wraps(view)
     def decorated(*args, **kwargs):
-        if current_app.config['EDITION'] != 'SELF_HOSTED':
+        if dify_config.EDITION != 'SELF_HOSTED':
             abort(404)
 
         return view(*args, **kwargs)
@@ -76,10 +76,12 @@ class TextApi(Resource):
         parser = reqparse.RequestParser()
         parser.add_argument('message_id', type=str, required=False, location='json')
         parser.add_argument('voice', type=str, location='json')
+        parser.add_argument('text', type=str, location='json')
         parser.add_argument('streaming', type=bool, location='json')
         args = parser.parse_args()
 
-        message_id = args.get('message_id')
+        message_id = args.get('message_id', None)
+        text = args.get('text', None)
         if (app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
                 and app_model.workflow
                 and app_model.workflow.features_dict):
@@ -87,15 +89,15 @@ class TextApi(Resource):
             voice = args.get('voice') if args.get('voice') else text_to_speech.get('voice')
         else:
             try:
-                voice = args.get('voice') if args.get('voice') else app_model.app_model_config.text_to_speech_dict.get(
-                    'voice')
+                voice = args.get('voice') if args.get('voice') else app_model.app_model_config.text_to_speech_dict.get('voice')
             except Exception:
                 voice = None
         response = AudioService.transcript_tts(
             app_model=app_model,
             message_id=message_id,
             end_user=end_user.external_user_id,
-            voice=voice
+            voice=voice,
+            text=text
         )
 
         return response
@@ -1,6 +1,6 @@
 import logging
 
-from flask_restful import Resource, reqparse
+from flask_restful import Resource, fields, marshal_with, reqparse
 from werkzeug.exceptions import InternalServerError
 
 from controllers.service_api import api
@@ -21,14 +21,43 @@ from core.errors.error import (
     QuotaExceededError,
 )
 from core.model_runtime.errors.invoke import InvokeError
+from extensions.ext_database import db
 from libs import helper
 from models.model import App, AppMode, EndUser
+from models.workflow import WorkflowRun
 from services.app_generate_service import AppGenerateService
 
 logger = logging.getLogger(__name__)
 
 
 class WorkflowRunApi(Resource):
+    workflow_run_fields = {
+        'id': fields.String,
+        'workflow_id': fields.String,
+        'status': fields.String,
+        'inputs': fields.Raw,
+        'outputs': fields.Raw,
+        'error': fields.String,
+        'total_steps': fields.Integer,
+        'total_tokens': fields.Integer,
+        'created_at': fields.DateTime,
+        'finished_at': fields.DateTime,
+        'elapsed_time': fields.Float,
+    }
+
+    @validate_app_token
+    @marshal_with(workflow_run_fields)
+    def get(self, app_model: App, workflow_id: str):
+        """
+        Get a workflow task running detail
+        """
+        app_mode = AppMode.value_of(app_model.mode)
+        if app_mode != AppMode.WORKFLOW:
+            raise NotWorkflowAppError()
+
+        workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_id).first()
+        return workflow_run
+
     @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
     def post(self, app_model: App, end_user: EndUser):
         """
@@ -88,5 +117,5 @@ class WorkflowTaskStopApi(Resource):
         }
 
 
-api.add_resource(WorkflowRunApi, '/workflows/run')
+api.add_resource(WorkflowRunApi, '/workflows/run/<string:workflow_id>', '/workflows/run')
 api.add_resource(WorkflowTaskStopApi, '/workflows/tasks/<string:task_id>/stop')
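With the extra route, the service API can now be polled for a run's detail at GET /workflows/run/<workflow_id>, marshaled through workflow_run_fields. A hedged usage sketch; the host, app API key, and run ID below are placeholders:

```python
import requests

API_BASE = 'http://localhost:5001/v1'  # placeholder service-api base URL
API_KEY = 'app-xxxxxxxx'               # placeholder app API key
RUN_ID = 'workflow-run-uuid'           # id returned when the run was started

resp = requests.get(
    f'{API_BASE}/workflows/run/{RUN_ID}',
    headers={'Authorization': f'Bearer {API_KEY}'},
)
run = resp.json()
print(run['status'], run['elapsed_time'], run['total_tokens'])
```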
@@ -74,10 +74,12 @@ class TextApi(WebApiResource):
         parser = reqparse.RequestParser()
         parser.add_argument('message_id', type=str, required=False, location='json')
         parser.add_argument('voice', type=str, location='json')
+        parser.add_argument('text', type=str, location='json')
         parser.add_argument('streaming', type=bool, location='json')
         args = parser.parse_args()
 
-        message_id = args.get('message_id')
+        message_id = args.get('message_id', None)
+        text = args.get('text', None)
         if (app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
                 and app_model.workflow
                 and app_model.workflow.features_dict):
@@ -94,7 +96,8 @@ class TextApi(WebApiResource):
             app_model=app_model,
             message_id=message_id,
             end_user=end_user.external_user_id,
-            voice=voice
+            voice=voice,
+            text=text
         )
 
         return response
@@ -342,10 +342,14 @@ class FunctionCallAgentRunner(BaseAgentRunner):
         """
         tool_calls = []
         for prompt_message in llm_result_chunk.delta.message.tool_calls:
+            args = {}
+            if prompt_message.function.arguments != '':
+                args = json.loads(prompt_message.function.arguments)
+
             tool_calls.append((
                 prompt_message.id,
                 prompt_message.function.name,
-                json.loads(prompt_message.function.arguments),
+                args,
             ))
 
         return tool_calls
@@ -359,10 +363,14 @@ class FunctionCallAgentRunner(BaseAgentRunner):
         """
         tool_calls = []
         for prompt_message in llm_result.message.tool_calls:
+            args = {}
+            if prompt_message.function.arguments != '':
+                args = json.loads(prompt_message.function.arguments)
+
            tool_calls.append((
                 prompt_message.id,
                 prompt_message.function.name,
-                json.loads(prompt_message.function.arguments),
+                args,
             ))
 
         return tool_calls
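Both helpers now guard json.loads against empty arguments: streamed tool calls can arrive before any argument tokens, and json.loads('') raises json.JSONDecodeError. The guard in isolation:

```python
import json


def parse_arguments(raw: str) -> dict:
    # json.loads('') raises json.JSONDecodeError, so treat empty input as {}
    args = {}
    if raw != '':
        args = json.loads(raw)
    return args


print(parse_arguments(''))                   # {}
print(parse_arguments('{"query": "dify"}'))  # {'query': 'dify'}
```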
@@ -1,6 +1,7 @@
-from typing import Optional, Union
+from collections.abc import Mapping
+from typing import Any
 
-from core.app.app_config.entities import AppAdditionalFeatures, EasyUIBasedAppModelConfigFrom
+from core.app.app_config.entities import AppAdditionalFeatures
 from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
 from core.app.app_config.features.more_like_this.manager import MoreLikeThisConfigManager
 from core.app.app_config.features.opening_statement.manager import OpeningStatementConfigManager
@@ -10,37 +11,19 @@ from core.app.app_config.features.suggested_questions_after_answer.manager impor
     SuggestedQuestionsAfterAnswerConfigManager,
 )
 from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager
-from models.model import AppMode, AppModelConfig
+from models.model import AppMode
 
 
 class BaseAppConfigManager:
 
-    @classmethod
-    def convert_to_config_dict(cls, config_from: EasyUIBasedAppModelConfigFrom,
-                               app_model_config: Union[AppModelConfig, dict],
-                               config_dict: Optional[dict] = None) -> dict:
-        """
-        Convert app model config to config dict
-        :param config_from: app model config from
-        :param app_model_config: app model config
-        :param config_dict: app model config dict
-        :return:
-        """
-        if config_from != EasyUIBasedAppModelConfigFrom.ARGS:
-            app_model_config_dict = app_model_config.to_dict()
-            config_dict = app_model_config_dict.copy()
-
-        return config_dict
-
     @classmethod
-    def convert_features(cls, config_dict: dict, app_mode: AppMode) -> AppAdditionalFeatures:
+    def convert_features(cls, config_dict: Mapping[str, Any], app_mode: AppMode) -> AppAdditionalFeatures:
         """
         Convert app config to app model config
 
         :param config_dict: app config
         :param app_mode: app mode
         """
-        config_dict = config_dict.copy()
+        config_dict = dict(config_dict.items())
 
         additional_features = AppAdditionalFeatures()
         additional_features.show_retrieve_source = RetrievalResourceConfigManager.convert(
@@ -1,11 +1,12 @@
-from typing import Optional
+from collections.abc import Mapping
+from typing import Any, Optional
 
 from core.app.app_config.entities import FileExtraConfig
 
 
 class FileUploadConfigManager:
     @classmethod
-    def convert(cls, config: dict, is_vision: bool = True) -> Optional[FileExtraConfig]:
+    def convert(cls, config: Mapping[str, Any], is_vision: bool = True) -> Optional[FileExtraConfig]:
         """
         Convert model config to model config
 
@@ -3,13 +3,13 @@ from core.app.app_config.entities import TextToSpeechEntity
 
 class TextToSpeechConfigManager:
     @classmethod
-    def convert(cls, config: dict) -> bool:
+    def convert(cls, config: dict):
         """
         Convert model config to model config
 
         :param config: model config args
         """
-        text_to_speech = False
+        text_to_speech = None
         text_to_speech_dict = config.get('text_to_speech')
         if text_to_speech_dict:
             if text_to_speech_dict.get('enabled'):
@@ -1,3 +1,4 @@
+import contextvars
 import logging
 import os
 import threading
@@ -8,6 +9,7 @@ from typing import Union
 from flask import Flask, current_app
 from pydantic import ValidationError
 
+import contexts
 from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
 from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
 from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner
@@ -107,6 +109,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
             extras=extras,
             trace_manager=trace_manager
         )
+        contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
 
         return self._generate(
             app_model=app_model,
@@ -173,6 +176,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
                 inputs=args['inputs']
             )
         )
+        contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
 
         return self._generate(
             app_model=app_model,
@@ -225,6 +229,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
             'queue_manager': queue_manager,
             'conversation_id': conversation.id,
             'message_id': message.id,
+            'user': user,
+            'context': contextvars.copy_context()
         })
 
         worker_thread.start()
@@ -249,7 +255,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
                          application_generate_entity: AdvancedChatAppGenerateEntity,
                          queue_manager: AppQueueManager,
                          conversation_id: str,
-                         message_id: str) -> None:
+                         message_id: str,
+                         user: Account,
+                         context: contextvars.Context) -> None:
         """
         Generate worker in a new thread.
         :param flask_app: Flask app
@@ -259,6 +267,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
         :param message_id: message ID
         :return:
         """
+        for var, val in context.items():
+            var.set(val)
         with flask_app.app_context():
             try:
                 runner = AdvancedChatAppRunner()
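ContextVars set on the request thread do not propagate into a new threading.Thread on their own, which is why the generator snapshots contextvars.copy_context() and the worker replays every variable before entering the Flask app context. The pattern in isolation:

```python
import contextvars
import threading

tenant_id: contextvars.ContextVar[str] = contextvars.ContextVar('tenant_id')


def worker(context: contextvars.Context) -> None:
    # Replay the snapshot; without this, tenant_id.get() raises LookupError here.
    for var, val in context.items():
        var.set(val)
    print('worker sees:', tenant_id.get())


tenant_id.set('tenant-a')
thread = threading.Thread(target=worker, args=(contextvars.copy_context(),))
thread.start()
thread.join()  # prints: worker sees: tenant-a
```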
@@ -1,7 +1,8 @@
 import logging
 import os
 import time
-from typing import Optional, cast
+from collections.abc import Mapping
+from typing import Any, Optional, cast
 
 from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
 from core.app.apps.advanced_chat.workflow_event_trigger_callback import WorkflowEventTriggerCallback
@@ -14,6 +15,7 @@ from core.app.entities.app_invoke_entities import (
 )
 from core.app.entities.queue_entities import QueueAnnotationReplyEvent, QueueStopEvent, QueueTextChunkEvent
 from core.moderation.base import ModerationException
+from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
 from core.workflow.entities.node_entities import SystemVariable, UserFrom
 from core.workflow.workflow_engine_manager import WorkflowEngineManager
 from extensions.ext_database import db
@@ -86,7 +88,7 @@ class AdvancedChatAppRunner(AppRunner):
 
         db.session.close()
 
-        workflow_callbacks = [WorkflowEventTriggerCallback(
+        workflow_callbacks: list[WorkflowCallback] = [WorkflowEventTriggerCallback(
             queue_manager=queue_manager,
             workflow=workflow
         )]
@@ -160,7 +162,7 @@ class AdvancedChatAppRunner(AppRunner):
         self, queue_manager: AppQueueManager,
         app_record: App,
         app_generate_entity: AdvancedChatAppGenerateEntity,
-        inputs: dict,
+        inputs: Mapping[str, Any],
         query: str,
         message_id: str
     ) -> bool:
@@ -1,9 +1,11 @@
 import json
 from collections.abc import Generator
-from typing import cast
+from typing import Any, cast
 
 from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
 from core.app.entities.task_entities import (
+    AppBlockingResponse,
+    AppStreamResponse,
     ChatbotAppBlockingResponse,
     ChatbotAppStreamResponse,
     ErrorStreamResponse,
@@ -18,12 +20,13 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
     _blocking_response_type = ChatbotAppBlockingResponse
 
     @classmethod
-    def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict:
+    def convert_blocking_full_response(cls, blocking_response: AppBlockingResponse) -> dict[str, Any]:
         """
         Convert blocking full response.
         :param blocking_response: blocking response
         :return:
         """
+        blocking_response = cast(ChatbotAppBlockingResponse, blocking_response)
         response = {
             'event': 'message',
             'task_id': blocking_response.task_id,
@@ -39,7 +42,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
         return response
 
     @classmethod
-    def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict:
+    def convert_blocking_simple_response(cls, blocking_response: AppBlockingResponse) -> dict[str, Any]:
         """
         Convert blocking simple response.
         :param blocking_response: blocking response
@@ -53,8 +56,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
         return response
 
     @classmethod
-    def convert_stream_full_response(cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]) \
-            -> Generator[str, None, None]:
+    def convert_stream_full_response(cls, stream_response: Generator[AppStreamResponse, None, None]) -> Generator[str, Any, None]:
         """
         Convert stream full response.
         :param stream_response: stream response
@@ -83,8 +85,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
             yield json.dumps(response_chunk)
 
     @classmethod
-    def convert_stream_simple_response(cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]) \
-            -> Generator[str, None, None]:
+    def convert_stream_simple_response(cls, stream_response: Generator[AppStreamResponse, None, None]) -> Generator[str, Any, None]:
         """
         Convert stream simple response.
         :param stream_response: stream response
@@ -113,7 +113,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
         self._stream_generate_routes = self._get_stream_generate_routes()
         self._conversation_name_generate_thread = None
 
-    def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
+    def process(self):
         """
         Process generate task pipeline.
         :return:
@@ -136,8 +136,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
         else:
             return self._to_blocking_response(generator)
 
-    def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]) \
-            -> ChatbotAppBlockingResponse:
+    def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]) -> ChatbotAppBlockingResponse:
         """
         Process blocking response.
         :return:
@@ -167,8 +166,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
 
         raise Exception('Queue listening stopped unexpectedly.')
 
-    def _to_stream_response(self, generator: Generator[StreamResponse, None, None]) \
-            -> Generator[ChatbotAppStreamResponse, None, None]:
+    def _to_stream_response(self, generator: Generator[StreamResponse, None, None]) -> Generator[ChatbotAppStreamResponse, Any, None]:
         """
         To stream response.
         :return:
@@ -14,13 +14,13 @@ from core.app.entities.queue_entities import (
     QueueWorkflowStartedEvent,
     QueueWorkflowSucceededEvent,
 )
-from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
+from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
 from core.workflow.entities.base_node_data_entities import BaseNodeData
 from core.workflow.entities.node_entities import NodeType
 from models.workflow import Workflow
 
 
-class WorkflowEventTriggerCallback(BaseWorkflowCallback):
+class WorkflowEventTriggerCallback(WorkflowCallback):
 
     def __init__(self, queue_manager: AppQueueManager, workflow: Workflow):
         self._queue_manager = queue_manager
@@ -1,7 +1,7 @@
 import logging
 from abc import ABC, abstractmethod
 from collections.abc import Generator
-from typing import Union
+from typing import Any, Union
 
 from core.app.entities.app_invoke_entities import InvokeFrom
 from core.app.entities.task_entities import AppBlockingResponse, AppStreamResponse
@@ -15,44 +15,41 @@ class AppGenerateResponseConverter(ABC):
     @classmethod
     def convert(cls, response: Union[
         AppBlockingResponse,
-        Generator[AppStreamResponse, None, None]
-    ], invoke_from: InvokeFrom) -> Union[
-        dict,
-        Generator[str, None, None]
-    ]:
+        Generator[AppStreamResponse, Any, None]
+    ], invoke_from: InvokeFrom):
         if invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]:
-            if isinstance(response, cls._blocking_response_type):
+            if isinstance(response, AppBlockingResponse):
                 return cls.convert_blocking_full_response(response)
             else:
-                def _generate():
+                def _generate_full_response() -> Generator[str, Any, None]:
                     for chunk in cls.convert_stream_full_response(response):
                         if chunk == 'ping':
                             yield f'event: {chunk}\n\n'
                         else:
                             yield f'data: {chunk}\n\n'
 
-                return _generate()
+                return _generate_full_response()
         else:
-            if isinstance(response, cls._blocking_response_type):
+            if isinstance(response, AppBlockingResponse):
                 return cls.convert_blocking_simple_response(response)
             else:
-                def _generate():
+                def _generate_simple_response() -> Generator[str, Any, None]:
                    for chunk in cls.convert_stream_simple_response(response):
                         if chunk == 'ping':
                             yield f'event: {chunk}\n\n'
                         else:
                             yield f'data: {chunk}\n\n'
 
-                return _generate()
+                return _generate_simple_response()
 
     @classmethod
     @abstractmethod
-    def convert_blocking_full_response(cls, blocking_response: AppBlockingResponse) -> dict:
+    def convert_blocking_full_response(cls, blocking_response: AppBlockingResponse) -> dict[str, Any]:
        raise NotImplementedError
 
     @classmethod
     @abstractmethod
-    def convert_blocking_simple_response(cls, blocking_response: AppBlockingResponse) -> dict:
+    def convert_blocking_simple_response(cls, blocking_response: AppBlockingResponse) -> dict[str, Any]:
         raise NotImplementedError
 
     @classmethod
@@ -68,7 +65,7 @@ class AppGenerateResponseConverter(ABC):
         raise NotImplementedError
 
     @classmethod
-    def _get_simple_metadata(cls, metadata: dict) -> dict:
+    def _get_simple_metadata(cls, metadata: dict[str, Any]):
         """
         Get simple metadata.
         :param metadata: metadata
@@ -1,3 +1,4 @@
+import contextvars
 import logging
 import os
 import threading
@@ -8,6 +9,7 @@ from typing import Union
 from flask import Flask, current_app
 from pydantic import ValidationError
 
+import contexts
 from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
 from core.app.apps.base_app_generator import BaseAppGenerator
 from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom
@@ -38,7 +40,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
                  invoke_from: InvokeFrom,
                  stream: bool = True,
                  call_depth: int = 0,
-                 ) -> Union[dict, Generator[dict, None, None]]:
+                 ):
         """
         Generate App response.
 
@@ -86,6 +88,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
             call_depth=call_depth,
             trace_manager=trace_manager
         )
+        contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
 
         return self._generate(
             app_model=app_model,
@@ -126,7 +129,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
         worker_thread = threading.Thread(target=self._generate_worker, kwargs={
             'flask_app': current_app._get_current_object(),
             'application_generate_entity': application_generate_entity,
-            'queue_manager': queue_manager
+            'queue_manager': queue_manager,
+            'context': contextvars.copy_context()
         })
 
         worker_thread.start()
@@ -150,8 +154,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
                             node_id: str,
                             user: Account,
                             args: dict,
-                            stream: bool = True) \
-            -> Union[dict, Generator[dict, None, None]]:
+                            stream: bool = True):
         """
         Generate App response.
 
@@ -193,6 +196,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
                 inputs=args['inputs']
             )
         )
+        contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
 
         return self._generate(
             app_model=app_model,
@@ -205,7 +209,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
 
     def _generate_worker(self, flask_app: Flask,
                          application_generate_entity: WorkflowAppGenerateEntity,
-                         queue_manager: AppQueueManager) -> None:
+                         queue_manager: AppQueueManager,
+                         context: contextvars.Context) -> None:
         """
         Generate worker in a new thread.
         :param flask_app: Flask app
@@ -213,6 +218,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
         :param queue_manager: queue manager
         :return:
         """
+        for var, val in context.items():
+            var.set(val)
         with flask_app.app_context():
             try:
                 # workflow app
@@ -10,6 +10,7 @@ from core.app.entities.app_invoke_entities import (
     InvokeFrom,
     WorkflowAppGenerateEntity,
 )
+from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
 from core.workflow.entities.node_entities import SystemVariable, UserFrom
 from core.workflow.workflow_engine_manager import WorkflowEngineManager
 from extensions.ext_database import db
@@ -56,7 +57,7 @@ class WorkflowAppRunner:
 
         db.session.close()
 
-        workflow_callbacks = [WorkflowEventTriggerCallback(
+        workflow_callbacks: list[WorkflowCallback] = [WorkflowEventTriggerCallback(
             queue_manager=queue_manager,
             workflow=workflow
         )]
@@ -14,13 +14,13 @@ from core.app.entities.queue_entities import (
     QueueWorkflowStartedEvent,
     QueueWorkflowSucceededEvent,
 )
-from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
+from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
 from core.workflow.entities.base_node_data_entities import BaseNodeData
 from core.workflow.entities.node_entities import NodeType
 from models.workflow import Workflow
 
 
-class WorkflowEventTriggerCallback(BaseWorkflowCallback):
+class WorkflowEventTriggerCallback(WorkflowCallback):
 
     def __init__(self, queue_manager: AppQueueManager, workflow: Workflow):
         self._queue_manager = queue_manager
@@ -2,7 +2,7 @@ from typing import Optional
 
 from core.app.entities.queue_entities import AppQueueEvent
 from core.model_runtime.utils.encoders import jsonable_encoder
-from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
+from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
 from core.workflow.entities.base_node_data_entities import BaseNodeData
 from core.workflow.entities.node_entities import NodeType
 
@@ -15,7 +15,7 @@ _TEXT_COLOR_MAPPING = {
 }
 
 
-class WorkflowLoggingCallback(BaseWorkflowCallback):
+class WorkflowLoggingCallback(WorkflowCallback):
 
     def __init__(self) -> None:
         self.current_node_id = None
@@ -1,3 +1,4 @@
+from collections.abc import Mapping
 from enum import Enum
 from typing import Any, Optional
 
@@ -76,7 +77,7 @@ class AppGenerateEntity(BaseModel):
     # app config
     app_config: AppConfig
 
-    inputs: dict[str, Any]
+    inputs: Mapping[str, Any]
     files: list[FileVar] = []
     user_id: str
 
@@ -140,7 +141,7 @@ class AdvancedChatAppGenerateEntity(AppGenerateEntity):
     app_config: WorkflowUIBasedAppConfig
 
     conversation_id: Optional[str] = None
-    query: Optional[str] = None
+    query: str
 
     class SingleIterationRunEntity(BaseModel):
         """
api/core/app/segments/__init__.py (new file, 27 lines)
@@ -0,0 +1,27 @@
+from .segment_group import SegmentGroup
+from .segments import Segment
+from .types import SegmentType
+from .variables import (
+    ArrayVariable,
+    FileVariable,
+    FloatVariable,
+    IntegerVariable,
+    ObjectVariable,
+    SecretVariable,
+    StringVariable,
+    Variable,
+)
+
+__all__ = [
+    'IntegerVariable',
+    'FloatVariable',
+    'ObjectVariable',
+    'SecretVariable',
+    'FileVariable',
+    'StringVariable',
+    'ArrayVariable',
+    'Variable',
+    'SegmentType',
+    'SegmentGroup',
+    'Segment'
+]
api/core/app/segments/factory.py (new file, 64 lines)
@@ -0,0 +1,64 @@
+from collections.abc import Mapping
+from typing import Any
+
+from core.file.file_obj import FileVar
+
+from .segments import Segment, StringSegment
+from .types import SegmentType
+from .variables import (
+    ArrayVariable,
+    FileVariable,
+    FloatVariable,
+    IntegerVariable,
+    ObjectVariable,
+    SecretVariable,
+    StringVariable,
+    Variable,
+)
+
+
+def build_variable_from_mapping(m: Mapping[str, Any], /) -> Variable:
+    if (value_type := m.get('value_type')) is None:
+        raise ValueError('missing value type')
+    if not m.get('name'):
+        raise ValueError('missing name')
+    if (value := m.get('value')) is None:
+        raise ValueError('missing value')
+    match value_type:
+        case SegmentType.STRING:
+            return StringVariable.model_validate(m)
+        case SegmentType.NUMBER if isinstance(value, int):
+            return IntegerVariable.model_validate(m)
+        case SegmentType.NUMBER if isinstance(value, float):
+            return FloatVariable.model_validate(m)
+        case SegmentType.SECRET:
+            return SecretVariable.model_validate(m)
+        case SegmentType.NUMBER if not isinstance(value, float | int):
+            raise ValueError(f'invalid number value {value}')
+    raise ValueError(f'not supported value type {value_type}')
+
+
+def build_anonymous_variable(value: Any, /) -> Variable:
+    if isinstance(value, str):
+        return StringVariable(name='anonymous', value=value)
+    if isinstance(value, int):
+        return IntegerVariable(name='anonymous', value=value)
+    if isinstance(value, float):
+        return FloatVariable(name='anonymous', value=value)
+    if isinstance(value, dict):
+        # TODO: Limit the depth of the object
+        obj = {k: build_anonymous_variable(v) for k, v in value.items()}
+        return ObjectVariable(name='anonymous', value=obj)
+    if isinstance(value, list):
+        # TODO: Limit the depth of the array
+        elements = [build_anonymous_variable(v) for v in value]
+        return ArrayVariable(name='anonymous', value=elements)
+    if isinstance(value, FileVar):
+        return FileVariable(name='anonymous', value=value)
+    raise ValueError(f'not supported value {value}')
+
+
+def build_segment(value: Any, /) -> Segment:
+    if isinstance(value, str):
+        return StringSegment(value=value)
+    raise ValueError(f'not supported value {value}')
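build_variable_from_mapping dispatches on value_type, splitting 'number' into integer and float variants via match guards. A usage sketch against the module above (assuming the package imports as shown; the values are illustrative):

```python
from core.app.segments import factory

var = factory.build_variable_from_mapping({
    'value_type': 'number',
    'name': 'temperature',
    'value': 0.7,
})
print(type(var).__name__)  # FloatVariable: 0.7 is a float, so the float guard matches

secret = factory.build_variable_from_mapping({
    'value_type': 'secret',
    'name': 'api_key',
    'value': 'sk-1234567890abcdef',  # illustrative fake token
})
print(secret.log)  # obfuscated; see SecretVariable in variables.py below
```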
api/core/app/segments/parser.py (new file, 17 lines)
@@ -0,0 +1,17 @@
+import re
+
+from core.app.segments import SegmentGroup, factory
+from core.workflow.entities.variable_pool import VariablePool
+
+VARIABLE_PATTERN = re.compile(r'\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}')
+
+
+def convert_template(*, template: str, variable_pool: VariablePool):
+    parts = re.split(VARIABLE_PATTERN, template)
+    segments = []
+    for part in parts:
+        if '.' in part and (value := variable_pool.get(part.split('.'))):
+            segments.append(value)
+        else:
+            segments.append(factory.build_segment(part))
+    return SegmentGroup(segments=segments)
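VARIABLE_PATTERN matches selectors of the form {{#node_id.key#}} with up to ten dotted parts. Because the pattern contains one capture group, re.split interleaves literal text with the captured selectors, which convert_template then resolves against the variable pool or wraps as plain segments. A quick sketch of the split:

```python
import re

VARIABLE_PATTERN = re.compile(r'\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}')

template = 'Hello {{#sys.user_name#}}, today is {{#start.weekday#}}.'
print(re.split(VARIABLE_PATTERN, template))
# ['Hello ', 'sys.user_name', ', today is ', 'start.weekday', '.']
```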
api/core/app/segments/segment_group.py (new file, 19 lines)
@@ -0,0 +1,19 @@
+from pydantic import BaseModel
+
+from .segments import Segment
+
+
+class SegmentGroup(BaseModel):
+    segments: list[Segment]
+
+    @property
+    def text(self):
+        return ''.join([segment.text for segment in self.segments])
+
+    @property
+    def log(self):
+        return ''.join([segment.log for segment in self.segments])
+
+    @property
+    def markdown(self):
+        return ''.join([segment.markdown for segment in self.segments])
api/core/app/segments/segments.py (new file, 39 lines)
@@ -0,0 +1,39 @@
+from typing import Any
+
+from pydantic import BaseModel, ConfigDict, field_validator
+
+from .types import SegmentType
+
+
+class Segment(BaseModel):
+    model_config = ConfigDict(frozen=True)
+
+    value_type: SegmentType
+    value: Any
+
+    @field_validator('value_type')
+    def validate_value_type(cls, value):
+        """
+        This validator checks if the provided value is equal to the default value of the 'value_type' field.
+        If the value is different, a ValueError is raised.
+        """
+        if value != cls.model_fields['value_type'].default:
+            raise ValueError("Cannot modify 'value_type'")
+        return value
+
+    @property
+    def text(self) -> str:
+        return str(self.value)
+
+    @property
+    def log(self) -> str:
+        return str(self.value)
+
+    @property
+    def markdown(self) -> str:
+        return str(self.value)
+
+
+class StringSegment(Segment):
+    value_type: SegmentType = SegmentType.STRING
+    value: str
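Segments are frozen pydantic models, and the value_type validator pins each subclass to its declared tag, so both mutation and a mismatched tag fail at validation time. A sketch (assuming the module imports as shown):

```python
from core.app.segments.segments import StringSegment

seg = StringSegment(value='hello')
print(seg.text, seg.value_type)  # hello SegmentType.STRING

# Frozen model: attribute assignment raises a pydantic ValidationError.
# seg.value = 'other'

# Pinned tag: a different value_type trips the field validator, which raises
# ValueError("Cannot modify 'value_type'"), surfaced as a ValidationError.
# StringSegment(value='hello', value_type='number')
```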
api/core/app/segments/types.py (new file, 17 lines)
@@ -0,0 +1,17 @@
+from enum import Enum
+
+
+class SegmentType(str, Enum):
+    STRING = 'string'
+    NUMBER = 'number'
+    FILE = 'file'
+
+    SECRET = 'secret'
+
+    OBJECT = 'object'
+
+    ARRAY = 'array'
+    ARRAY_STRING = 'array[string]'
+    ARRAY_NUMBER = 'array[number]'
+    ARRAY_OBJECT = 'array[object]'
+    ARRAY_FILE = 'array[file]'
api/core/app/segments/variables.py (new file, 83 lines)
@@ -0,0 +1,83 @@
+import json
+from collections.abc import Mapping, Sequence
+
+from pydantic import Field
+
+from core.file.file_obj import FileVar
+from core.helper import encrypter
+
+from .segments import Segment, StringSegment
+from .types import SegmentType
+
+
+class Variable(Segment):
+    """
+    A variable is a segment that has a name.
+    """
+
+    id: str = Field(
+        default='',
+        description="Unique identity for variable. It's only used by environment variables now.",
+    )
+    name: str
+
+
+class StringVariable(StringSegment, Variable):
+    pass
+
+
+class FloatVariable(Variable):
+    value_type: SegmentType = SegmentType.NUMBER
+    value: float
+
+
+class IntegerVariable(Variable):
+    value_type: SegmentType = SegmentType.NUMBER
+    value: int
+
+
+class ObjectVariable(Variable):
+    value_type: SegmentType = SegmentType.OBJECT
+    value: Mapping[str, Variable]
+
+    @property
+    def text(self) -> str:
+        # TODO: Process variables.
+        return json.dumps(self.model_dump()['value'], ensure_ascii=False)
+
+    @property
+    def log(self) -> str:
+        # TODO: Process variables.
+        return json.dumps(self.model_dump()['value'], ensure_ascii=False, indent=2)
+
+    @property
+    def markdown(self) -> str:
+        # TODO: Use markdown code block
+        return json.dumps(self.model_dump()['value'], ensure_ascii=False, indent=2)
+
+
+class ArrayVariable(Variable):
+    value_type: SegmentType = SegmentType.ARRAY
+    value: Sequence[Variable]
+
+    @property
+    def markdown(self) -> str:
+        return '\n'.join(['- ' + item.markdown for item in self.value])
+
+
+class FileVariable(Variable):
+    value_type: SegmentType = SegmentType.FILE
+    # TODO: embed FileVar in this model.
+    value: FileVar
+
+    @property
+    def markdown(self) -> str:
+        return self.value.to_markdown()
+
+
+class SecretVariable(StringVariable):
+    value_type: SegmentType = SegmentType.SECRET
+
+    @property
+    def log(self) -> str:
+        return encrypter.obfuscated_token(self.value)
@ -1,9 +1,11 @@
import os
from collections.abc import Mapping, Sequence
from typing import Any, Optional, TextIO, Union

from pydantic import BaseModel

from core.ops.ops_trace_manager import TraceQueueManager, TraceTask, TraceTaskName
from core.tools.entities.tool_entities import ToolInvokeMessage

_TEXT_COLOR_MAPPING = {
    "blue": "36;1",
@ -43,7 +45,7 @@ class DifyAgentCallbackHandler(BaseModel):
    def on_tool_start(
        self,
        tool_name: str,
        tool_inputs: dict[str, Any],
        tool_inputs: Mapping[str, Any],
    ) -> None:
        """Do nothing."""
        print_text("\n[on_tool_start] ToolCall:" + tool_name + "\n" + str(tool_inputs) + "\n", color=self.color)
@ -51,8 +53,8 @@ class DifyAgentCallbackHandler(BaseModel):
    def on_tool_end(
        self,
        tool_name: str,
        tool_inputs: dict[str, Any],
        tool_outputs: str,
        tool_inputs: Mapping[str, Any],
        tool_outputs: Sequence[ToolInvokeMessage],
        message_id: Optional[str] = None,
        timer: Optional[Any] = None,
        trace_manager: Optional[TraceQueueManager] = None
@ -1,4 +1,5 @@
from typing import Union
from collections.abc import Mapping, Sequence
from typing import Any, Union

import requests

@ -16,7 +17,7 @@ class MessageFileParser:
        self.tenant_id = tenant_id
        self.app_id = app_id

    def validate_and_transform_files_arg(self, files: list[dict], file_extra_config: FileExtraConfig,
    def validate_and_transform_files_arg(self, files: Sequence[Mapping[str, Any]], file_extra_config: FileExtraConfig,
                                         user: Union[Account, EndUser]) -> list[FileVar]:
        """
        validate and transform files arg
@ -21,7 +21,7 @@ logger = logging.getLogger(__name__)
CODE_EXECUTION_ENDPOINT = dify_config.CODE_EXECUTION_ENDPOINT
CODE_EXECUTION_API_KEY = dify_config.CODE_EXECUTION_API_KEY

CODE_EXECUTION_TIMEOUT= (10, 60)
CODE_EXECUTION_TIMEOUT = (10, 60)

class CodeExecutionException(Exception):
    pass
@ -64,7 +64,7 @@ class CodeExecutor:

    @classmethod
    def execute_code(cls,
                     language: Literal['python3', 'javascript', 'jinja2'],
                     language: CodeLanguage,
                     preload: str,
                     code: str,
                     dependencies: Optional[list[CodeDependency]] = None) -> str:
@ -119,7 +119,7 @@ class CodeExecutor:
        return response.data.stdout

    @classmethod
    def execute_workflow_code_template(cls, language: Literal['python3', 'javascript', 'jinja2'], code: str, inputs: dict, dependencies: Optional[list[CodeDependency]] = None) -> dict:
    def execute_workflow_code_template(cls, language: CodeLanguage, code: str, inputs: dict, dependencies: Optional[list[CodeDependency]] = None) -> dict:
        """
        Execute code
        :param language: code language
@ -6,11 +6,16 @@ from models.account import Tenant


def obfuscated_token(token: str):
    return token[:6] + '*' * (len(token) - 8) + token[-2:]
    if not token:
        return token
    if len(token) <= 8:
        return '*' * 20
    return token[:6] + '*' * 12 + token[-2:]


def encrypt_token(tenant_id: str, token: str):
    tenant = db.session.query(Tenant).filter(Tenant.id == tenant_id).first()
    if not (tenant := db.session.query(Tenant).filter(Tenant.id == tenant_id).first()):
        raise ValueError(f'Tenant with id {tenant_id} not found')
    encrypted_token = rsa.encrypt(token, tenant.encrypt_public_key)
    return base64.b64encode(encrypted_token).decode()
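Reviewer note: sample outputs of the new obfuscated_token (tokens are made up):

    obfuscated_token('sk-1234567890abcdef')  # 'sk-123************ef' (fixed 12 stars, no length leak)
    obfuscated_token('short')                # '********************' (short tokens are fully masked)
    obfuscated_token('')                     # '' (empty passthrough)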
@ -14,6 +14,9 @@ def get_position_map(folder_path: str, *, file_name: str = "_position.yaml") ->
    :return: a dict with name as key and index as value
    """
    position_file_name = os.path.join(folder_path, file_name)
    if not position_file_name or not os.path.exists(position_file_name):
        return {}

    positions = load_yaml_file(position_file_name, ignore_error=True)
    position_map = {}
    index = 0
@ -64,6 +64,7 @@ User Input:
SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = (
    "Please help me predict the three most likely questions that human would ask, "
    "and keeping each question under 20 characters.\n"
    "MAKE SURE your output is the SAME language as the Assistant's latest response(if the main response is written in Chinese, then the language of your output must be using Chinese.)!\n"
    "The output must be an array in JSON format following the specified schema:\n"
    "[\"question1\",\"question2\",\"question3\"]\n"
)
@ -103,7 +103,7 @@ class TokenBufferMemory:

        if curr_message_tokens > max_token_limit:
            pruned_memory = []
            while curr_message_tokens > max_token_limit and prompt_messages:
            while curr_message_tokens > max_token_limit and len(prompt_messages) > 1:
                pruned_memory.append(prompt_messages.pop(0))
                curr_message_tokens = self.model_instance.get_llm_num_tokens(
                    prompt_messages
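Reviewer note: the new guard keeps at least one message in the window even when it alone exceeds the limit (toy sketch with made-up token counts):

    prompt_messages = ['m1', 'm2', 'm3']
    curr, limit, pruned = 100, 10, []
    while curr > limit and len(prompt_messages) > 1:
        pruned.append(prompt_messages.pop(0))
        curr -= 40
    assert prompt_messages  # the old `and prompt_messages` condition could drain the list entirely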
@ -413,6 +413,7 @@ class LBModelManager:
        for load_balancing_config in self._load_balancing_configs:
            if load_balancing_config.name == "__inherit__":
                if not managed_credentials:
                    # FIXME: Mutation to loop iterable `self._load_balancing_configs` during iteration
                    # remove __inherit__ if managed credentials is not provided
                    self._load_balancing_configs.remove(load_balancing_config)
                else:
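Reviewer note: the FIXME is warranted, since list.remove() during iteration skips elements. A safe pattern (a sketch, not the repo's code) is to iterate over a snapshot:

    for load_balancing_config in list(self._load_balancing_configs):
        if load_balancing_config.name == "__inherit__" and not managed_credentials:
            self._load_balancing_configs.remove(load_balancing_config)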
@ -27,9 +27,9 @@ parameter_rules:
  - name: max_tokens
    use_template: max_tokens
    required: true
    default: 4096
    default: 8192
    min: 1
    max: 4096
    max: 8192
  - name: response_format
    use_template: response_format
pricing:
@ -113,6 +113,11 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
        if system:
            extra_model_kwargs['system'] = system

        # Add the new header for claude-3-5-sonnet-20240620 model
        extra_headers = {}
        if model == "claude-3-5-sonnet-20240620":
            extra_headers["anthropic-beta"] = "max-tokens-3-5-sonnet-2024-07-15"

        if tools:
            extra_model_kwargs['tools'] = [
                self._transform_tool_prompt(tool) for tool in tools
@ -121,6 +126,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
                model=model,
                messages=prompt_message_dicts,
                stream=stream,
                extra_headers=extra_headers,
                **model_parameters,
                **extra_model_kwargs
            )
@ -130,6 +136,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
                model=model,
                messages=prompt_message_dicts,
                stream=stream,
                extra_headers=extra_headers,
                **model_parameters,
                **extra_model_kwargs
            )
@ -138,7 +145,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
            return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages)

        return self._handle_chat_generate_response(model, credentials, response, prompt_messages)


    def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                                 model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None,
                                 stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None,
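Reviewer note: this pairs with the YAML change above. The "anthropic-beta: max-tokens-3-5-sonnet-2024-07-15" header is what unlocks the 8192-token output cap for claude-3-5-sonnet, so the header and the max_tokens bump should land together.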
@ -501,7 +501,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
                sub_messages.append(sub_message_dict)
            message_dict = {"role": "user", "content": sub_messages}
        elif isinstance(message, AssistantPromptMessage):
            message = cast(AssistantPromptMessage, message)
            # message = cast(AssistantPromptMessage, message)
            message_dict = {"role": "assistant", "content": message.content}
            if message.tool_calls:
                message_dict["tool_calls"] = [helper.dump_model(tool_call) for tool_call in message.tool_calls]
@ -1,6 +1,8 @@
- gpt-4
- gpt-4o
- gpt-4o-2024-05-13
- gpt-4o-mini
- gpt-4o-mini-2024-07-18
- gpt-4-turbo
- gpt-4-turbo-2024-04-09
- gpt-4-turbo-preview
@ -0,0 +1,44 @@
model: gpt-4o-mini-2024-07-18
label:
  zh_Hans: gpt-4o-mini-2024-07-18
  en_US: gpt-4o-mini-2024-07-18
model_type: llm
features:
  - multi-tool-call
  - agent-thought
  - stream-tool-call
  - vision
model_properties:
  mode: chat
  context_size: 128000
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: presence_penalty
    use_template: presence_penalty
  - name: frequency_penalty
    use_template: frequency_penalty
  - name: max_tokens
    use_template: max_tokens
    default: 512
    min: 1
    max: 16384
  - name: response_format
    label:
      zh_Hans: 回复格式
      en_US: response_format
    type: string
    help:
      zh_Hans: 指定模型必须输出的格式
      en_US: specifying the format that the model must output
    required: false
    options:
      - text
      - json_object
pricing:
  input: '0.15'
  output: '0.60'
  unit: '0.000001'
  currency: USD
@ -0,0 +1,44 @@
model: gpt-4o-mini
label:
  zh_Hans: gpt-4o-mini
  en_US: gpt-4o-mini
model_type: llm
features:
  - multi-tool-call
  - agent-thought
  - stream-tool-call
  - vision
model_properties:
  mode: chat
  context_size: 128000
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: presence_penalty
    use_template: presence_penalty
  - name: frequency_penalty
    use_template: frequency_penalty
  - name: max_tokens
    use_template: max_tokens
    default: 512
    min: 1
    max: 16384
  - name: response_format
    label:
      zh_Hans: 回复格式
      en_US: response_format
    type: string
    help:
      zh_Hans: 指定模型必须输出的格式
      en_US: specifying the format that the model must output
    required: false
    options:
      - text
      - json_object
pricing:
  input: '0.15'
  output: '0.60'
  unit: '0.000001'
  currency: USD
@ -1,4 +1,5 @@
- openai/gpt-4o
- openai/gpt-4o-mini
- openai/gpt-4
- openai/gpt-4-32k
- openai/gpt-3.5-turbo
@ -0,0 +1,43 @@
model: openai/gpt-4o-mini
label:
  en_US: gpt-4o-mini
model_type: llm
features:
  - multi-tool-call
  - agent-thought
  - stream-tool-call
  - vision
model_properties:
  mode: chat
  context_size: 128000
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: presence_penalty
    use_template: presence_penalty
  - name: frequency_penalty
    use_template: frequency_penalty
  - name: max_tokens
    use_template: max_tokens
    default: 512
    min: 1
    max: 16384
  - name: response_format
    label:
      zh_Hans: 回复格式
      en_US: response_format
    type: string
    help:
      zh_Hans: 指定模型必须输出的格式
      en_US: specifying the format that the model must output
    required: false
    options:
      - text
      - json_object
pricing:
  input: "0.15"
  output: "0.60"
  unit: "0.000001"
  currency: USD
Binary file not shown.
After Width: | Height: | Size: 9.2 KiB
Binary file not shown.
After Width: | Height: | Size: 9.5 KiB
238
api/core/model_runtime/model_providers/sagemaker/llm/llm.py
Normal file
@ -0,0 +1,238 @@
import json
import logging
from collections.abc import Generator
from typing import Any, Optional, Union

import boto3

from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    PromptMessage,
    PromptMessageTool,
)
from core.model_runtime.entities.model_entities import (
    AIModelEntity,
    FetchFrom,
    ModelFeature,
    ModelPropertyKey,
    ModelType,
    ParameterRule,
    ParameterType,
)
from core.model_runtime.errors.invoke import (
    InvokeAuthorizationError,
    InvokeBadRequestError,
    InvokeConnectionError,
    InvokeError,
    InvokeRateLimitError,
    InvokeServerUnavailableError,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel

logger = logging.getLogger(__name__)


class SageMakerLargeLanguageModel(LargeLanguageModel):
    """
    Model class for SageMaker large language model.
    """
    sagemaker_client: Any = None

    def _invoke(self, model: str, credentials: dict,
                prompt_messages: list[PromptMessage], model_parameters: dict,
                tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                stream: bool = True, user: Optional[str] = None) \
            -> Union[LLMResult, Generator]:
        """
        Invoke large language model

        :param model: model name
        :param credentials: model credentials
        :param prompt_messages: prompt messages
        :param model_parameters: model parameters
        :param tools: tools for tool calling
        :param stop: stop words
        :param stream: is stream response
        :param user: unique user id
        :return: full response or stream response chunk generator result
        """
        # get model mode
        model_mode = self.get_model_mode(model, credentials)

        if not self.sagemaker_client:
            access_key = credentials.get('access_key')
            secret_key = credentials.get('secret_key')
            aws_region = credentials.get('aws_region')
            if aws_region:
                if access_key and secret_key:
                    self.sagemaker_client = boto3.client("sagemaker-runtime",
                                                         aws_access_key_id=access_key,
                                                         aws_secret_access_key=secret_key,
                                                         region_name=aws_region)
                else:
                    self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region)
            else:
                self.sagemaker_client = boto3.client("sagemaker-runtime")

        sagemaker_endpoint = credentials.get('sagemaker_endpoint')
        response_model = self.sagemaker_client.invoke_endpoint(
            EndpointName=sagemaker_endpoint,
            Body=json.dumps(
                {
                    "inputs": prompt_messages[0].content,
                    "parameters": {"stop": stop},
                    "history": []
                }
            ),
            ContentType="application/json",
        )

        assistant_text = response_model['Body'].read().decode('utf8')

        # transform assistant message to prompt message
        assistant_prompt_message = AssistantPromptMessage(
            content=assistant_text
        )

        usage = self._calc_response_usage(model, credentials, 0, 0)

        response = LLMResult(
            model=model,
            prompt_messages=prompt_messages,
            message=assistant_prompt_message,
            usage=usage
        )

        return response

    def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                       tools: Optional[list[PromptMessageTool]] = None) -> int:
        """
        Get number of tokens for given prompt messages

        :param model: model name
        :param credentials: model credentials
        :param prompt_messages: prompt messages
        :param tools: tools for tool calling
        :return:
        """
        # get model mode
        model_mode = self.get_model_mode(model)

        try:
            return 0
        except Exception as e:
            raise self._transform_invoke_error(e)

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        Validate model credentials

        :param model: model name
        :param credentials: model credentials
        :return:
        """
        try:
            # get model mode
            model_mode = self.get_model_mode(model)
        except Exception as ex:
            raise CredentialsValidateFailedError(str(ex))

    @property
    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
        """
        Map model invoke error to unified error
        The key is the error type thrown to the caller
        The value is the error type thrown by the model,
        which needs to be converted into a unified error type for the caller.

        :return: Invoke error mapping
        """
        return {
            InvokeConnectionError: [
                InvokeConnectionError
            ],
            InvokeServerUnavailableError: [
                InvokeServerUnavailableError
            ],
            InvokeRateLimitError: [
                InvokeRateLimitError
            ],
            InvokeAuthorizationError: [
                InvokeAuthorizationError
            ],
            InvokeBadRequestError: [
                InvokeBadRequestError,
                KeyError,
                ValueError
            ]
        }

    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
        """
        used to define customizable model schema
        """
        rules = [
            ParameterRule(
                name='temperature',
                type=ParameterType.FLOAT,
                use_template='temperature',
                label=I18nObject(
                    zh_Hans='温度',
                    en_US='Temperature'
                ),
            ),
            ParameterRule(
                name='top_p',
                type=ParameterType.FLOAT,
                use_template='top_p',
                label=I18nObject(
                    zh_Hans='Top P',
                    en_US='Top P'
                )
            ),
            ParameterRule(
                name='max_tokens',
                type=ParameterType.INT,
                use_template='max_tokens',
                min=1,
                max=credentials.get('context_length', 2048),
                default=512,
                label=I18nObject(
                    zh_Hans='最大生成长度',
                    en_US='Max Tokens'
                )
            )
        ]

        completion_type = LLMMode.value_of(credentials["mode"])

        if completion_type == LLMMode.CHAT:
            print(f"completion_type : {LLMMode.CHAT.value}")

        if completion_type == LLMMode.COMPLETION:
            print(f"completion_type : {LLMMode.COMPLETION.value}")

        features = []

        support_function_call = credentials.get('support_function_call', False)
        if support_function_call:
            features.append(ModelFeature.TOOL_CALL)

        support_vision = credentials.get('support_vision', False)
        if support_vision:
            features.append(ModelFeature.VISION)

        context_length = credentials.get('context_length', 2048)

        entity = AIModelEntity(
            model=model,
            label=I18nObject(
                en_US=model
            ),
            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
            model_type=ModelType.LLM,
            features=features,
            model_properties={
                ModelPropertyKey.MODE: completion_type,
                ModelPropertyKey.CONTEXT_SIZE: context_length
            },
            parameter_rules=rules
        )

        return entity
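Reviewer note: the request/response contract this class assumes of the user-provided SageMaker endpoint (field names taken from the code above):

    request  = {"inputs": "<prompt text>", "parameters": {"stop": ["..."]}, "history": []}
    response = raw UTF-8 text in Body   # read via response_model['Body'].read().decode('utf8')

Only prompt_messages[0] is sent, so any multi-turn context has to be folded into that first message by the caller.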
@ -0,0 +1,190 @@
import json
import logging
from typing import Any, Optional

import boto3

from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
from core.model_runtime.errors.invoke import (
    InvokeAuthorizationError,
    InvokeBadRequestError,
    InvokeConnectionError,
    InvokeError,
    InvokeRateLimitError,
    InvokeServerUnavailableError,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.rerank_model import RerankModel

logger = logging.getLogger(__name__)


class SageMakerRerankModel(RerankModel):
    """
    Model class for SageMaker rerank model.
    """
    sagemaker_client: Any = None

    def _sagemaker_rerank(self, query_input: str, docs: list[str], rerank_endpoint: str):
        inputs = [query_input] * len(docs)
        response_model = self.sagemaker_client.invoke_endpoint(
            EndpointName=rerank_endpoint,
            Body=json.dumps(
                {
                    "inputs": inputs,
                    "docs": docs
                }
            ),
            ContentType="application/json",
        )
        json_str = response_model['Body'].read().decode('utf8')
        json_obj = json.loads(json_str)
        scores = json_obj['scores']
        return scores if isinstance(scores, list) else [scores]

    def _invoke(self, model: str, credentials: dict,
                query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None,
                user: Optional[str] = None) \
            -> RerankResult:
        """
        Invoke rerank model

        :param model: model name
        :param credentials: model credentials
        :param query: search query
        :param docs: docs for reranking
        :param score_threshold: score threshold
        :param top_n: top n
        :param user: unique user id
        :return: rerank result
        """
        line = 0
        try:
            if len(docs) == 0:
                return RerankResult(
                    model=model,
                    docs=docs
                )

            line = 1
            if not self.sagemaker_client:
                access_key = credentials.get('aws_access_key_id')
                secret_key = credentials.get('aws_secret_access_key')
                aws_region = credentials.get('aws_region')
                if aws_region:
                    if access_key and secret_key:
                        self.sagemaker_client = boto3.client("sagemaker-runtime",
                                                             aws_access_key_id=access_key,
                                                             aws_secret_access_key=secret_key,
                                                             region_name=aws_region)
                    else:
                        self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region)
                else:
                    self.sagemaker_client = boto3.client("sagemaker-runtime")

            line = 2

            sagemaker_endpoint = credentials.get('sagemaker_endpoint')
            candidate_docs = []

            scores = self._sagemaker_rerank(query, docs, sagemaker_endpoint)
            for idx in range(len(scores)):
                candidate_docs.append({"content": docs[idx], "score": scores[idx]})

            # sort in place; the bare `sorted(...)` call here discarded its result
            candidate_docs.sort(key=lambda x: x['score'], reverse=True)

            line = 3
            rerank_documents = []
            for idx, result in enumerate(candidate_docs):
                rerank_document = RerankDocument(
                    index=idx,
                    text=result.get('content'),
                    score=result.get('score', -100.0)
                )

                if score_threshold is not None:
                    if rerank_document.score >= score_threshold:
                        rerank_documents.append(rerank_document)
                else:
                    rerank_documents.append(rerank_document)

            return RerankResult(
                model=model,
                docs=rerank_documents
            )

        except Exception as e:
            logger.exception(f'Exception {e}, line : {line}')
            raise

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        Validate model credentials

        :param model: model name
        :param credentials: model credentials
        :return:
        """
        try:
            self._invoke(
                model=model,
                credentials=credentials,
                query="What is the capital of the United States?",
                docs=[
                    "Carson City is the capital city of the American state of Nevada. At the 2010 United States "
                    "Census, Carson City had a population of 55,274.",
                    "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
                    "are a political division controlled by the United States. Its capital is Saipan.",
                ],
                score_threshold=0.8
            )
        except Exception as ex:
            raise CredentialsValidateFailedError(str(ex))

    @property
    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
        """
        Map model invoke error to unified error
        The key is the error type thrown to the caller
        The value is the error type thrown by the model,
        which needs to be converted into a unified error type for the caller.

        :return: Invoke error mapping
        """
        return {
            InvokeConnectionError: [
                InvokeConnectionError
            ],
            InvokeServerUnavailableError: [
                InvokeServerUnavailableError
            ],
            InvokeRateLimitError: [
                InvokeRateLimitError
            ],
            InvokeAuthorizationError: [
                InvokeAuthorizationError
            ],
            InvokeBadRequestError: [
                InvokeBadRequestError,
                KeyError,
                ValueError
            ]
        }

    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
        """
        used to define customizable model schema
        """
        entity = AIModelEntity(
            model=model,
            label=I18nObject(
                en_US=model
            ),
            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
            model_type=ModelType.RERANK,
            model_properties={},
            parameter_rules=[]
        )

        return entity
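Reviewer note: the rerank endpoint contract assumed above, with a made-up example:

    request  = {"inputs": ["capital of the US?", "capital of the US?"], "docs": ["doc A", "doc B"]}
    response = {"scores": [0.91, 0.12]}   # aligned index-for-index with docs

The query is duplicated once per document so the endpoint scores each (query, doc) pair.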
@ -0,0 +1,17 @@
import logging

from core.model_runtime.model_providers.__base.model_provider import ModelProvider

logger = logging.getLogger(__name__)


class SageMakerProvider(ModelProvider):
    def validate_provider_credentials(self, credentials: dict) -> None:
        """
        Validate provider credentials

        if validate failed, raise exception

        :param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
        """
        pass
125
api/core/model_runtime/model_providers/sagemaker/sagemaker.yaml
Normal file
@ -0,0 +1,125 @@
provider: sagemaker
label:
  zh_Hans: Sagemaker
  en_US: Sagemaker
icon_small:
  en_US: icon_s_en.png
icon_large:
  en_US: icon_l_en.png
description:
  en_US: Customized model on Sagemaker
  zh_Hans: Sagemaker 上私有化部署的模型
background: "#ECE9E3"
help:
  title:
    en_US: How to deploy customized model on Sagemaker
    zh_Hans: 如何在 Sagemaker 上部署私有化模型
  url:
    en_US: https://github.com/aws-samples/dify-aws-tool/blob/main/README.md#how-to-deploy-sagemaker-endpoint
    zh_Hans: https://github.com/aws-samples/dify-aws-tool/blob/main/README_ZH.md#%E5%A6%82%E4%BD%95%E9%83%A8%E7%BD%B2sagemaker%E6%8E%A8%E7%90%86%E7%AB%AF%E7%82%B9
supported_model_types:
  - llm
  - text-embedding
  - rerank
configurate_methods:
  - customizable-model
model_credential_schema:
  model:
    label:
      en_US: Model Name
      zh_Hans: 模型名称
    placeholder:
      en_US: Enter your model name
      zh_Hans: 输入模型名称
  credential_form_schemas:
    - variable: mode
      show_on:
        - variable: __model_type
          value: llm
      label:
        en_US: Completion mode
      type: select
      required: false
      default: chat
      placeholder:
        zh_Hans: 选择对话类型
        en_US: Select completion mode
      options:
        - value: completion
          label:
            en_US: Completion
            zh_Hans: 补全
        - value: chat
          label:
            en_US: Chat
            zh_Hans: 对话
    - variable: sagemaker_endpoint
      label:
        en_US: sagemaker endpoint
      type: text-input
      required: true
      placeholder:
        zh_Hans: 请输入你的 Sagemaker 推理端点
        en_US: Enter your Sagemaker Inference endpoint
    - variable: aws_access_key_id
      required: false
      label:
        en_US: Access Key (If not provided, credentials are obtained from the running environment.)
        zh_Hans: Access Key (如果未提供,凭证将从运行环境中获取。)
      type: secret-input
      placeholder:
        en_US: Enter your Access Key
        zh_Hans: 在此输入您的 Access Key
    - variable: aws_secret_access_key
      required: false
      label:
        en_US: Secret Access Key
        zh_Hans: Secret Access Key
      type: secret-input
      placeholder:
        en_US: Enter your Secret Access Key
        zh_Hans: 在此输入您的 Secret Access Key
    - variable: aws_region
      required: false
      label:
        en_US: AWS Region
        zh_Hans: AWS 地区
      type: select
      default: us-east-1
      options:
        - value: us-east-1
          label:
            en_US: US East (N. Virginia)
            zh_Hans: 美国东部 (弗吉尼亚北部)
        - value: us-west-2
          label:
            en_US: US West (Oregon)
            zh_Hans: 美国西部 (俄勒冈州)
        - value: ap-southeast-1
          label:
            en_US: Asia Pacific (Singapore)
            zh_Hans: 亚太地区 (新加坡)
        - value: ap-northeast-1
          label:
            en_US: Asia Pacific (Tokyo)
            zh_Hans: 亚太地区 (东京)
        - value: eu-central-1
          label:
            en_US: Europe (Frankfurt)
            zh_Hans: 欧洲 (法兰克福)
        - value: us-gov-west-1
          label:
            en_US: AWS GovCloud (US-West)
            zh_Hans: AWS GovCloud (US-West)
        - value: ap-southeast-2
          label:
            en_US: Asia Pacific (Sydney)
            zh_Hans: 亚太地区 (悉尼)
        - value: cn-north-1
          label:
            en_US: AWS Beijing (cn-north-1)
            zh_Hans: 中国北京 (cn-north-1)
        - value: cn-northwest-1
          label:
            en_US: AWS Ningxia (cn-northwest-1)
            zh_Hans: 中国宁夏 (cn-northwest-1)
@ -0,0 +1,214 @@
import itertools
import json
import logging
import time
from typing import Any, Optional

import boto3

from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from core.model_runtime.errors.invoke import (
    InvokeAuthorizationError,
    InvokeBadRequestError,
    InvokeConnectionError,
    InvokeError,
    InvokeRateLimitError,
    InvokeServerUnavailableError,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel

BATCH_SIZE = 20
CONTEXT_SIZE = 8192

logger = logging.getLogger(__name__)


def batch_generator(generator, batch_size):
    while True:
        batch = list(itertools.islice(generator, batch_size))
        if not batch:
            break
        yield batch


class SageMakerEmbeddingModel(TextEmbeddingModel):
    """
    Model class for SageMaker text embedding model.
    """
    sagemaker_client: Any = None

    def _sagemaker_embedding(self, sm_client, endpoint_name, content_list: list[str]):
        response_model = sm_client.invoke_endpoint(
            EndpointName=endpoint_name,
            Body=json.dumps(
                {
                    "inputs": content_list,
                    "parameters": {},
                    "is_query": False,
                    "instruction": ''
                }
            ),
            ContentType="application/json",
        )
        json_str = response_model['Body'].read().decode('utf8')
        json_obj = json.loads(json_str)
        embeddings = json_obj['embeddings']
        return embeddings

    def _invoke(self, model: str, credentials: dict,
                texts: list[str], user: Optional[str] = None) \
            -> TextEmbeddingResult:
        """
        Invoke text embedding model

        :param model: model name
        :param credentials: model credentials
        :param texts: texts to embed
        :param user: unique user id
        :return: embeddings result
        """
        # get model properties
        try:
            line = 1
            if not self.sagemaker_client:
                access_key = credentials.get('aws_access_key_id')
                secret_key = credentials.get('aws_secret_access_key')
                aws_region = credentials.get('aws_region')
                if aws_region:
                    if access_key and secret_key:
                        self.sagemaker_client = boto3.client("sagemaker-runtime",
                                                             aws_access_key_id=access_key,
                                                             aws_secret_access_key=secret_key,
                                                             region_name=aws_region)
                    else:
                        self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region)
                else:
                    self.sagemaker_client = boto3.client("sagemaker-runtime")

            line = 2
            sagemaker_endpoint = credentials.get('sagemaker_endpoint')

            line = 3
            truncated_texts = [item[:CONTEXT_SIZE] for item in texts]

            batches = batch_generator((text for text in truncated_texts), batch_size=BATCH_SIZE)
            all_embeddings = []

            line = 4
            for batch in batches:
                embeddings = self._sagemaker_embedding(self.sagemaker_client, sagemaker_endpoint, batch)
                all_embeddings.extend(embeddings)

            line = 5
            # calc usage
            usage = self._calc_response_usage(
                model=model,
                credentials=credentials,
                tokens=0  # not a SaaS API, so token usage is meaningless here
            )
            line = 6

            return TextEmbeddingResult(
                embeddings=all_embeddings,
                usage=usage,
                model=model
            )

        except Exception as e:
            logger.exception(f'Exception {e}, line : {line}')
            raise

    def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
        """
        Get number of tokens for given prompt messages

        :param model: model name
        :param credentials: model credentials
        :param texts: texts to embed
        :return:
        """
        return 0

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        Validate model credentials

        :param model: model name
        :param credentials: model credentials
        :return:
        """
        try:
            print("validate_credentials ok....")
        except Exception as ex:
            raise CredentialsValidateFailedError(str(ex))

    def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
        """
        Calculate response usage

        :param model: model name
        :param credentials: model credentials
        :param tokens: input tokens
        :return: usage
        """
        # get input price info
        input_price_info = self.get_price(
            model=model,
            credentials=credentials,
            price_type=PriceType.INPUT,
            tokens=tokens
        )

        # transform usage
        usage = EmbeddingUsage(
            tokens=tokens,
            total_tokens=tokens,
            unit_price=input_price_info.unit_price,
            price_unit=input_price_info.unit,
            total_price=input_price_info.total_amount,
            currency=input_price_info.currency,
            latency=time.perf_counter() - self.started_at
        )

        return usage

    @property
    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
        return {
            InvokeConnectionError: [
                InvokeConnectionError
            ],
            InvokeServerUnavailableError: [
                InvokeServerUnavailableError
            ],
            InvokeRateLimitError: [
                InvokeRateLimitError
            ],
            InvokeAuthorizationError: [
                InvokeAuthorizationError
            ],
            InvokeBadRequestError: [
                KeyError
            ]
        }

    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
        """
        used to define customizable model schema
        """

        entity = AIModelEntity(
            model=model,
            label=I18nObject(
                en_US=model
            ),
            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
            model_type=ModelType.TEXT_EMBEDDING,
            model_properties={
                ModelPropertyKey.CONTEXT_SIZE: CONTEXT_SIZE,
                ModelPropertyKey.MAX_CHUNKS: BATCH_SIZE,
            },
            parameter_rules=[]
        )

        return entity
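Reviewer note: batch_generator is the one generally reusable piece here; a quick check of its behaviour:

    list(batch_generator(iter(range(7)), batch_size=3))
    # -> [[0, 1, 2], [3, 4, 5], [6]]

Note that texts are hard-truncated to CONTEXT_SIZE characters (not tokens) before batching, a coarse but endpoint-agnostic guard.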
Binary file not shown.
After Width: | Height: | Size: 9.0 KiB
Binary file not shown.
After Width: | Height: | Size: 1.9 KiB
@ -0,0 +1,6 @@
- step-1-8k
- step-1-32k
- step-1-128k
- step-1-256k
- step-1v-8k
- step-1v-32k
328
api/core/model_runtime/model_providers/stepfun/llm/llm.py
Normal file
@ -0,0 +1,328 @@
import json
from collections.abc import Generator
from typing import Optional, Union, cast

import requests

from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    ImagePromptMessageContent,
    PromptMessage,
    PromptMessageContent,
    PromptMessageContentType,
    PromptMessageTool,
    SystemPromptMessage,
    ToolPromptMessage,
    UserPromptMessage,
)
from core.model_runtime.entities.model_entities import (
    AIModelEntity,
    FetchFrom,
    ModelFeature,
    ModelPropertyKey,
    ModelType,
    ParameterRule,
    ParameterType,
)
from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel


class StepfunLargeLanguageModel(OAIAPICompatLargeLanguageModel):
    def _invoke(self, model: str, credentials: dict,
                prompt_messages: list[PromptMessage], model_parameters: dict,
                tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                stream: bool = True, user: Optional[str] = None) \
            -> Union[LLMResult, Generator]:
        self._add_custom_parameters(credentials)
        self._add_function_call(model, credentials)
        user = user[:32] if user else None
        return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)

    def validate_credentials(self, model: str, credentials: dict) -> None:
        self._add_custom_parameters(credentials)
        super().validate_credentials(model, credentials)

    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
        return AIModelEntity(
            model=model,
            label=I18nObject(en_US=model, zh_Hans=model),
            model_type=ModelType.LLM,
            features=[ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL, ModelFeature.STREAM_TOOL_CALL]
            if credentials.get('function_calling_type') == 'tool_call'
            else [],
            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
            model_properties={
                ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', 8000)),
                ModelPropertyKey.MODE: LLMMode.CHAT.value,
            },
            parameter_rules=[
                ParameterRule(
                    name='temperature',
                    use_template='temperature',
                    label=I18nObject(en_US='Temperature', zh_Hans='温度'),
                    type=ParameterType.FLOAT,
                ),
                ParameterRule(
                    name='max_tokens',
                    use_template='max_tokens',
                    default=512,
                    min=1,
                    max=int(credentials.get('max_tokens', 1024)),
                    label=I18nObject(en_US='Max Tokens', zh_Hans='最大标记'),
                    type=ParameterType.INT,
                ),
                ParameterRule(
                    name='top_p',
                    use_template='top_p',
                    label=I18nObject(en_US='Top P', zh_Hans='Top P'),
                    type=ParameterType.FLOAT,
                ),
            ]
        )

    def _add_custom_parameters(self, credentials: dict) -> None:
        credentials['mode'] = 'chat'
        credentials['endpoint_url'] = 'https://api.stepfun.com/v1'

    def _add_function_call(self, model: str, credentials: dict) -> None:
        model_schema = self.get_model_schema(model, credentials)
        if model_schema and {
            ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL
        }.intersection(model_schema.features or []):
            credentials['function_calling_type'] = 'tool_call'

    def _convert_prompt_message_to_dict(self, message: PromptMessage, credentials: Optional[dict] = None) -> dict:
        """
        Convert PromptMessage to dict for OpenAI API format
        """
        if isinstance(message, UserPromptMessage):
            message = cast(UserPromptMessage, message)
            if isinstance(message.content, str):
                message_dict = {"role": "user", "content": message.content}
            else:
                sub_messages = []
                for message_content in message.content:
                    if message_content.type == PromptMessageContentType.TEXT:
                        message_content = cast(PromptMessageContent, message_content)
                        sub_message_dict = {
                            "type": "text",
                            "text": message_content.data
                        }
                        sub_messages.append(sub_message_dict)
                    elif message_content.type == PromptMessageContentType.IMAGE:
                        message_content = cast(ImagePromptMessageContent, message_content)
                        sub_message_dict = {
                            "type": "image_url",
                            "image_url": {
                                "url": message_content.data,
                            }
                        }
                        sub_messages.append(sub_message_dict)
                message_dict = {"role": "user", "content": sub_messages}
        elif isinstance(message, AssistantPromptMessage):
            message = cast(AssistantPromptMessage, message)
            message_dict = {"role": "assistant", "content": message.content}
            if message.tool_calls:
                message_dict["tool_calls"] = []
                for function_call in message.tool_calls:
                    message_dict["tool_calls"].append({
                        "id": function_call.id,
                        "type": function_call.type,
                        "function": {
                            "name": function_call.function.name,
                            "arguments": function_call.function.arguments
                        }
                    })
        elif isinstance(message, ToolPromptMessage):
            message = cast(ToolPromptMessage, message)
            message_dict = {"role": "tool", "content": message.content, "tool_call_id": message.tool_call_id}
        elif isinstance(message, SystemPromptMessage):
            message = cast(SystemPromptMessage, message)
            message_dict = {"role": "system", "content": message.content}
        else:
            raise ValueError(f"Got unknown type {message}")

        if message.name:
            message_dict["name"] = message.name

        return message_dict

    def _extract_response_tool_calls(self, response_tool_calls: list[dict]) -> list[AssistantPromptMessage.ToolCall]:
        """
        Extract tool calls from response

        :param response_tool_calls: response tool calls
        :return: list of tool calls
        """
        tool_calls = []
        if response_tool_calls:
            for response_tool_call in response_tool_calls:
                function = AssistantPromptMessage.ToolCall.ToolCallFunction(
                    name=response_tool_call["function"]["name"] if response_tool_call.get("function", {}).get("name") else "",
                    arguments=response_tool_call["function"]["arguments"] if response_tool_call.get("function", {}).get("arguments") else ""
                )

                tool_call = AssistantPromptMessage.ToolCall(
                    id=response_tool_call["id"] if response_tool_call.get("id") else "",
                    type=response_tool_call["type"] if response_tool_call.get("type") else "",
                    function=function
                )
                tool_calls.append(tool_call)

        return tool_calls

    def _handle_generate_stream_response(self, model: str, credentials: dict, response: requests.Response,
                                         prompt_messages: list[PromptMessage]) -> Generator:
        """
        Handle llm stream response

        :param model: model name
        :param credentials: model credentials
        :param response: streamed response
        :param prompt_messages: prompt messages
        :return: llm response chunk generator
        """
        full_assistant_content = ''
        chunk_index = 0

        def create_final_llm_result_chunk(index: int, message: AssistantPromptMessage, finish_reason: str) \
                -> LLMResultChunk:
            # calculate num tokens
            prompt_tokens = self._num_tokens_from_string(model, prompt_messages[0].content)
            completion_tokens = self._num_tokens_from_string(model, full_assistant_content)

            # transform usage
            usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)

            return LLMResultChunk(
                model=model,
                prompt_messages=prompt_messages,
                delta=LLMResultChunkDelta(
                    index=index,
                    message=message,
                    finish_reason=finish_reason,
                    usage=usage
                )
            )

        tools_calls: list[AssistantPromptMessage.ToolCall] = []
        finish_reason = "Unknown"

        def increase_tool_call(new_tool_calls: list[AssistantPromptMessage.ToolCall]):
            def get_tool_call(tool_name: str):
                if not tool_name:
                    return tools_calls[-1]

                tool_call = next((tool_call for tool_call in tools_calls if tool_call.function.name == tool_name), None)
                if tool_call is None:
                    tool_call = AssistantPromptMessage.ToolCall(
                        id='',
                        type='',
                        function=AssistantPromptMessage.ToolCall.ToolCallFunction(name=tool_name, arguments="")
                    )
                    tools_calls.append(tool_call)

                return tool_call

            for new_tool_call in new_tool_calls:
                # get tool call
                tool_call = get_tool_call(new_tool_call.function.name)
                # update tool call
                if new_tool_call.id:
                    tool_call.id = new_tool_call.id
                if new_tool_call.type:
                    tool_call.type = new_tool_call.type
                if new_tool_call.function.name:
                    tool_call.function.name = new_tool_call.function.name
                if new_tool_call.function.arguments:
                    tool_call.function.arguments += new_tool_call.function.arguments

        for chunk in response.iter_lines(decode_unicode=True, delimiter="\n\n"):
            if chunk:
                # ignore sse comments
                if chunk.startswith(':'):
                    continue
                decoded_chunk = chunk.strip().lstrip('data: ').lstrip()
                chunk_json = None
                try:
                    chunk_json = json.loads(decoded_chunk)
                # stream ended
                except json.JSONDecodeError as e:
                    yield create_final_llm_result_chunk(
                        index=chunk_index + 1,
                        message=AssistantPromptMessage(content=""),
                        finish_reason="Non-JSON encountered."
                    )
                    break
                if not chunk_json or len(chunk_json['choices']) == 0:
                    continue

                choice = chunk_json['choices'][0]
                finish_reason = chunk_json['choices'][0].get('finish_reason')
                chunk_index += 1

                if 'delta' in choice:
                    delta = choice['delta']
                    delta_content = delta.get('content')

                    assistant_message_tool_calls = delta.get('tool_calls', None)
                    # assistant_message_function_call = delta.delta.function_call

                    # extract tool calls from response
                    if assistant_message_tool_calls:
                        tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
                        increase_tool_call(tool_calls)

                    if delta_content is None or delta_content == '':
                        continue

                    # transform assistant message to prompt message
                    assistant_prompt_message = AssistantPromptMessage(
                        content=delta_content,
                        tool_calls=tool_calls if assistant_message_tool_calls else []
                    )

                    full_assistant_content += delta_content
                elif 'text' in choice:
                    choice_text = choice.get('text', '')
                    if choice_text == '':
                        continue

                    # transform assistant message to prompt message
                    assistant_prompt_message = AssistantPromptMessage(content=choice_text)
                    full_assistant_content += choice_text
                else:
                    continue

                # check payload indicator for completion
                yield LLMResultChunk(
                    model=model,
                    prompt_messages=prompt_messages,
                    delta=LLMResultChunkDelta(
                        index=chunk_index,
                        message=assistant_prompt_message,
                    )
                )

        chunk_index += 1

        if tools_calls:
            yield LLMResultChunk(
                model=model,
                prompt_messages=prompt_messages,
                delta=LLMResultChunkDelta(
                    index=chunk_index,
                    message=AssistantPromptMessage(
                        tool_calls=tools_calls,
                        content=""
                    ),
                )
            )

        yield create_final_llm_result_chunk(
            index=chunk_index,
            message=AssistantPromptMessage(content=""),
            finish_reason=finish_reason
        )
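Reviewer note: how the streamed tool-call merge above behaves (chunk contents are made up). Deltas arriving with an empty function name attach to the most recent tool call, and argument fragments concatenate:

    # chunk 1 delta: name='get_weather', arguments='{"city": "Par'
    # chunk 2 delta: name='',            arguments='is"}'
    # merged entry:  name='get_weather', arguments='{"city": "Paris"}'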
@ -0,0 +1,25 @@
model: step-1-128k
label:
  zh_Hans: step-1-128k
  en_US: step-1-128k
model_type: llm
features:
  - agent-thought
model_properties:
  mode: chat
  context_size: 128000
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: max_tokens
    use_template: max_tokens
    default: 1024
    min: 1
    max: 128000
pricing:
  input: '0.04'
  output: '0.20'
  unit: '0.001'
  currency: RMB
@ -0,0 +1,25 @@
model: step-1-256k
label:
  zh_Hans: step-1-256k
  en_US: step-1-256k
model_type: llm
features:
  - agent-thought
model_properties:
  mode: chat
  context_size: 256000
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: max_tokens
    use_template: max_tokens
    default: 1024
    min: 1
    max: 256000
pricing:
  input: '0.095'
  output: '0.300'
  unit: '0.001'
  currency: RMB
@ -0,0 +1,28 @@
model: step-1-32k
label:
  zh_Hans: step-1-32k
  en_US: step-1-32k
model_type: llm
features:
  - agent-thought
  - tool-call
  - multi-tool-call
  - stream-tool-call
model_properties:
  mode: chat
  context_size: 32000
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: max_tokens
    use_template: max_tokens
    default: 1024
    min: 1
    max: 32000
pricing:
  input: '0.015'
  output: '0.070'
  unit: '0.001'
  currency: RMB
@ -0,0 +1,28 @@
model: step-1-8k
label:
  zh_Hans: step-1-8k
  en_US: step-1-8k
model_type: llm
features:
  - agent-thought
  - tool-call
  - multi-tool-call
  - stream-tool-call
model_properties:
  mode: chat
  context_size: 8000
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: max_tokens
    use_template: max_tokens
    default: 512
    min: 1
    max: 8000
pricing:
  input: '0.005'
  output: '0.020'
  unit: '0.001'
  currency: RMB
@ -0,0 +1,25 @@
model: step-1v-32k
label:
  zh_Hans: step-1v-32k
  en_US: step-1v-32k
model_type: llm
features:
  - vision
model_properties:
  mode: chat
  context_size: 32000
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: max_tokens
    use_template: max_tokens
    default: 1024
    min: 1
    max: 32000
pricing:
  input: '0.015'
  output: '0.070'
  unit: '0.001'
  currency: RMB
@ -0,0 +1,25 @@
model: step-1v-8k
label:
  zh_Hans: step-1v-8k
  en_US: step-1v-8k
model_type: llm
features:
  - vision
model_properties:
  mode: chat
  context_size: 8192
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: max_tokens
    use_template: max_tokens
    default: 512
    min: 1
    max: 8192
pricing:
  input: '0.005'
  output: '0.020'
  unit: '0.001'
  currency: RMB
30
api/core/model_runtime/model_providers/stepfun/stepfun.py
Normal file
@ -0,0 +1,30 @@
import logging

from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.model_provider import ModelProvider

logger = logging.getLogger(__name__)


class StepfunProvider(ModelProvider):

    def validate_provider_credentials(self, credentials: dict) -> None:
        """
        Validate provider credentials
        if validate failed, raise exception

        :param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
        """
        try:
            model_instance = self.get_model_instance(ModelType.LLM)

            model_instance.validate_credentials(
                model='step-1-8k',
                credentials=credentials
            )
        except CredentialsValidateFailedError as ex:
            raise ex
        except Exception as ex:
            logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
            raise ex
81
api/core/model_runtime/model_providers/stepfun/stepfun.yaml
Normal file
@ -0,0 +1,81 @@
provider: stepfun
label:
  zh_Hans: 阶跃星辰
  en_US: Stepfun
description:
  en_US: Models provided by stepfun, such as step-1-8k, step-1-32k, step-1v-8k, step-1v-32k, step-1-128k and step-1-256k.
  zh_Hans: 阶跃星辰提供的模型,例如 step-1-8k、step-1-32k、step-1v-8k、step-1v-32k、step-1-128k 和 step-1-256k。
icon_small:
  en_US: icon_s_en.png
icon_large:
  en_US: icon_l_en.png
background: "#FFFFFF"
help:
  title:
    en_US: Get your API Key from stepfun
    zh_Hans: 从 stepfun 获取 API Key
  url:
    en_US: https://platform.stepfun.com/interface-key
supported_model_types:
  - llm
configurate_methods:
  - predefined-model
  - customizable-model
provider_credential_schema:
  credential_form_schemas:
    - variable: api_key
      label:
        en_US: API Key
      type: secret-input
      required: true
      placeholder:
        zh_Hans: 在此输入您的 API Key
        en_US: Enter your API Key
model_credential_schema:
  model:
    label:
      en_US: Model Name
      zh_Hans: 模型名称
    placeholder:
      en_US: Enter your model name
      zh_Hans: 输入模型名称
  credential_form_schemas:
    - variable: api_key
      label:
        en_US: API Key
      type: secret-input
      required: true
      placeholder:
        zh_Hans: 在此输入您的 API Key
        en_US: Enter your API Key
    - variable: context_size
      label:
        zh_Hans: 模型上下文长度
        en_US: Model context size
      required: true
      type: text-input
      default: '8192'
      placeholder:
        zh_Hans: 在此输入您的模型上下文长度
        en_US: Enter your Model context size
    - variable: max_tokens
      label:
        zh_Hans: 最大 token 上限
        en_US: Upper bound for max tokens
      default: '8192'
      type: text-input
    - variable: function_calling_type
      label:
        en_US: Function calling
      type: select
      required: false
      default: no_call
      options:
        - value: no_call
          label:
            en_US: Not supported
            zh_Hans: 不支持
        - value: tool_call
          label:
            en_US: Tool Call
            zh_Hans: Tool Call
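Reviewer note: because configurate_methods lists customizable-model, the model_credential_schema above defines what a user-added model must supply. A hypothetical credentials payload matching that schema (all values illustrative; context_size and max_tokens are text-input fields, so they are submitted as strings):

credentials = {
    'api_key': 'YOUR-API-KEY',             # secret-input, required
    'context_size': '256000',              # text-input, default '8192'
    'max_tokens': '8192',                  # text-input, default '8192'
    'function_calling_type': 'tool_call',  # select: 'no_call' or 'tool_call'
}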
@ -35,3 +35,4 @@ parameter_rules:
      zh_Hans: 禁用模型自行进行外部搜索。
      en_US: Disable the model to perform external search.
    required: false
+deprecated: true
@ -1,4 +1,4 @@
-model: ernie-4.0-8k-Latest
+model: ernie-4.0-8k-latest
label:
  en_US: Ernie-4.0-8K-Latest
model_type: llm
@ -0,0 +1,40 @@
model: ernie-4.0-turbo-8k
label:
  en_US: Ernie-4.0-turbo-8K
model_type: llm
features:
  - agent-thought
model_properties:
  mode: chat
  context_size: 8192
parameter_rules:
  - name: temperature
    use_template: temperature
    min: 0.1
    max: 1.0
    default: 0.8
  - name: top_p
    use_template: top_p
  - name: max_tokens
    use_template: max_tokens
    default: 1024
    min: 2
    max: 2048
  - name: presence_penalty
    use_template: presence_penalty
    default: 1.0
    min: 1.0
    max: 2.0
  - name: frequency_penalty
    use_template: frequency_penalty
  - name: response_format
    use_template: response_format
  - name: disable_search
    label:
      zh_Hans: 禁用搜索
      en_US: Disable Search
    type: boolean
    help:
      zh_Hans: 禁用模型自行进行外部搜索。
      en_US: Disable the model to perform external search.
    required: false
@ -28,3 +28,4 @@ parameter_rules:
    default: 1.0
    min: 1.0
    max: 2.0
+deprecated: true
@ -0,0 +1,30 @@
model: ernie-character-8k-0321
label:
  en_US: ERNIE-Character-8K
model_type: llm
features:
  - agent-thought
model_properties:
  mode: chat
  context_size: 8192
parameter_rules:
  - name: temperature
    use_template: temperature
    min: 0.1
    max: 1.0
    default: 0.95
  - name: top_p
    use_template: top_p
    min: 0
    max: 1.0
    default: 0.7
  - name: max_tokens
    use_template: max_tokens
    default: 1024
    min: 2
    max: 1024
  - name: presence_penalty
    use_template: presence_penalty
    default: 1.0
    min: 1.0
    max: 2.0
@ -28,3 +28,4 @@ parameter_rules:
    default: 1.0
    min: 1.0
    max: 2.0
+deprecated: true
@ -28,3 +28,4 @@ parameter_rules:
    default: 1.0
    min: 1.0
    max: 2.0
+deprecated: true
@ -97,6 +97,7 @@ class BaiduAccessToken:
        baidu_access_tokens_lock.release()
        return token

+
class ErnieMessage:
    class Role(Enum):
        USER = 'user'
@ -137,7 +138,9 @@ class ErnieBotModel:
        'ernie-speed-appbuilder': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ai_apaas',
        'ernie-lite-8k-0922': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant',
        'ernie-lite-8k-0308': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-lite-8k',
        'ernie-character-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k',
        'ernie-character-8k-0321': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k',
        'ernie-4.0-turbo-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k',
        'ernie-4.0-turbo-8k-preview': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k-preview',
    }
@ -149,7 +152,8 @@ class ErnieBotModel:
        'ernie-3.5-8k-1222',
        'ernie-3.5-4k-0205',
        'ernie-3.5-128k',
-       'ernie-4.0-8k'
+       'ernie-4.0-8k',
+       'ernie-4.0-turbo-8k',
+       'ernie-4.0-turbo-8k-preview'
    ]
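Reviewer note on the last hunk: the comma added after 'ernie-4.0-8k' is not cosmetic. Python silently concatenates adjacent string literals, so appending the new entries without it would have corrupted the list instead of raising an error. A small illustration of the pitfall:

# If the comma had been forgotten, the two literals would silently join:
models = [
    'ernie-3.5-128k',
    'ernie-4.0-8k'          # <- missing comma
    'ernie-4.0-turbo-8k',
]
print(models)  # ['ernie-3.5-128k', 'ernie-4.0-8kernie-4.0-turbo-8k']
assert 'ernie-4.0-turbo-8k' not in models  # the intended entry never exists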