Mirror of https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
Synced 2025-08-12 21:39:05 +08:00
Merge branch 'refs/heads/main' into feat/workflow-parallel-support
# Conflicts:
#	api/core/app/apps/advanced_chat/app_generator.py
#	api/core/app/apps/advanced_chat/generate_task_pipeline.py
#	api/core/app/apps/workflow/app_runner.py
#	api/core/app/apps/workflow/generate_task_pipeline.py
#	api/core/app/task_pipeline/workflow_cycle_state_manager.py
#	api/core/workflow/entities/variable_pool.py
#	api/core/workflow/nodes/code/code_node.py
#	api/core/workflow/nodes/llm/llm_node.py
#	api/core/workflow/nodes/start/start_node.py
#	api/core/workflow/nodes/variable_assigner/__init__.py
#	api/tests/integration_tests/workflow/nodes/test_llm.py
#	api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py
#	api/tests/unit_tests/core/workflow/nodes/test_answer.py
#	api/tests/unit_tests/core/workflow/nodes/test_if_else.py
#	api/tests/unit_tests/core/workflow/nodes/test_variable_assigner.py
Commit 35be41b337

.gitignore (vendored):
@@ -178,3 +178,4 @@ pyrightconfig.json
+api/.vscode
 .idea/
 .vscode
@@ -267,4 +267,13 @@ APP_MAX_ACTIVE_REQUESTS=0

 # Celery beat configuration
 CELERY_BEAT_SCHEDULER_TIME=1
+
+# Position configuration
+POSITION_TOOL_PINS=
+POSITION_TOOL_INCLUDES=
+POSITION_TOOL_EXCLUDES=
+
+POSITION_PROVIDER_PINS=
+POSITION_PROVIDER_INCLUDES=
+POSITION_PROVIDER_EXCLUDES=
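These position variables are plain comma-separated strings: pins order items first, includes act as an allow-list, excludes as a deny-list. A minimal illustration (the names below are made up, not part of this commit):

    POSITION_TOOL_PINS=code,webscraper
    POSITION_PROVIDER_PINS=openai,anthropic
    POSITION_PROVIDER_EXCLUDES=some_legacy_provider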
.idea/icon.png → api/.idea/icon.png (generated; binary unchanged, 1.7 KiB)
.idea/vcs.xml → api/.idea/vcs.xml (generated)
@@ -5,8 +5,8 @@
             "name": "Python: Flask",
             "type": "debugpy",
             "request": "launch",
-            "python": "${workspaceFolder}/api/.venv/bin/python",
-            "cwd": "${workspaceFolder}/api",
+            "python": "${workspaceFolder}/.venv/bin/python",
+            "cwd": "${workspaceFolder}",
             "envFile": ".env",
             "module": "flask",
             "justMyCode": true,
@@ -18,15 +18,15 @@
             "args": [
                 "run",
                 "--host=0.0.0.0",
-                "--port=5001",
+                "--port=5001"
             ]
         },
         {
             "name": "Python: Celery",
             "type": "debugpy",
             "request": "launch",
-            "python": "${workspaceFolder}/api/.venv/bin/python",
-            "cwd": "${workspaceFolder}/api",
+            "python": "${workspaceFolder}/.venv/bin/python",
+            "cwd": "${workspaceFolder}",
             "module": "celery",
             "justMyCode": true,
             "envFile": ".env",
@@ -37,6 +37,8 @@ class DifyConfig(

     CODE_MAX_NUMBER: int = 9223372036854775807
     CODE_MIN_NUMBER: int = -9223372036854775808
+    CODE_MAX_DEPTH: int = 5
+    CODE_MAX_PRECISION: int = 20
     CODE_MAX_STRING_LENGTH: int = 80000
     CODE_MAX_STRING_ARRAY_LENGTH: int = 30
     CODE_MAX_OBJECT_ARRAY_LENGTH: int = 30
@@ -406,6 +406,7 @@ class DataSetConfig(BaseSettings):
         default=False,
     )

+
 class WorkspaceConfig(BaseSettings):
     """
     Workspace configs
@@ -442,6 +443,63 @@ class CeleryBeatConfig(BaseSettings):
     )


+class PositionConfig(BaseSettings):
+
+    POSITION_PROVIDER_PINS: str = Field(
+        description='The heads of model providers',
+        default='',
+    )
+
+    POSITION_PROVIDER_INCLUDES: str = Field(
+        description='The included model providers',
+        default='',
+    )
+
+    POSITION_PROVIDER_EXCLUDES: str = Field(
+        description='The excluded model providers',
+        default='',
+    )
+
+    POSITION_TOOL_PINS: str = Field(
+        description='The heads of tools',
+        default='',
+    )
+
+    POSITION_TOOL_INCLUDES: str = Field(
+        description='The included tools',
+        default='',
+    )
+
+    POSITION_TOOL_EXCLUDES: str = Field(
+        description='The excluded tools',
+        default='',
+    )
+
+    @computed_field
+    def POSITION_PROVIDER_PINS_LIST(self) -> list[str]:
+        return [item.strip() for item in self.POSITION_PROVIDER_PINS.split(',') if item.strip() != '']
+
+    @computed_field
+    def POSITION_PROVIDER_INCLUDES_SET(self) -> set[str]:
+        return {item.strip() for item in self.POSITION_PROVIDER_INCLUDES.split(',') if item.strip() != ''}
+
+    @computed_field
+    def POSITION_PROVIDER_EXCLUDES_SET(self) -> set[str]:
+        return {item.strip() for item in self.POSITION_PROVIDER_EXCLUDES.split(',') if item.strip() != ''}
+
+    @computed_field
+    def POSITION_TOOL_PINS_LIST(self) -> list[str]:
+        return [item.strip() for item in self.POSITION_TOOL_PINS.split(',') if item.strip() != '']
+
+    @computed_field
+    def POSITION_TOOL_INCLUDES_SET(self) -> set[str]:
+        return {item.strip() for item in self.POSITION_TOOL_INCLUDES.split(',') if item.strip() != ''}
+
+    @computed_field
+    def POSITION_TOOL_EXCLUDES_SET(self) -> set[str]:
+        return {item.strip() for item in self.POSITION_TOOL_EXCLUDES.split(',') if item.strip() != ''}
+
+
 class FeatureConfig(
     # place the configs in alphabet order
     AppExecutionConfig,
@@ -466,6 +524,7 @@ class FeatureConfig(
     UpdateConfig,
     WorkflowConfig,
     WorkspaceConfig,
+    PositionConfig,

     # hosted services config
     HostedServiceConfig,
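A quick sketch of how the PositionConfig computed fields behave once settings are loaded (the env values here are illustrative):

    import os

    # the raw values are comma-separated strings; whitespace is stripped, empties dropped
    os.environ['POSITION_PROVIDER_PINS'] = 'openai, anthropic'
    os.environ['POSITION_TOOL_EXCLUDES'] = 'legacy_tool'

    from configs import dify_config

    print(dify_config.POSITION_PROVIDER_PINS_LIST)  # ['openai', 'anthropic']
    print(dify_config.POSITION_TOOL_EXCLUDES_SET)   # {'legacy_tool'}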
@@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):

     CURRENT_VERSION: str = Field(
         description='Dify version',
-        default='0.7.0',
+        default='0.7.1',
     )

     COMMIT_SHA: str = Field(
@@ -61,6 +61,7 @@ class AppListApi(Resource):
         parser.add_argument('name', type=str, required=True, location='json')
         parser.add_argument('description', type=str, location='json')
         parser.add_argument('mode', type=str, choices=ALLOW_CREATE_APP_MODES, location='json')
+        parser.add_argument('icon_type', type=str, location='json')
         parser.add_argument('icon', type=str, location='json')
         parser.add_argument('icon_background', type=str, location='json')
         args = parser.parse_args()
@@ -94,6 +95,7 @@ class AppImportApi(Resource):
         parser.add_argument('data', type=str, required=True, nullable=False, location='json')
         parser.add_argument('name', type=str, location='json')
         parser.add_argument('description', type=str, location='json')
+        parser.add_argument('icon_type', type=str, location='json')
         parser.add_argument('icon', type=str, location='json')
         parser.add_argument('icon_background', type=str, location='json')
         args = parser.parse_args()
@@ -167,6 +169,7 @@ class AppApi(Resource):
         parser = reqparse.RequestParser()
         parser.add_argument('name', type=str, required=True, nullable=False, location='json')
         parser.add_argument('description', type=str, location='json')
+        parser.add_argument('icon_type', type=str, location='json')
         parser.add_argument('icon', type=str, location='json')
         parser.add_argument('icon_background', type=str, location='json')
         parser.add_argument('max_active_requests', type=int, location='json')
@@ -208,6 +211,7 @@ class AppCopyApi(Resource):
         parser = reqparse.RequestParser()
         parser.add_argument('name', type=str, location='json')
         parser.add_argument('description', type=str, location='json')
+        parser.add_argument('icon_type', type=str, location='json')
         parser.add_argument('icon', type=str, location='json')
         parser.add_argument('icon_background', type=str, location='json')
         args = parser.parse_args()
@@ -154,6 +154,8 @@ class ChatConversationApi(Resource):
         parser.add_argument('message_count_gte', type=int_range(1, 99999), required=False, location='args')
         parser.add_argument('page', type=int_range(1, 99999), required=False, default=1, location='args')
         parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
+        parser.add_argument('sort_by', type=str, choices=['created_at', '-created_at', 'updated_at', '-updated_at'],
+                            required=False, default='-updated_at', location='args')
         args = parser.parse_args()

         subquery = (
@@ -225,7 +227,17 @@ class ChatConversationApi(Resource):
         if app_model.mode == AppMode.ADVANCED_CHAT.value:
             query = query.where(Conversation.invoke_from != InvokeFrom.DEBUGGER.value)

-        query = query.order_by(Conversation.created_at.desc())
+        match args['sort_by']:
+            case 'created_at':
+                query = query.order_by(Conversation.created_at.asc())
+            case '-created_at':
+                query = query.order_by(Conversation.created_at.desc())
+            case 'updated_at':
+                query = query.order_by(Conversation.updated_at.asc())
+            case '-updated_at':
+                query = query.order_by(Conversation.updated_at.desc())
+            case _:
+                query = query.order_by(Conversation.created_at.desc())

         conversations = db.paginate(
             query,
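On the client side, the new parameter is an ordinary query-string argument: a leading '-' selects descending order, and any unrecognized value falls back to created_at descending. An illustrative request (the path is an example, not taken from this diff):

    GET /console/api/apps/<app_id>/chat-conversations?sort_by=-updated_at&page=1&limit=20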
@@ -16,6 +16,7 @@ from models.model import Site
 def parse_app_site_args():
     parser = reqparse.RequestParser()
     parser.add_argument('title', type=str, required=False, location='json')
+    parser.add_argument('icon_type', type=str, required=False, location='json')
     parser.add_argument('icon', type=str, required=False, location='json')
     parser.add_argument('icon_background', type=str, required=False, location='json')
     parser.add_argument('description', type=str, required=False, location='json')
@@ -53,6 +54,7 @@ class AppSite(Resource):

         for attr_name in [
             'title',
+            'icon_type',
             'icon',
             'icon_background',
             'description',
@@ -460,6 +460,7 @@ class ConvertToWorkflowApi(Resource):
         if request.data:
             parser = reqparse.RequestParser()
             parser.add_argument('name', type=str, required=False, nullable=True, location='json')
+            parser.add_argument('icon_type', type=str, required=False, nullable=True, location='json')
             parser.add_argument('icon', type=str, required=False, nullable=True, location='json')
             parser.add_argument('icon_background', type=str, required=False, nullable=True, location='json')
             args = parser.parse_args()
@@ -573,13 +573,13 @@ class DatasetRetrievalSettingMockApi(Resource):
     @account_initialization_required
     def get(self, vector_type):
         match vector_type:
-            case VectorType.MILVUS | VectorType.RELYT | VectorType.PGVECTOR | VectorType.TIDB_VECTOR | VectorType.CHROMA | VectorType.TENCENT:
+            case VectorType.MILVUS | VectorType.RELYT | VectorType.TIDB_VECTOR | VectorType.CHROMA | VectorType.TENCENT | VectorType.PGVECTO_RS:
                 return {
                     'retrieval_method': [
                         RetrievalMethod.SEMANTIC_SEARCH.value
                     ]
                 }
-            case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH | VectorType.ANALYTICDB | VectorType.MYSCALE | VectorType.ORACLE | VectorType.ELASTICSEARCH:
+            case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH | VectorType.ANALYTICDB | VectorType.MYSCALE | VectorType.ORACLE | VectorType.ELASTICSEARCH | VectorType.PGVECTOR:
                 return {
                     'retrieval_method': [
                         RetrievalMethod.SEMANTIC_SEARCH.value,
@@ -25,6 +25,8 @@ class ConversationApi(Resource):
         parser = reqparse.RequestParser()
         parser.add_argument('last_id', type=uuid_value, location='args')
         parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
+        parser.add_argument('sort_by', type=str, choices=['created_at', '-created_at', 'updated_at', '-updated_at'],
+                            required=False, default='-updated_at', location='args')
         args = parser.parse_args()

         try:
@@ -33,7 +35,8 @@ class ConversationApi(Resource):
                 user=end_user,
                 last_id=args['last_id'],
                 limit=args['limit'],
-                invoke_from=InvokeFrom.SERVICE_API
+                invoke_from=InvokeFrom.SERVICE_API,
+                sort_by=args['sort_by']
             )
         except services.errors.conversation.LastConversationNotExistsError:
             raise NotFound("Last Conversation Not Exists.")
@@ -53,19 +53,22 @@ class SegmentApi(DatasetApiResource):
             raise ProviderNotInitializeError(
                 "No Embedding Model available. Please configure a valid provider "
                 "in the Settings -> Model Provider.")
         except ProviderTokenNotInitError as ex:
             raise ProviderNotInitializeError(ex.description)
         # validate args
         parser = reqparse.RequestParser()
         parser.add_argument('segments', type=list, required=False, nullable=True, location='json')
         args = parser.parse_args()
-        for args_item in args['segments']:
-            SegmentService.segment_create_args_validate(args_item, document)
-        segments = SegmentService.multi_create_segment(args['segments'], document, dataset)
-        return {
-            'data': marshal(segments, segment_fields),
-            'doc_form': document.doc_form
-        }, 200
+        if args['segments'] is not None:
+            for args_item in args['segments']:
+                SegmentService.segment_create_args_validate(args_item, document)
+            segments = SegmentService.multi_create_segment(args['segments'], document, dataset)
+            return {
+                'data': marshal(segments, segment_fields),
+                'doc_form': document.doc_form
+            }, 200
+        else:
+            return {"error": "Segments is required"}, 400

     def get(self, tenant_id, dataset_id, document_id):
         """Create single segment."""
@@ -26,6 +26,8 @@ class ConversationListApi(WebApiResource):
         parser.add_argument('last_id', type=uuid_value, location='args')
         parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
         parser.add_argument('pinned', type=str, choices=['true', 'false', None], location='args')
+        parser.add_argument('sort_by', type=str, choices=['created_at', '-created_at', 'updated_at', '-updated_at'],
+                            required=False, default='-updated_at', location='args')
         args = parser.parse_args()

         pinned = None
@@ -40,6 +42,7 @@ class ConversationListApi(WebApiResource):
                 limit=args['limit'],
                 invoke_from=InvokeFrom.WEB_APP,
                 pinned=pinned,
+                sort_by=args['sort_by']
             )
         except LastConversationNotExistsError:
             raise NotFound("Last Conversation Not Exists.")
@@ -6,6 +6,7 @@ from configs import dify_config
 from controllers.web import api
 from controllers.web.wraps import WebApiResource
 from extensions.ext_database import db
+from libs.helper import AppIconUrlField
 from models.account import TenantStatus
 from models.model import Site
 from services.feature_service import FeatureService
@@ -28,8 +29,10 @@ class AppSiteApi(WebApiResource):
         'title': fields.String,
         'chat_color_theme': fields.String,
         'chat_color_theme_inverted': fields.Boolean,
+        'icon_type': fields.String,
         'icon': fields.String,
         'icon_background': fields.String,
+        'icon_url': AppIconUrlField,
         'description': fields.String,
         'copyright': fields.String,
         'privacy_policy': fields.String,
@@ -64,15 +64,19 @@ class BaseAgentRunner(AppRunner):
         """
         Agent runner
         :param tenant_id: tenant id
         :param application_generate_entity: application generate entity
         :param conversation: conversation
         :param app_config: app generate entity
         :param model_config: model config
         :param config: dataset config
         :param queue_manager: queue manager
         :param message: message
         :param user_id: user id
         :param agent_llm_callback: agent llm callback
         :param callback: callback
         :param memory: memory
         :param prompt_messages: prompt messages
         :param variables_pool: variables pool
         :param db_variables: db variables
         :param model_instance: model instance
         """
         self.tenant_id = tenant_id
         self.application_generate_entity = application_generate_entity
@@ -445,7 +449,7 @@ class BaseAgentRunner(AppRunner):
             try:
                 tool_responses = json.loads(agent_thought.observation)
             except Exception as e:
-                tool_responses = { tool: agent_thought.observation for tool in tools }
+                tool_responses = dict.fromkeys(tools, agent_thought.observation)

             for tool in tools:
                 # generate a uuid for tool call
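dict.fromkeys builds the same mapping as the removed comprehension, one shared value for every key; for example, dict.fromkeys(['search', 'calculator'], 'timed out') == {'search': 'timed out', 'calculator': 'timed out'}.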
@@ -292,6 +292,8 @@ class CotAgentRunner(BaseAgentRunner, ABC):
         handle invoke action
         :param action: action
         :param tool_instances: tool instances
         :param message_file_ids: message file ids
+        :param trace_manager: trace manager
         :return: observation, meta
         """
         # action is tool call, invoke tool
@@ -93,7 +93,7 @@ class DatasetConfigManager:
                 reranking_model=dataset_configs.get('reranking_model'),
                 weights=dataset_configs.get('weights'),
                 reranking_enabled=dataset_configs.get('reranking_enabled', True),
-                rerank_mode=dataset_configs["reranking_mode"],
+                rerank_mode=dataset_configs.get('rerank_mode', 'reranking_model'),
             )
         )
@@ -1,6 +1,6 @@
 import re

-from core.app.app_config.entities import ExternalDataVariableEntity, VariableEntity
+from core.app.app_config.entities import ExternalDataVariableEntity, VariableEntity, VariableEntityType
 from core.external_data_tool.factory import ExternalDataToolFactory


@@ -13,7 +13,7 @@ class BasicVariablesConfigManager:
         :param config: model config args
         """
         external_data_variables = []
-        variables = []
+        variable_entities = []

         # old external_data_tools
         external_data_tools = config.get('external_data_tools', [])
@@ -30,50 +30,41 @@ class BasicVariablesConfigManager:
             )

         # variables and external_data_tools
-        for variable in config.get('user_input_form', []):
-            typ = list(variable.keys())[0]
-            if typ == 'external_data_tool':
-                val = variable[typ]
-                if 'config' not in val:
+        for variables in config.get('user_input_form', []):
+            variable_type = list(variables.keys())[0]
+            if variable_type == VariableEntityType.EXTERNAL_DATA_TOOL:
+                variable = variables[variable_type]
+                if 'config' not in variable:
                     continue

                 external_data_variables.append(
                     ExternalDataVariableEntity(
-                        variable=val['variable'],
-                        type=val['type'],
-                        config=val['config']
+                        variable=variable['variable'],
+                        type=variable['type'],
+                        config=variable['config']
                     )
                 )
-            elif typ in [
-                VariableEntity.Type.TEXT_INPUT.value,
-                VariableEntity.Type.PARAGRAPH.value,
-                VariableEntity.Type.NUMBER.value,
+            elif variable_type in [
+                VariableEntityType.TEXT_INPUT,
+                VariableEntityType.PARAGRAPH,
+                VariableEntityType.NUMBER,
+                VariableEntityType.SELECT,
             ]:
-                variables.append(
+                variable = variables[variable_type]
+                variable_entities.append(
                     VariableEntity(
-                        type=VariableEntity.Type.value_of(typ),
-                        variable=variable[typ].get('variable'),
-                        description=variable[typ].get('description'),
-                        label=variable[typ].get('label'),
-                        required=variable[typ].get('required', False),
-                        max_length=variable[typ].get('max_length'),
-                        default=variable[typ].get('default'),
-                    )
-                )
-            elif typ == VariableEntity.Type.SELECT.value:
-                variables.append(
-                    VariableEntity(
-                        type=VariableEntity.Type.SELECT,
-                        variable=variable[typ].get('variable'),
-                        description=variable[typ].get('description'),
-                        label=variable[typ].get('label'),
-                        required=variable[typ].get('required', False),
-                        options=variable[typ].get('options'),
-                        default=variable[typ].get('default'),
+                        type=variable_type,
+                        variable=variable.get('variable'),
+                        description=variable.get('description'),
+                        label=variable.get('label'),
+                        required=variable.get('required', False),
+                        max_length=variable.get('max_length'),
+                        options=variable.get('options'),
+                        default=variable.get('default'),
                     )
                 )

-        return variables, external_data_variables
+        return variable_entities, external_data_variables

     @classmethod
     def validate_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]:
@@ -183,4 +174,4 @@ class BasicVariablesConfigManager:
             config=config
         )

-        return config, ["external_data_tools"]
+        return config, ["external_data_tools"]
@@ -82,43 +82,29 @@ class PromptTemplateEntity(BaseModel):
     advanced_completion_prompt_template: Optional[AdvancedCompletionPromptTemplateEntity] = None


+class VariableEntityType(str, Enum):
+    TEXT_INPUT = "text-input"
+    SELECT = "select"
+    PARAGRAPH = "paragraph"
+    NUMBER = "number"
+    EXTERNAL_DATA_TOOL = "external-data-tool"
+
+
 class VariableEntity(BaseModel):
     """
     Variable Entity.
     """
-    class Type(Enum):
-        TEXT_INPUT = 'text-input'
-        SELECT = 'select'
-        PARAGRAPH = 'paragraph'
-        NUMBER = 'number'
-
-        @classmethod
-        def value_of(cls, value: str) -> 'VariableEntity.Type':
-            """
-            Get value of given mode.
-
-            :param value: mode value
-            :return: mode
-            """
-            for mode in cls:
-                if mode.value == value:
-                    return mode
-            raise ValueError(f'invalid variable type value {value}')
-
     variable: str
     label: str
     description: Optional[str] = None
-    type: Type
+    type: VariableEntityType
     required: bool = False
     max_length: Optional[int] = None
     options: Optional[list[str]] = None
     default: Optional[str] = None
     hint: Optional[str] = None

     @property
     def name(self) -> str:
         return self.variable


 class ExternalDataVariableEntity(BaseModel):
     """
@@ -252,4 +238,4 @@ class WorkflowUIBasedAppConfig(AppConfig):
     """
     Workflow UI Based App Config Entity.
     """
-    workflow_id: str
+    workflow_id: str
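A minimal sketch of constructing a variable with the flattened enum (the field values are illustrative):

    from core.app.app_config.entities import VariableEntity, VariableEntityType

    entity = VariableEntity(
        type=VariableEntityType.SELECT,  # str-based Enum: no value_of() lookup needed
        variable='color',
        label='Color',
        required=True,
        options=['red', 'green'],
    )
    assert entity.type == 'select'  # members of a str Enum compare equal to raw strings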
@@ -23,6 +23,8 @@ from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotA
 from core.file.message_file_parser import MessageFileParser
 from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
 from core.ops.ops_trace_manager import TraceQueueManager
+from core.workflow.entities.variable_pool import VariablePool
+from core.workflow.enums import SystemVariableKey
 from extensions.ext_database import db
 from models.account import Account
 from models.model import App, Conversation, EndUser, Message
@@ -67,8 +69,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):

         # get conversation
         conversation = None
-        if args.get('conversation_id'):
-            conversation = self._get_conversation_by_user(app_model, args.get('conversation_id', ''), user)
+        conversation_id = args.get('conversation_id')
+        if conversation_id:
+            conversation = self._get_conversation_by_user(app_model=app_model, conversation_id=conversation_id, user=user)

         # parse files
         files = args['files'] if args.get('files') else []
@@ -225,6 +228,62 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
             message_id=message.id
         )

+        # Init conversation variables
+        stmt = select(ConversationVariable).where(
+            ConversationVariable.app_id == conversation.app_id, ConversationVariable.conversation_id == conversation.id
+        )
+        with Session(db.engine) as session:
+            conversation_variables = session.scalars(stmt).all()
+            if not conversation_variables:
+                # Create conversation variables if they don't exist.
+                conversation_variables = [
+                    ConversationVariable.from_variable(
+                        app_id=conversation.app_id, conversation_id=conversation.id, variable=variable
+                    )
+                    for variable in workflow.conversation_variables
+                ]
+                session.add_all(conversation_variables)
+            # Convert database entities to variables.
+            conversation_variables = [item.to_variable() for item in conversation_variables]
+
+            session.commit()
+
+        # Increment dialogue count.
+        conversation.dialogue_count += 1
+
+        conversation_id = conversation.id
+        conversation_dialogue_count = conversation.dialogue_count
+        db.session.commit()
+        db.session.refresh(conversation)
+
+        inputs = application_generate_entity.inputs
+        query = application_generate_entity.query
+        files = application_generate_entity.files
+
+        user_id = None
+        if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
+            end_user = db.session.query(EndUser).filter(EndUser.id == application_generate_entity.user_id).first()
+            if end_user:
+                user_id = end_user.session_id
+        else:
+            user_id = application_generate_entity.user_id
+
+        # Create a variable pool.
+        system_inputs = {
+            SystemVariableKey.QUERY: query,
+            SystemVariableKey.FILES: files,
+            SystemVariableKey.CONVERSATION_ID: conversation_id,
+            SystemVariableKey.USER_ID: user_id,
+            SystemVariableKey.DIALOGUE_COUNT: conversation_dialogue_count,
+        }
+        variable_pool = VariablePool(
+            system_variables=system_inputs,
+            user_inputs=inputs,
+            environment_variables=workflow.environment_variables,
+            conversation_variables=conversation_variables,
+        )
+        contexts.workflow_variable_pool.set(variable_pool)
+
         # new thread
         worker_thread = threading.Thread(target=self._generate_worker, kwargs={
             'flask_app': current_app._get_current_object(),  # type: ignore
@@ -296,7 +355,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
             logger.exception("Validation Error when generating")
             queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
         except (ValueError, InvokeError) as e:
-            if os.environ.get("DEBUG") and os.environ.get("DEBUG", "false").lower() == 'true':
+            if os.environ.get("DEBUG", "false").lower() == 'true':
                 logger.exception("Error when generating")
             queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
         except Exception as e:
@@ -47,7 +47,7 @@ from core.app.task_pipeline.message_cycle_manage import MessageCycleManage
 from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage
 from core.model_runtime.utils.encoders import jsonable_encoder
 from core.ops.ops_trace_manager import TraceQueueManager
-from core.workflow.enums import SystemVariable
+from core.workflow.enums import SystemVariableKey
 from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
 from events.message_event import message_was_created
 from extensions.ext_database import db
@@ -69,7 +69,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
     _application_generate_entity: AdvancedChatAppGenerateEntity
     _workflow: Workflow
     _user: Union[Account, EndUser]
-    _workflow_system_variables: dict[SystemVariable, Any]
+    _workflow_system_variables: dict[SystemVariableKey, Any]

     def __init__(
         self,
@@ -102,10 +102,10 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
         self._conversation = conversation
         self._message = message
         self._workflow_system_variables = {
-            SystemVariable.QUERY: message.query,
-            SystemVariable.FILES: application_generate_entity.files,
-            SystemVariable.CONVERSATION_ID: conversation.id,
-            SystemVariable.USER_ID: user_id,
+            SystemVariableKey.QUERY: message.query,
+            SystemVariableKey.FILES: application_generate_entity.files,
+            SystemVariableKey.CONVERSATION_ID: conversation.id,
+            SystemVariableKey.USER_ID: user_id,
         }

         self._task_state = WorkflowTaskState()
@@ -312,7 +312,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
         elif isinstance(event, QueueParallelBranchRunStartedEvent):
             if not workflow_run:
                 raise Exception('Workflow run not initialized.')

             yield self._workflow_parallel_branch_start_to_stream_response(
                 task_id=self._application_generate_entity.task_id,
                 workflow_run=workflow_run,
@@ -321,7 +321,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
         elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent):
             if not workflow_run:
                 raise Exception('Workflow run not initialized.')

             yield self._workflow_parallel_branch_finished_to_stream_response(
                 task_id=self._application_generate_entity.task_id,
                 workflow_run=workflow_run,
@@ -1,7 +1,7 @@
 from collections.abc import Mapping
 from typing import Any, Optional

-from core.app.app_config.entities import AppConfig, VariableEntity
+from core.app.app_config.entities import AppConfig, VariableEntity, VariableEntityType


 class BaseAppGenerator:
@@ -9,29 +9,29 @@ class BaseAppGenerator:
         user_inputs = user_inputs or {}
         # Filter input variables from form configuration, handle required fields, default values, and option values
         variables = app_config.variables
-        filtered_inputs = {var.name: self._validate_input(inputs=user_inputs, var=var) for var in variables}
+        filtered_inputs = {var.variable: self._validate_input(inputs=user_inputs, var=var) for var in variables}
         filtered_inputs = {k: self._sanitize_value(v) for k, v in filtered_inputs.items()}
         return filtered_inputs

     def _validate_input(self, *, inputs: Mapping[str, Any], var: VariableEntity):
-        user_input_value = inputs.get(var.name)
+        user_input_value = inputs.get(var.variable)
         if var.required and not user_input_value:
-            raise ValueError(f'{var.name} is required in input form')
+            raise ValueError(f'{var.variable} is required in input form')
         if not var.required and not user_input_value:
             # TODO: should we return None here if the default value is None?
             return var.default or ''
         if (
             var.type
             in (
-                VariableEntity.Type.TEXT_INPUT,
-                VariableEntity.Type.SELECT,
-                VariableEntity.Type.PARAGRAPH,
+                VariableEntityType.TEXT_INPUT,
+                VariableEntityType.SELECT,
+                VariableEntityType.PARAGRAPH,
             )
             and user_input_value
             and not isinstance(user_input_value, str)
         ):
-            raise ValueError(f"(type '{var.type}') {var.name} in input form must be a string")
-        if var.type == VariableEntity.Type.NUMBER and isinstance(user_input_value, str):
+            raise ValueError(f"(type '{var.type}') {var.variable} in input form must be a string")
+        if var.type == VariableEntityType.NUMBER and isinstance(user_input_value, str):
             # may raise ValueError if user_input_value is not a valid number
             try:
                 if '.' in user_input_value:
@@ -39,14 +39,14 @@ class BaseAppGenerator:
                 else:
                     return int(user_input_value)
             except ValueError:
-                raise ValueError(f"{var.name} in input form must be a valid number")
-        if var.type == VariableEntity.Type.SELECT:
+                raise ValueError(f"{var.variable} in input form must be a valid number")
+        if var.type == VariableEntityType.SELECT:
             options = var.options or []
             if user_input_value not in options:
-                raise ValueError(f'{var.name} in input form must be one of the following: {options}')
-        elif var.type in (VariableEntity.Type.TEXT_INPUT, VariableEntity.Type.PARAGRAPH):
+                raise ValueError(f'{var.variable} in input form must be one of the following: {options}')
+        elif var.type in (VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH):
             if var.max_length and user_input_value and len(user_input_value) > var.max_length:
-                raise ValueError(f'{var.name} in input form must be less than {var.max_length} characters')
+                raise ValueError(f'{var.variable} in input form must be less than {var.max_length} characters')

         return user_input_value
@@ -256,6 +256,7 @@ class AppRunner:
         :param invoke_result: invoke result
         :param queue_manager: application queue manager
         :param stream: stream
+        :param agent: agent
         :return:
         """
         if not stream:
@@ -278,6 +279,7 @@ class AppRunner:
         Handle invoke result direct
         :param invoke_result: invoke result
         :param queue_manager: application queue manager
+        :param agent: agent
         :return:
         """
         queue_manager.publish(
@@ -293,6 +295,7 @@ class AppRunner:
         Handle invoke result
         :param invoke_result: invoke result
         :param queue_manager: application queue manager
+        :param agent: agent
         :return:
         """
         model = None
@@ -1,6 +1,7 @@
 import json
 import logging
+from collections.abc import Generator
 from datetime import datetime, timezone
 from typing import Optional, Union

 from sqlalchemy import and_
@@ -36,17 +37,17 @@ logger = logging.getLogger(__name__)
 class MessageBasedAppGenerator(BaseAppGenerator):

     def _handle_response(
-        self, application_generate_entity: Union[
-            ChatAppGenerateEntity,
-            CompletionAppGenerateEntity,
-            AgentChatAppGenerateEntity,
-            AdvancedChatAppGenerateEntity
-        ],
-        queue_manager: AppQueueManager,
-        conversation: Conversation,
-        message: Message,
-        user: Union[Account, EndUser],
-        stream: bool = False,
+            self, application_generate_entity: Union[
+                ChatAppGenerateEntity,
+                CompletionAppGenerateEntity,
+                AgentChatAppGenerateEntity,
+                AdvancedChatAppGenerateEntity
+            ],
+            queue_manager: AppQueueManager,
+            conversation: Conversation,
+            message: Message,
+            user: Union[Account, EndUser],
+            stream: bool = False,
     ) -> Union[
         ChatbotAppBlockingResponse,
         CompletionAppBlockingResponse,
@@ -138,6 +139,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
         """
         Initialize generate records
         :param application_generate_entity: application generate entity
+        :param conversation: conversation
         :return:
         """
         app_config = application_generate_entity.app_config
@@ -192,6 +194,9 @@ class MessageBasedAppGenerator(BaseAppGenerator):
             db.session.add(conversation)
             db.session.commit()
             db.session.refresh(conversation)
+        else:
+            conversation.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
+            db.session.commit()

         message = Message(
             app_id=app_config.app_id,
@@ -13,7 +13,7 @@ from core.app.entities.app_invoke_entities import (
 from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
 from core.workflow.entities.node_entities import UserFrom
 from core.workflow.entities.variable_pool import VariablePool
-from core.workflow.enums import SystemVariable
+from core.workflow.enums import SystemVariableKey
 from core.workflow.workflow_entry import WorkflowEntry
 from extensions.ext_database import db
 from models.model import App, EndUser
@@ -79,14 +79,14 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
                 user_inputs=self.application_generate_entity.single_iteration_run.inputs
             )
         else:
             inputs = self.application_generate_entity.inputs
             files = self.application_generate_entity.files

             # Create a variable pool.
             system_inputs = {
-                SystemVariable.FILES: files,
-                SystemVariable.USER_ID: user_id,
+                SystemVariableKey.FILES: files,
+                SystemVariableKey.USER_ID: user_id,
             }

             variable_pool = VariablePool(
@@ -98,7 +98,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):

         # init graph
         graph = self._init_graph(graph_config=workflow.graph_dict)

         # RUN WORKFLOW
         workflow_entry = WorkflowEntry(
             tenant_id=workflow.tenant_id,
@@ -41,7 +41,9 @@ from core.app.entities.task_entities import (
 from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
 from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage
 from core.ops.ops_trace_manager import TraceQueueManager
-from core.workflow.enums import SystemVariable
+from core.workflow.entities.node_entities import NodeType
+from core.workflow.enums import SystemVariableKey
+from core.workflow.nodes.end.end_node import EndNode
 from extensions.ext_database import db
 from models.account import Account
 from models.model import EndUser
@@ -64,7 +66,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
     _user: Union[Account, EndUser]
     _task_state: WorkflowTaskState
     _application_generate_entity: WorkflowAppGenerateEntity
-    _workflow_system_variables: dict[SystemVariable, Any]
+    _workflow_system_variables: dict[SystemVariableKey, Any]

     def __init__(self, application_generate_entity: WorkflowAppGenerateEntity,
                  workflow: Workflow,
@@ -88,8 +90,8 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa

         self._workflow = workflow
         self._workflow_system_variables = {
-            SystemVariable.FILES: application_generate_entity.files,
-            SystemVariable.USER_ID: user_id
+            SystemVariableKey.FILES: application_generate_entity.files,
+            SystemVariableKey.USER_ID: user_id
         }

         self._task_state = WorkflowTaskState()
@@ -4,7 +4,7 @@ from enum import Enum
 from threading import Lock
 from typing import Literal, Optional

-from httpx import get, post
+from httpx import Timeout, get, post
 from pydantic import BaseModel
 from yarl import URL

@@ -21,7 +21,7 @@ logger = logging.getLogger(__name__)
 CODE_EXECUTION_ENDPOINT = dify_config.CODE_EXECUTION_ENDPOINT
 CODE_EXECUTION_API_KEY = dify_config.CODE_EXECUTION_API_KEY

-CODE_EXECUTION_TIMEOUT = (10, 60)
+CODE_EXECUTION_TIMEOUT = Timeout(connect=10, write=10, read=60, pool=None)

 class CodeExecutionException(Exception):
     pass
@@ -116,7 +116,7 @@ class CodeExecutor:
         if response.data.error:
             raise CodeExecutionException(response.data.error)

-        return response.data.stdout
+        return response.data.stdout or ''

     @classmethod
     def execute_workflow_code_template(cls, language: CodeLanguage, code: str, inputs: dict, dependencies: Optional[list[CodeDependency]] = None) -> dict:
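httpx does not accept the requests-style (connect, read) tuple, so the timeout becomes an explicit httpx.Timeout. A small sketch of the same settings in isolation (the URL is illustrative):

    import httpx

    # connect and write capped at 10 s, read at 60 s; pool=None puts no limit
    # on waiting for a connection from the pool
    timeout = httpx.Timeout(connect=10, write=10, read=60, pool=None)
    response = httpx.post('https://sandbox.example/v1/sandbox/run', json={}, timeout=timeout)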
@@ -13,7 +13,7 @@ class Python3CodeProvider(CodeNodeProvider):
     def get_default_code(cls) -> str:
         return dedent(
             """
-            def main(arg1: int, arg2: int) -> dict:
+            def main(arg1: str, arg2: str) -> dict:
                 return {
                     "result": arg1 + arg2,
                 }
@@ -3,6 +3,7 @@ from collections import OrderedDict
 from collections.abc import Callable
 from typing import Any

+from configs import dify_config
 from core.tools.utils.yaml_utils import load_yaml_file


@@ -19,6 +20,87 @@ def get_position_map(folder_path: str, *, file_name: str = "_position.yaml") ->
     return {name: index for index, name in enumerate(positions)}


+def get_tool_position_map(folder_path: str, file_name: str = "_position.yaml") -> dict[str, int]:
+    """
+    Get the mapping for tools from name to index from a YAML file.
+    :param folder_path: the folder containing the YAML file
+    :param file_name: the YAML file name, defaults to '_position.yaml'
+    :return: a dict with name as key and index as value
+    """
+    position_map = get_position_map(folder_path, file_name=file_name)
+
+    return pin_position_map(
+        position_map,
+        pin_list=dify_config.POSITION_TOOL_PINS_LIST,
+    )
+
+
+def get_provider_position_map(folder_path: str, file_name: str = "_position.yaml") -> dict[str, int]:
+    """
+    Get the mapping for providers from name to index from a YAML file.
+    :param folder_path: the folder containing the YAML file
+    :param file_name: the YAML file name, defaults to '_position.yaml'
+    :return: a dict with name as key and index as value
+    """
+    position_map = get_position_map(folder_path, file_name=file_name)
+    return pin_position_map(
+        position_map,
+        pin_list=dify_config.POSITION_PROVIDER_PINS_LIST,
+    )
+
+
+def pin_position_map(original_position_map: dict[str, int], pin_list: list[str]) -> dict[str, int]:
+    """
+    Pin the items in the pin list to the beginning of the position map.
+    Overall logic: exclude > include > pin
+    :param original_position_map: the position map to be sorted and filtered
+    :param pin_list: the list of pins to be put at the beginning
+    :return: the sorted position map
+    """
+    positions = sorted(original_position_map.keys(), key=lambda x: original_position_map[x])
+
+    # Add pins to position map
+    position_map = {name: idx for idx, name in enumerate(pin_list)}
+
+    # Add remaining positions to position map
+    start_idx = len(position_map)
+    for name in positions:
+        if name not in position_map:
+            position_map[name] = start_idx
+            start_idx += 1
+
+    return position_map
+
+
+def is_filtered(
+        include_set: set[str],
+        exclude_set: set[str],
+        data: Any,
+        name_func: Callable[[Any], str],
+) -> bool:
+    """
+    Check if the object should be filtered out.
+    Overall logic: exclude > include > pin
+    :param include_set: the set of names to be included
+    :param exclude_set: the set of names to be excluded
+    :param data: the data to be filtered
+    :param name_func: the function to get the name of the object
+    :return: True if the object should be filtered out, False otherwise
+    """
+    if not data:
+        return False
+    if not include_set and not exclude_set:
+        return False
+
+    name = name_func(data)
+
+    if name in exclude_set:  # exclude_set is prioritized
+        return True
+    if include_set and name not in include_set:  # filter out only if include_set is not empty
+        return True
+    return False
+
+
 def sort_by_position_map(
     position_map: dict[str, int],
     data: list[Any],
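A minimal illustration of the pinning behavior (the tool names are made up):

    from core.helper.position_helper import pin_position_map

    position_map = {'tool_a': 0, 'tool_b': 1, 'tool_c': 2}
    # 'tool_c' jumps to the front; the others keep their relative order
    print(pin_position_map(position_map, pin_list=['tool_c']))
    # {'tool_c': 0, 'tool_a': 1, 'tool_b': 2}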
@@ -700,6 +700,7 @@ class IndexingRunner:
                 DatasetDocument.tokens: tokens,
                 DatasetDocument.completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None),
                 DatasetDocument.indexing_latency: indexing_end_at - indexing_start_at,
+                DatasetDocument.error: None,
             }
         )
@@ -271,9 +271,8 @@ class ModelInstance:

         :param content_text: text content to be translated
         :param tenant_id: user tenant id
-        :param user: unique user id
         :param voice: model timbre
         :param streaming: output is streaming
         :param user: unique user id
         :return: text for given audio file
         """
         if not isinstance(self.model_type_instance, TTSModel):
@@ -369,6 +368,15 @@ class ModelManager:

         return ModelInstance(provider_model_bundle, model)

+    def get_default_provider_model_name(self, tenant_id: str, model_type: ModelType) -> tuple[str, str]:
+        """
+        Return the first provider and the first model in that provider
+        :param tenant_id: tenant id
+        :param model_type: model type
+        :return: provider name, model name
+        """
+        return self._provider_manager.get_first_provider_first_model(tenant_id, model_type)
+
     def get_default_model_instance(self, tenant_id: str, model_type: ModelType) -> ModelInstance:
         """
         Get default model instance
@@ -401,6 +409,10 @@ class LBModelManager:
                  managed_credentials: Optional[dict] = None) -> None:
         """
         Load balancing model manager
+        :param tenant_id: tenant_id
+        :param provider: provider
+        :param model_type: model_type
+        :param model: model name
         :param load_balancing_configs: all load balancing configurations
         :param managed_credentials: credentials if load balancing configuration name is __inherit__
         """
@@ -499,7 +511,6 @@ class LBModelManager:
             config.id
         )

-
         res = redis_client.exists(cooldown_cache_key)
         res = cast(bool, res)
         return res
@@ -151,9 +151,9 @@ class AIModel(ABC):
             os.path.join(provider_model_type_path, model_schema_yaml)
             for model_schema_yaml in os.listdir(provider_model_type_path)
             if not model_schema_yaml.startswith('__')
-               and not model_schema_yaml.startswith('_')
-               and os.path.isfile(os.path.join(provider_model_type_path, model_schema_yaml))
-               and model_schema_yaml.endswith('.yaml')
+            and not model_schema_yaml.startswith('_')
+            and os.path.isfile(os.path.join(provider_model_type_path, model_schema_yaml))
+            and model_schema_yaml.endswith('.yaml')
         ]

         # get _position.yaml file path
@@ -185,7 +185,7 @@ if you are not sure about the structure.
                 stream=stream,
                 user=user
             )

         model_parameters.pop("response_format")
         stop = stop or []
         stop.extend(["\n```", "```\n"])
@@ -249,10 +249,10 @@ if you are not sure about the structure.
             prompt_messages=prompt_messages,
             input_generator=new_generator()
         )

         return response

     def _code_block_mode_stream_processor(self, model: str, prompt_messages: list[PromptMessage],
                                           input_generator: Generator[LLMResultChunk, None, None]
                                           ) -> Generator[LLMResultChunk, None, None]:
         """
@@ -310,7 +310,7 @@ if you are not sure about the structure.
             )
         )

     def _code_block_mode_stream_processor_with_backtick(self, model: str, prompt_messages: list,
                                                         input_generator: Generator[LLMResultChunk, None, None]) \
             -> Generator[LLMResultChunk, None, None]:
         """
@@ -470,7 +470,7 @@ if you are not sure about the structure.
         :return: full response or stream response chunk generator result
         """
         raise NotImplementedError

     @abstractmethod
     def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                        tools: Optional[list[PromptMessageTool]] = None) -> int:
@@ -792,6 +792,13 @@ if you are not sure about the structure.
                 if not isinstance(parameter_value, str):
                     raise ValueError(f"Model Parameter {parameter_name} should be string.")

                 # validate options
                 if parameter_rule.options and parameter_value not in parameter_rule.options:
                     raise ValueError(f"Model Parameter {parameter_name} should be one of {parameter_rule.options}.")
+            elif parameter_rule.type == ParameterType.TEXT:
+                if not isinstance(parameter_value, str):
+                    raise ValueError(f"Model Parameter {parameter_name} should be text.")
+
+                # validate options
+                if parameter_rule.options and parameter_value not in parameter_rule.options:
+                    raise ValueError(f"Model Parameter {parameter_name} should be one of {parameter_rule.options}.")
@@ -70,7 +70,7 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
         # doc: https://platform.openai.com/docs/guides/text-to-speech
         credentials_kwargs = self._to_credential_kwargs(credentials)
         client = AzureOpenAI(**credentials_kwargs)
-        # max font is 4096,there is 3500 limit for each request
+        # max length is 4096 characters; a 3500-character limit is applied per request
         max_length = 3500
         if len(content_text) > max_length:
             sentences = self._split_text_into_sentences(content_text, max_length=max_length)
@ -6,7 +6,7 @@ from typing import Optional
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from core.helper.module_import_helper import load_single_subclass_from_source
|
||||
from core.helper.position_helper import get_position_map, sort_to_dict_by_position_map
|
||||
from core.helper.position_helper import get_provider_position_map, sort_to_dict_by_position_map
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.entities.provider_entities import ProviderConfig, ProviderEntity, SimpleProviderEntity
|
||||
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
|
||||
@ -234,7 +234,7 @@ class ModelProviderFactory:
|
||||
]
|
||||
|
||||
# get _position.yaml file path
|
||||
position_map = get_position_map(model_providers_path)
|
||||
position_map = get_provider_position_map(model_providers_path)
|
||||
|
||||
# traverse all model_provider_dir_paths
|
||||
model_providers: list[ModelProviderExtension] = []
|
||||
|
@@ -84,7 +84,8 @@ class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel):

     def _add_custom_parameters(self, credentials: dict) -> None:
         credentials['mode'] = 'chat'
-        credentials['endpoint_url'] = 'https://api.moonshot.cn/v1'
+        if 'endpoint_url' not in credentials or credentials['endpoint_url'] == "":
+            credentials['endpoint_url'] = 'https://api.moonshot.cn/v1'

     def _add_function_call(self, model: str, credentials: dict) -> None:
         model_schema = self.get_model_schema(model, credentials)
@@ -31,6 +31,14 @@ provider_credential_schema:
       placeholder:
         zh_Hans: 在此输入您的 API Key
         en_US: Enter your API Key
+    - variable: endpoint_url
+      label:
+        en_US: API Base
+      type: text-input
+      required: false
+      placeholder:
+        zh_Hans: Base URL, 如:https://api.moonshot.cn/v1
+        en_US: Base URL, e.g. https://api.moonshot.cn/v1
 model_credential_schema:
   model:
     label:
@@ -37,6 +37,9 @@ parameter_rules:
     options:
       - text
       - json_object
+      - json_schema
+  - name: json_schema
+    use_template: json_schema
 pricing:
   input: '0.15'
   output: '0.60'
@@ -0,0 +1,44 @@ (new file)
model: gpt-4o-2024-08-06
label:
  zh_Hans: gpt-4o-2024-08-06
  en_US: gpt-4o-2024-08-06
model_type: llm
features:
  - multi-tool-call
  - agent-thought
  - stream-tool-call
  - vision
model_properties:
  mode: chat
  context_size: 128000
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: presence_penalty
    use_template: presence_penalty
  - name: frequency_penalty
    use_template: frequency_penalty
  - name: max_tokens
    use_template: max_tokens
    default: 512
    min: 1
    max: 16384
  - name: response_format
    label:
      zh_Hans: 回复格式
      en_US: response_format
    type: string
    help:
      zh_Hans: 指定模型必须输出的格式
      en_US: specifying the format that the model must output
    required: false
    options:
      - text
      - json_object
pricing:
  input: '2.50'
  output: '10.00'
  unit: '0.000001'
  currency: USD
@@ -0,0 +1,4 @@ (new file)
model: netease-youdao/bce-reranker-base_v1
model_type: rerank
model_properties:
  context_size: 512
@@ -0,0 +1,4 @@ (new file)
model: BAAI/bge-reranker-v2-m3
model_type: rerank
model_properties:
  context_size: 8192
@@ -0,0 +1,87 @@ (new file)
from typing import Optional

import httpx

from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
from core.model_runtime.errors.invoke import (
    InvokeAuthorizationError,
    InvokeBadRequestError,
    InvokeConnectionError,
    InvokeError,
    InvokeRateLimitError,
    InvokeServerUnavailableError,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.rerank_model import RerankModel


class SiliconflowRerankModel(RerankModel):

    def _invoke(self, model: str, credentials: dict, query: str, docs: list[str],
                score_threshold: Optional[float] = None, top_n: Optional[int] = None,
                user: Optional[str] = None) -> RerankResult:
        if len(docs) == 0:
            return RerankResult(model=model, docs=[])

        base_url = credentials.get('base_url', 'https://api.siliconflow.cn/v1')
        if base_url.endswith('/'):
            base_url = base_url[:-1]
        try:
            response = httpx.post(
                base_url + '/rerank',
                json={
                    "model": model,
                    "query": query,
                    "documents": docs,
                    "top_n": top_n,
                    "return_documents": True
                },
                headers={"Authorization": f"Bearer {credentials.get('api_key')}"}
            )
            response.raise_for_status()
            results = response.json()

            rerank_documents = []
            for result in results['results']:
                rerank_document = RerankDocument(
                    index=result['index'],
                    text=result['document']['text'],
                    score=result['relevance_score'],
                )
                if score_threshold is None or result['relevance_score'] >= score_threshold:
                    rerank_documents.append(rerank_document)

            return RerankResult(model=model, docs=rerank_documents)
        except httpx.HTTPStatusError as e:
            raise InvokeServerUnavailableError(str(e))

    def validate_credentials(self, model: str, credentials: dict) -> None:
        try:
            self._invoke(
                model=model,
                credentials=credentials,
                query="What is the capital of the United States?",
                docs=[
                    "Carson City is the capital city of the American state of Nevada. At the 2010 United States "
                    "Census, Carson City had a population of 55,274.",
                    "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
                    "are a political division controlled by the United States. Its capital is Saipan.",
                ],
                score_threshold=0.8
            )
        except Exception as ex:
            raise CredentialsValidateFailedError(str(ex))

    @property
    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
        """
        Map model invoke error to unified error
        """
        return {
            InvokeConnectionError: [httpx.ConnectError],
            InvokeServerUnavailableError: [httpx.RemoteProtocolError],
            InvokeRateLimitError: [],
            InvokeAuthorizationError: [httpx.HTTPStatusError],
            InvokeBadRequestError: [httpx.RequestError]
        }
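A minimal usage sketch of the new rerank model (the credentials and documents are illustrative):

    model = SiliconflowRerankModel()
    result = model._invoke(
        model='BAAI/bge-reranker-v2-m3',
        credentials={'api_key': 'sk-...'},
        query='capital of the United States',
        docs=[
            'Carson City is the capital city of Nevada.',
            'Washington, D.C. is the capital of the United States.',
        ],
        top_n=1,
    )
    # result.docs holds RerankDocument entries (index, text, relevance score),
    # already filtered by score_threshold when one is given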
@@ -12,10 +12,11 @@ help:
   en_US: Get your API Key from SiliconFlow
   zh_Hans: 从 SiliconFlow 获取 API Key
   url:
-    en_US: https://cloud.siliconflow.cn/keys
+    en_US: https://cloud.siliconflow.cn/account/ak
 supported_model_types:
   - llm
   - text-embedding
+  - rerank
   - speech2text
 configurate_methods:
   - predefined-model
@@ -35,7 +35,10 @@ from core.model_runtime.model_providers.volcengine_maas.errors import (
     RateLimitErrors,
     ServerUnavailableErrors,
 )
-from core.model_runtime.model_providers.volcengine_maas.llm.models import ModelConfigs
+from core.model_runtime.model_providers.volcengine_maas.llm.models import (
+    get_model_config,
+    get_v2_req_params,
+)
 from core.model_runtime.model_providers.volcengine_maas.volc_sdk import MaasException

 logger = logging.getLogger(__name__)
@@ -95,37 +98,12 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
             -> LLMResult | Generator:

         client = MaaSClient.from_credential(credentials)
-
-        req_params = ModelConfigs.get(
-            credentials['base_model_name'], {}).get('req_params', {}).copy()
-        if credentials.get('context_size'):
-            req_params['max_prompt_tokens'] = credentials.get('context_size')
-        if credentials.get('max_tokens'):
-            req_params['max_new_tokens'] = credentials.get('max_tokens')
-        if model_parameters.get('max_tokens'):
-            req_params['max_new_tokens'] = model_parameters.get('max_tokens')
-        if model_parameters.get('temperature'):
-            req_params['temperature'] = model_parameters.get('temperature')
-        if model_parameters.get('top_p'):
-            req_params['top_p'] = model_parameters.get('top_p')
-        if model_parameters.get('top_k'):
-            req_params['top_k'] = model_parameters.get('top_k')
-        if model_parameters.get('presence_penalty'):
-            req_params['presence_penalty'] = model_parameters.get(
-                'presence_penalty')
-        if model_parameters.get('frequency_penalty'):
-            req_params['frequency_penalty'] = model_parameters.get(
-                'frequency_penalty')
-        if stop:
-            req_params['stop'] = stop
-
+        req_params = get_v2_req_params(credentials, model_parameters, stop)
         extra_model_kwargs = {}

         if tools:
             extra_model_kwargs['tools'] = [
                 MaaSClient.transform_tool_prompt_to_maas_config(tool) for tool in tools
             ]

         resp = MaaSClient.wrap_exception(
             lambda: client.chat(req_params, prompt_messages, stream, **extra_model_kwargs))
         if not stream:
|
||||
"""
|
||||
used to define customizable model schema
|
||||
"""
|
||||
max_tokens = ModelConfigs.get(
|
||||
credentials['base_model_name'], {}).get('req_params', {}).get('max_new_tokens')
|
||||
if credentials.get('max_tokens'):
|
||||
max_tokens = int(credentials.get('max_tokens'))
|
||||
model_config = get_model_config(credentials)
|
||||
|
||||
rules = [
|
||||
ParameterRule(
|
||||
name='temperature',
|
||||
@ -234,10 +210,10 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
|
||||
name='presence_penalty',
|
||||
type=ParameterType.FLOAT,
|
||||
use_template='presence_penalty',
|
||||
label={
|
||||
'en_US': 'Presence Penalty',
|
||||
'zh_Hans': '存在惩罚',
|
||||
},
|
||||
label=I18nObject(
|
||||
en_US='Presence Penalty',
|
||||
zh_Hans= '存在惩罚',
|
||||
),
|
||||
min=-2.0,
|
||||
max=2.0,
|
||||
),
|
||||
@ -245,10 +221,10 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
|
||||
name='frequency_penalty',
|
||||
type=ParameterType.FLOAT,
|
||||
use_template='frequency_penalty',
|
||||
label={
|
||||
'en_US': 'Frequency Penalty',
|
||||
'zh_Hans': '频率惩罚',
|
||||
},
|
||||
label=I18nObject(
|
||||
en_US= 'Frequency Penalty',
|
||||
zh_Hans= '频率惩罚',
|
||||
),
|
||||
min=-2.0,
|
||||
max=2.0,
|
||||
),
|
||||
@ -257,7 +233,7 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
|
||||
type=ParameterType.INT,
|
||||
use_template='max_tokens',
|
||||
min=1,
|
||||
max=max_tokens,
|
||||
max=model_config.properties.max_tokens,
|
||||
default=512,
|
||||
label=I18nObject(
|
||||
zh_Hans='最大生成长度',
|
||||
@ -266,17 +242,10 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
|
||||
),
|
||||
]
|
||||
|
||||
model_properties = ModelConfigs.get(
|
||||
credentials['base_model_name'], {}).get('model_properties', {}).copy()
|
||||
if credentials.get('mode'):
|
||||
model_properties[ModelPropertyKey.MODE] = credentials.get('mode')
|
||||
if credentials.get('context_size'):
|
||||
model_properties[ModelPropertyKey.CONTEXT_SIZE] = int(
|
||||
credentials.get('context_size', 4096))
|
||||
|
||||
model_features = ModelConfigs.get(
|
||||
credentials['base_model_name'], {}).get('features', [])
|
||||
|
||||
model_properties = {}
|
||||
model_properties[ModelPropertyKey.CONTEXT_SIZE] = model_config.properties.context_size
|
||||
model_properties[ModelPropertyKey.MODE] = model_config.properties.mode.value
|
||||
|
||||
entity = AIModelEntity(
|
||||
model=model,
|
||||
label=I18nObject(
|
||||
@ -286,7 +255,7 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
|
||||
model_type=ModelType.LLM,
|
||||
model_properties=model_properties,
|
||||
parameter_rules=rules,
|
||||
features=model_features,
|
||||
features=model_config.features,
|
||||
)
|
||||
|
||||
return entity
|
||||
|
@ -1,181 +1,123 @@
from pydantic import BaseModel

from core.model_runtime.entities.llm_entities import LLMMode
from core.model_runtime.entities.model_entities import ModelFeature

ModelConfigs = {
    'Doubao-pro-4k': {
        'req_params': {
            'max_prompt_tokens': 4096,
            'max_new_tokens': 4096,
        },
        'model_properties': {
            'context_size': 4096,
            'mode': 'chat',
        },
        'features': [
            ModelFeature.TOOL_CALL
        ],
    },
    'Doubao-lite-4k': {
        'req_params': {
            'max_prompt_tokens': 4096,
            'max_new_tokens': 4096,
        },
        'model_properties': {
            'context_size': 4096,
            'mode': 'chat',
        },
        'features': [
            ModelFeature.TOOL_CALL
        ],
    },
    'Doubao-pro-32k': {
        'req_params': {
            'max_prompt_tokens': 32768,
            'max_new_tokens': 32768,
        },
        'model_properties': {
            'context_size': 32768,
            'mode': 'chat',
        },
        'features': [
            ModelFeature.TOOL_CALL
        ],
    },
    'Doubao-lite-32k': {
        'req_params': {
            'max_prompt_tokens': 32768,
            'max_new_tokens': 32768,
        },
        'model_properties': {
            'context_size': 32768,
            'mode': 'chat',
        },
        'features': [
            ModelFeature.TOOL_CALL
        ],
    },
    'Doubao-pro-128k': {
        'req_params': {
            'max_prompt_tokens': 131072,
            'max_new_tokens': 131072,
        },
        'model_properties': {
            'context_size': 131072,
            'mode': 'chat',
        },
        'features': [
            ModelFeature.TOOL_CALL
        ],
    },
    'Doubao-lite-128k': {
        'req_params': {
            'max_prompt_tokens': 131072,
            'max_new_tokens': 131072,
        },
        'model_properties': {
            'context_size': 131072,
            'mode': 'chat',
        },
        'features': [
            ModelFeature.TOOL_CALL
        ],
    },
    'Skylark2-pro-4k': {
        'req_params': {
            'max_prompt_tokens': 4096,
            'max_new_tokens': 4000,
        },
        'model_properties': {
            'context_size': 4096,
            'mode': 'chat',
        },
        'features': [],
    },
    'Llama3-8B': {
        'req_params': {
            'max_prompt_tokens': 8192,
            'max_new_tokens': 8192,
        },
        'model_properties': {
            'context_size': 8192,
            'mode': 'chat',
        },
        'features': [],
    },
    'Llama3-70B': {
        'req_params': {
            'max_prompt_tokens': 8192,
            'max_new_tokens': 8192,
        },
        'model_properties': {
            'context_size': 8192,
            'mode': 'chat',
        },
        'features': [],
    },
    'Moonshot-v1-8k': {
        'req_params': {
            'max_prompt_tokens': 8192,
            'max_new_tokens': 4096,
        },
        'model_properties': {
            'context_size': 8192,
            'mode': 'chat',
        },
        'features': [],
    },
    'Moonshot-v1-32k': {
        'req_params': {
            'max_prompt_tokens': 32768,
            'max_new_tokens': 16384,
        },
        'model_properties': {
            'context_size': 32768,
            'mode': 'chat',
        },
        'features': [],
    },
    'Moonshot-v1-128k': {
        'req_params': {
            'max_prompt_tokens': 131072,
            'max_new_tokens': 65536,
        },
        'model_properties': {
            'context_size': 131072,
            'mode': 'chat',
        },
        'features': [],
    },
    'GLM3-130B': {
        'req_params': {
            'max_prompt_tokens': 8192,
            'max_new_tokens': 4096,
        },
        'model_properties': {
            'context_size': 8192,
            'mode': 'chat',
        },
        'features': [],
    },
    'GLM3-130B-Fin': {
        'req_params': {
            'max_prompt_tokens': 8192,
            'max_new_tokens': 4096,
        },
        'model_properties': {
            'context_size': 8192,
            'mode': 'chat',
        },
        'features': [],
    },
    'Mistral-7B': {
        'req_params': {
            'max_prompt_tokens': 8192,
            'max_new_tokens': 2048,
        },
        'model_properties': {
            'context_size': 8192,
            'mode': 'chat',
        },
        'features': [],
    }
}

class ModelProperties(BaseModel):
    context_size: int
    max_tokens: int
    mode: LLMMode

class ModelConfig(BaseModel):
    properties: ModelProperties
    features: list[ModelFeature]


configs: dict[str, ModelConfig] = {
    'Doubao-pro-4k': ModelConfig(
        properties=ModelProperties(context_size=4096, max_tokens=4096, mode=LLMMode.CHAT),
        features=[ModelFeature.TOOL_CALL]
    ),
    'Doubao-lite-4k': ModelConfig(
        properties=ModelProperties(context_size=4096, max_tokens=4096, mode=LLMMode.CHAT),
        features=[ModelFeature.TOOL_CALL]
    ),
    'Doubao-pro-32k': ModelConfig(
        properties=ModelProperties(context_size=32768, max_tokens=32768, mode=LLMMode.CHAT),
        features=[ModelFeature.TOOL_CALL]
    ),
    'Doubao-lite-32k': ModelConfig(
        properties=ModelProperties(context_size=32768, max_tokens=32768, mode=LLMMode.CHAT),
        features=[ModelFeature.TOOL_CALL]
    ),
    'Doubao-pro-128k': ModelConfig(
        properties=ModelProperties(context_size=131072, max_tokens=131072, mode=LLMMode.CHAT),
        features=[ModelFeature.TOOL_CALL]
    ),
    'Doubao-lite-128k': ModelConfig(
        properties=ModelProperties(context_size=131072, max_tokens=131072, mode=LLMMode.CHAT),
        features=[ModelFeature.TOOL_CALL]
    ),
    'Skylark2-pro-4k': ModelConfig(
        properties=ModelProperties(context_size=4096, max_tokens=4000, mode=LLMMode.CHAT),
        features=[]
    ),
    'Llama3-8B': ModelConfig(
        properties=ModelProperties(context_size=8192, max_tokens=8192, mode=LLMMode.CHAT),
        features=[]
    ),
    'Llama3-70B': ModelConfig(
        properties=ModelProperties(context_size=8192, max_tokens=8192, mode=LLMMode.CHAT),
        features=[]
    ),
    'Moonshot-v1-8k': ModelConfig(
        properties=ModelProperties(context_size=8192, max_tokens=4096, mode=LLMMode.CHAT),
        features=[]
    ),
    'Moonshot-v1-32k': ModelConfig(
        properties=ModelProperties(context_size=32768, max_tokens=16384, mode=LLMMode.CHAT),
        features=[]
    ),
    'Moonshot-v1-128k': ModelConfig(
        properties=ModelProperties(context_size=131072, max_tokens=65536, mode=LLMMode.CHAT),
        features=[]
    ),
    'GLM3-130B': ModelConfig(
        properties=ModelProperties(context_size=8192, max_tokens=4096, mode=LLMMode.CHAT),
        features=[]
    ),
    'GLM3-130B-Fin': ModelConfig(
        properties=ModelProperties(context_size=8192, max_tokens=4096, mode=LLMMode.CHAT),
        features=[]
    ),
    'Mistral-7B': ModelConfig(
        properties=ModelProperties(context_size=8192, max_tokens=2048, mode=LLMMode.CHAT),
        features=[]
    )
}

def get_model_config(credentials: dict) -> ModelConfig:
    base_model = credentials.get('base_model_name', '')
    model_configs = configs.get(base_model)
    if not model_configs:
        return ModelConfig(
            properties=ModelProperties(
                context_size=int(credentials.get('context_size', 0)),
                max_tokens=int(credentials.get('max_tokens', 0)),
                mode=LLMMode.value_of(credentials.get('mode', 'chat')),
            ),
            features=[]
        )
    return model_configs


def get_v2_req_params(credentials: dict, model_parameters: dict,
                      stop: list[str] | None = None):
    req_params = {}
    # predefined properties
    model_configs = get_model_config(credentials)
    if model_configs:
        req_params['max_prompt_tokens'] = model_configs.properties.context_size
        req_params['max_new_tokens'] = model_configs.properties.max_tokens

    # model parameters
    if model_parameters.get('max_tokens'):
        req_params['max_new_tokens'] = model_parameters.get('max_tokens')
    if model_parameters.get('temperature'):
        req_params['temperature'] = model_parameters.get('temperature')
    if model_parameters.get('top_p'):
        req_params['top_p'] = model_parameters.get('top_p')
    if model_parameters.get('top_k'):
        req_params['top_k'] = model_parameters.get('top_k')
    if model_parameters.get('presence_penalty'):
        req_params['presence_penalty'] = model_parameters.get(
            'presence_penalty')
    if model_parameters.get('frequency_penalty'):
        req_params['frequency_penalty'] = model_parameters.get(
            'frequency_penalty')

    if stop:
        req_params['stop'] = stop

    return req_params
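A minimal sketch (not part of the diff) of how `get_v2_req_params` merges the predefined model limits with per-request parameters; the credential and parameter values below are hypothetical:

```python
credentials = {'base_model_name': 'Doubao-pro-4k'}
model_parameters = {'temperature': 0.7, 'max_tokens': 1024}

req_params = get_v2_req_params(credentials, model_parameters, stop=['\n\n'])
# max_prompt_tokens/max_new_tokens start from the Doubao-pro-4k config (4096/4096);
# max_new_tokens is then overridden by the caller's max_tokens.
assert req_params['max_prompt_tokens'] == 4096
assert req_params['max_new_tokens'] == 1024
assert req_params['temperature'] == 0.7
assert req_params['stop'] == ['\n\n']
```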
@ -1,9 +1,27 @@
from pydantic import BaseModel


class ModelProperties(BaseModel):
    context_size: int
    max_chunks: int

class ModelConfig(BaseModel):
    properties: ModelProperties

ModelConfigs = {
    'Doubao-embedding': {
        'req_params': {},
        'model_properties': {
            'context_size': 4096,
            'max_chunks': 1,
        }
    },
    'Doubao-embedding': ModelConfig(
        properties=ModelProperties(context_size=4096, max_chunks=1)
    ),
}

def get_model_config(credentials: dict) -> ModelConfig:
    base_model = credentials.get('base_model_name', '')
    model_configs = ModelConfigs.get(base_model)
    if not model_configs:
        return ModelConfig(
            properties=ModelProperties(
                context_size=int(credentials.get('context_size', 0)),
                max_chunks=int(credentials.get('max_chunks', 0)),
            )
        )
    return model_configs
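Likewise for the embedding side, a sketch of the fallback path when the base model is not predefined (the model name and values below are hypothetical):

```python
custom = get_model_config({
    'base_model_name': 'my-custom-embedding',  # not in ModelConfigs
    'context_size': '2048',
    'max_chunks': '4',
})
assert custom.properties.context_size == 2048
assert custom.properties.max_chunks == 4
```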
@ -30,7 +30,7 @@ from core.model_runtime.model_providers.volcengine_maas.errors import (
    RateLimitErrors,
    ServerUnavailableErrors,
)
from core.model_runtime.model_providers.volcengine_maas.text_embedding.models import ModelConfigs
from core.model_runtime.model_providers.volcengine_maas.text_embedding.models import get_model_config
from core.model_runtime.model_providers.volcengine_maas.volc_sdk import MaasException


@ -115,14 +115,10 @@ class VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel):
        """
        generate custom model entities from credentials
        """
        model_properties = ModelConfigs.get(
            credentials['base_model_name'], {}).get('model_properties', {}).copy()
        if credentials.get('context_size'):
            model_properties[ModelPropertyKey.CONTEXT_SIZE] = int(
                credentials.get('context_size', 4096))
        if credentials.get('max_chunks'):
            model_properties[ModelPropertyKey.MAX_CHUNKS] = int(
                credentials.get('max_chunks', 4096))
        model_config = get_model_config(credentials)
        model_properties = {}
        model_properties[ModelPropertyKey.CONTEXT_SIZE] = model_config.properties.context_size
        model_properties[ModelPropertyKey.MAX_CHUNKS] = model_config.properties.max_chunks
        entity = AIModelEntity(
            model=model,
            label=I18nObject(en_US=model),
api/core/model_runtime/model_providers/wenxin/_common.py (198 lines, new file)
@ -0,0 +1,198 @@
from datetime import datetime, timedelta
from threading import Lock

from requests import post

from core.model_runtime.model_providers.wenxin.wenxin_errors import (
    BadRequestError,
    InternalServerError,
    InvalidAPIKeyError,
    InvalidAuthenticationError,
    RateLimitReachedError,
)

baidu_access_tokens: dict[str, 'BaiduAccessToken'] = {}
baidu_access_tokens_lock = Lock()


class BaiduAccessToken:
    api_key: str
    access_token: str
    expires: datetime

    def __init__(self, api_key: str) -> None:
        self.api_key = api_key
        self.access_token = ''
        self.expires = datetime.now() + timedelta(days=3)

    @staticmethod
    def _get_access_token(api_key: str, secret_key: str) -> str:
        """
        request access token from Baidu
        """
        try:
            response = post(
                url=f'https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id={api_key}&client_secret={secret_key}',
                headers={
                    'Content-Type': 'application/json',
                    'Accept': 'application/json'
                },
            )
        except Exception as e:
            raise InvalidAuthenticationError(f'Failed to get access token from Baidu: {e}')

        resp = response.json()
        if 'error' in resp:
            if resp['error'] == 'invalid_client':
                raise InvalidAPIKeyError(f'Invalid API key or secret key: {resp["error_description"]}')
            elif resp['error'] == 'unknown_error':
                raise InternalServerError(f'Internal server error: {resp["error_description"]}')
            elif resp['error'] == 'invalid_request':
                raise BadRequestError(f'Bad request: {resp["error_description"]}')
            elif resp['error'] == 'rate_limit_exceeded':
                raise RateLimitReachedError(f'Rate limit reached: {resp["error_description"]}')
            else:
                raise Exception(f'Unknown error: {resp["error_description"]}')

        return resp['access_token']

    @staticmethod
    def get_access_token(api_key: str, secret_key: str) -> 'BaiduAccessToken':
        """
        LLM from Baidu requires an access token to invoke the API.
        We only hold api_key and secret_key, and the access token is valid for 30 days,
        so we cache the access token for 3 days (this also avoids the cache growing unboundedly).

        It may be more efficient to use a ticker to refresh the access token, but that would add
        complexity, so we simply refresh access tokens whenever get_access_token is called.
        """

        # look up the cache and remove expired access tokens
        baidu_access_tokens_lock.acquire()
        now = datetime.now()
        for key in list(baidu_access_tokens.keys()):
            token = baidu_access_tokens[key]
            if token.expires < now:
                baidu_access_tokens.pop(key)

        if api_key not in baidu_access_tokens:
            # the access token is not in the cache; request it
            token = BaiduAccessToken(api_key)
            baidu_access_tokens[api_key] = token
            # release the lock early to enhance performance;
            # _get_access_token raises on failure, so releasing here also avoids a deadlock
            baidu_access_tokens_lock.release()
            # try to get an access token
            token_str = BaiduAccessToken._get_access_token(api_key, secret_key)
            token.access_token = token_str
            token.expires = now + timedelta(days=3)
            return token
        else:
            # the access token is in the cache; return it
            token = baidu_access_tokens[api_key]
            baidu_access_tokens_lock.release()
            return token


class _CommonWenxin:
    api_bases = {
        'ernie-bot': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-4k-0205',
        'ernie-bot-4': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro',
        'ernie-bot-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions',
        'ernie-bot-turbo': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant',
        'ernie-3.5-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions',
        'ernie-3.5-8k-0205': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-8k-0205',
        'ernie-3.5-8k-1222': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-8k-1222',
        'ernie-3.5-4k-0205': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-4k-0205',
        'ernie-3.5-128k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-128k',
        'ernie-4.0-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro',
        'ernie-4.0-8k-latest': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro',
        'ernie-speed-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_speed',
        'ernie-speed-128k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-speed-128k',
        'ernie-speed-appbuilder': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ai_apaas',
        'ernie-lite-8k-0922': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant',
        'ernie-lite-8k-0308': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-lite-8k',
        'ernie-character-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k',
        'ernie-character-8k-0321': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k',
        'ernie-4.0-turbo-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k',
        'ernie-4.0-turbo-8k-preview': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k-preview',
        'yi_34b_chat': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/yi_34b_chat',
        'embedding-v1': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1',
        'bge-large-en': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/bge_large_en',
        'bge-large-zh': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/bge_large_zh',
        'tao-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/tao_8k',
    }

    function_calling_supports = [
        'ernie-bot',
        'ernie-bot-8k',
        'ernie-3.5-8k',
        'ernie-3.5-8k-0205',
        'ernie-3.5-8k-1222',
        'ernie-3.5-4k-0205',
        'ernie-3.5-128k',
        'ernie-4.0-8k',
        'ernie-4.0-turbo-8k',
        'ernie-4.0-turbo-8k-preview',
        'yi_34b_chat'
    ]

    api_key: str = ''
    secret_key: str = ''

    def __init__(self, api_key: str, secret_key: str):
        self.api_key = api_key
        self.secret_key = secret_key

    @staticmethod
    def _to_credential_kwargs(credentials: dict) -> dict:
        credentials_kwargs = {
            "api_key": credentials['api_key'],
            "secret_key": credentials['secret_key']
        }
        return credentials_kwargs

    def _handle_error(self, code: int, msg: str):
        error_map = {
            1: InternalServerError,
            2: InternalServerError,
            3: BadRequestError,
            4: RateLimitReachedError,
            6: InvalidAuthenticationError,
            13: InvalidAPIKeyError,
            14: InvalidAPIKeyError,
            15: InvalidAPIKeyError,
            17: RateLimitReachedError,
            18: RateLimitReachedError,
            19: RateLimitReachedError,
            100: InvalidAPIKeyError,
            111: InvalidAPIKeyError,
            200: InternalServerError,
            336000: InternalServerError,
            336001: BadRequestError,
            336002: BadRequestError,
            336003: BadRequestError,
            336004: InvalidAuthenticationError,
            336005: InvalidAPIKeyError,
            336006: BadRequestError,
            336007: BadRequestError,
            336008: BadRequestError,
            336100: InternalServerError,
            336101: BadRequestError,
            336102: BadRequestError,
            336103: BadRequestError,
            336104: BadRequestError,
            336105: BadRequestError,
            336200: InternalServerError,
            336303: BadRequestError,
            337006: BadRequestError
        }

        if code in error_map:
            raise error_map[code](msg)
        else:
            raise InternalServerError(f'Unknown error: {msg}')

    def _get_access_token(self) -> str:
        token = BaiduAccessToken.get_access_token(self.api_key, self.secret_key)
        return token.access_token
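A sketch of the caching contract described in the docstring above; it requires valid Baidu credentials, and the key strings below are placeholders:

```python
token = BaiduAccessToken.get_access_token('my-api-key', 'my-secret-key')
again = BaiduAccessToken.get_access_token('my-api-key', 'my-secret-key')
# The second call is served from baidu_access_tokens without another HTTP
# round trip, until the 3-day expiry passes.
assert again is token
```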
@ -1,102 +1,17 @@
from collections.abc import Generator
from datetime import datetime, timedelta
from enum import Enum
from json import dumps, loads
from threading import Lock
from typing import Any, Union

from requests import Response, post

from core.model_runtime.entities.message_entities import PromptMessageTool
from core.model_runtime.model_providers.wenxin.llm.ernie_bot_errors import (
from core.model_runtime.model_providers.wenxin._common import _CommonWenxin
from core.model_runtime.model_providers.wenxin.wenxin_errors import (
    BadRequestError,
    InternalServerError,
    InvalidAPIKeyError,
    InvalidAuthenticationError,
    RateLimitReachedError,
)

# map api_key to access_token
baidu_access_tokens: dict[str, 'BaiduAccessToken'] = {}
baidu_access_tokens_lock = Lock()

class BaiduAccessToken:
    api_key: str
    access_token: str
    expires: datetime

    def __init__(self, api_key: str) -> None:
        self.api_key = api_key
        self.access_token = ''
        self.expires = datetime.now() + timedelta(days=3)

    def _get_access_token(api_key: str, secret_key: str) -> str:
        """
        request access token from Baidu
        """
        try:
            response = post(
                url=f'https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id={api_key}&client_secret={secret_key}',
                headers={
                    'Content-Type': 'application/json',
                    'Accept': 'application/json'
                },
            )
        except Exception as e:
            raise InvalidAuthenticationError(f'Failed to get access token from Baidu: {e}')

        resp = response.json()
        if 'error' in resp:
            if resp['error'] == 'invalid_client':
                raise InvalidAPIKeyError(f'Invalid API key or secret key: {resp["error_description"]}')
            elif resp['error'] == 'unknown_error':
                raise InternalServerError(f'Internal server error: {resp["error_description"]}')
            elif resp['error'] == 'invalid_request':
                raise BadRequestError(f'Bad request: {resp["error_description"]}')
            elif resp['error'] == 'rate_limit_exceeded':
                raise RateLimitReachedError(f'Rate limit reached: {resp["error_description"]}')
            else:
                raise Exception(f'Unknown error: {resp["error_description"]}')

        return resp['access_token']

    @staticmethod
    def get_access_token(api_key: str, secret_key: str) -> 'BaiduAccessToken':
        """
        LLM from Baidu requires an access token to invoke the API.
        We only hold api_key and secret_key, and the access token is valid for 30 days,
        so we cache the access token for 3 days (this also avoids the cache growing unboundedly).

        It may be more efficient to use a ticker to refresh the access token, but that would add
        complexity, so we simply refresh access tokens whenever get_access_token is called.
        """

        # look up the cache and remove expired access tokens
        baidu_access_tokens_lock.acquire()
        now = datetime.now()
        for key in list(baidu_access_tokens.keys()):
            token = baidu_access_tokens[key]
            if token.expires < now:
                baidu_access_tokens.pop(key)

        if api_key not in baidu_access_tokens:
            # the access token is not in the cache; request it
            token = BaiduAccessToken(api_key)
            baidu_access_tokens[api_key] = token
            # release the lock early to enhance performance;
            # _get_access_token raises on failure, so releasing here also avoids a deadlock
            baidu_access_tokens_lock.release()
            # try to get an access token
            token_str = BaiduAccessToken._get_access_token(api_key, secret_key)
            token.access_token = token_str
            token.expires = now + timedelta(days=3)
            return token
        else:
            # the access token is in the cache; return it
            token = baidu_access_tokens[api_key]
            baidu_access_tokens_lock.release()
            return token


class ErnieMessage:
    class Role(Enum):
@ -120,51 +35,7 @@ class ErnieMessage:
        self.content = content
        self.role = role

class ErnieBotModel:
    api_bases = {
        'ernie-bot': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-4k-0205',
        'ernie-bot-4': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro',
        'ernie-bot-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions',
        'ernie-bot-turbo': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant',
        'ernie-3.5-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions',
        'ernie-3.5-8k-0205': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-8k-0205',
        'ernie-3.5-8k-1222': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-8k-1222',
        'ernie-3.5-4k-0205': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-4k-0205',
        'ernie-3.5-128k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-128k',
        'ernie-4.0-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro',
        'ernie-4.0-8k-latest': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro',
        'ernie-speed-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_speed',
        'ernie-speed-128k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-speed-128k',
        'ernie-speed-appbuilder': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ai_apaas',
        'ernie-lite-8k-0922': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant',
        'ernie-lite-8k-0308': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-lite-8k',
        'ernie-character-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k',
        'ernie-character-8k-0321': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k',
        'ernie-4.0-turbo-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k',
        'ernie-4.0-turbo-8k-preview': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k-preview',
        'yi_34b_chat': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/yi_34b_chat',
    }

    function_calling_supports = [
        'ernie-bot',
        'ernie-bot-8k',
        'ernie-3.5-8k',
        'ernie-3.5-8k-0205',
        'ernie-3.5-8k-1222',
        'ernie-3.5-4k-0205',
        'ernie-3.5-128k',
        'ernie-4.0-8k',
        'ernie-4.0-turbo-8k',
        'ernie-4.0-turbo-8k-preview',
        'yi_34b_chat'
    ]

    api_key: str = ''
    secret_key: str = ''

    def __init__(self, api_key: str, secret_key: str):
        self.api_key = api_key
        self.secret_key = secret_key
class ErnieBotModel(_CommonWenxin):

    def generate(self, model: str, stream: bool, messages: list[ErnieMessage],
                 parameters: dict[str, Any], timeout: int, tools: list[PromptMessageTool], \
@ -199,51 +70,6 @@ class ErnieBotModel:
            return self._handle_chat_stream_generate_response(resp)
        return self._handle_chat_generate_response(resp)

    def _handle_error(self, code: int, msg: str):
        error_map = {
            1: InternalServerError,
            2: InternalServerError,
            3: BadRequestError,
            4: RateLimitReachedError,
            6: InvalidAuthenticationError,
            13: InvalidAPIKeyError,
            14: InvalidAPIKeyError,
            15: InvalidAPIKeyError,
            17: RateLimitReachedError,
            18: RateLimitReachedError,
            19: RateLimitReachedError,
            100: InvalidAPIKeyError,
            111: InvalidAPIKeyError,
            200: InternalServerError,
            336000: InternalServerError,
            336001: BadRequestError,
            336002: BadRequestError,
            336003: BadRequestError,
            336004: InvalidAuthenticationError,
            336005: InvalidAPIKeyError,
            336006: BadRequestError,
            336007: BadRequestError,
            336008: BadRequestError,
            336100: InternalServerError,
            336101: BadRequestError,
            336102: BadRequestError,
            336103: BadRequestError,
            336104: BadRequestError,
            336105: BadRequestError,
            336200: InternalServerError,
            336303: BadRequestError,
            337006: BadRequestError
        }

        if code in error_map:
            raise error_map[code](msg)
        else:
            raise InternalServerError(f'Unknown error: {msg}')

    def _get_access_token(self) -> str:
        token = BaiduAccessToken.get_access_token(self.api_key, self.secret_key)
        return token.access_token

    def _copy_messages(self, messages: list[ErnieMessage]) -> list[ErnieMessage]:
        return [ErnieMessage(message.content, message.role) for message in messages]


@ -1,17 +0,0 @@
class InvalidAuthenticationError(Exception):
    pass

class InvalidAPIKeyError(Exception):
    pass

class RateLimitReachedError(Exception):
    pass

class InsufficientAccountBalance(Exception):
    pass

class InternalServerError(Exception):
    pass

class BadRequestError(Exception):
    pass
@ -11,24 +11,13 @@ from core.model_runtime.entities.message_entities import (
    UserPromptMessage,
)
from core.model_runtime.errors.invoke import (
    InvokeAuthorizationError,
    InvokeBadRequestError,
    InvokeConnectionError,
    InvokeError,
    InvokeRateLimitError,
    InvokeServerUnavailableError,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.model_providers.wenxin.llm.ernie_bot import BaiduAccessToken, ErnieBotModel, ErnieMessage
from core.model_runtime.model_providers.wenxin.llm.ernie_bot_errors import (
    BadRequestError,
    InsufficientAccountBalance,
    InternalServerError,
    InvalidAPIKeyError,
    InvalidAuthenticationError,
    RateLimitReachedError,
)
from core.model_runtime.model_providers.wenxin._common import BaiduAccessToken
from core.model_runtime.model_providers.wenxin.llm.ernie_bot import ErnieBotModel, ErnieMessage
from core.model_runtime.model_providers.wenxin.wenxin_errors import invoke_error_mapping

ERNIE_BOT_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object.
The structure of the {{block}} object can be found in the instructions; use {"answer": "$your_answer"} as the default structure
@ -140,7 +129,7 @@ class ErnieBotLargeLanguageModel(LargeLanguageModel):
        api_key = credentials['api_key']
        secret_key = credentials['secret_key']
        try:
            BaiduAccessToken._get_access_token(api_key, secret_key)
            BaiduAccessToken.get_access_token(api_key, secret_key)
        except Exception as e:
            raise CredentialsValidateFailedError(f'Credentials validation failed: {e}')

@ -254,22 +243,4 @@ class ErnieBotLargeLanguageModel(LargeLanguageModel):

        :return: Invoke error mapping
        """
        return {
            InvokeConnectionError: [
            ],
            InvokeServerUnavailableError: [
                InternalServerError
            ],
            InvokeRateLimitError: [
                RateLimitReachedError
            ],
            InvokeAuthorizationError: [
                InvalidAuthenticationError,
                InsufficientAccountBalance,
                InvalidAPIKeyError,
            ],
            InvokeBadRequestError: [
                BadRequestError,
                KeyError
            ]
        }
        return invoke_error_mapping()
@ -0,0 +1,9 @@
model: bge-large-en
model_type: text-embedding
model_properties:
  context_size: 512
  max_chunks: 16
pricing:
  input: '0.0005'
  unit: '0.001'
  currency: RMB
@ -0,0 +1,9 @@
model: bge-large-zh
model_type: text-embedding
model_properties:
  context_size: 512
  max_chunks: 16
pricing:
  input: '0.0005'
  unit: '0.001'
  currency: RMB
@ -0,0 +1,9 @@
model: embedding-v1
model_type: text-embedding
model_properties:
  context_size: 384
  max_chunks: 16
pricing:
  input: '0.0005'
  unit: '0.001'
  currency: RMB
@ -0,0 +1,9 @@
model: tao-8k
model_type: text-embedding
model_properties:
  context_size: 8192
  max_chunks: 1
pricing:
  input: '0.0005'
  unit: '0.001'
  currency: RMB
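These pricing fields are combined by the runtime's price calculation, with `unit` meaning prices are quoted per 1,000 tokens; a quick worked example of the arithmetic:

```python
tokens = 12_000
unit = 0.001          # pricing unit: quoted per 1,000 tokens
input_price = 0.0005  # RMB per unit
total_rmb = tokens * unit * input_price
assert round(total_rmb, 6) == 0.006
```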
@ -0,0 +1,184 @@
import time
from abc import abstractmethod
from collections.abc import Mapping
from json import dumps
from typing import Any, Optional

import numpy as np
from requests import Response, post

from core.model_runtime.entities.model_entities import PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from core.model_runtime.errors.invoke import InvokeError
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.model_runtime.model_providers.wenxin._common import BaiduAccessToken, _CommonWenxin
from core.model_runtime.model_providers.wenxin.wenxin_errors import (
    BadRequestError,
    InternalServerError,
    invoke_error_mapping,
)


class TextEmbedding:
    @abstractmethod
    def embed_documents(self, model: str, texts: list[str], user: str) -> (list[list[float]], int, int):
        raise NotImplementedError


class WenxinTextEmbedding(_CommonWenxin, TextEmbedding):
    def embed_documents(self, model: str, texts: list[str], user: str) -> (list[list[float]], int, int):
        access_token = self._get_access_token()
        url = f'{self.api_bases[model]}?access_token={access_token}'
        body = self._build_embed_request_body(model, texts, user)
        headers = {
            'Content-Type': 'application/json',
        }

        resp = post(url, data=dumps(body), headers=headers)
        if resp.status_code != 200:
            raise InternalServerError(f'Failed to invoke ernie bot: {resp.text}')
        return self._handle_embed_response(model, resp)

    def _build_embed_request_body(self, model: str, texts: list[str], user: str) -> dict[str, Any]:
        if len(texts) == 0:
            raise BadRequestError('The number of texts should not be zero.')
        body = {
            'input': texts,
            'user_id': user,
        }
        return body

    def _handle_embed_response(self, model: str, response: Response) -> (list[list[float]], int, int):
        data = response.json()
        if 'error_code' in data:
            code = data['error_code']
            msg = data['error_msg']
            # raise error
            self._handle_error(code, msg)

        embeddings = [v['embedding'] for v in data['data']]
        _usage = data['usage']
        tokens = _usage['prompt_tokens']
        total_tokens = _usage['total_tokens']

        return embeddings, tokens, total_tokens


class WenxinTextEmbeddingModel(TextEmbeddingModel):
    def _create_text_embedding(self, api_key: str, secret_key: str) -> TextEmbedding:
        return WenxinTextEmbedding(api_key, secret_key)

    def _invoke(self, model: str, credentials: dict, texts: list[str],
                user: Optional[str] = None) -> TextEmbeddingResult:
        """
        Invoke text embedding model

        :param model: model name
        :param credentials: model credentials
        :param texts: texts to embed
        :param user: unique user id
        :return: embeddings result
        """

        api_key = credentials['api_key']
        secret_key = credentials['secret_key']
        embedding: TextEmbedding = self._create_text_embedding(api_key, secret_key)
        user = user if user else 'ErnieBotDefault'

        context_size = self._get_context_size(model, credentials)
        max_chunks = self._get_max_chunks(model, credentials)
        inputs = []
        indices = []
        used_tokens = 0
        used_total_tokens = 0

        for i, text in enumerate(texts):

            # Here the token count is only an approximation based on the GPT-2 tokenizer
            num_tokens = self._get_num_tokens_by_gpt2(text)

            if num_tokens >= context_size:
                cutoff = int(np.floor(len(text) * (context_size / num_tokens)))
                # if the token count exceeds the context length, only use the start of the text
                inputs.append(text[0:cutoff])
            else:
                inputs.append(text)
            indices += [i]

        batched_embeddings = []
        _iter = range(0, len(inputs), max_chunks)
        for i in _iter:
            embeddings_batch, _used_tokens, _total_used_tokens = embedding.embed_documents(
                model,
                inputs[i: i + max_chunks],
                user)
            used_tokens += _used_tokens
            used_total_tokens += _total_used_tokens
            batched_embeddings += embeddings_batch

        usage = self._calc_response_usage(model, credentials, used_tokens, used_total_tokens)
        return TextEmbeddingResult(
            model=model,
            embeddings=batched_embeddings,
            usage=usage,
        )

    def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
        """
        Get number of tokens for given prompt messages

        :param model: model name
        :param credentials: model credentials
        :param texts: texts to embed
        :return:
        """
        if len(texts) == 0:
            return 0
        total_num_tokens = 0
        for text in texts:
            total_num_tokens += self._get_num_tokens_by_gpt2(text)

        return total_num_tokens

    def validate_credentials(self, model: str, credentials: Mapping) -> None:
        api_key = credentials['api_key']
        secret_key = credentials['secret_key']
        try:
            BaiduAccessToken.get_access_token(api_key, secret_key)
        except Exception as e:
            raise CredentialsValidateFailedError(f'Credentials validation failed: {e}')

    @property
    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
        return invoke_error_mapping()

    def _calc_response_usage(self, model: str, credentials: dict, tokens: int, total_tokens: int) -> EmbeddingUsage:
        """
        Calculate response usage

        :param model: model name
        :param credentials: model credentials
        :param tokens: input tokens
        :return: usage
        """
        # get input price info
        input_price_info = self.get_price(
            model=model,
            credentials=credentials,
            price_type=PriceType.INPUT,
            tokens=tokens
        )

        # transform usage
        usage = EmbeddingUsage(
            tokens=tokens,
            total_tokens=total_tokens,
            unit_price=input_price_info.unit_price,
            price_unit=input_price_info.unit,
            total_price=input_price_info.total_amount,
            currency=input_price_info.currency,
            latency=time.perf_counter() - self.started_at
        )

        return usage
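The batching in `_invoke` above walks the inputs in windows of `max_chunks`; the slicing can be seen in isolation:

```python
inputs = ['a', 'b', 'c', 'd', 'e']
max_chunks = 2
batches = [inputs[i: i + max_chunks] for i in range(0, len(inputs), max_chunks)]
assert batches == [['a', 'b'], ['c', 'd'], ['e']]
```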
@ -17,6 +17,7 @@ help:
    en_US: https://cloud.baidu.com/wenxin.html
supported_model_types:
  - llm
  - text-embedding
configurate_methods:
  - predefined-model
provider_credential_schema:
@ -0,0 +1,57 @@
from core.model_runtime.errors.invoke import (
    InvokeAuthorizationError,
    InvokeBadRequestError,
    InvokeConnectionError,
    InvokeError,
    InvokeRateLimitError,
    InvokeServerUnavailableError,
)


def invoke_error_mapping() -> dict[type[InvokeError], list[type[Exception]]]:
    """
    Map model invoke error to unified error
    The key is the error type thrown to the caller
    The value is the error type thrown by the model,
    which needs to be converted into a unified error type for the caller.

    :return: Invoke error mapping
    """
    return {
        InvokeConnectionError: [
        ],
        InvokeServerUnavailableError: [
            InternalServerError
        ],
        InvokeRateLimitError: [
            RateLimitReachedError
        ],
        InvokeAuthorizationError: [
            InvalidAuthenticationError,
            InsufficientAccountBalance,
            InvalidAPIKeyError,
        ],
        InvokeBadRequestError: [
            BadRequestError,
            KeyError
        ]
    }


class InvalidAuthenticationError(Exception):
    pass

class InvalidAPIKeyError(Exception):
    pass

class RateLimitReachedError(Exception):
    pass

class InsufficientAccountBalance(Exception):
    pass

class InternalServerError(Exception):
    pass

class BadRequestError(Exception):
    pass
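A sketch of how a caller might use this mapping to translate provider-specific exceptions into the unified `InvokeError` hierarchy (the helper below is illustrative, not part of the diff):

```python
def to_invoke_error(exc: Exception) -> type[InvokeError]:
    # Walk the mapping and return the first unified error whose provider-side
    # exception types match the raised exception.
    for invoke_error, provider_errors in invoke_error_mapping().items():
        if any(isinstance(exc, t) for t in provider_errors):
            return invoke_error
    return InvokeError

assert to_invoke_error(RateLimitReachedError('busy')) is InvokeRateLimitError
```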
@ -85,7 +85,8 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
            tools=tools, stop=stop, stream=stream, user=user,
            extra_model_kwargs=XinferenceHelper.get_xinference_extra_parameter(
                server_url=credentials['server_url'],
                model_uid=credentials['model_uid']
                model_uid=credentials['model_uid'],
                api_key=credentials.get('api_key'),
            )
        )

@ -106,7 +107,8 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):

        extra_param = XinferenceHelper.get_xinference_extra_parameter(
            server_url=credentials['server_url'],
            model_uid=credentials['model_uid']
            model_uid=credentials['model_uid'],
            api_key=credentials.get('api_key')
        )
        if 'completion_type' not in credentials:
            if 'chat' in extra_param.model_ability:
@ -396,7 +398,8 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
        else:
            extra_args = XinferenceHelper.get_xinference_extra_parameter(
                server_url=credentials['server_url'],
                model_uid=credentials['model_uid']
                model_uid=credentials['model_uid'],
                api_key=credentials.get('api_key')
            )

            if 'chat' in extra_args.model_ability:
@ -464,6 +467,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):

        xinference_client = Client(
            base_url=credentials['server_url'],
            api_key=credentials.get('api_key'),
        )

        xinference_model = xinference_client.get_model(credentials['model_uid'])

@ -108,7 +108,8 @@ class XinferenceRerankModel(RerankModel):

        # initialize client
        client = Client(
            base_url=credentials['server_url']
            base_url=credentials['server_url'],
            api_key=credentials.get('api_key'),
        )

        xinference_client = client.get_model(model_uid=credentials['model_uid'])

@ -52,7 +52,8 @@ class XinferenceSpeech2TextModel(Speech2TextModel):

        # initialize client
        client = Client(
            base_url=credentials['server_url']
            base_url=credentials['server_url'],
            api_key=credentials.get('api_key'),
        )

        xinference_client = client.get_model(model_uid=credentials['model_uid'])

@ -110,14 +110,22 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):

        server_url = credentials['server_url']
        model_uid = credentials['model_uid']
        extra_args = XinferenceHelper.get_xinference_extra_parameter(server_url=server_url, model_uid=model_uid)
        api_key = credentials.get('api_key')
        extra_args = XinferenceHelper.get_xinference_extra_parameter(
            server_url=server_url,
            model_uid=model_uid,
            api_key=api_key,
        )

        if extra_args.max_tokens:
            credentials['max_tokens'] = extra_args.max_tokens
        if server_url.endswith('/'):
            server_url = server_url[:-1]

        client = Client(base_url=server_url)
        client = Client(
            base_url=server_url,
            api_key=api_key,
        )

        try:
            handle = client.get_model(model_uid=model_uid)

@ -81,7 +81,8 @@ class XinferenceText2SpeechModel(TTSModel):

        extra_param = XinferenceHelper.get_xinference_extra_parameter(
            server_url=credentials['server_url'],
            model_uid=credentials['model_uid']
            model_uid=credentials['model_uid'],
            api_key=credentials.get('api_key'),
        )

        if 'text-to-audio' not in extra_param.model_ability:
@ -203,7 +204,11 @@ class XinferenceText2SpeechModel(TTSModel):
            credentials['server_url'] = credentials['server_url'][:-1]

        try:
            handle = RESTfulAudioModelHandle(credentials['model_uid'], credentials['server_url'], auth_headers={})
            api_key = credentials.get('api_key')
            auth_headers = {'Authorization': f'Bearer {api_key}'} if api_key else {}
            handle = RESTfulAudioModelHandle(
                credentials['model_uid'], credentials['server_url'], auth_headers=auth_headers
            )

            model_support_voice = [x.get("value") for x in
                                   self.get_tts_model_voices(model=model, credentials=credentials)]
@ -35,13 +35,13 @@ cache_lock = Lock()

class XinferenceHelper:
    @staticmethod
    def get_xinference_extra_parameter(server_url: str, model_uid: str) -> XinferenceModelExtraParameter:
    def get_xinference_extra_parameter(server_url: str, model_uid: str, api_key: str) -> XinferenceModelExtraParameter:
        XinferenceHelper._clean_cache()
        with cache_lock:
            if model_uid not in cache:
                cache[model_uid] = {
                    'expires': time() + 300,
                    'value': XinferenceHelper._get_xinference_extra_parameter(server_url, model_uid)
                    'value': XinferenceHelper._get_xinference_extra_parameter(server_url, model_uid, api_key)
                }
            return cache[model_uid]['value']

@ -56,7 +56,7 @@ class XinferenceHelper:
        pass

    @staticmethod
    def _get_xinference_extra_parameter(server_url: str, model_uid: str) -> XinferenceModelExtraParameter:
    def _get_xinference_extra_parameter(server_url: str, model_uid: str, api_key: str) -> XinferenceModelExtraParameter:
        """
        get xinference model extra parameter like model_format and model_handle_type
        """
@ -70,9 +70,10 @@ class XinferenceHelper:
        session = Session()
        session.mount('http://', HTTPAdapter(max_retries=3))
        session.mount('https://', HTTPAdapter(max_retries=3))
        headers = {'Authorization': f'Bearer {api_key}'} if api_key else {}

        try:
            response = session.get(url, timeout=10)
            response = session.get(url, headers=headers, timeout=10)
        except (MissingSchema, ConnectionError, Timeout) as e:
            raise RuntimeError(f'get xinference model extra parameter failed, url: {url}, error: {e}')
        if response.status_code != 200:
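A call-site sketch for the new signature; the server URL, model UID, and key below are placeholders, and a reachable Xinference server is assumed:

```python
extra = XinferenceHelper.get_xinference_extra_parameter(
    server_url='http://localhost:9997',
    model_uid='my-model-uid',
    api_key=None,  # pass the server's API key when auth is enabled
)
print(extra.model_ability)
```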
@ -5,6 +5,7 @@ from typing import Optional

from sqlalchemy.exc import IntegrityError

from configs import dify_config
from core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity
from core.entities.provider_configuration import ProviderConfiguration, ProviderConfigurations, ProviderModelBundle
from core.entities.provider_entities import (
@ -18,12 +19,9 @@ from core.entities.provider_entities import (
)
from core.helper import encrypter
from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
from core.helper.position_helper import is_filtered
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.entities.provider_entities import (
    CredentialFormSchema,
    FormType,
    ProviderEntity,
)
from core.model_runtime.entities.provider_entities import CredentialFormSchema, FormType, ProviderEntity
from core.model_runtime.model_providers import model_provider_factory
from extensions import ext_hosting_provider
from extensions.ext_database import db
@ -45,6 +43,7 @@ class ProviderManager:
    """
    ProviderManager is a class that manages the model providers, including hosting and customized model providers.
    """

    def __init__(self) -> None:
        self.decoding_rsa_key = None
        self.decoding_cipher_rsa = None
@ -117,6 +116,16 @@ class ProviderManager:

        # Construct ProviderConfiguration objects for each provider
        for provider_entity in provider_entities:

            # handle include, exclude
            if is_filtered(
                include_set=dify_config.POSITION_PROVIDER_INCLUDES_SET,
                exclude_set=dify_config.POSITION_PROVIDER_EXCLUDES_SET,
                data=provider_entity,
                name_func=lambda x: x.provider,
            ):
                continue

            provider_name = provider_entity.provider
            provider_records = provider_name_to_provider_records_dict.get(provider_entity.provider, [])
            provider_model_records = provider_name_to_provider_model_records_dict.get(provider_entity.provider, [])
@ -271,6 +280,24 @@ class ProviderManager:
            )
        )

    def get_first_provider_first_model(self, tenant_id: str, model_type: ModelType) -> tuple[str, str]:
        """
        Get names of first model and its provider

        :param tenant_id: workspace id
        :param model_type: model type
        :return: provider name, model name
        """
        provider_configurations = self.get_configurations(tenant_id)

        # get available models from provider_configurations
        all_models = provider_configurations.get_models(
            model_type=model_type,
            only_active=False
        )

        return all_models[0].provider.provider, all_models[0].model

    def update_default_model_record(self, tenant_id: str, model_type: ModelType, provider: str, model: str) \
            -> TenantDefaultModel:
        """
@ -811,7 +838,7 @@ class ProviderManager:
            -> list[ModelSettings]:
        """
        Convert to model settings.

        :param provider_entity: provider entity
        :param provider_model_settings: provider model settings, including enabled and load-balancing enabled
        :param load_balancing_model_configs: load balancing model configs
        :return:
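The `is_filtered` semantics, as inferred from the call site above (the helper's behavior and the names below are assumptions, not taken from the diff): a provider is skipped when an include set exists and the provider is not in it, or when it appears in the exclude set.

```python
from core.helper.position_helper import is_filtered

class FakeProvider:
    provider = 'openai'

skip = is_filtered(
    include_set={'anthropic'},   # only providers listed here survive
    exclude_set=set(),
    data=FakeProvider(),
    name_func=lambda x: x.provider,
)
assert skip is True  # 'openai' is not in the include set
```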
@ -152,8 +152,27 @@ class PGVector(BaseVector):
        return docs

    def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
        # do not support bm25 search
        return []
        top_k = kwargs.get("top_k", 5)

        with self._get_cursor() as cur:
            cur.execute(
                f"""SELECT meta, text, ts_rank(to_tsvector(coalesce(text, '')), to_tsquery(%s)) AS score
                FROM {self.table_name}
                WHERE to_tsvector(text) @@ plainto_tsquery(%s)
                ORDER BY score DESC
                LIMIT {top_k}""",
                # f"'{query}'" is required in order to account for whitespace in query
                (f"'{query}'", f"'{query}'"),
            )

            docs = []

            for record in cur:
                metadata, text, score = record
                metadata["score"] = score
                docs.append(Document(page_content=text, metadata=metadata))

        return docs

    def delete(self) -> None:
        with self._get_cursor() as cur:
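Note: the hunk above replaces the stubbed `search_by_full_text` with a real PostgreSQL full-text query: `to_tsvector` folds the stored text into a searchable vector, `plainto_tsquery` parses the user query, and `ts_rank` orders the matches. A minimal standalone sketch of the same query (the DSN and table name are placeholders, not part of this change; it also uses `plainto_tsquery` for ranking instead of the quoted `to_tsquery` workaround in the diff):

```python
import psycopg2

# Placeholder connection settings, for illustration only.
conn = psycopg2.connect("dbname=dify user=postgres password=postgres host=localhost")

def full_text_search(table_name: str, query: str, top_k: int = 5) -> list[tuple]:
    """Rank rows by PostgreSQL full-text relevance, mirroring the new PGVector query."""
    with conn.cursor() as cur:
        cur.execute(
            f"""SELECT meta, text,
                       ts_rank(to_tsvector(coalesce(text, '')), plainto_tsquery(%s)) AS score
                FROM {table_name}
                WHERE to_tsvector(text) @@ plainto_tsquery(%s)
                ORDER BY score DESC
                LIMIT %s""",
            (query, query, top_k),
        )
        return cur.fetchall()

print(full_text_search("embedding_collection", "vector database"))
```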
@ -21,6 +21,7 @@ Dify supports multiple message types such as `text`, `link`, `image`, `file BLOB`, and `JSON`
        create an image message

        :param image: the url of the image
        :param save_as: save as
        :return: the image message
        """
```
@ -34,6 +35,7 @@ Dify supports multiple message types such as `text`, `link`, `image`, `file BLOB`, and `JSON`
        create a link message

        :param link: the url of the link
        :param save_as: save as
        :return: the link message
        """
```
@ -47,6 +49,7 @@ Dify supports multiple message types such as `text`, `link`, `image`, `file BLOB`, and `JSON`
        create a text message

        :param text: the text of the message
        :param save_as: save as
        :return: the text message
        """
```
@ -63,6 +66,8 @@ Dify supports multiple message types such as `text`, `link`, `image`, `file BLOB`, and `JSON`
        create a blob message

        :param blob: the blob
        :param meta: meta
        :param save_as: save as
        :return: the blob message
        """
```
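Note: the documented helpers are all instance methods on a tool. A hedged sketch of how a tool's `_invoke` might combine them, following the docstrings above (the tool class and its parameters are illustrative, not from this change):

```python
from typing import Any

from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool


class EchoTool(BuiltinTool):  # hypothetical tool, for illustration only
    def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> list[ToolInvokeMessage]:
        text = tool_parameters.get('text', '')
        return [
            # plain text shown to the user
            self.create_text_message(text),
            # a clickable link
            self.create_link_message('https://docs.dify.ai'),
            # raw bytes plus metadata, e.g. a generated file
            self.create_blob_message(blob=text.encode('utf-8'), meta={'mime_type': 'text/plain'}),
        ]
```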
@ -1,6 +1,6 @@
import os.path

from core.helper.position_helper import get_position_map, sort_by_position_map
from core.helper.position_helper import get_tool_position_map, sort_by_position_map
from core.tools.entities.api_entities import UserToolProvider


@ -10,11 +10,11 @@ class BuiltinToolProviderSort:
    @classmethod
    def sort(cls, providers: list[UserToolProvider]) -> list[UserToolProvider]:
        if not cls._position:
            cls._position = get_position_map(os.path.join(os.path.dirname(__file__), '..'))
            cls._position = get_tool_position_map(os.path.join(os.path.dirname(__file__), '..'))

        def name_func(provider: UserToolProvider) -> str:
            return provider.name

        sorted_providers = sort_by_position_map(cls._position, providers, name_func)

        return sorted_providers
        return sorted_providers
49
api/core/tools/provider/builtin/crossref/_assets/icon.svg
Normal file
@ -0,0 +1,49 @@
<?xml version="1.0" encoding="utf-8"?>
<!-- Generator: Adobe Illustrator 19.2.1, SVG Export Plug-In . SVG Version: 6.00 Build 0) -->
<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN" "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd">
<svg version="1.1" id="Layer_1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" x="0px" y="0px"
     viewBox="0 0 200 130.2" style="enable-background:new 0 0 200 130.2;" xml:space="preserve">
<style type="text/css">
    .st0{fill:#3EB1C8;}
    .st1{fill:#D8D2C4;}
    .st2{fill:#4F5858;}
    .st3{fill:#FFC72C;}
    .st4{fill:#EF3340;}
</style>
<g>
    <polygon class="st0" points="111.8,95.5 111.8,66.8 135.4,59 177.2,73.3 "/>
    <polygon class="st1" points="153.6,36.8 111.8,51.2 135.4,59 177.2,44.6 "/>
    <polygon class="st2" points="135.4,59 177.2,44.6 177.2,73.3 "/>
    <polygon class="st3" points="177.2,0.3 177.2,29 153.6,36.8 111.8,22.5 "/>
    <polygon class="st4" points="153.6,36.8 111.8,51.2 111.8,22.5 "/>
    <g>
        <g>
            <g>
                <g>
                    <path class="st2" d="M26.3,104.8c-0.5-3.7-4.1-6.5-8.1-6.5c-7.3,0-10.1,6.2-10.1,12.7c0,6.2,2.8,12.4,10.1,12.4
                        c5,0,7.8-3.4,8.4-8.3h7.9c-0.8,9.2-7.2,15.2-16.3,15.2C6.8,130.2,0,121.7,0,111c0-11,6.8-19.6,18.2-19.6c8.2,0,15,4.8,16,13.3
                        H26.3z"/>
                    <path class="st2" d="M37.4,102.5h7v5h0.1c1.4-3.4,5-5.7,8.6-5.7c0.5,0,1.1,0.1,1.6,0.3v6.9c-0.7-0.2-1.8-0.3-2.6-0.3
                        c-5.4,0-7.3,3.9-7.3,8.6v12.1h-7.4V102.5z"/>
                    <path class="st2" d="M68.7,101.8c8.5,0,13.9,5.6,13.9,14.2c0,8.5-5.5,14.1-13.9,14.1c-8.4,0-13.9-5.6-13.9-14.1
                        C54.9,107.4,60.3,101.8,68.7,101.8z M68.7,124.5c5,0,6.5-4.3,6.5-8.6c0-4.3-1.5-8.6-6.5-8.6c-5,0-6.5,4.3-6.5,8.6
                        C62.2,120.2,63.8,124.5,68.7,124.5z"/>
                    <path class="st2" d="M91.2,120.6c0.1,3.2,2.8,4.5,5.7,4.5c2.1,0,4.8-0.8,4.8-3.4c0-2.2-3.1-3-8.4-4.2c-4.3-0.9-8.5-2.4-8.5-7.2
                        c0-6.9,5.9-8.6,11.7-8.6c5.9,0,11.3,2,11.8,8.6h-7c-0.2-2.9-2.4-3.6-5-3.6c-1.7,0-4.1,0.3-4.1,2.5c0,2.6,4.2,3,8.4,4
                        c4.3,1,8.5,2.5,8.5,7.5c0,7.1-6.1,9.3-12.3,9.3c-6.2,0-12.3-2.3-12.6-9.5H91.2z"/>
                    <path class="st2" d="M118.1,120.6c0.1,3.2,2.8,4.5,5.7,4.5c2.1,0,4.8-0.8,4.8-3.4c0-2.2-3.1-3-8.4-4.2
                        c-4.3-0.9-8.5-2.4-8.5-7.2c0-6.9,5.9-8.6,11.7-8.6c5.9,0,11.3,2,11.8,8.6h-7c-0.2-2.9-2.4-3.6-5-3.6c-1.7,0-4.1,0.3-4.1,2.5
                        c0,2.6,4.2,3,8.4,4c4.3,1,8.5,2.5,8.5,7.5c0,7.1-6.1,9.3-12.3,9.3c-6.2,0-12.3-2.3-12.6-9.5H118.1z"/>
                    <path class="st2" d="M138.4,102.5h7v5h0.1c1.4-3.4,5-5.7,8.6-5.7c0.5,0,1.1,0.1,1.6,0.3v6.9c-0.7-0.2-1.8-0.3-2.6-0.3
                        c-5.4,0-7.3,3.9-7.3,8.6v12.1h-7.4V102.5z"/>
                    <path class="st2" d="M163.7,117.7c0.2,4.7,2.5,6.8,6.6,6.8c3,0,5.3-1.8,5.8-3.5h6.5c-2.1,6.3-6.5,9-12.6,9
                        c-8.5,0-13.7-5.8-13.7-14.1c0-8,5.6-14.2,13.7-14.2c9.1,0,13.6,7.7,13,15.9H163.7z M175.7,113.1c-0.7-3.7-2.3-5.7-5.9-5.7
                        c-4.7,0-6,3.6-6.1,5.7H175.7z"/>
                    <path class="st2" d="M187.2,107.5h-4.4v-4.9h4.4v-2.1c0-4.7,3-8.2,9-8.2c1.3,0,2.6,0.2,3.9,0.2V98c-0.9-0.1-1.8-0.2-2.7-0.2
                        c-2,0-2.8,0.8-2.8,3.1v1.6h5.1v4.9h-5.1v21.9h-7.4V107.5z"/>
                </g>
            </g>
        </g>
    </g>
</g>
</svg>
20
api/core/tools/provider/builtin/crossref/crossref.py
Normal file
@ -0,0 +1,20 @@
from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin.crossref.tools.query_doi import CrossRefQueryDOITool
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController


class CrossRefProvider(BuiltinToolProviderController):
    def _validate_credentials(self, credentials: dict) -> None:
        try:
            CrossRefQueryDOITool().fork_tool_runtime(
                runtime={
                    "credentials": credentials,
                }
            ).invoke(
                user_id='',
                tool_parameters={
                    "doi": '10.1007/s00894-022-05373-8',
                },
            )
        except Exception as e:
            raise ToolProviderCredentialValidationError(str(e))
29
api/core/tools/provider/builtin/crossref/crossref.yaml
Normal file
@ -0,0 +1,29 @@
identity:
  author: Sakura4036
  name: crossref
  label:
    en_US: CrossRef
    zh_Hans: CrossRef
  description:
    en_US: Crossref is a cross-publisher reference linking registration query system using DOI technology created in 2000. Crossref establishes cross-database links between the reference list and citation full text of papers, making it very convenient for readers to access the full text of papers.
    zh_Hans: Crossref是于2000年创建的使用DOI技术的跨出版商参考文献链接注册查询系统。Crossref建立了在论文的参考文献列表和引文全文之间的跨数据库链接,使得读者能够非常便捷地获取文献全文。
  icon: icon.svg
  tags:
    - search
credentials_for_provider:
  mailto:
    type: text-input
    required: true
    label:
      en_US: email address
      zh_Hans: email地址
      pt_BR: email address
    placeholder:
      en_US: Please input your email address
      zh_Hans: 请输入你的email地址
      pt_BR: Please input your email address
    help:
      en_US: According to the requirements of Crossref, an email address is required
      zh_Hans: 根据Crossref的要求,需要提供一个邮箱地址
      pt_BR: According to the requirements of Crossref, an email address is required
    url: https://api.crossref.org/swagger-ui/index.html
25
api/core/tools/provider/builtin/crossref/tools/query_doi.py
Normal file
@ -0,0 +1,25 @@
from typing import Any, Union

import requests

from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.errors import ToolParameterValidationError
from core.tools.tool.builtin_tool import BuiltinTool


class CrossRefQueryDOITool(BuiltinTool):
    """
    Tool for querying the metadata of a publication using its DOI.
    """
    def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
        doi = tool_parameters.get('doi')
        if not doi:
            raise ToolParameterValidationError('doi is required.')
        # doc: https://github.com/CrossRef/rest-api-doc
        url = f"https://api.crossref.org/works/{doi}"
        response = requests.get(url)
        response.raise_for_status()
        response = response.json()
        message = response.get('message', {})

        return self.create_json_message(message)
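Note: the tool is a thin wrapper over the public CrossRef REST API; querying it directly looks like this (standalone sketch; the DOI is the same sample the provider uses for credential validation):

```python
import requests

def fetch_crossref_metadata(doi: str) -> dict:
    """Fetch work metadata for a DOI from the CrossRef REST API."""
    response = requests.get(f"https://api.crossref.org/works/{doi}")
    response.raise_for_status()
    # The payload nests the record under "message".
    return response.json().get("message", {})

meta = fetch_crossref_metadata("10.1007/s00894-022-05373-8")
print(meta.get("title"), meta.get("DOI"))
```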
23
api/core/tools/provider/builtin/crossref/tools/query_doi.yaml
Normal file
@ -0,0 +1,23 @@
identity:
  name: crossref_query_doi
  author: Sakura4036
  label:
    en_US: CrossRef Query DOI
    zh_Hans: CrossRef DOI 查询
    pt_BR: CrossRef Query DOI
description:
  human:
    en_US: A tool for searching literature information using CrossRef by DOI.
    zh_Hans: 一个使用CrossRef通过DOI获取文献信息的工具。
    pt_BR: A tool for searching literature information using CrossRef by DOI.
  llm: A tool for searching literature information using CrossRef by DOI.
parameters:
  - name: doi
    type: string
    required: true
    label:
      en_US: DOI
      zh_Hans: DOI
      pt_BR: DOI
    llm_description: DOI for searching in CrossRef
    form: llm
120
api/core/tools/provider/builtin/crossref/tools/query_title.py
Normal file
@ -0,0 +1,120 @@
import time
from typing import Any, Union

import requests

from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool


def convert_time_str_to_seconds(time_str: str) -> int:
    """
    Convert a time string to seconds.
    example: 1s -> 1, 1m30s -> 90, 1h30m -> 5400, 1h30m30s -> 5430
    """
    time_str = time_str.lower().strip().replace(' ', '')
    seconds = 0
    if 'h' in time_str:
        hours, time_str = time_str.split('h')
        seconds += int(hours) * 3600
    if 'm' in time_str:
        minutes, time_str = time_str.split('m')
        seconds += int(minutes) * 60
    if 's' in time_str:
        seconds += int(time_str.replace('s', ''))
    return seconds


class CrossRefQueryTitleAPI:
    """
    Tool for querying the metadata of a publication using its title.
    Crossref API doc: https://github.com/CrossRef/rest-api-doc
    """
    query_url_template: str = "https://api.crossref.org/works?query.bibliographic={query}&rows={rows}&offset={offset}&sort={sort}&order={order}&mailto={mailto}"
    rate_limit: int = 50
    rate_interval: float = 1
    max_limit: int = 1000

    def __init__(self, mailto: str):
        self.mailto = mailto

    def _query(self, query: str, rows: int = 5, offset: int = 0, sort: str = 'relevance', order: str = 'desc', fuzzy_query: bool = False) -> list[dict]:
        """
        Query the metadata of a publication using its title.
        :param query: the title of the publication
        :param rows: the number of results to return
        :param sort: the sort field
        :param order: the sort order
        :param fuzzy_query: whether to return all items that match the query
        """
        url = self.query_url_template.format(query=query, rows=rows, offset=offset, sort=sort, order=order, mailto=self.mailto)
        response = requests.get(url)
        response.raise_for_status()
        rate_limit = int(response.headers['x-ratelimit-limit'])
        # convert time string to seconds
        rate_interval = convert_time_str_to_seconds(response.headers['x-ratelimit-interval'])

        self.rate_limit = rate_limit
        self.rate_interval = rate_interval

        response = response.json()
        if response['status'] != 'ok':
            return []

        message = response['message']
        if fuzzy_query:
            # fuzzy query return all items
            return message['items']
        else:
            for paper in message['items']:
                title = paper['title'][0]
                if title.lower() != query.lower():
                    continue
                return [paper]
        return []

    def query(self, query: str, rows: int = 5, sort: str = 'relevance', order: str = 'desc', fuzzy_query: bool = False) -> list[dict]:
        """
        Query the metadata of a publication using its title.
        :param query: the title of the publication
        :param rows: the number of results to return
        :param sort: the sort field
        :param order: the sort order
        :param fuzzy_query: whether to return all items that match the query
        """
        rows = min(rows, self.max_limit)
        if rows > self.rate_limit:
            # query multiple times
            query_times = rows // self.rate_limit + 1
            results = []

            for i in range(query_times):
                result = self._query(query, rows=self.rate_limit, offset=i * self.rate_limit, sort=sort, order=order, fuzzy_query=fuzzy_query)
                if fuzzy_query:
                    results.extend(result)
                else:
                    # fuzzy_query=False, only one result
                    if result:
                        return result
                time.sleep(self.rate_interval)
            return results
        else:
            # query once
            return self._query(query, rows, sort=sort, order=order, fuzzy_query=fuzzy_query)


class CrossRefQueryTitleTool(BuiltinTool):
    """
    Tool for querying the metadata of a publication using its title.
    """
    def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
        query = tool_parameters.get('query')
        fuzzy_query = tool_parameters.get('fuzzy_query', False)
        rows = tool_parameters.get('rows', 3)
        sort = tool_parameters.get('sort', 'relevance')
        order = tool_parameters.get('order', 'desc')
        mailto = self.runtime.credentials['mailto']

        result = CrossRefQueryTitleAPI(mailto).query(query, rows, sort, order, fuzzy_query)

        return [self.create_json_message(r) for r in result]
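Note: two details of query_title.py are easier to see in isolation: the `x-ratelimit-interval` header arrives as a string like "1m30s", and requests above the per-window limit are split into offset pages. A small self-contained sketch (the helper is restated verbatim so the snippet runs on its own):

```python
def convert_time_str_to_seconds(time_str: str) -> int:
    # "1h30m30s" -> 5430; same logic as the helper in query_title.py
    time_str = time_str.lower().strip().replace(' ', '')
    seconds = 0
    if 'h' in time_str:
        hours, time_str = time_str.split('h')
        seconds += int(hours) * 3600
    if 'm' in time_str:
        minutes, time_str = time_str.split('m')
        seconds += int(minutes) * 60
    if 's' in time_str:
        seconds += int(time_str.replace('s', ''))
    return seconds

assert convert_time_str_to_seconds('1s') == 1
assert convert_time_str_to_seconds('1m30s') == 90
assert convert_time_str_to_seconds('1h30m') == 5400

# Paging: 120 requested rows against a 50-row rate limit become offsets 0, 50, 100.
rows, rate_limit = 120, 50
offsets = [i * rate_limit for i in range(rows // rate_limit + 1)]
print(offsets)  # [0, 50, 100]
```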
105
api/core/tools/provider/builtin/crossref/tools/query_title.yaml
Normal file
@ -0,0 +1,105 @@
identity:
  name: crossref_query_title
  author: Sakura4036
  label:
    en_US: CrossRef Title Query
    zh_Hans: CrossRef 标题查询
    pt_BR: CrossRef Title Query
description:
  human:
    en_US: A tool for querying literature information using CrossRef by title.
    zh_Hans: 一个使用CrossRef通过标题搜索文献信息的工具。
    pt_BR: A tool for querying literature information using CrossRef by title.
  llm: A tool for querying literature information using CrossRef by title.
parameters:
  - name: query
    type: string
    required: true
    label:
      en_US: query
      zh_Hans: 查询语句
      pt_BR: query
    human_description:
      en_US: Query bibliographic information, useful for citation look up. Includes titles, authors, ISSNs and publication years
      zh_Hans: 用于搜索文献信息,有助于查找引用。包括标题,作者,ISSN和出版年份
      pt_BR: Query bibliographic information, useful for citation look up. Includes titles, authors, ISSNs and publication years
    llm_description: key words for querying in CrossRef
    form: llm
  - name: fuzzy_query
    type: boolean
    default: false
    label:
      en_US: Whether to fuzzy search
      zh_Hans: 是否模糊搜索
      pt_BR: Whether to fuzzy search
    human_description:
      en_US: used for selecting the query type, fuzzy query returns more results, precise query returns 1 or none
      zh_Hans: 用于选择搜索类型,模糊搜索返回更多结果,精确搜索返回1条结果或无
      pt_BR: used for selecting the query type, fuzzy query returns more results, precise query returns 1 or none
    form: form
  - name: limit
    type: number
    required: false
    label:
      en_US: max query number
      zh_Hans: 最大搜索数
      pt_BR: max query number
    human_description:
      en_US: max query number(fuzzy search returns the maximum number of results or precise search the maximum number of matches)
      zh_Hans: 最大搜索数(模糊搜索返回的最大结果数或精确搜索最大匹配数)
      pt_BR: max query number(fuzzy search returns the maximum number of results or precise search the maximum number of matches)
    form: llm
    default: 50
  - name: sort
    type: select
    required: true
    options:
      - value: relevance
        label:
          en_US: relevance
          zh_Hans: 相关性
          pt_BR: relevance
      - value: published
        label:
          en_US: publication date
          zh_Hans: 出版日期
          pt_BR: publication date
      - value: references-count
        label:
          en_US: references-count
          zh_Hans: 引用次数
          pt_BR: references-count
    default: relevance
    label:
      en_US: sorting field
      zh_Hans: 排序字段
      pt_BR: sorting field
    human_description:
      en_US: Sorting of query results
      zh_Hans: 检索结果的排序字段
      pt_BR: Sorting of query results
    form: form
  - name: order
    type: select
    required: true
    options:
      - value: desc
        label:
          en_US: descending
          zh_Hans: 降序
          pt_BR: descending
      - value: asc
        label:
          en_US: ascending
          zh_Hans: 升序
          pt_BR: ascending
    default: desc
    label:
      en_US: Order
      zh_Hans: 排序
      pt_BR: Order
    human_description:
      en_US: Order of query results
      zh_Hans: 检索结果的排序方式
      pt_BR: Order of query results
    form: form
@ -29,6 +29,6 @@ class GitlabProvider(BuiltinToolProviderController):
            if response.status_code != 200:
                raise ToolProviderCredentialValidationError((response.json()).get('message'))
        except Exception as e:
            raise ToolProviderCredentialValidationError("Gitlab Access Tokens and Api Version is invalid. {}".format(e))
            raise ToolProviderCredentialValidationError("Gitlab Access Tokens is invalid. {}".format(e))
        except Exception as e:
            raise ToolProviderCredentialValidationError(str(e))
@ -2,37 +2,37 @@ identity:
  author: Leo.Wang
  name: gitlab
  label:
    en_US: Gitlab
    zh_Hans: Gitlab
    en_US: GitLab
    zh_Hans: GitLab
  description:
    en_US: Gitlab plugin for commit
    zh_Hans: 用于获取Gitlab commit的插件
    en_US: GitLab plugin, API v4 only.
    zh_Hans: 用于获取GitLab内容的插件,目前仅支持 API v4。
  icon: gitlab.svg
credentials_for_provider:
  access_tokens:
    type: secret-input
    required: true
    label:
      en_US: Gitlab access token
      zh_Hans: Gitlab access token
      en_US: GitLab access token
      zh_Hans: GitLab access token
    placeholder:
      en_US: Please input your Gitlab access token
      zh_Hans: 请输入你的 Gitlab access token
      en_US: Please input your GitLab access token
      zh_Hans: 请输入你的 GitLab access token
    help:
      en_US: Get your Gitlab access token from Gitlab
      zh_Hans: 从 Gitlab 获取您的 access token
      en_US: Get your GitLab access token from GitLab
      zh_Hans: 从 GitLab 获取您的 access token
    url: https://docs.gitlab.com/16.9/ee/api/oauth2.html
  site_url:
    type: text-input
    required: false
    default: 'https://gitlab.com'
    label:
      en_US: Gitlab site url
      zh_Hans: Gitlab site url
      en_US: GitLab site url
      zh_Hans: GitLab site url
    placeholder:
      en_US: Please input your Gitlab site url
      zh_Hans: 请输入你的 Gitlab site url
      en_US: Please input your GitLab site url
      zh_Hans: 请输入你的 GitLab site url
    help:
      en_US: Find your Gitlab url
      zh_Hans: 找到你的Gitlab url
      en_US: Find your GitLab url
      zh_Hans: 找到你的 GitLab url
    url: https://gitlab.com/help
@ -18,6 +18,7 @@ class GitlabCommitsTool(BuiltinTool):
        employee = tool_parameters.get('employee', '')
        start_time = tool_parameters.get('start_time', '')
        end_time = tool_parameters.get('end_time', '')
        change_type = tool_parameters.get('change_type', 'all')

        if not project:
            return self.create_text_message('Project is required')
@ -36,11 +37,11 @@ class GitlabCommitsTool(BuiltinTool):
            site_url = 'https://gitlab.com'

        # Get commit content
        result = self.fetch(user_id, site_url, access_token, project, employee, start_time, end_time)
        result = self.fetch(user_id, site_url, access_token, project, employee, start_time, end_time, change_type)

        return self.create_text_message(json.dumps(result, ensure_ascii=False))
        return [self.create_json_message(item) for item in result]

    def fetch(self, user_id: str, site_url: str, access_token: str, project: str, employee: str = None, start_time: str = '', end_time: str = '') -> list[dict[str, Any]]:
    def fetch(self, user_id: str, site_url: str, access_token: str, project: str, employee: str = None, start_time: str = '', end_time: str = '', change_type: str = '') -> list[dict[str, Any]]:
        domain = site_url
        headers = {"PRIVATE-TOKEN": access_token}
        results = []
@ -74,7 +75,7 @@ class GitlabCommitsTool(BuiltinTool):

            for commit in commits:
                commit_sha = commit['id']
                print(f"\tCommit SHA: {commit_sha}")
                author_name = commit['author_name']

                diff_url = f"{domain}/api/v4/projects/{project_id}/repository/commits/{commit_sha}/diff"
                diff_response = requests.get(diff_url, headers=headers)
@ -87,14 +88,23 @@ class GitlabCommitsTool(BuiltinTool):
                    removed_lines = diff['diff'].count('\n-')
                    total_changes = added_lines + removed_lines

                    if total_changes > 1:
                        final_code = ''.join([line[1:] for line in diff['diff'].split('\n') if line.startswith('+') and not line.startswith('+++')])
                        results.append({
                            "project": project_name,
                            "commit_sha": commit_sha,
                            "diff": final_code
                        })
                        print(f"Commit code:{final_code}")
                    if change_type == "new":
                        if added_lines > 1:
                            final_code = ''.join([line[1:] for line in diff['diff'].split('\n') if line.startswith('+') and not line.startswith('+++')])
                            results.append({
                                "commit_sha": commit_sha,
                                "author_name": author_name,
                                "diff": final_code
                            })
                    else:
                        if total_changes > 1:
                            final_code = ''.join([line[1:] for line in diff['diff'].split('\n') if (line.startswith('+') or line.startswith('-')) and not line.startswith('+++') and not line.startswith('---')])
                            final_code_escaped = json.dumps(final_code)[1:-1]  # Escape the final code
                            results.append({
                                "commit_sha": commit_sha,
                                "author_name": author_name,
                                "diff": final_code_escaped
                            })
        except requests.RequestException as e:
            print(f"Error fetching data from GitLab: {e}")
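Note: the new change_type filter hinges on counting '\n+' and '\n-' markers in a unified diff and then keeping either only additions or both sides. A self-contained sketch on a toy diff (the joined output has no newlines, exactly as in the tool above):

```python
import json

sample_diff = (
    "@@ -1,2 +1,3 @@\n"
    " unchanged line\n"
    "-removed line\n"
    "+added line one\n"
    "+added line two\n"
)

added_lines = sample_diff.count('\n+')    # 2
removed_lines = sample_diff.count('\n-')  # 1

# change_type == "new": keep only added lines, dropping the '+++' file header
final_code = ''.join(
    line[1:] for line in sample_diff.split('\n')
    if line.startswith('+') and not line.startswith('+++')
)

# change_type == "all": keep both sides, then escape for JSON transport
all_changes = ''.join(
    line[1:] for line in sample_diff.split('\n')
    if (line.startswith('+') or line.startswith('-'))
    and not line.startswith('+++') and not line.startswith('---')
)
escaped = json.dumps(all_changes)[1:-1]  # strip the surrounding quotes
print(added_lines, removed_lines, final_code, escaped)
```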
@ -2,24 +2,24 @@ identity:
  name: gitlab_commits
  author: Leo.Wang
  label:
    en_US: Gitlab Commits
    zh_Hans: Gitlab代码提交内容
    en_US: GitLab Commits
    zh_Hans: GitLab 提交内容查询
description:
  human:
    en_US: A tool for query gitlab commits. Input should be a exists username.
    zh_Hans: 一个用于查询gitlab代码提交记录的的工具,输入的内容应该是一个已存在的用户名或者项目名。
  llm: A tool for query gitlab commits. Input should be a exists username or project.
    en_US: A tool for querying GitLab commits. Input should be an existing username or project.
    zh_Hans: 一个用于查询 GitLab 代码提交内容的工具,输入的内容应该是一个已存在的用户名或者项目名。
  llm: A tool for querying GitLab commits. Input should be an existing username or project.
parameters:
  - name: employee
  - name: username
    type: string
    required: false
    label:
      en_US: employee
      en_US: username
      zh_Hans: 员工用户名
    human_description:
      en_US: employee
      en_US: username
      zh_Hans: 员工用户名
    llm_description: employee for gitlab
    llm_description: User name for GitLab
    form: llm
  - name: project
    type: string
@ -30,7 +30,7 @@ parameters:
    human_description:
      en_US: project
      zh_Hans: 项目名
    llm_description: project for gitlab
    llm_description: project for GitLab
    form: llm
  - name: start_time
    type: string
@ -41,7 +41,7 @@ parameters:
    human_description:
      en_US: start_time
      zh_Hans: 开始时间
    llm_description: start_time for gitlab
    llm_description: Start time for GitLab
    form: llm
  - name: end_time
    type: string
@ -52,5 +52,26 @@ parameters:
    human_description:
      en_US: end_time
      zh_Hans: 结束时间
    llm_description: end_time for gitlab
    llm_description: End time for GitLab
    form: llm
  - name: change_type
    type: select
    required: false
    options:
      - value: all
        label:
          en_US: all
          zh_Hans: 所有
      - value: new
        label:
          en_US: new
          zh_Hans: 新增
    default: all
    label:
      en_US: change_type
      zh_Hans: 变更类型
    human_description:
      en_US: change_type
      zh_Hans: 变更类型
    llm_description: Content change type for GitLab
    form: llm
95
api/core/tools/provider/builtin/gitlab/tools/gitlab_files.py
Normal file
@ -0,0 +1,95 @@
from typing import Any, Union

import requests

from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool


class GitlabFilesTool(BuiltinTool):
    def _invoke(self,
                user_id: str,
                tool_parameters: dict[str, Any]
                ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:

        project = tool_parameters.get('project', '')
        branch = tool_parameters.get('branch', '')
        path = tool_parameters.get('path', '')

        if not project:
            return self.create_text_message('Project is required')
        if not branch:
            return self.create_text_message('Branch is required')

        if not path:
            return self.create_text_message('Path is required')

        access_token = self.runtime.credentials.get('access_tokens')
        site_url = self.runtime.credentials.get('site_url')

        if 'access_tokens' not in self.runtime.credentials or not self.runtime.credentials.get('access_tokens'):
            return self.create_text_message("Gitlab API Access Tokens is required.")
        if 'site_url' not in self.runtime.credentials or not self.runtime.credentials.get('site_url'):
            site_url = 'https://gitlab.com'

        # Get project ID from project name
        project_id = self.get_project_id(site_url, access_token, project)
        if not project_id:
            return self.create_text_message(f"Project '{project}' not found.")

        # Get commit content
        result = self.fetch(user_id, project_id, site_url, access_token, branch, path)

        return [self.create_json_message(item) for item in result]

    def extract_project_name_and_path(self, path: str) -> tuple[str, str]:
        parts = path.split('/', 1)
        if len(parts) < 2:
            return None, None
        return parts[0], parts[1]

    def get_project_id(self, site_url: str, access_token: str, project_name: str) -> Union[str, None]:
        headers = {"PRIVATE-TOKEN": access_token}
        try:
            url = f"{site_url}/api/v4/projects?search={project_name}"
            response = requests.get(url, headers=headers)
            response.raise_for_status()
            projects = response.json()
            for project in projects:
                if project['name'] == project_name:
                    return project['id']
        except requests.RequestException as e:
            print(f"Error fetching project ID from GitLab: {e}")
        return None

    def fetch(self, user_id: str, project_id: str, site_url: str, access_token: str, branch: str, path: str = None) -> list[dict[str, Any]]:
        domain = site_url
        headers = {"PRIVATE-TOKEN": access_token}
        results = []

        try:
            # List files and directories in the given path
            url = f"{domain}/api/v4/projects/{project_id}/repository/tree?path={path}&ref={branch}"
            response = requests.get(url, headers=headers)
            response.raise_for_status()
            items = response.json()

            for item in items:
                item_path = item['path']
                if item['type'] == 'tree':  # It's a directory
                    results.extend(self.fetch(user_id, project_id, site_url, access_token, branch, item_path))
                else:  # It's a file
                    file_url = f"{domain}/api/v4/projects/{project_id}/repository/files/{item_path}/raw?ref={branch}"
                    file_response = requests.get(file_url, headers=headers)
                    file_response.raise_for_status()
                    file_content = file_response.text
                    results.append({
                        "path": item_path,
                        "branch": branch,
                        "content": file_content
                    })
        except requests.RequestException as e:
            print(f"Error fetching data from GitLab: {e}")

        return results
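Note: the tool walks two GitLab API v4 endpoints, `/repository/tree` for listing and `/repository/files/{path}/raw` for content. A trimmed standalone sketch (host, token, and project id are placeholders); note that in real use the file path should be URL-encoded, e.g. with `urllib.parse.quote(path, safe='')`, which the diffed code interpolates directly:

```python
import requests

GITLAB = "https://gitlab.com"               # placeholder host
HEADERS = {"PRIVATE-TOKEN": "<token>"}      # placeholder token

def walk_files(project_id: int, branch: str, path: str = "") -> list[dict]:
    """Recursively collect file contents under `path`, like GitlabFilesTool.fetch."""
    results = []
    tree_url = f"{GITLAB}/api/v4/projects/{project_id}/repository/tree?path={path}&ref={branch}"
    for item in requests.get(tree_url, headers=HEADERS).json():
        if item["type"] == "tree":  # directory: recurse
            results.extend(walk_files(project_id, branch, item["path"]))
        else:                       # file: fetch raw content
            raw_url = (f"{GITLAB}/api/v4/projects/{project_id}"
                       f"/repository/files/{item['path']}/raw?ref={branch}")
            results.append({"path": item["path"],
                            "content": requests.get(raw_url, headers=HEADERS).text})
    return results
```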
45
api/core/tools/provider/builtin/gitlab/tools/gitlab_files.yaml
Normal file
@ -0,0 +1,45 @@
identity:
  name: gitlab_files
  author: Leo.Wang
  label:
    en_US: GitLab Files
    zh_Hans: GitLab 文件获取
description:
  human:
    en_US: A tool for querying GitLab files. Input should be a branch and an existing file or directory path.
    zh_Hans: 一个用于查询 GitLab 文件的工具,输入的内容应该是分支和一个已存在文件或者文件夹路径。
  llm: A tool for querying GitLab files. Input should be an existing file or directory path.
parameters:
  - name: project
    type: string
    required: true
    label:
      en_US: project
      zh_Hans: 项目
    human_description:
      en_US: project
      zh_Hans: 项目
    llm_description: Project for GitLab
    form: llm
  - name: branch
    type: string
    required: true
    label:
      en_US: branch
      zh_Hans: 分支
    human_description:
      en_US: branch
      zh_Hans: 分支
    llm_description: Branch for GitLab
    form: llm
  - name: path
    type: string
    required: true
    label:
      en_US: path
      zh_Hans: 文件路径
    human_description:
      en_US: path
      zh_Hans: 文件路径
    llm_description: File path for GitLab
    form: llm
43
api/core/tools/provider/builtin/jina/tools/jina_tokenizer.py
Normal file
@ -0,0 +1,43 @@
from typing import Any

from core.helper import ssrf_proxy
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool


class JinaTokenizerTool(BuiltinTool):
    _jina_tokenizer_endpoint = 'https://tokenize.jina.ai/'

    def _invoke(
        self,
        user_id: str,
        tool_parameters: dict[str, Any],
    ) -> ToolInvokeMessage:
        content = tool_parameters['content']
        body = {
            "content": content
        }

        headers = {
            'Content-Type': 'application/json'
        }

        if 'api_key' in self.runtime.credentials and self.runtime.credentials.get('api_key'):
            headers['Authorization'] = "Bearer " + self.runtime.credentials.get('api_key')

        if tool_parameters.get('return_chunks', False):
            body['return_chunks'] = True

        if tool_parameters.get('return_tokens', False):
            body['return_tokens'] = True

        if tokenizer := tool_parameters.get('tokenizer'):
            body['tokenizer'] = tokenizer

        response = ssrf_proxy.post(
            self._jina_tokenizer_endpoint,
            headers=headers,
            json=body,
        )

        return self.create_json_message(response.json())
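Note: the tool proxies a plain POST through Dify's SSRF-safe client; outside that wrapper the same call is an ordinary `requests.post`. A standalone sketch with the body fields the tool supports (the sample content is illustrative):

```python
import requests

body = {
    "content": "Dify is an LLM application development platform.",
    "return_chunks": True,        # optional: segment into chunks
    "tokenizer": "cl100k_base",   # optional: pick the tiktoken vocabulary
}
headers = {"Content-Type": "application/json"}
# An API key is optional in the tool above; when present it is sent as:
# headers["Authorization"] = "Bearer <api_key>"

resp = requests.post("https://tokenize.jina.ai/", headers=headers, json=body)
resp.raise_for_status()
print(resp.json())
```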
70
api/core/tools/provider/builtin/jina/tools/jina_tokenizer.yaml
Normal file
@ -0,0 +1,70 @@
identity:
  name: jina_tokenizer
  author: hjlarry
  label:
    en_US: JinaTokenizer
description:
  human:
    en_US: Free API to tokenize text and segment long text into chunks.
    zh_Hans: 免费的API可以将文本tokenize,也可以将长文本分割成多个部分。
  llm: Free API to tokenize text and segment long text into chunks.
parameters:
  - name: content
    type: string
    required: true
    label:
      en_US: Content
      zh_Hans: 内容
    llm_description: the content which needs to be tokenized or segmented
    form: llm
  - name: return_tokens
    type: boolean
    required: false
    label:
      en_US: Return the tokens
      zh_Hans: 是否返回tokens
    human_description:
      en_US: Return the tokens and their corresponding ids in the response.
      zh_Hans: 返回tokens及其对应的ids。
    form: form
  - name: return_chunks
    type: boolean
    label:
      en_US: Return the chunks
      zh_Hans: 是否分块
    human_description:
      en_US: Chunking the input into semantically meaningful segments while handling a wide variety of text types and edge cases based on common structural cues.
      zh_Hans: 将输入分块为具有语义意义的片段,同时根据常见的结构线索处理各种文本类型和边缘情况。
    form: form
  - name: tokenizer
    type: select
    options:
      - value: cl100k_base
        label:
          en_US: cl100k_base
      - value: o200k_base
        label:
          en_US: o200k_base
      - value: p50k_base
        label:
          en_US: p50k_base
      - value: r50k_base
        label:
          en_US: r50k_base
      - value: p50k_edit
        label:
          en_US: p50k_edit
      - value: gpt2
        label:
          en_US: gpt2
    label:
      en_US: Tokenizer
    human_description:
      en_US: |
        · cl100k_base --- gpt-4, gpt-3.5-turbo, gpt-3.5
        · o200k_base --- gpt-4o, gpt-4o-mini
        · p50k_base --- text-davinci-003, text-davinci-002
        · r50k_base --- text-davinci-001, text-curie-001
        · p50k_edit --- text-davinci-edit-001, code-davinci-edit-001
        · gpt2 --- gpt-2
    form: form
73
api/core/tools/provider/builtin/novitaai/_novita_tool_base.py
Normal file
@ -0,0 +1,73 @@
from novita_client import (
    Txt2ImgV3Embedding,
    Txt2ImgV3HiresFix,
    Txt2ImgV3LoRA,
    Txt2ImgV3Refiner,
    V3TaskImage,
)


class NovitaAiToolBase:
    def _extract_loras(self, loras_str: str):
        if not loras_str:
            return []

        loras_ori_list = loras_str.strip().split(';')
        result_list = []
        for lora_str in loras_ori_list:
            lora_info = lora_str.strip().split(',')
            lora = Txt2ImgV3LoRA(
                model_name=lora_info[0].strip(),
                strength=float(lora_info[1]),
            )
            result_list.append(lora)

        return result_list

    def _extract_embeddings(self, embeddings_str: str):
        if not embeddings_str:
            return []

        embeddings_ori_list = embeddings_str.strip().split(';')
        result_list = []
        for embedding_str in embeddings_ori_list:
            embedding = Txt2ImgV3Embedding(
                model_name=embedding_str.strip()
            )
            result_list.append(embedding)

        return result_list

    def _extract_hires_fix(self, hires_fix_str: str):
        hires_fix_info = hires_fix_str.strip().split(',')
        if 'upscaler' in hires_fix_info:
            hires_fix = Txt2ImgV3HiresFix(
                target_width=int(hires_fix_info[0]),
                target_height=int(hires_fix_info[1]),
                strength=float(hires_fix_info[2]),
                upscaler=hires_fix_info[3].strip()
            )
        else:
            hires_fix = Txt2ImgV3HiresFix(
                target_width=int(hires_fix_info[0]),
                target_height=int(hires_fix_info[1]),
                strength=float(hires_fix_info[2])
            )

        return hires_fix

    def _extract_refiner(self, switch_at: str):
        refiner = Txt2ImgV3Refiner(
            switch_at=float(switch_at)
        )
        return refiner

    def _is_hit_nsfw_detection(self, image: V3TaskImage, confidence_threshold: float) -> bool:
        """
        is hit nsfw
        """
        if image.nsfw_detection_result is None:
            return False
        if image.nsfw_detection_result.valid and image.nsfw_detection_result.confidence >= confidence_threshold:
            return True
        return False
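Note: the helpers above assume compact string encodings: loras are "model,strength" pairs joined by ';', embeddings are names joined by ';', and hires_fix is "width,height,strength[,upscaler]". A dependency-free sketch of the same lora parsing, with plain tuples standing in for the novita_client dataclasses (the model names are made up):

```python
def extract_loras(loras_str: str) -> list[tuple[str, float]]:
    # "add_detail,0.6;epi_noise,0.4" -> [("add_detail", 0.6), ("epi_noise", 0.4)]
    if not loras_str:
        return []
    pairs = []
    for chunk in loras_str.strip().split(';'):
        name, strength = chunk.strip().split(',')
        pairs.append((name.strip(), float(strength)))
    return pairs

assert extract_loras("add_detail,0.6;epi_noise,0.4") == [("add_detail", 0.6), ("epi_noise", 0.4)]
assert extract_loras("") == []
```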
@ -4,19 +4,15 @@ from typing import Any, Union

from novita_client import (
    NovitaClient,
    Txt2ImgV3Embedding,
    Txt2ImgV3HiresFix,
    Txt2ImgV3LoRA,
    Txt2ImgV3Refiner,
    V3TaskImage,
)

from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin.novitaai._novita_tool_base import NovitaAiToolBase
from core.tools.tool.builtin_tool import BuiltinTool


class NovitaAiTxt2ImgTool(BuiltinTool):
class NovitaAiTxt2ImgTool(BuiltinTool, NovitaAiToolBase):
    def _invoke(self,
                user_id: str,
                tool_parameters: dict[str, Any],
@ -73,65 +69,19 @@ class NovitaAiTxt2ImgTool(BuiltinTool):

        # process loras
        if 'loras' in res_parameters:
            loras_ori_list = res_parameters.get('loras').strip().split(';')
            locals_list = []
            for lora_str in loras_ori_list:
                lora_info = lora_str.strip().split(',')
                lora = Txt2ImgV3LoRA(
                    model_name=lora_info[0].strip(),
                    strength=float(lora_info[1]),
                )
                locals_list.append(lora)

            res_parameters['loras'] = locals_list
            res_parameters['loras'] = self._extract_loras(res_parameters.get('loras'))

        # process embeddings
        if 'embeddings' in res_parameters:
            embeddings_ori_list = res_parameters.get('embeddings').strip().split(';')
            locals_list = []
            for embedding_str in embeddings_ori_list:
                embedding = Txt2ImgV3Embedding(
                    model_name=embedding_str.strip()
                )
                locals_list.append(embedding)

            res_parameters['embeddings'] = locals_list
            res_parameters['embeddings'] = self._extract_embeddings(res_parameters.get('embeddings'))

        # process hires_fix
        if 'hires_fix' in res_parameters:
            hires_fix_ori = res_parameters.get('hires_fix')
            hires_fix_info = hires_fix_ori.strip().split(',')
            if 'upscaler' in hires_fix_info:
                hires_fix = Txt2ImgV3HiresFix(
                    target_width=int(hires_fix_info[0]),
                    target_height=int(hires_fix_info[1]),
                    strength=float(hires_fix_info[2]),
                    upscaler=hires_fix_info[3].strip()
                )
            else:
                hires_fix = Txt2ImgV3HiresFix(
                    target_width=int(hires_fix_info[0]),
                    target_height=int(hires_fix_info[1]),
                    strength=float(hires_fix_info[2])
                )
            res_parameters['hires_fix'] = self._extract_hires_fix(res_parameters.get('hires_fix'))

            res_parameters['hires_fix'] = hires_fix

        if 'refiner_switch_at' in res_parameters:
            refiner = Txt2ImgV3Refiner(
                switch_at=float(res_parameters.get('refiner_switch_at'))
            )
            del res_parameters['refiner_switch_at']
            res_parameters['refiner'] = refiner
        # process refiner
        if 'refiner_switch_at' in res_parameters:
            res_parameters['refiner'] = self._extract_refiner(res_parameters.get('refiner_switch_at'))
            del res_parameters['refiner_switch_at']

        return res_parameters

    def _is_hit_nsfw_detection(self, image: V3TaskImage, confidence_threshold: float) -> bool:
        """
        is hit nsfw
        """
        if image.nsfw_detection_result is None:
            return False
        if image.nsfw_detection_result.valid and image.nsfw_detection_result.confidence >= confidence_threshold:
            return True
        return False
@ -1,6 +1,6 @@
from typing import Optional

from core.app.app_config.entities import VariableEntity
from core.app.app_config.entities import VariableEntity, VariableEntityType
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import (
@ -18,6 +18,13 @@ from models.model import App, AppMode
from models.tools import WorkflowToolProvider
from models.workflow import Workflow

VARIABLE_TO_PARAMETER_TYPE_MAPPING = {
    VariableEntityType.TEXT_INPUT: ToolParameter.ToolParameterType.STRING,
    VariableEntityType.PARAGRAPH: ToolParameter.ToolParameterType.STRING,
    VariableEntityType.SELECT: ToolParameter.ToolParameterType.SELECT,
    VariableEntityType.NUMBER: ToolParameter.ToolParameterType.NUMBER,
}


class WorkflowToolProviderController(ToolProviderController):
    provider_id: str
@ -28,7 +35,7 @@ class WorkflowToolProviderController(ToolProviderController):

        if not app:
            raise ValueError('app not found')

        controller = WorkflowToolProviderController(**{
            'identity': {
                'author': db_provider.user.name if db_provider.user_id and db_provider.user else '',
@ -46,7 +53,7 @@ class WorkflowToolProviderController(ToolProviderController):
            'credentials_schema': {},
            'provider_id': db_provider.id or '',
        })

        # init tools

        controller.tools = [controller._get_db_provider_tool(db_provider, app)]
@ -56,7 +63,7 @@ class WorkflowToolProviderController(ToolProviderController):
    @property
    def provider_type(self) -> ToolProviderType:
        return ToolProviderType.WORKFLOW

    def _get_db_provider_tool(self, db_provider: WorkflowToolProvider, app: App) -> WorkflowTool:
        """
        get db provider tool
@ -93,23 +100,11 @@ class WorkflowToolProviderController(ToolProviderController):
            if variable:
                parameter_type = None
                options = None
                if variable.type in [
                    VariableEntity.Type.TEXT_INPUT,
                    VariableEntity.Type.PARAGRAPH,
                ]:
                    parameter_type = ToolParameter.ToolParameterType.STRING
                elif variable.type in [
                    VariableEntity.Type.SELECT
                ]:
                    parameter_type = ToolParameter.ToolParameterType.SELECT
                elif variable.type in [
                    VariableEntity.Type.NUMBER
                ]:
                    parameter_type = ToolParameter.ToolParameterType.NUMBER
                else:
                if variable.type not in VARIABLE_TO_PARAMETER_TYPE_MAPPING:
                    raise ValueError(f'unsupported variable type {variable.type}')

                if variable.type == VariableEntity.Type.SELECT and variable.options:
                parameter_type = VARIABLE_TO_PARAMETER_TYPE_MAPPING[variable.type]

                if variable.type == VariableEntityType.SELECT and variable.options:
                    options = [
                        ToolParameterOption(
                            value=option,
@ -200,7 +195,7 @@ class WorkflowToolProviderController(ToolProviderController):
        """
        if self.tools is not None:
            return self.tools

        db_providers: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter(
            WorkflowToolProvider.tenant_id == tenant_id,
            WorkflowToolProvider.app_id == self.provider_id,
@ -208,11 +203,11 @@ class WorkflowToolProviderController(ToolProviderController):

        if not db_providers:
            return []

        self.tools = [self._get_db_provider_tool(db_providers, db_providers.app)]

        return self.tools

    def get_tool(self, tool_name: str) -> Optional[WorkflowTool]:
        """
        get tool by name
@ -226,5 +221,5 @@ class WorkflowToolProviderController(ToolProviderController):
        for tool in self.tools:
            if tool.identity.name == tool_name:
                return tool

        return None
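Note: the refactor above replaces an if/elif ladder with a lookup table plus a single membership check. A generic sketch of the pattern (the enum values here are illustrative, not the real ones from core.app.app_config.entities):

```python
from enum import Enum

class VariableEntityType(str, Enum):  # names mirror the diff; values are illustrative
    TEXT_INPUT = "text-input"
    PARAGRAPH = "paragraph"
    SELECT = "select"
    NUMBER = "number"

class ToolParameterType(str, Enum):
    STRING = "string"
    SELECT = "select"
    NUMBER = "number"

VARIABLE_TO_PARAMETER_TYPE_MAPPING = {
    VariableEntityType.TEXT_INPUT: ToolParameterType.STRING,
    VariableEntityType.PARAGRAPH: ToolParameterType.STRING,
    VariableEntityType.SELECT: ToolParameterType.SELECT,
    VariableEntityType.NUMBER: ToolParameterType.NUMBER,
}

def to_parameter_type(variable_type: VariableEntityType) -> ToolParameterType:
    # one dict lookup instead of four branches; unknown types still fail loudly
    if variable_type not in VARIABLE_TO_PARAMETER_TYPE_MAPPING:
        raise ValueError(f'unsupported variable type {variable_type}')
    return VARIABLE_TO_PARAMETER_TYPE_MAPPING[variable_type]

assert to_parameter_type(VariableEntityType.PARAGRAPH) is ToolParameterType.STRING
```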
@ -10,14 +10,11 @@ from configs import dify_config
from core.agent.entities import AgentToolEntity
from core.app.entities.app_invoke_entities import InvokeFrom
from core.helper.module_import_helper import load_single_subclass_from_source
from core.helper.position_helper import is_filtered
from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.entities.api_entities import UserToolProvider, UserToolProviderTypeLiteral
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import (
    ApiProviderAuthType,
    ToolInvokeFrom,
    ToolParameter,
)
from core.tools.entities.tool_entities import ApiProviderAuthType, ToolInvokeFrom, ToolParameter
from core.tools.errors import ToolProviderNotFoundError
from core.tools.provider.api_tool_provider import ApiToolProviderController
from core.tools.provider.builtin._positions import BuiltinToolProviderSort
@ -26,10 +23,7 @@ from core.tools.tool.api_tool import ApiTool
from core.tools.tool.builtin_tool import BuiltinTool
from core.tools.tool.tool import Tool
from core.tools.tool_label_manager import ToolLabelManager
from core.tools.utils.configuration import (
    ToolConfigurationManager,
    ToolParameterConfigurationManager,
)
from core.tools.utils.configuration import ToolConfigurationManager, ToolParameterConfigurationManager
from core.tools.utils.tool_parameter_converter import ToolParameterConverter
from extensions.ext_database import db
from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider
@ -37,6 +31,7 @@ from services.tools.tools_transform_service import ToolTransformService

logger = logging.getLogger(__name__)


class ToolManager:
    _builtin_provider_lock = Lock()
    _builtin_providers = {}
@ -106,7 +101,7 @@ class ToolManager:
                         tenant_id: str,
                         invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
                         tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT) \
            -> Union[BuiltinTool, ApiTool]:
        -> Union[BuiltinTool, ApiTool]:
        """
        get the tool runtime

@ -345,7 +340,7 @@ class ToolManager:
                provider_class = load_single_subclass_from_source(
                    module_name=f'core.tools.provider.builtin.{provider}.{provider}',
                    script_path=path.join(path.dirname(path.realpath(__file__)),
                                          'provider', 'builtin', provider, f'{provider}.py'),
                        'provider', 'builtin', provider, f'{provider}.py'),
                    parent_type=BuiltinToolProviderController)
                provider: BuiltinToolProviderController = provider_class()
                cls._builtin_providers[provider.identity.name] = provider
@ -413,6 +408,15 @@ class ToolManager:

        # append builtin providers
        for provider in builtin_providers:
            # handle include, exclude
            if is_filtered(
                    include_set=dify_config.POSITION_TOOL_INCLUDES_SET,
                    exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET,
                    data=provider,
                    name_func=lambda x: x.identity.name
            ):
                continue

            user_provider = ToolTransformService.builtin_provider_to_user_provider(
                provider_controller=provider,
                db_provider=find_db_builtin_provider(provider.identity.name),
@ -472,7 +476,7 @@ class ToolManager:

    @classmethod
    def get_api_provider_controller(cls, tenant_id: str, provider_id: str) -> tuple[
            ApiToolProviderController, dict[str, Any]]:
        ApiToolProviderController, dict[str, Any]]:
        """
        get the api provider

@ -592,4 +596,5 @@ class ToolManager:
        else:
            raise ValueError(f"provider type {provider_type} not found")


ToolManager.load_builtin_providers_cache()
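Note: `is_filtered` is imported from core.helper.position_helper and called with an include set, an exclude set, and a name extractor. One plausible reading of its contract, as a sketch only (the real helper's implementation may differ):

```python
from typing import Any, Callable

def is_filtered(include_set: set[str], exclude_set: set[str],
                data: Any, name_func: Callable[[Any], str]) -> bool:
    """Return True when `data` should be skipped: either explicitly excluded,
    or absent from a non-empty include whitelist."""
    name = name_func(data)
    if exclude_set and name in exclude_set:
        return True
    if include_set and name not in include_set:
        return True
    return False

# Example: only the 'openai' provider survives this filter configuration.
providers = ['openai', 'anthropic']
kept = [p for p in providers if not is_filtered({'openai'}, set(), p, lambda x: x)]
assert kept == ['openai']
```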
@ -7,14 +7,14 @@ from typing_extensions import deprecated

from core.app.segments import Segment, Variable, factory
from core.file.file_obj import FileVar
from core.workflow.enums import SystemVariable
from core.workflow.enums import SystemVariableKey

VariableValue = Union[str, int, float, dict, list, FileVar]


SYSTEM_VARIABLE_NODE_ID = 'sys'
ENVIRONMENT_VARIABLE_NODE_ID = 'env'
CONVERSATION_VARIABLE_NODE_ID = 'conversation'
SYSTEM_VARIABLE_NODE_ID = "sys"
ENVIRONMENT_VARIABLE_NODE_ID = "env"
CONVERSATION_VARIABLE_NODE_ID = "conversation"


class VariablePool(BaseModel):
@ -32,7 +32,7 @@ class VariablePool(BaseModel):
        description='User inputs',
    )

    system_variables: Mapping[SystemVariable, Any] = Field(
    system_variables: Mapping[SystemVariableKey, Any] = Field(
        description='System variables',
    )

@ -78,7 +78,7 @@ class VariablePool(BaseModel):
            None
        """
        if len(selector) < 2:
            raise ValueError('Invalid selector')
            raise ValueError("Invalid selector")

        if value is None:
            return
@ -105,13 +105,13 @@ class VariablePool(BaseModel):
            ValueError: If the selector is invalid.
        """
        if len(selector) < 2:
            raise ValueError('Invalid selector')
            raise ValueError("Invalid selector")
        hash_key = hash(tuple(selector[1:]))
        value = self.variable_dictionary[selector[0]].get(hash_key)

        return value

    @deprecated('This method is deprecated, use `get` instead.')
    @deprecated("This method is deprecated, use `get` instead.")
    def get_any(self, selector: Sequence[str], /) -> Any | None:
        """
        Retrieves the value from the variable pool based on the given selector.
@ -126,7 +126,7 @@ class VariablePool(BaseModel):
            ValueError: If the selector is invalid.
        """
        if len(selector) < 2:
            raise ValueError('Invalid selector')
            raise ValueError("Invalid selector")
        hash_key = hash(tuple(selector[1:]))
        value = self.variable_dictionary[selector[0]].get(hash_key)
        return value.to_object() if value else None
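Note: a selector is a sequence whose first element names the owning node ("sys", "env", "conversation", or a node id) and whose tail is hashed as the lookup key, which is why every accessor rejects selectors shorter than two elements. A stripped-down sketch of that storage scheme:

```python
from collections import defaultdict
from collections.abc import Sequence
from typing import Any

class MiniVariablePool:
    """Toy model of VariablePool's two-level dict: node_id -> hash(name parts) -> value."""
    def __init__(self) -> None:
        self._store: dict[str, dict[int, Any]] = defaultdict(dict)

    def add(self, selector: Sequence[str], value: Any) -> None:
        if len(selector) < 2:
            raise ValueError("Invalid selector")
        self._store[selector[0]][hash(tuple(selector[1:]))] = value

    def get(self, selector: Sequence[str]) -> Any:
        if len(selector) < 2:
            raise ValueError("Invalid selector")
        return self._store[selector[0]].get(hash(tuple(selector[1:])))

pool = MiniVariablePool()
pool.add(['sys', 'query'], 'hello')
assert pool.get(['sys', 'query']) == 'hello'
```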
@ -1,25 +1,13 @@
from enum import Enum


class SystemVariable(str, Enum):
class SystemVariableKey(str, Enum):
    """
    System Variables.
    """
    QUERY = 'query'
    FILES = 'files'
    CONVERSATION_ID = 'conversation_id'
    USER_ID = 'user_id'
    DIALOGUE_COUNT = 'dialogue_count'

    @classmethod
    def value_of(cls, value: str):
        """
        Get value of given system variable.

        :param value: system variable value
        :return: system variable
        """
        for system_variable in cls:
            if system_variable.value == value:
                return system_variable
        raise ValueError(f'invalid system variable value {value}')
    QUERY = "query"
    FILES = "files"
    CONVERSATION_ID = "conversation_id"
    USER_ID = "user_id"
    DIALOGUE_COUNT = "dialogue_count"
@ -13,8 +13,8 @@ from models.workflow import WorkflowNodeExecutionStatus

MAX_NUMBER = dify_config.CODE_MAX_NUMBER
MIN_NUMBER = dify_config.CODE_MIN_NUMBER
MAX_PRECISION = 20
MAX_DEPTH = 5
MAX_PRECISION = dify_config.CODE_MAX_PRECISION
MAX_DEPTH = dify_config.CODE_MAX_DEPTH
MAX_STRING_LENGTH = dify_config.CODE_MAX_STRING_LENGTH
MAX_STRING_ARRAY_LENGTH = dify_config.CODE_MAX_STRING_ARRAY_LENGTH
MAX_OBJECT_ARRAY_LENGTH = dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH
@ -23,7 +23,7 @@ MAX_NUMBER_ARRAY_LENGTH = dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH

class CodeNode(BaseNode):
    _node_data_cls = CodeNodeData
    node_type = NodeType.CODE
    _node_type = NodeType.CODE

    @classmethod
    def get_default_config(cls, filters: Optional[dict] = None) -> dict:
@ -316,8 +316,8 @@ class CodeNode(BaseNode):

    @classmethod
    def _extract_variable_selector_to_variable_mapping(
        cls,
        graph_config: Mapping[str, Any],
            cls,
            graph_config: Mapping[str, Any],
            node_id: str,
            node_data: CodeNodeData
    ) -> Mapping[str, Sequence[str]]:
@ -25,7 +25,7 @@ from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptT
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariable
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.event import InNodeEvent
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.event import RunCompletedEvent, RunEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
@ -110,7 +110,7 @@ class LLMNode(BaseNode):
        # fetch prompt messages
        prompt_messages, stop = self._fetch_prompt_messages(
            node_data=node_data,
            query=variable_pool.get_any(['sys', SystemVariable.QUERY.value])
            query=variable_pool.get_any(['sys', SystemVariableKey.QUERY.value])
            if node_data.memory else None,
            query_prompt_template=node_data.memory.query_prompt_template if node_data.memory else None,
            inputs=inputs,
@ -370,7 +370,7 @@ class LLMNode(BaseNode):

            inputs[variable_selector.variable] = variable_value

        return inputs  # type: ignore
        return inputs

    def _fetch_files(self, node_data: LLMNodeData, variable_pool: VariablePool) -> list["FileVar"]:
        """
@ -382,7 +382,7 @@ class LLMNode(BaseNode):
        if not node_data.vision.enabled:
            return []

        files = variable_pool.get_any(['sys', SystemVariable.FILES.value])
        files = variable_pool.get_any(['sys', SystemVariableKey.FILES.value])
        if not files:
            return []

@ -543,7 +543,7 @@ class LLMNode(BaseNode):
            return None

        # get conversation id
        conversation_id = variable_pool.get_any(['sys', SystemVariable.CONVERSATION_ID.value])
        conversation_id = variable_pool.get_any(['sys', SystemVariableKey.CONVERSATION_ID.value])
        if conversation_id is None:
            return None

@ -722,10 +722,10 @@ class LLMNode(BaseNode):
            variable_mapping['#context#'] = node_data.context.variable_selector

        if node_data.vision.enabled:
            variable_mapping['#files#'] = ['sys', SystemVariable.FILES.value]
            variable_mapping['#files#'] = ['sys', SystemVariableKey.FILES.value]

        if node_data.memory:
            variable_mapping['#sys.query#'] = ['sys', SystemVariable.QUERY.value]
            variable_mapping['#sys.query#'] = ['sys', SystemVariableKey.QUERY.value]

        if node_data.prompt_config:
            enable_jinja = False
@ -1,3 +1,7 @@
from collections.abc import Sequence

from pydantic import Field

from core.app.app_config.entities import VariableEntity
from core.workflow.entities.base_node_data_entities import BaseNodeData

@ -6,4 +10,4 @@ class StartNodeData(BaseNodeData):
    """
    Start Node Data
    """
    variables: list[VariableEntity] = []
    variables: Sequence[VariableEntity] = Field(default_factory=list)
@ -3,6 +3,7 @@ from collections.abc import Mapping, Sequence
|
||||
from typing import Any
|
||||
|
||||
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||
from core.workflow.entities.variable_pool import SYSTEM_VARIABLE_NODE_ID, VariablePool
|
||||
from core.workflow.nodes.base_node import BaseNode
|
||||
from core.workflow.nodes.start.entities import StartNodeData
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
@ -17,22 +18,22 @@ class StartNode(BaseNode):
|
||||
Run node
|
||||
:return:
|
||||
"""
|
||||
# Get cleaned inputs
|
||||
cleaned_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs)
|
||||
node_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs)
|
||||
system_inputs = self.graph_runtime_state.variable_pool.system_variables
|
||||
|
||||
for var in self.graph_runtime_state.variable_pool.system_variables:
|
||||
cleaned_inputs['sys.' + var.value] = self.graph_runtime_state.variable_pool.system_variables[var]
|
||||
for var in system_inputs:
|
||||
node_inputs[SYSTEM_VARIABLE_NODE_ID + '.' + var] = system_inputs[var]
|
||||
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=cleaned_inputs,
|
||||
outputs=cleaned_inputs
|
||||
inputs=node_inputs,
|
||||
outputs=node_inputs
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
graph_config: Mapping[str, Any],
|
||||
cls,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: StartNodeData
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
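
The _run rewrite above drops the hard-coded 'sys.' string in favour of the shared SYSTEM_VARIABLE_NODE_ID constant plus a local system_inputs alias. A toy walk-through of the resulting key shapes, using hypothetical input values and assuming the constant is the string 'sys':

SYSTEM_VARIABLE_NODE_ID = 'sys'  # assumed value of the imported constant

user_inputs = {'name': 'Alice'}                             # hypothetical
system_inputs = {'query': 'hi', 'conversation_id': 'c-1'}   # hypothetical

node_inputs = dict(user_inputs)
for var in system_inputs:
    # System variables are namespaced under the reserved node id.
    node_inputs[SYSTEM_VARIABLE_NODE_ID + '.' + var] = system_inputs[var]

assert node_inputs == {
    'name': 'Alice',
    'sys.query': 'hi',
    'sys.conversation_id': 'c-1',
}
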
api/core/workflow/nodes/tool/tool_node.py
@@ -2,7 +2,7 @@ from collections.abc import Mapping, Sequence
 from os import path
 from typing import Any, cast
 
-from core.app.segments import ArrayAnyVariable, parser
+from core.app.segments import ArrayAnySegment, ArrayAnyVariable, parser
 from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
 from core.file.file_obj import FileTransferMethod, FileType, FileVar
 from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
@@ -11,7 +11,7 @@ from core.tools.tool_manager import ToolManager
 from core.tools.utils.message_transformer import ToolFileMessageTransformer
 from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
 from core.workflow.entities.variable_pool import VariablePool
-from core.workflow.enums import SystemVariable
+from core.workflow.enums import SystemVariableKey
 from core.workflow.nodes.base_node import BaseNode
 from core.workflow.nodes.tool.entities import ToolNodeData
 from core.workflow.utils.variable_template_parser import VariableTemplateParser
@@ -141,8 +141,8 @@ class ToolNode(BaseNode):
         return result
 
     def _fetch_files(self, variable_pool: VariablePool) -> list[FileVar]:
-        variable = variable_pool.get(['sys', SystemVariable.FILES.value])
-        assert isinstance(variable, ArrayAnyVariable)
+        variable = variable_pool.get(['sys', SystemVariableKey.FILES.value])
+        assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment)
         return list(variable.value) if variable else []
 
     def _convert_tool_messages(self, messages: list[ToolInvokeMessage])\
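
Two things change in _fetch_files: the lookup key moves to SystemVariableKey, and the assert now accepts a plain ArrayAnySegment as well as an ArrayAnyVariable. Note that isinstance() with the X | Y union syntax (PEP 604) only works on Python 3.10+. A self-contained sketch with stand-in classes (the subclass relation is an assumption for illustration):

class ArrayAnySegment:          # stand-in for core.app.segments.ArrayAnySegment
    def __init__(self, value):
        self.value = value

class ArrayAnyVariable(ArrayAnySegment):  # stand-in; assumed subclass relation
    pass

def fetch_files(variable):
    # On Python >= 3.10, a PEP 604 union is a valid isinstance() target.
    assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment)
    return list(variable.value) if variable else []

print(fetch_files(ArrayAnySegment(['a.png'])))  # ['a.png']
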
api/core/workflow/nodes/variable_assigner/__init__.py
@@ -1,109 +1,8 @@
-from collections.abc import Sequence
-from enum import Enum
-from typing import Optional, cast
+from .node import VariableAssignerNode
+from .node_data import VariableAssignerData, WriteMode
 
-from sqlalchemy import select
-from sqlalchemy.orm import Session
-
-from core.app.segments import SegmentType, Variable, factory
-from core.workflow.entities.base_node_data_entities import BaseNodeData
-from core.workflow.entities.node_entities import NodeRunResult, NodeType
-from core.workflow.entities.variable_pool import VariablePool
-from core.workflow.nodes.base_node import BaseNode
-from extensions.ext_database import db
-from models import ConversationVariable, WorkflowNodeExecutionStatus
-
-
-class VariableAssignerNodeError(Exception):
-    pass
-
-
-class WriteMode(str, Enum):
-    OVER_WRITE = 'over-write'
-    APPEND = 'append'
-    CLEAR = 'clear'
-
-
-class VariableAssignerData(BaseNodeData):
-    title: str = 'Variable Assigner'
-    desc: Optional[str] = 'Assign a value to a variable'
-    assigned_variable_selector: Sequence[str]
-    write_mode: WriteMode
-    input_variable_selector: Sequence[str]
-
-
-class VariableAssignerNode(BaseNode):
-    _node_data_cls: type[BaseNodeData] = VariableAssignerData
-    _node_type: NodeType = NodeType.CONVERSATION_VARIABLE_ASSIGNER
-
-    def _run(self) -> NodeRunResult:
-        data = cast(VariableAssignerData, self.node_data)
-
-        # Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject
-        original_variable = self.graph_runtime_state.variable_pool.get(data.assigned_variable_selector)
-        if not isinstance(original_variable, Variable):
-            raise VariableAssignerNodeError('assigned variable not found')
-
-        match data.write_mode:
-            case WriteMode.OVER_WRITE:
-                income_value = self.graph_runtime_state.variable_pool.get(data.input_variable_selector)
-                if not income_value:
-                    raise VariableAssignerNodeError('input value not found')
-                updated_variable = original_variable.model_copy(update={'value': income_value.value})
-
-            case WriteMode.APPEND:
-                income_value = self.graph_runtime_state.variable_pool.get(data.input_variable_selector)
-                if not income_value:
-                    raise VariableAssignerNodeError('input value not found')
-                updated_value = original_variable.value + [income_value.value]
-                updated_variable = original_variable.model_copy(update={'value': updated_value})
-
-            case WriteMode.CLEAR:
-                income_value = get_zero_value(original_variable.value_type)
-                updated_variable = original_variable.model_copy(update={'value': income_value.to_object()})
-
-            case _:
-                raise VariableAssignerNodeError(f'unsupported write mode: {data.write_mode}')
-
-        # Over write the variable.
-        self.graph_runtime_state.variable_pool.add(data.assigned_variable_selector, updated_variable)
-
-        # Update conversation variable.
-        # TODO: Find a better way to use the database.
-        conversation_id = self.graph_runtime_state.variable_pool.get(['sys', 'conversation_id'])
-        if not conversation_id:
-            raise VariableAssignerNodeError('conversation_id not found')
-        update_conversation_variable(conversation_id=conversation_id.text, variable=updated_variable)
-
-        return NodeRunResult(
-            status=WorkflowNodeExecutionStatus.SUCCEEDED,
-            inputs={
-                'value': income_value.to_object(),
-            },
-        )
-
-
-def update_conversation_variable(conversation_id: str, variable: Variable):
-    stmt = select(ConversationVariable).where(
-        ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id
-    )
-    with Session(db.engine) as session:
-        row = session.scalar(stmt)
-        if not row:
-            raise VariableAssignerNodeError('conversation variable not found in the database')
-        row.data = variable.model_dump_json()
-        session.commit()
-
-
-def get_zero_value(t: SegmentType):
-    match t:
-        case SegmentType.ARRAY_OBJECT | SegmentType.ARRAY_STRING | SegmentType.ARRAY_NUMBER:
-            return factory.build_segment([])
-        case SegmentType.OBJECT:
-            return factory.build_segment({})
-        case SegmentType.STRING:
-            return factory.build_segment('')
-        case SegmentType.NUMBER:
-            return factory.build_segment(0)
-        case _:
-            raise VariableAssignerNodeError(f'unsupported variable type: {t}')
+
+__all__ = [
+    'VariableAssignerNode',
+    'VariableAssignerData',
+    'WriteMode',
+]
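
The last hunk hollows out variable_assigner/__init__.py: the implementation moves into the sibling .node and .node_data modules, and the package root keeps only re-exports plus __all__. Existing call sites are unaffected, since the public import path stays the same. A sketch of what should keep working inside the repo after the refactor:

from core.workflow.nodes.variable_assigner import (
    VariableAssignerData,
    VariableAssignerNode,
    WriteMode,
)

# WriteMode is a str-valued enum, so the three modes keep their old
# string values and compare equal to them.
assert WriteMode.OVER_WRITE == 'over-write'
assert WriteMode.APPEND == 'append'
assert WriteMode.CLEAR == 'clear'
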
Some files were not shown because too many files have changed in this diff.