Merge branch 'refs/heads/main' into feat/workflow-parallel-support

# Conflicts:
#	api/core/app/apps/advanced_chat/app_generator.py
#	api/core/app/apps/advanced_chat/generate_task_pipeline.py
#	api/core/app/apps/workflow/app_runner.py
#	api/core/app/apps/workflow/generate_task_pipeline.py
#	api/core/app/task_pipeline/workflow_cycle_state_manager.py
#	api/core/workflow/entities/variable_pool.py
#	api/core/workflow/nodes/code/code_node.py
#	api/core/workflow/nodes/llm/llm_node.py
#	api/core/workflow/nodes/start/start_node.py
#	api/core/workflow/nodes/variable_assigner/__init__.py
#	api/tests/integration_tests/workflow/nodes/test_llm.py
#	api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py
#	api/tests/unit_tests/core/workflow/nodes/test_answer.py
#	api/tests/unit_tests/core/workflow/nodes/test_if_else.py
#	api/tests/unit_tests/core/workflow/nodes/test_variable_assigner.py
takatost 2024-08-21 16:59:23 +08:00
commit 35be41b337
252 changed files with 5991 additions and 2942 deletions

.gitignore vendored
View File

@ -178,3 +178,4 @@ pyrightconfig.json
api/.vscode
.idea/
.vscode

View File

@ -267,4 +267,13 @@ APP_MAX_ACTIVE_REQUESTS=0
# Celery beat configuration
CELERY_BEAT_SCHEDULER_TIME=1
CELERY_BEAT_SCHEDULER_TIME=1
# Position configuration
POSITION_TOOL_PINS=
POSITION_TOOL_INCLUDES=
POSITION_TOOL_EXCLUDES=
POSITION_PROVIDER_PINS=
POSITION_PROVIDER_INCLUDES=
POSITION_PROVIDER_EXCLUDES=

View File

Binary image file changed (1.7 KiB before, 1.7 KiB after).

View File

@ -5,8 +5,8 @@
"name": "Python: Flask",
"type": "debugpy",
"request": "launch",
"python": "${workspaceFolder}/api/.venv/bin/python",
"cwd": "${workspaceFolder}/api",
"python": "${workspaceFolder}/.venv/bin/python",
"cwd": "${workspaceFolder}",
"envFile": ".env",
"module": "flask",
"justMyCode": true,
@ -18,15 +18,15 @@
"args": [
"run",
"--host=0.0.0.0",
"--port=5001",
"--port=5001"
]
},
{
"name": "Python: Celery",
"type": "debugpy",
"request": "launch",
"python": "${workspaceFolder}/api/.venv/bin/python",
"cwd": "${workspaceFolder}/api",
"python": "${workspaceFolder}/.venv/bin/python",
"cwd": "${workspaceFolder}",
"module": "celery",
"justMyCode": true,
"envFile": ".env",

View File

@ -37,6 +37,8 @@ class DifyConfig(
CODE_MAX_NUMBER: int = 9223372036854775807
CODE_MIN_NUMBER: int = -9223372036854775808
CODE_MAX_DEPTH: int = 5
CODE_MAX_PRECISION: int = 20
CODE_MAX_STRING_LENGTH: int = 80000
CODE_MAX_STRING_ARRAY_LENGTH: int = 30
CODE_MAX_OBJECT_ARRAY_LENGTH: int = 30

View File

@ -406,6 +406,7 @@ class DataSetConfig(BaseSettings):
default=False,
)
class WorkspaceConfig(BaseSettings):
"""
Workspace configs
@ -442,6 +443,63 @@ class CeleryBeatConfig(BaseSettings):
)
class PositionConfig(BaseSettings):
POSITION_PROVIDER_PINS: str = Field(
description='The heads of model providers',
default='',
)
POSITION_PROVIDER_INCLUDES: str = Field(
description='The included model providers',
default='',
)
POSITION_PROVIDER_EXCLUDES: str = Field(
description='The excluded model providers',
default='',
)
POSITION_TOOL_PINS: str = Field(
description='The heads of tools',
default='',
)
POSITION_TOOL_INCLUDES: str = Field(
description='The included tools',
default='',
)
POSITION_TOOL_EXCLUDES: str = Field(
description='The excluded tools',
default='',
)
@computed_field
def POSITION_PROVIDER_PINS_LIST(self) -> list[str]:
return [item.strip() for item in self.POSITION_PROVIDER_PINS.split(',') if item.strip() != '']
@computed_field
def POSITION_PROVIDER_INCLUDES_SET(self) -> set[str]:
return {item.strip() for item in self.POSITION_PROVIDER_INCLUDES.split(',') if item.strip() != ''}
@computed_field
def POSITION_PROVIDER_EXCLUDES_SET(self) -> set[str]:
return {item.strip() for item in self.POSITION_PROVIDER_EXCLUDES.split(',') if item.strip() != ''}
@computed_field
def POSITION_TOOL_PINS_LIST(self) -> list[str]:
return [item.strip() for item in self.POSITION_TOOL_PINS.split(',') if item.strip() != '']
@computed_field
def POSITION_TOOL_INCLUDES_SET(self) -> set[str]:
return {item.strip() for item in self.POSITION_TOOL_INCLUDES.split(',') if item.strip() != ''}
@computed_field
def POSITION_TOOL_EXCLUDES_SET(self) -> set[str]:
return {item.strip() for item in self.POSITION_TOOL_EXCLUDES.split(',') if item.strip() != ''}
class FeatureConfig(
# place the configs in alphabet order
AppExecutionConfig,
@ -466,6 +524,7 @@ class FeatureConfig(
UpdateConfig,
WorkflowConfig,
WorkspaceConfig,
PositionConfig,
# hosted services config
HostedServiceConfig,
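
The new PositionConfig turns each comma-separated POSITION_* value into either an ordered pin list or an include/exclude set, trimming whitespace and dropping empty items. A minimal standalone sketch of that parsing behavior (the helper names are illustrative, not part of the commit):

def parse_pin_list(raw: str) -> list[str]:
    # Order matters for pins: they are later placed at the head of the position map.
    return [item.strip() for item in raw.split(',') if item.strip() != '']

def parse_filter_set(raw: str) -> set[str]:
    # Includes and excludes are pure membership tests, so a set suffices.
    return {item.strip() for item in raw.split(',') if item.strip() != ''}

assert parse_pin_list('openai, anthropic,,') == ['openai', 'anthropic']
assert parse_filter_set('a, b,a') == {'a', 'b'}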

View File

@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):
CURRENT_VERSION: str = Field(
description='Dify version',
default='0.7.0',
default='0.7.1',
)
COMMIT_SHA: str = Field(

View File

@ -61,6 +61,7 @@ class AppListApi(Resource):
parser.add_argument('name', type=str, required=True, location='json')
parser.add_argument('description', type=str, location='json')
parser.add_argument('mode', type=str, choices=ALLOW_CREATE_APP_MODES, location='json')
parser.add_argument('icon_type', type=str, location='json')
parser.add_argument('icon', type=str, location='json')
parser.add_argument('icon_background', type=str, location='json')
args = parser.parse_args()
@ -94,6 +95,7 @@ class AppImportApi(Resource):
parser.add_argument('data', type=str, required=True, nullable=False, location='json')
parser.add_argument('name', type=str, location='json')
parser.add_argument('description', type=str, location='json')
parser.add_argument('icon_type', type=str, location='json')
parser.add_argument('icon', type=str, location='json')
parser.add_argument('icon_background', type=str, location='json')
args = parser.parse_args()
@ -167,6 +169,7 @@ class AppApi(Resource):
parser = reqparse.RequestParser()
parser.add_argument('name', type=str, required=True, nullable=False, location='json')
parser.add_argument('description', type=str, location='json')
parser.add_argument('icon_type', type=str, location='json')
parser.add_argument('icon', type=str, location='json')
parser.add_argument('icon_background', type=str, location='json')
parser.add_argument('max_active_requests', type=int, location='json')
@ -208,6 +211,7 @@ class AppCopyApi(Resource):
parser = reqparse.RequestParser()
parser.add_argument('name', type=str, location='json')
parser.add_argument('description', type=str, location='json')
parser.add_argument('icon_type', type=str, location='json')
parser.add_argument('icon', type=str, location='json')
parser.add_argument('icon_background', type=str, location='json')
args = parser.parse_args()

View File

@ -154,6 +154,8 @@ class ChatConversationApi(Resource):
parser.add_argument('message_count_gte', type=int_range(1, 99999), required=False, location='args')
parser.add_argument('page', type=int_range(1, 99999), required=False, default=1, location='args')
parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
parser.add_argument('sort_by', type=str, choices=['created_at', '-created_at', 'updated_at', '-updated_at'],
required=False, default='-updated_at', location='args')
args = parser.parse_args()
subquery = (
@ -225,7 +227,17 @@ class ChatConversationApi(Resource):
if app_model.mode == AppMode.ADVANCED_CHAT.value:
query = query.where(Conversation.invoke_from != InvokeFrom.DEBUGGER.value)
query = query.order_by(Conversation.created_at.desc())
match args['sort_by']:
case 'created_at':
query = query.order_by(Conversation.created_at.asc())
case '-created_at':
query = query.order_by(Conversation.created_at.desc())
case 'updated_at':
query = query.order_by(Conversation.updated_at.asc())
case '-updated_at':
query = query.order_by(Conversation.updated_at.desc())
case _:
query = query.order_by(Conversation.created_at.desc())
conversations = db.paginate(
query,
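
The new sort_by argument accepts created_at, -created_at, updated_at and -updated_at, where a leading '-' means descending; the match statement maps each value to the corresponding order_by clause and falls back to created_at descending. A minimal sketch of the same dispatch without SQLAlchemy (the helper is illustrative):

def parse_sort_by(sort_by: str) -> tuple[str, bool]:
    # Returns (column_name, descending); mirrors the `case _` fallback above.
    column, descending = sort_by.lstrip('-'), sort_by.startswith('-')
    if column not in ('created_at', 'updated_at'):
        return 'created_at', True
    return column, descending

assert parse_sort_by('-updated_at') == ('updated_at', True)
assert parse_sort_by('bogus') == ('created_at', True)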

View File

@ -16,6 +16,7 @@ from models.model import Site
def parse_app_site_args():
parser = reqparse.RequestParser()
parser.add_argument('title', type=str, required=False, location='json')
parser.add_argument('icon_type', type=str, required=False, location='json')
parser.add_argument('icon', type=str, required=False, location='json')
parser.add_argument('icon_background', type=str, required=False, location='json')
parser.add_argument('description', type=str, required=False, location='json')
@ -53,6 +54,7 @@ class AppSite(Resource):
for attr_name in [
'title',
'icon_type',
'icon',
'icon_background',
'description',

View File

@ -460,6 +460,7 @@ class ConvertToWorkflowApi(Resource):
if request.data:
parser = reqparse.RequestParser()
parser.add_argument('name', type=str, required=False, nullable=True, location='json')
parser.add_argument('icon_type', type=str, required=False, nullable=True, location='json')
parser.add_argument('icon', type=str, required=False, nullable=True, location='json')
parser.add_argument('icon_background', type=str, required=False, nullable=True, location='json')
args = parser.parse_args()

View File

@ -573,13 +573,13 @@ class DatasetRetrievalSettingMockApi(Resource):
@account_initialization_required
def get(self, vector_type):
match vector_type:
case VectorType.MILVUS | VectorType.RELYT | VectorType.PGVECTOR | VectorType.TIDB_VECTOR | VectorType.CHROMA | VectorType.TENCENT:
case VectorType.MILVUS | VectorType.RELYT | VectorType.TIDB_VECTOR | VectorType.CHROMA | VectorType.TENCENT | VectorType.PGVECTO_RS:
return {
'retrieval_method': [
RetrievalMethod.SEMANTIC_SEARCH.value
]
}
case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH| VectorType.ANALYTICDB | VectorType.MYSCALE | VectorType.ORACLE | VectorType.ELASTICSEARCH:
case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH | VectorType.ANALYTICDB | VectorType.MYSCALE | VectorType.ORACLE | VectorType.ELASTICSEARCH | VectorType.PGVECTOR:
return {
'retrieval_method': [
RetrievalMethod.SEMANTIC_SEARCH.value,

View File

@ -25,6 +25,8 @@ class ConversationApi(Resource):
parser = reqparse.RequestParser()
parser.add_argument('last_id', type=uuid_value, location='args')
parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
parser.add_argument('sort_by', type=str, choices=['created_at', '-created_at', 'updated_at', '-updated_at'],
required=False, default='-updated_at', location='args')
args = parser.parse_args()
try:
@ -33,7 +35,8 @@ class ConversationApi(Resource):
user=end_user,
last_id=args['last_id'],
limit=args['limit'],
invoke_from=InvokeFrom.SERVICE_API
invoke_from=InvokeFrom.SERVICE_API,
sort_by=args['sort_by']
)
except services.errors.conversation.LastConversationNotExistsError:
raise NotFound("Last Conversation Not Exists.")

View File

@ -53,19 +53,22 @@ class SegmentApi(DatasetApiResource):
raise ProviderNotInitializeError(
"No Embedding Model available. Please configure a valid provider "
"in the Settings -> Model Provider.")
except ProviderTokenNotInitError as ex:
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
# validate args
parser = reqparse.RequestParser()
parser.add_argument('segments', type=list, required=False, nullable=True, location='json')
args = parser.parse_args()
for args_item in args['segments']:
SegmentService.segment_create_args_validate(args_item, document)
segments = SegmentService.multi_create_segment(args['segments'], document, dataset)
return {
'data': marshal(segments, segment_fields),
'doc_form': document.doc_form
}, 200
if args['segments'] is not None:
for args_item in args['segments']:
SegmentService.segment_create_args_validate(args_item, document)
segments = SegmentService.multi_create_segment(args['segments'], document, dataset)
return {
'data': marshal(segments, segment_fields),
'doc_form': document.doc_form
}, 200
else:
return {"error": "Segments is required"}, 400
def get(self, tenant_id, dataset_id, document_id):
"""Create single segment."""

View File

@ -26,6 +26,8 @@ class ConversationListApi(WebApiResource):
parser.add_argument('last_id', type=uuid_value, location='args')
parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
parser.add_argument('pinned', type=str, choices=['true', 'false', None], location='args')
parser.add_argument('sort_by', type=str, choices=['created_at', '-created_at', 'updated_at', '-updated_at'],
required=False, default='-updated_at', location='args')
args = parser.parse_args()
pinned = None
@ -40,6 +42,7 @@ class ConversationListApi(WebApiResource):
limit=args['limit'],
invoke_from=InvokeFrom.WEB_APP,
pinned=pinned,
sort_by=args['sort_by']
)
except LastConversationNotExistsError:
raise NotFound("Last Conversation Not Exists.")

View File

@ -6,6 +6,7 @@ from configs import dify_config
from controllers.web import api
from controllers.web.wraps import WebApiResource
from extensions.ext_database import db
from libs.helper import AppIconUrlField
from models.account import TenantStatus
from models.model import Site
from services.feature_service import FeatureService
@ -28,8 +29,10 @@ class AppSiteApi(WebApiResource):
'title': fields.String,
'chat_color_theme': fields.String,
'chat_color_theme_inverted': fields.Boolean,
'icon_type': fields.String,
'icon': fields.String,
'icon_background': fields.String,
'icon_url': AppIconUrlField,
'description': fields.String,
'copyright': fields.String,
'privacy_policy': fields.String,

View File

@ -64,15 +64,19 @@ class BaseAgentRunner(AppRunner):
"""
Agent runner
:param tenant_id: tenant id
:param application_generate_entity: application generate entity
:param conversation: conversation
:param app_config: app generate entity
:param model_config: model config
:param config: dataset config
:param queue_manager: queue manager
:param message: message
:param user_id: user id
:param agent_llm_callback: agent llm callback
:param callback: callback
:param memory: memory
:param prompt_messages: prompt messages
:param variables_pool: variables pool
:param db_variables: db variables
:param model_instance: model instance
"""
self.tenant_id = tenant_id
self.application_generate_entity = application_generate_entity
@ -445,7 +449,7 @@ class BaseAgentRunner(AppRunner):
try:
tool_responses = json.loads(agent_thought.observation)
except Exception as e:
tool_responses = { tool: agent_thought.observation for tool in tools }
tool_responses = dict.fromkeys(tools, agent_thought.observation)
for tool in tools:
# generate a uuid for tool call

View File

@ -292,6 +292,8 @@ class CotAgentRunner(BaseAgentRunner, ABC):
handle invoke action
:param action: action
:param tool_instances: tool instances
:param message_file_ids: message file ids
:param trace_manager: trace manager
:return: observation, meta
"""
# action is tool call, invoke tool

View File

@ -93,7 +93,7 @@ class DatasetConfigManager:
reranking_model=dataset_configs.get('reranking_model'),
weights=dataset_configs.get('weights'),
reranking_enabled=dataset_configs.get('reranking_enabled', True),
rerank_mode=dataset_configs["reranking_mode"],
rerank_mode=dataset_configs.get('rerank_mode', 'reranking_model'),
)
)

View File

@ -1,6 +1,6 @@
import re
from core.app.app_config.entities import ExternalDataVariableEntity, VariableEntity
from core.app.app_config.entities import ExternalDataVariableEntity, VariableEntity, VariableEntityType
from core.external_data_tool.factory import ExternalDataToolFactory
@ -13,7 +13,7 @@ class BasicVariablesConfigManager:
:param config: model config args
"""
external_data_variables = []
variables = []
variable_entities = []
# old external_data_tools
external_data_tools = config.get('external_data_tools', [])
@ -30,50 +30,41 @@ class BasicVariablesConfigManager:
)
# variables and external_data_tools
for variable in config.get('user_input_form', []):
typ = list(variable.keys())[0]
if typ == 'external_data_tool':
val = variable[typ]
if 'config' not in val:
for variables in config.get('user_input_form', []):
variable_type = list(variables.keys())[0]
if variable_type == VariableEntityType.EXTERNAL_DATA_TOOL:
variable = variables[variable_type]
if 'config' not in variable:
continue
external_data_variables.append(
ExternalDataVariableEntity(
variable=val['variable'],
type=val['type'],
config=val['config']
variable=variable['variable'],
type=variable['type'],
config=variable['config']
)
)
elif typ in [
VariableEntity.Type.TEXT_INPUT.value,
VariableEntity.Type.PARAGRAPH.value,
VariableEntity.Type.NUMBER.value,
elif variable_type in [
VariableEntityType.TEXT_INPUT,
VariableEntityType.PARAGRAPH,
VariableEntityType.NUMBER,
VariableEntityType.SELECT,
]:
variables.append(
variable = variables[variable_type]
variable_entities.append(
VariableEntity(
type=VariableEntity.Type.value_of(typ),
variable=variable[typ].get('variable'),
description=variable[typ].get('description'),
label=variable[typ].get('label'),
required=variable[typ].get('required', False),
max_length=variable[typ].get('max_length'),
default=variable[typ].get('default'),
)
)
elif typ == VariableEntity.Type.SELECT.value:
variables.append(
VariableEntity(
type=VariableEntity.Type.SELECT,
variable=variable[typ].get('variable'),
description=variable[typ].get('description'),
label=variable[typ].get('label'),
required=variable[typ].get('required', False),
options=variable[typ].get('options'),
default=variable[typ].get('default'),
type=variable_type,
variable=variable.get('variable'),
description=variable.get('description'),
label=variable.get('label'),
required=variable.get('required', False),
max_length=variable.get('max_length'),
options=variable.get('options'),
default=variable.get('default'),
)
)
return variables, external_data_variables
return variable_entities, external_data_variables
@classmethod
def validate_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]:
@ -183,4 +174,4 @@ class BasicVariablesConfigManager:
config=config
)
return config, ["external_data_tools"]
return config, ["external_data_tools"]

View File

@ -82,43 +82,29 @@ class PromptTemplateEntity(BaseModel):
advanced_completion_prompt_template: Optional[AdvancedCompletionPromptTemplateEntity] = None
class VariableEntityType(str, Enum):
TEXT_INPUT = "text-input"
SELECT = "select"
PARAGRAPH = "paragraph"
NUMBER = "number"
EXTERNAL_DATA_TOOL = "external-data-tool"
class VariableEntity(BaseModel):
"""
Variable Entity.
"""
class Type(Enum):
TEXT_INPUT = 'text-input'
SELECT = 'select'
PARAGRAPH = 'paragraph'
NUMBER = 'number'
@classmethod
def value_of(cls, value: str) -> 'VariableEntity.Type':
"""
Get value of given mode.
:param value: mode value
:return: mode
"""
for mode in cls:
if mode.value == value:
return mode
raise ValueError(f'invalid variable type value {value}')
variable: str
label: str
description: Optional[str] = None
type: Type
type: VariableEntityType
required: bool = False
max_length: Optional[int] = None
options: Optional[list[str]] = None
default: Optional[str] = None
hint: Optional[str] = None
@property
def name(self) -> str:
return self.variable
class ExternalDataVariableEntity(BaseModel):
"""
@ -252,4 +238,4 @@ class WorkflowUIBasedAppConfig(AppConfig):
"""
Workflow UI Based App Config Entity.
"""
workflow_id: str
workflow_id: str
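
With the nested VariableEntity.Type enum replaced by the module-level VariableEntityType(str, Enum), a raw form-type string compares equal to the enum member directly, which is why the value_of() helper could be deleted. A small self-contained illustration:

from enum import Enum

class VariableEntityType(str, Enum):
    TEXT_INPUT = "text-input"
    SELECT = "select"
    PARAGRAPH = "paragraph"
    NUMBER = "number"
    EXTERNAL_DATA_TOOL = "external-data-tool"

# str subclassing makes plain strings interchangeable with members,
# so config values need no explicit lookup or conversion.
assert "text-input" == VariableEntityType.TEXT_INPUT
assert VariableEntityType("select") is VariableEntityType.SELECT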

View File

@ -23,6 +23,8 @@ from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotA
from core.file.message_file_parser import MessageFileParser
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from extensions.ext_database import db
from models.account import Account
from models.model import App, Conversation, EndUser, Message
@ -67,8 +69,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
# get conversation
conversation = None
if args.get('conversation_id'):
conversation = self._get_conversation_by_user(app_model, args.get('conversation_id', ''), user)
conversation_id = args.get('conversation_id')
if conversation_id:
conversation = self._get_conversation_by_user(app_model=app_model, conversation_id=conversation_id, user=user)
# parse files
files = args['files'] if args.get('files') else []
@ -225,6 +228,62 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
message_id=message.id
)
# Init conversation variables
stmt = select(ConversationVariable).where(
ConversationVariable.app_id == conversation.app_id, ConversationVariable.conversation_id == conversation.id
)
with Session(db.engine) as session:
conversation_variables = session.scalars(stmt).all()
if not conversation_variables:
# Create conversation variables if they don't exist.
conversation_variables = [
ConversationVariable.from_variable(
app_id=conversation.app_id, conversation_id=conversation.id, variable=variable
)
for variable in workflow.conversation_variables
]
session.add_all(conversation_variables)
# Convert database entities to variables.
conversation_variables = [item.to_variable() for item in conversation_variables]
session.commit()
# Increment dialogue count.
conversation.dialogue_count += 1
conversation_id = conversation.id
conversation_dialogue_count = conversation.dialogue_count
db.session.commit()
db.session.refresh(conversation)
inputs = application_generate_entity.inputs
query = application_generate_entity.query
files = application_generate_entity.files
user_id = None
if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
end_user = db.session.query(EndUser).filter(EndUser.id == application_generate_entity.user_id).first()
if end_user:
user_id = end_user.session_id
else:
user_id = application_generate_entity.user_id
# Create a variable pool.
system_inputs = {
SystemVariableKey.QUERY: query,
SystemVariableKey.FILES: files,
SystemVariableKey.CONVERSATION_ID: conversation_id,
SystemVariableKey.USER_ID: user_id,
SystemVariableKey.DIALOGUE_COUNT: conversation_dialogue_count,
}
variable_pool = VariablePool(
system_variables=system_inputs,
user_inputs=inputs,
environment_variables=workflow.environment_variables,
conversation_variables=conversation_variables,
)
contexts.workflow_variable_pool.set(variable_pool)
# new thread
worker_thread = threading.Thread(target=self._generate_worker, kwargs={
'flask_app': current_app._get_current_object(), # type: ignore
@ -296,7 +355,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
logger.exception("Validation Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except (ValueError, InvokeError) as e:
if os.environ.get("DEBUG") and os.environ.get("DEBUG", "false").lower() == 'true':
if os.environ.get("DEBUG", "false").lower() == 'true':
logger.exception("Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except Exception as e:

View File

@ -47,7 +47,7 @@ from core.app.task_pipeline.message_cycle_manage import MessageCycleManage
from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage
from core.model_runtime.utils.encoders import jsonable_encoder
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.enums import SystemVariable
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from events.message_event import message_was_created
from extensions.ext_database import db
@ -69,7 +69,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
_application_generate_entity: AdvancedChatAppGenerateEntity
_workflow: Workflow
_user: Union[Account, EndUser]
_workflow_system_variables: dict[SystemVariable, Any]
_workflow_system_variables: dict[SystemVariableKey, Any]
def __init__(
self,
@ -102,10 +102,10 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
self._conversation = conversation
self._message = message
self._workflow_system_variables = {
SystemVariable.QUERY: message.query,
SystemVariable.FILES: application_generate_entity.files,
SystemVariable.CONVERSATION_ID: conversation.id,
SystemVariable.USER_ID: user_id,
SystemVariableKey.QUERY: message.query,
SystemVariableKey.FILES: application_generate_entity.files,
SystemVariableKey.CONVERSATION_ID: conversation.id,
SystemVariableKey.USER_ID: user_id,
}
self._task_state = WorkflowTaskState()
@ -312,7 +312,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
elif isinstance(event, QueueParallelBranchRunStartedEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
yield self._workflow_parallel_branch_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
@ -321,7 +321,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
yield self._workflow_parallel_branch_finished_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,

View File

@ -1,7 +1,7 @@
from collections.abc import Mapping
from typing import Any, Optional
from core.app.app_config.entities import AppConfig, VariableEntity
from core.app.app_config.entities import AppConfig, VariableEntity, VariableEntityType
class BaseAppGenerator:
@ -9,29 +9,29 @@ class BaseAppGenerator:
user_inputs = user_inputs or {}
# Filter input variables from form configuration, handle required fields, default values, and option values
variables = app_config.variables
filtered_inputs = {var.name: self._validate_input(inputs=user_inputs, var=var) for var in variables}
filtered_inputs = {var.variable: self._validate_input(inputs=user_inputs, var=var) for var in variables}
filtered_inputs = {k: self._sanitize_value(v) for k, v in filtered_inputs.items()}
return filtered_inputs
def _validate_input(self, *, inputs: Mapping[str, Any], var: VariableEntity):
user_input_value = inputs.get(var.name)
user_input_value = inputs.get(var.variable)
if var.required and not user_input_value:
raise ValueError(f'{var.name} is required in input form')
raise ValueError(f'{var.variable} is required in input form')
if not var.required and not user_input_value:
# TODO: should we return None here if the default value is None?
return var.default or ''
if (
var.type
in (
VariableEntity.Type.TEXT_INPUT,
VariableEntity.Type.SELECT,
VariableEntity.Type.PARAGRAPH,
VariableEntityType.TEXT_INPUT,
VariableEntityType.SELECT,
VariableEntityType.PARAGRAPH,
)
and user_input_value
and not isinstance(user_input_value, str)
):
raise ValueError(f"(type '{var.type}') {var.name} in input form must be a string")
if var.type == VariableEntity.Type.NUMBER and isinstance(user_input_value, str):
raise ValueError(f"(type '{var.type}') {var.variable} in input form must be a string")
if var.type == VariableEntityType.NUMBER and isinstance(user_input_value, str):
# may raise ValueError if user_input_value is not a valid number
try:
if '.' in user_input_value:
@ -39,14 +39,14 @@ class BaseAppGenerator:
else:
return int(user_input_value)
except ValueError:
raise ValueError(f"{var.name} in input form must be a valid number")
if var.type == VariableEntity.Type.SELECT:
raise ValueError(f"{var.variable} in input form must be a valid number")
if var.type == VariableEntityType.SELECT:
options = var.options or []
if user_input_value not in options:
raise ValueError(f'{var.name} in input form must be one of the following: {options}')
elif var.type in (VariableEntity.Type.TEXT_INPUT, VariableEntity.Type.PARAGRAPH):
raise ValueError(f'{var.variable} in input form must be one of the following: {options}')
elif var.type in (VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH):
if var.max_length and user_input_value and len(user_input_value) > var.max_length:
raise ValueError(f'{var.name} in input form must be less than {var.max_length} characters')
raise ValueError(f'{var.variable} in input form must be less than {var.max_length} characters')
return user_input_value
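
For NUMBER variables the validator coerces the submitted string itself, choosing float when a decimal point is present and int otherwise, and raising a readable error when parsing fails. The coercion in isolation (a sketch, not the committed function):

def coerce_number(variable: str, value: str) -> int | float:
    # Mirrors the branch above: '.' selects float parsing, otherwise int.
    try:
        return float(value) if '.' in value else int(value)
    except ValueError:
        raise ValueError(f"{variable} in input form must be a valid number")

assert coerce_number('age', '42') == 42
assert coerce_number('ratio', '0.5') == 0.5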

View File

@ -256,6 +256,7 @@ class AppRunner:
:param invoke_result: invoke result
:param queue_manager: application queue manager
:param stream: stream
:param agent: agent
:return:
"""
if not stream:
@ -278,6 +279,7 @@ class AppRunner:
Handle invoke result direct
:param invoke_result: invoke result
:param queue_manager: application queue manager
:param agent: agent
:return:
"""
queue_manager.publish(
@ -293,6 +295,7 @@ class AppRunner:
Handle invoke result
:param invoke_result: invoke result
:param queue_manager: application queue manager
:param agent: agent
:return:
"""
model = None

View File

@ -1,6 +1,7 @@
import json
import logging
from collections.abc import Generator
from datetime import datetime, timezone
from typing import Optional, Union
from sqlalchemy import and_
@ -36,17 +37,17 @@ logger = logging.getLogger(__name__)
class MessageBasedAppGenerator(BaseAppGenerator):
def _handle_response(
self, application_generate_entity: Union[
ChatAppGenerateEntity,
CompletionAppGenerateEntity,
AgentChatAppGenerateEntity,
AdvancedChatAppGenerateEntity
],
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message,
user: Union[Account, EndUser],
stream: bool = False,
self, application_generate_entity: Union[
ChatAppGenerateEntity,
CompletionAppGenerateEntity,
AgentChatAppGenerateEntity,
AdvancedChatAppGenerateEntity
],
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message,
user: Union[Account, EndUser],
stream: bool = False,
) -> Union[
ChatbotAppBlockingResponse,
CompletionAppBlockingResponse,
@ -138,6 +139,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
"""
Initialize generate records
:param application_generate_entity: application generate entity
:param conversation: conversation
:return:
"""
app_config = application_generate_entity.app_config
@ -192,6 +194,9 @@ class MessageBasedAppGenerator(BaseAppGenerator):
db.session.add(conversation)
db.session.commit()
db.session.refresh(conversation)
else:
conversation.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
db.session.commit()
message = Message(
app_id=app_config.app_id,

View File

@ -13,7 +13,7 @@ from core.app.entities.app_invoke_entities import (
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
from core.workflow.entities.node_entities import UserFrom
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariable
from core.workflow.enums import SystemVariableKey
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from models.model import App, EndUser
@ -79,14 +79,14 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
user_inputs=self.application_generate_entity.single_iteration_run.inputs
)
else:
inputs = self.application_generate_entity.inputs
files = self.application_generate_entity.files
# Create a variable pool.
system_inputs = {
SystemVariable.FILES: files,
SystemVariable.USER_ID: user_id,
SystemVariableKey.FILES: files,
SystemVariableKey.USER_ID: user_id,
}
variable_pool = VariablePool(
@ -98,7 +98,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
# init graph
graph = self._init_graph(graph_config=workflow.graph_dict)
# RUN WORKFLOW
workflow_entry = WorkflowEntry(
tenant_id=workflow.tenant_id,

View File

@ -41,7 +41,9 @@ from core.app.entities.task_entities import (
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.enums import SystemVariable
from core.workflow.entities.node_entities import NodeType
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.end.end_node import EndNode
from extensions.ext_database import db
from models.account import Account
from models.model import EndUser
@ -64,7 +66,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
_user: Union[Account, EndUser]
_task_state: WorkflowTaskState
_application_generate_entity: WorkflowAppGenerateEntity
_workflow_system_variables: dict[SystemVariable, Any]
_workflow_system_variables: dict[SystemVariableKey, Any]
def __init__(self, application_generate_entity: WorkflowAppGenerateEntity,
workflow: Workflow,
@ -88,8 +90,8 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
self._workflow = workflow
self._workflow_system_variables = {
SystemVariable.FILES: application_generate_entity.files,
SystemVariable.USER_ID: user_id
SystemVariableKey.FILES: application_generate_entity.files,
SystemVariableKey.USER_ID: user_id
}
self._task_state = WorkflowTaskState()

View File

@ -4,7 +4,7 @@ from enum import Enum
from threading import Lock
from typing import Literal, Optional
from httpx import get, post
from httpx import Timeout, get, post
from pydantic import BaseModel
from yarl import URL
@ -21,7 +21,7 @@ logger = logging.getLogger(__name__)
CODE_EXECUTION_ENDPOINT = dify_config.CODE_EXECUTION_ENDPOINT
CODE_EXECUTION_API_KEY = dify_config.CODE_EXECUTION_API_KEY
CODE_EXECUTION_TIMEOUT = (10, 60)
CODE_EXECUTION_TIMEOUT = Timeout(connect=10, write=10, read=60, pool=None)
class CodeExecutionException(Exception):
pass
@ -116,7 +116,7 @@ class CodeExecutor:
if response.data.error:
raise CodeExecutionException(response.data.error)
return response.data.stdout
return response.data.stdout or ''
@classmethod
def execute_workflow_code_template(cls, language: CodeLanguage, code: str, inputs: dict, dependencies: Optional[list[CodeDependency]] = None) -> dict:
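
Replacing the bare (10, 60) tuple with httpx.Timeout makes each phase explicit: 10 s to connect, 10 s to write the request, 60 s to read the response, and no pool-acquire limit. A usage sketch (the endpoint URL and payload are placeholders):

import httpx

CODE_EXECUTION_TIMEOUT = httpx.Timeout(connect=10, write=10, read=60, pool=None)

# The Timeout object applies the per-phase limits to the whole request.
response = httpx.post(
    'https://sandbox.example.com/v1/sandbox/run',
    json={'language': 'python3', 'code': 'print(1)'},
    timeout=CODE_EXECUTION_TIMEOUT,
)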

View File

@ -13,7 +13,7 @@ class Python3CodeProvider(CodeNodeProvider):
def get_default_code(cls) -> str:
return dedent(
"""
def main(arg1: int, arg2: int) -> dict:
def main(arg1: str, arg2: str) -> dict:
return {
"result": arg1 + arg2,
}

View File

@ -3,6 +3,7 @@ from collections import OrderedDict
from collections.abc import Callable
from typing import Any
from configs import dify_config
from core.tools.utils.yaml_utils import load_yaml_file
@ -19,6 +20,87 @@ def get_position_map(folder_path: str, *, file_name: str = "_position.yaml") ->
return {name: index for index, name in enumerate(positions)}
def get_tool_position_map(folder_path: str, file_name: str = "_position.yaml") -> dict[str, int]:
"""
Get the mapping for tools from name to index from a YAML file.
:param folder_path:
:param file_name: the YAML file name, default to '_position.yaml'
:return: a dict with name as key and index as value
"""
position_map = get_position_map(folder_path, file_name=file_name)
return pin_position_map(
position_map,
pin_list=dify_config.POSITION_TOOL_PINS_LIST,
)
def get_provider_position_map(folder_path: str, file_name: str = "_position.yaml") -> dict[str, int]:
"""
Get the mapping for providers from name to index from a YAML file.
:param folder_path:
:param file_name: the YAML file name, default to '_position.yaml'
:return: a dict with name as key and index as value
"""
position_map = get_position_map(folder_path, file_name=file_name)
return pin_position_map(
position_map,
pin_list=dify_config.POSITION_PROVIDER_PINS_LIST,
)
def pin_position_map(original_position_map: dict[str, int], pin_list: list[str]) -> dict[str, int]:
"""
Pin the items in the pin list to the beginning of the position map.
Overall logic: exclude > include > pin
:param original_position_map: the position map to be sorted and filtered
:param pin_list: the list of pins to be put at the beginning
:return: the sorted position map
"""
positions = sorted(original_position_map.keys(), key=lambda x: original_position_map[x])
# Add pins to position map
position_map = {name: idx for idx, name in enumerate(pin_list)}
# Add remaining positions to position map
start_idx = len(position_map)
for name in positions:
if name not in position_map:
position_map[name] = start_idx
start_idx += 1
return position_map
def is_filtered(
include_set: set[str],
exclude_set: set[str],
data: Any,
name_func: Callable[[Any], str],
) -> bool:
"""
Check if the object should be filtered out.
Overall logic: exclude > include > pin
:param include_set: the set of names to be included
:param exclude_set: the set of names to be excluded
:param name_func: the function to get the name of the object
:param data: the data to be filtered
:return: True if the object should be filtered out, False otherwise
"""
if not data:
return False
if not include_set and not exclude_set:
return False
name = name_func(data)
if name in exclude_set: # exclude_set is prioritized
return True
if include_set and name not in include_set: # filter out only if include_set is not empty
return True
return False
def sort_by_position_map(
position_map: dict[str, int],
data: list[Any],
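
pin_position_map places pinned names first, in pin-list order, and renumbers everything else behind them; is_filtered applies the stated exclude > include precedence. A compact check of both behaviors (assuming the helpers above are importable from core.helper.position_helper):

from core.helper.position_helper import is_filtered, pin_position_map

position_map = {'alpha': 0, 'beta': 1, 'gamma': 2}
# 'gamma' is pinned to the front; the rest keep their relative order.
assert pin_position_map(position_map, pin_list=['gamma']) == {'gamma': 0, 'alpha': 1, 'beta': 2}

# Exclusion wins over inclusion: 'beta' is filtered out even though included.
assert is_filtered({'beta'}, {'beta'}, 'beta', lambda x: x) is True
assert is_filtered({'beta'}, set(), 'beta', lambda x: x) is False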

View File

@ -700,6 +700,7 @@ class IndexingRunner:
DatasetDocument.tokens: tokens,
DatasetDocument.completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None),
DatasetDocument.indexing_latency: indexing_end_at - indexing_start_at,
DatasetDocument.error: None,
}
)

View File

@ -271,9 +271,8 @@ class ModelInstance:
:param content_text: text content to be translated
:param tenant_id: user tenant id
:param user: unique user id
:param voice: model timbre
:param streaming: output is streaming
:param user: unique user id
:return: text for given audio file
"""
if not isinstance(self.model_type_instance, TTSModel):
@ -369,6 +368,15 @@ class ModelManager:
return ModelInstance(provider_model_bundle, model)
def get_default_provider_model_name(self, tenant_id: str, model_type: ModelType) -> tuple[str, str]:
"""
Return first provider and the first model in the provider
:param tenant_id: tenant id
:param model_type: model type
:return: provider name, model name
"""
return self._provider_manager.get_first_provider_first_model(tenant_id, model_type)
def get_default_model_instance(self, tenant_id: str, model_type: ModelType) -> ModelInstance:
"""
Get default model instance
@ -401,6 +409,10 @@ class LBModelManager:
managed_credentials: Optional[dict] = None) -> None:
"""
Load balancing model manager
:param tenant_id: tenant_id
:param provider: provider
:param model_type: model_type
:param model: model name
:param load_balancing_configs: all load balancing configurations
:param managed_credentials: credentials if load balancing configuration name is __inherit__
"""
@ -499,7 +511,6 @@ class LBModelManager:
config.id
)
res = redis_client.exists(cooldown_cache_key)
res = cast(bool, res)
return res

View File

@ -151,9 +151,9 @@ class AIModel(ABC):
os.path.join(provider_model_type_path, model_schema_yaml)
for model_schema_yaml in os.listdir(provider_model_type_path)
if not model_schema_yaml.startswith('__')
and not model_schema_yaml.startswith('_')
and os.path.isfile(os.path.join(provider_model_type_path, model_schema_yaml))
and model_schema_yaml.endswith('.yaml')
and not model_schema_yaml.startswith('_')
and os.path.isfile(os.path.join(provider_model_type_path, model_schema_yaml))
and model_schema_yaml.endswith('.yaml')
]
# get _position.yaml file path

View File

@ -185,7 +185,7 @@ if you are not sure about the structure.
stream=stream,
user=user
)
model_parameters.pop("response_format")
stop = stop or []
stop.extend(["\n```", "```\n"])
@ -249,10 +249,10 @@ if you are not sure about the structure.
prompt_messages=prompt_messages,
input_generator=new_generator()
)
return response
def _code_block_mode_stream_processor(self, model: str, prompt_messages: list[PromptMessage],
def _code_block_mode_stream_processor(self, model: str, prompt_messages: list[PromptMessage],
input_generator: Generator[LLMResultChunk, None, None]
) -> Generator[LLMResultChunk, None, None]:
"""
@ -310,7 +310,7 @@ if you are not sure about the structure.
)
)
def _code_block_mode_stream_processor_with_backtick(self, model: str, prompt_messages: list,
def _code_block_mode_stream_processor_with_backtick(self, model: str, prompt_messages: list,
input_generator: Generator[LLMResultChunk, None, None]) \
-> Generator[LLMResultChunk, None, None]:
"""
@ -470,7 +470,7 @@ if you are not sure about the structure.
:return: full response or stream response chunk generator result
"""
raise NotImplementedError
@abstractmethod
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None) -> int:
@ -792,6 +792,13 @@ if you are not sure about the structure.
if not isinstance(parameter_value, str):
raise ValueError(f"Model Parameter {parameter_name} should be string.")
# validate options
if parameter_rule.options and parameter_value not in parameter_rule.options:
raise ValueError(f"Model Parameter {parameter_name} should be one of {parameter_rule.options}.")
elif parameter_rule.type == ParameterType.TEXT:
if not isinstance(parameter_value, str):
raise ValueError(f"Model Parameter {parameter_name} should be text.")
# validate options
if parameter_rule.options and parameter_value not in parameter_rule.options:
raise ValueError(f"Model Parameter {parameter_name} should be one of {parameter_rule.options}.")

View File

@ -70,7 +70,7 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
# doc: https://platform.openai.com/docs/guides/text-to-speech
credentials_kwargs = self._to_credential_kwargs(credentials)
client = AzureOpenAI(**credentials_kwargs)
# max font is 4096,there is 3500 limit for each request
# max length is 4096 characters, there is 3500 limit for each request
max_length = 3500
if len(content_text) > max_length:
sentences = self._split_text_into_sentences(content_text, max_length=max_length)

View File

@ -6,7 +6,7 @@ from typing import Optional
from pydantic import BaseModel, ConfigDict
from core.helper.module_import_helper import load_single_subclass_from_source
from core.helper.position_helper import get_position_map, sort_to_dict_by_position_map
from core.helper.position_helper import get_provider_position_map, sort_to_dict_by_position_map
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.entities.provider_entities import ProviderConfig, ProviderEntity, SimpleProviderEntity
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
@ -234,7 +234,7 @@ class ModelProviderFactory:
]
# get _position.yaml file path
position_map = get_position_map(model_providers_path)
position_map = get_provider_position_map(model_providers_path)
# traverse all model_provider_dir_paths
model_providers: list[ModelProviderExtension] = []

View File

@ -84,7 +84,8 @@ class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel):
def _add_custom_parameters(self, credentials: dict) -> None:
credentials['mode'] = 'chat'
credentials['endpoint_url'] = 'https://api.moonshot.cn/v1'
if 'endpoint_url' not in credentials or credentials['endpoint_url'] == "":
credentials['endpoint_url'] = 'https://api.moonshot.cn/v1'
def _add_function_call(self, model: str, credentials: dict) -> None:
model_schema = self.get_model_schema(model, credentials)

View File

@ -31,6 +31,14 @@ provider_credential_schema:
placeholder:
zh_Hans: 在此输入您的 API Key
en_US: Enter your API Key
- variable: endpoint_url
label:
en_US: API Base
type: text-input
required: false
placeholder:
zh_Hans: Base URL, 如https://api.moonshot.cn/v1
en_US: Base URL, e.g. https://api.moonshot.cn/v1
model_credential_schema:
model:
label:

View File

@ -37,6 +37,9 @@ parameter_rules:
options:
- text
- json_object
- json_schema
- name: json_schema
use_template: json_schema
pricing:
input: '0.15'
output: '0.60'

View File

@ -0,0 +1,44 @@
model: gpt-4o-2024-08-06
label:
zh_Hans: gpt-4o-2024-08-06
en_US: gpt-4o-2024-08-06
model_type: llm
features:
- multi-tool-call
- agent-thought
- stream-tool-call
- vision
model_properties:
mode: chat
context_size: 128000
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: presence_penalty
use_template: presence_penalty
- name: frequency_penalty
use_template: frequency_penalty
- name: max_tokens
use_template: max_tokens
default: 512
min: 1
max: 16384
- name: response_format
label:
zh_Hans: 回复格式
en_US: response_format
type: string
help:
zh_Hans: 指定模型必须输出的格式
en_US: specifying the format that the model must output
required: false
options:
- text
- json_object
pricing:
input: '2.50'
output: '10.00'
unit: '0.000001'
currency: USD

View File

@ -0,0 +1,4 @@
model: netease-youdao/bce-reranker-base_v1
model_type: rerank
model_properties:
context_size: 512

View File

@ -0,0 +1,4 @@
model: BAAI/bge-reranker-v2-m3
model_type: rerank
model_properties:
context_size: 8192

View File

@ -0,0 +1,87 @@
from typing import Optional
import httpx
from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
InvokeConnectionError,
InvokeError,
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.rerank_model import RerankModel
class SiliconflowRerankModel(RerankModel):
def _invoke(self, model: str, credentials: dict, query: str, docs: list[str],
score_threshold: Optional[float] = None, top_n: Optional[int] = None,
user: Optional[str] = None) -> RerankResult:
if len(docs) == 0:
return RerankResult(model=model, docs=[])
base_url = credentials.get('base_url', 'https://api.siliconflow.cn/v1')
if base_url.endswith('/'):
base_url = base_url[:-1]
try:
response = httpx.post(
base_url + '/rerank',
json={
"model": model,
"query": query,
"documents": docs,
"top_n": top_n,
"return_documents": True
},
headers={"Authorization": f"Bearer {credentials.get('api_key')}"}
)
response.raise_for_status()
results = response.json()
rerank_documents = []
for result in results['results']:
rerank_document = RerankDocument(
index=result['index'],
text=result['document']['text'],
score=result['relevance_score'],
)
if score_threshold is None or result['relevance_score'] >= score_threshold:
rerank_documents.append(rerank_document)
return RerankResult(model=model, docs=rerank_documents)
except httpx.HTTPStatusError as e:
raise InvokeServerUnavailableError(str(e))
def validate_credentials(self, model: str, credentials: dict) -> None:
try:
self._invoke(
model=model,
credentials=credentials,
query="What is the capital of the United States?",
docs=[
"Carson City is the capital city of the American state of Nevada. At the 2010 United States "
"Census, Carson City had a population of 55,274.",
"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
"are a political division controlled by the United States. Its capital is Saipan.",
],
score_threshold=0.8
)
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
"""
Map model invoke error to unified error
"""
return {
InvokeConnectionError: [httpx.ConnectError],
InvokeServerUnavailableError: [httpx.RemoteProtocolError],
InvokeRateLimitError: [],
InvokeAuthorizationError: [httpx.HTTPStatusError],
InvokeBadRequestError: [httpx.RequestError]
}

View File

@ -12,10 +12,11 @@ help:
en_US: Get your API Key from SiliconFlow
zh_Hans: 从 SiliconFlow 获取 API Key
url:
en_US: https://cloud.siliconflow.cn/keys
en_US: https://cloud.siliconflow.cn/account/ak
supported_model_types:
- llm
- text-embedding
- rerank
- speech2text
configurate_methods:
- predefined-model

View File

@ -35,7 +35,10 @@ from core.model_runtime.model_providers.volcengine_maas.errors import (
RateLimitErrors,
ServerUnavailableErrors,
)
from core.model_runtime.model_providers.volcengine_maas.llm.models import ModelConfigs
from core.model_runtime.model_providers.volcengine_maas.llm.models import (
get_model_config,
get_v2_req_params,
)
from core.model_runtime.model_providers.volcengine_maas.volc_sdk import MaasException
logger = logging.getLogger(__name__)
@ -95,37 +98,12 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
-> LLMResult | Generator:
client = MaaSClient.from_credential(credentials)
req_params = ModelConfigs.get(
credentials['base_model_name'], {}).get('req_params', {}).copy()
if credentials.get('context_size'):
req_params['max_prompt_tokens'] = credentials.get('context_size')
if credentials.get('max_tokens'):
req_params['max_new_tokens'] = credentials.get('max_tokens')
if model_parameters.get('max_tokens'):
req_params['max_new_tokens'] = model_parameters.get('max_tokens')
if model_parameters.get('temperature'):
req_params['temperature'] = model_parameters.get('temperature')
if model_parameters.get('top_p'):
req_params['top_p'] = model_parameters.get('top_p')
if model_parameters.get('top_k'):
req_params['top_k'] = model_parameters.get('top_k')
if model_parameters.get('presence_penalty'):
req_params['presence_penalty'] = model_parameters.get(
'presence_penalty')
if model_parameters.get('frequency_penalty'):
req_params['frequency_penalty'] = model_parameters.get(
'frequency_penalty')
if stop:
req_params['stop'] = stop
req_params = get_v2_req_params(credentials, model_parameters, stop)
extra_model_kwargs = {}
if tools:
extra_model_kwargs['tools'] = [
MaaSClient.transform_tool_prompt_to_maas_config(tool) for tool in tools
]
resp = MaaSClient.wrap_exception(
lambda: client.chat(req_params, prompt_messages, stream, **extra_model_kwargs))
if not stream:
@ -197,10 +175,8 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
"""
used to define customizable model schema
"""
max_tokens = ModelConfigs.get(
credentials['base_model_name'], {}).get('req_params', {}).get('max_new_tokens')
if credentials.get('max_tokens'):
max_tokens = int(credentials.get('max_tokens'))
model_config = get_model_config(credentials)
rules = [
ParameterRule(
name='temperature',
@ -234,10 +210,10 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
name='presence_penalty',
type=ParameterType.FLOAT,
use_template='presence_penalty',
label={
'en_US': 'Presence Penalty',
'zh_Hans': '存在惩罚',
},
label=I18nObject(
en_US='Presence Penalty',
zh_Hans= '存在惩罚',
),
min=-2.0,
max=2.0,
),
@ -245,10 +221,10 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
name='frequency_penalty',
type=ParameterType.FLOAT,
use_template='frequency_penalty',
label={
'en_US': 'Frequency Penalty',
'zh_Hans': '频率惩罚',
},
label=I18nObject(
en_US= 'Frequency Penalty',
zh_Hans= '频率惩罚',
),
min=-2.0,
max=2.0,
),
@ -257,7 +233,7 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
type=ParameterType.INT,
use_template='max_tokens',
min=1,
max=max_tokens,
max=model_config.properties.max_tokens,
default=512,
label=I18nObject(
zh_Hans='最大生成长度',
@ -266,17 +242,10 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
),
]
model_properties = ModelConfigs.get(
credentials['base_model_name'], {}).get('model_properties', {}).copy()
if credentials.get('mode'):
model_properties[ModelPropertyKey.MODE] = credentials.get('mode')
if credentials.get('context_size'):
model_properties[ModelPropertyKey.CONTEXT_SIZE] = int(
credentials.get('context_size', 4096))
model_features = ModelConfigs.get(
credentials['base_model_name'], {}).get('features', [])
model_properties = {}
model_properties[ModelPropertyKey.CONTEXT_SIZE] = model_config.properties.context_size
model_properties[ModelPropertyKey.MODE] = model_config.properties.mode.value
entity = AIModelEntity(
model=model,
label=I18nObject(
@ -286,7 +255,7 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
model_type=ModelType.LLM,
model_properties=model_properties,
parameter_rules=rules,
features=model_features,
features=model_config.features,
)
return entity

View File

@ -1,181 +1,123 @@
from pydantic import BaseModel
from core.model_runtime.entities.llm_entities import LLMMode
from core.model_runtime.entities.model_entities import ModelFeature
ModelConfigs = {
'Doubao-pro-4k': {
'req_params': {
'max_prompt_tokens': 4096,
'max_new_tokens': 4096,
},
'model_properties': {
'context_size': 4096,
'mode': 'chat',
},
'features': [
ModelFeature.TOOL_CALL
],
},
'Doubao-lite-4k': {
'req_params': {
'max_prompt_tokens': 4096,
'max_new_tokens': 4096,
},
'model_properties': {
'context_size': 4096,
'mode': 'chat',
},
'features': [
ModelFeature.TOOL_CALL
],
},
'Doubao-pro-32k': {
'req_params': {
'max_prompt_tokens': 32768,
'max_new_tokens': 32768,
},
'model_properties': {
'context_size': 32768,
'mode': 'chat',
},
'features': [
ModelFeature.TOOL_CALL
],
},
'Doubao-lite-32k': {
'req_params': {
'max_prompt_tokens': 32768,
'max_new_tokens': 32768,
},
'model_properties': {
'context_size': 32768,
'mode': 'chat',
},
'features': [
ModelFeature.TOOL_CALL
],
},
'Doubao-pro-128k': {
'req_params': {
'max_prompt_tokens': 131072,
'max_new_tokens': 131072,
},
'model_properties': {
'context_size': 131072,
'mode': 'chat',
},
'features': [
ModelFeature.TOOL_CALL
],
},
'Doubao-lite-128k': {
'req_params': {
'max_prompt_tokens': 131072,
'max_new_tokens': 131072,
},
'model_properties': {
'context_size': 131072,
'mode': 'chat',
},
'features': [
ModelFeature.TOOL_CALL
],
},
'Skylark2-pro-4k': {
'req_params': {
'max_prompt_tokens': 4096,
'max_new_tokens': 4000,
},
'model_properties': {
'context_size': 4096,
'mode': 'chat',
},
'features': [],
},
'Llama3-8B': {
'req_params': {
'max_prompt_tokens': 8192,
'max_new_tokens': 8192,
},
'model_properties': {
'context_size': 8192,
'mode': 'chat',
},
'features': [],
},
'Llama3-70B': {
'req_params': {
'max_prompt_tokens': 8192,
'max_new_tokens': 8192,
},
'model_properties': {
'context_size': 8192,
'mode': 'chat',
},
'features': [],
},
'Moonshot-v1-8k': {
'req_params': {
'max_prompt_tokens': 8192,
'max_new_tokens': 4096,
},
'model_properties': {
'context_size': 8192,
'mode': 'chat',
},
'features': [],
},
'Moonshot-v1-32k': {
'req_params': {
'max_prompt_tokens': 32768,
'max_new_tokens': 16384,
},
'model_properties': {
'context_size': 32768,
'mode': 'chat',
},
'features': [],
},
'Moonshot-v1-128k': {
'req_params': {
'max_prompt_tokens': 131072,
'max_new_tokens': 65536,
},
'model_properties': {
'context_size': 131072,
'mode': 'chat',
},
'features': [],
},
'GLM3-130B': {
'req_params': {
'max_prompt_tokens': 8192,
'max_new_tokens': 4096,
},
'model_properties': {
'context_size': 8192,
'mode': 'chat',
},
'features': [],
},
'GLM3-130B-Fin': {
'req_params': {
'max_prompt_tokens': 8192,
'max_new_tokens': 4096,
},
'model_properties': {
'context_size': 8192,
'mode': 'chat',
},
'features': [],
},
'Mistral-7B': {
'req_params': {
'max_prompt_tokens': 8192,
'max_new_tokens': 2048,
},
'model_properties': {
'context_size': 8192,
'mode': 'chat',
},
'features': [],
}
class ModelProperties(BaseModel):
context_size: int
max_tokens: int
mode: LLMMode
class ModelConfig(BaseModel):
properties: ModelProperties
features: list[ModelFeature]
configs: dict[str, ModelConfig] = {
'Doubao-pro-4k': ModelConfig(
properties=ModelProperties(context_size=4096, max_tokens=4096, mode=LLMMode.CHAT),
features=[ModelFeature.TOOL_CALL]
),
'Doubao-lite-4k': ModelConfig(
properties=ModelProperties(context_size=4096, max_tokens=4096, mode=LLMMode.CHAT),
features=[ModelFeature.TOOL_CALL]
),
'Doubao-pro-32k': ModelConfig(
properties=ModelProperties(context_size=32768, max_tokens=32768, mode=LLMMode.CHAT),
features=[ModelFeature.TOOL_CALL]
),
'Doubao-lite-32k': ModelConfig(
properties=ModelProperties(context_size=32768, max_tokens=32768, mode=LLMMode.CHAT),
features=[ModelFeature.TOOL_CALL]
),
'Doubao-pro-128k': ModelConfig(
properties=ModelProperties(context_size=131072, max_tokens=131072, mode=LLMMode.CHAT),
features=[ModelFeature.TOOL_CALL]
),
'Doubao-lite-128k': ModelConfig(
properties=ModelProperties(context_size=131072, max_tokens=131072, mode=LLMMode.CHAT),
features=[ModelFeature.TOOL_CALL]
),
'Skylark2-pro-4k': ModelConfig(
properties=ModelProperties(context_size=4096, max_tokens=4000, mode=LLMMode.CHAT),
features=[]
),
'Llama3-8B': ModelConfig(
properties=ModelProperties(context_size=8192, max_tokens=8192, mode=LLMMode.CHAT),
features=[]
),
'Llama3-70B': ModelConfig(
properties=ModelProperties(context_size=8192, max_tokens=8192, mode=LLMMode.CHAT),
features=[]
),
'Moonshot-v1-8k': ModelConfig(
properties=ModelProperties(context_size=8192, max_tokens=4096, mode=LLMMode.CHAT),
features=[]
),
'Moonshot-v1-32k': ModelConfig(
properties=ModelProperties(context_size=32768, max_tokens=16384, mode=LLMMode.CHAT),
features=[]
),
'Moonshot-v1-128k': ModelConfig(
properties=ModelProperties(context_size=131072, max_tokens=65536, mode=LLMMode.CHAT),
features=[]
),
'GLM3-130B': ModelConfig(
properties=ModelProperties(context_size=8192, max_tokens=4096, mode=LLMMode.CHAT),
features=[]
),
'GLM3-130B-Fin': ModelConfig(
properties=ModelProperties(context_size=8192, max_tokens=4096, mode=LLMMode.CHAT),
features=[]
),
'Mistral-7B': ModelConfig(
properties=ModelProperties(context_size=8192, max_tokens=2048, mode=LLMMode.CHAT),
features=[]
)
}
def get_model_config(credentials: dict)->ModelConfig:
base_model = credentials.get('base_model_name', '')
model_configs = configs.get(base_model)
if not model_configs:
return ModelConfig(
properties=ModelProperties(
context_size=int(credentials.get('context_size', 0)),
max_tokens=int(credentials.get('max_tokens', 0)),
mode=LLMMode.value_of(credentials.get('mode', 'chat')),
),
features=[]
)
return model_configs
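A quick sketch of the fallback path (hypothetical credentials): an unknown base_model_name falls back to properties read from the credentials themselves.
config = get_model_config({
    'base_model_name': 'My-Custom-Model',  # hypothetical name, not in `configs`
    'context_size': '4096',
    'max_tokens': '1024',
    'mode': 'chat',
})
assert config.properties.context_size == 4096 and config.properties.max_tokens == 1024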
def get_v2_req_params(credentials: dict, model_parameters: dict,
stop: list[str] | None = None):
req_params = {}
# predefined properties
model_configs = get_model_config(credentials)
if model_configs:
req_params['max_prompt_tokens'] = model_configs.properties.context_size
req_params['max_new_tokens'] = model_configs.properties.max_tokens
# model parameters
if model_parameters.get('max_tokens'):
req_params['max_new_tokens'] = model_parameters.get('max_tokens')
if model_parameters.get('temperature'):
req_params['temperature'] = model_parameters.get('temperature')
if model_parameters.get('top_p'):
req_params['top_p'] = model_parameters.get('top_p')
if model_parameters.get('top_k'):
req_params['top_k'] = model_parameters.get('top_k')
if model_parameters.get('presence_penalty'):
req_params['presence_penalty'] = model_parameters.get(
'presence_penalty')
if model_parameters.get('frequency_penalty'):
req_params['frequency_penalty'] = model_parameters.get(
'frequency_penalty')
if stop:
req_params['stop'] = stop
return req_params
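For illustration, assumed inputs and the resulting request parameters: the predefined limits seed the dict, then user-supplied parameters override them.
params = get_v2_req_params(
    {'base_model_name': 'Moonshot-v1-8k'},
    {'max_tokens': 512, 'temperature': 0.7},
    stop=['Observation:'],
)
# params == {'max_prompt_tokens': 8192, 'max_new_tokens': 512,
#            'temperature': 0.7, 'stop': ['Observation:']}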

View File

@ -1,9 +1,27 @@
from pydantic import BaseModel
class ModelProperties(BaseModel):
context_size: int
max_chunks: int
class ModelConfig(BaseModel):
properties: ModelProperties
ModelConfigs = {
'Doubao-embedding': {
'req_params': {},
'model_properties': {
'context_size': 4096,
'max_chunks': 1,
}
},
'Doubao-embedding': ModelConfig(
properties=ModelProperties(context_size=4096, max_chunks=1)
),
}
def get_model_config(credentials: dict) -> ModelConfig:
base_model = credentials.get('base_model_name', '')
model_configs = ModelConfigs.get(base_model)
if not model_configs:
return ModelConfig(
properties=ModelProperties(
context_size=int(credentials.get('context_size', 0)),
max_chunks=int(credentials.get('max_chunks', 0)),
)
)
return model_configs

View File

@ -30,7 +30,7 @@ from core.model_runtime.model_providers.volcengine_maas.errors import (
RateLimitErrors,
ServerUnavailableErrors,
)
from core.model_runtime.model_providers.volcengine_maas.text_embedding.models import ModelConfigs
from core.model_runtime.model_providers.volcengine_maas.text_embedding.models import get_model_config
from core.model_runtime.model_providers.volcengine_maas.volc_sdk import MaasException
@ -115,14 +115,10 @@ class VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel):
"""
generate custom model entities from credentials
"""
model_properties = ModelConfigs.get(
credentials['base_model_name'], {}).get('model_properties', {}).copy()
if credentials.get('context_size'):
model_properties[ModelPropertyKey.CONTEXT_SIZE] = int(
credentials.get('context_size', 4096))
if credentials.get('max_chunks'):
model_properties[ModelPropertyKey.MAX_CHUNKS] = int(
credentials.get('max_chunks', 4096))
model_config = get_model_config(credentials)
model_properties = {}
model_properties[ModelPropertyKey.CONTEXT_SIZE] = model_config.properties.context_size
model_properties[ModelPropertyKey.MAX_CHUNKS] = model_config.properties.max_chunks
entity = AIModelEntity(
model=model,
label=I18nObject(en_US=model),

View File

@ -0,0 +1,198 @@
from datetime import datetime, timedelta
from threading import Lock
from requests import post
from core.model_runtime.model_providers.wenxin.wenxin_errors import (
BadRequestError,
InternalServerError,
InvalidAPIKeyError,
InvalidAuthenticationError,
RateLimitReachedError,
)
baidu_access_tokens: dict[str, 'BaiduAccessToken'] = {}
baidu_access_tokens_lock = Lock()
class BaiduAccessToken:
api_key: str
access_token: str
expires: datetime
def __init__(self, api_key: str) -> None:
self.api_key = api_key
self.access_token = ''
self.expires = datetime.now() + timedelta(days=3)
@staticmethod
def _get_access_token(api_key: str, secret_key: str) -> str:
"""
request access token from Baidu
"""
try:
response = post(
url=f'https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id={api_key}&client_secret={secret_key}',
headers={
'Content-Type': 'application/json',
'Accept': 'application/json'
},
)
except Exception as e:
raise InvalidAuthenticationError(f'Failed to get access token from Baidu: {e}')
resp = response.json()
if 'error' in resp:
if resp['error'] == 'invalid_client':
raise InvalidAPIKeyError(f'Invalid API key or secret key: {resp["error_description"]}')
elif resp['error'] == 'unknown_error':
raise InternalServerError(f'Internal server error: {resp["error_description"]}')
elif resp['error'] == 'invalid_request':
raise BadRequestError(f'Bad request: {resp["error_description"]}')
elif resp['error'] == 'rate_limit_exceeded':
raise RateLimitReachedError(f'Rate limit reached: {resp["error_description"]}')
else:
raise Exception(f'Unknown error: {resp["error_description"]}')
return resp['access_token']
@staticmethod
def get_access_token(api_key: str, secret_key: str) -> 'BaiduAccessToken':
"""
Baidu's LLM API requires an access token to invoke.
We only hold the api_key and secret_key; the access token itself is valid for 30 days,
so we cache it for 3 days to bound memory use.
A background refresher would be more efficient but also more complex,
so tokens are refreshed lazily whenever get_access_token is called.
"""
# look up the cache and remove expired access tokens
baidu_access_tokens_lock.acquire()
now = datetime.now()
for key in list(baidu_access_tokens.keys()):
token = baidu_access_tokens[key]
if token.expires < now:
baidu_access_tokens.pop(key)
if api_key not in baidu_access_tokens:
# if access token not in cache, request it
token = BaiduAccessToken(api_key)
baidu_access_tokens[api_key] = token
# release the lock before the network call to improve concurrency
# _get_access_token raises on failure, so releasing here also avoids holding the lock forever
baidu_access_tokens_lock.release()
# try to get access token
token_str = BaiduAccessToken._get_access_token(api_key, secret_key)
token.access_token = token_str
token.expires = now + timedelta(days=3)
return token
else:
# if access token in cache, return it
token = baidu_access_tokens[api_key]
baidu_access_tokens_lock.release()
return token
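A minimal usage sketch with hypothetical keys; the Wenxin endpoints below consume the token as a query parameter.
# Hypothetical credentials, for illustration only.
token = BaiduAccessToken.get_access_token('my-api-key', 'my-secret-key')
request_url = f"{_CommonWenxin.api_bases['ernie-3.5-8k']}?access_token={token.access_token}"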
class _CommonWenxin:
api_bases = {
'ernie-bot': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-4k-0205',
'ernie-bot-4': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro',
'ernie-bot-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions',
'ernie-bot-turbo': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant',
'ernie-3.5-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions',
'ernie-3.5-8k-0205': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-8k-0205',
'ernie-3.5-8k-1222': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-8k-1222',
'ernie-3.5-4k-0205': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-4k-0205',
'ernie-3.5-128k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-128k',
'ernie-4.0-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro',
'ernie-4.0-8k-latest': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro',
'ernie-speed-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_speed',
'ernie-speed-128k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-speed-128k',
'ernie-speed-appbuilder': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ai_apaas',
'ernie-lite-8k-0922': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant',
'ernie-lite-8k-0308': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-lite-8k',
'ernie-character-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k',
'ernie-character-8k-0321': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k',
'ernie-4.0-turbo-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k',
'ernie-4.0-turbo-8k-preview': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k-preview',
'yi_34b_chat': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/yi_34b_chat',
'embedding-v1': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1',
'bge-large-en': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/bge_large_en',
'bge-large-zh': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/bge_large_zh',
'tao-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/tao_8k',
}
function_calling_supports = [
'ernie-bot',
'ernie-bot-8k',
'ernie-3.5-8k',
'ernie-3.5-8k-0205',
'ernie-3.5-8k-1222',
'ernie-3.5-4k-0205',
'ernie-3.5-128k',
'ernie-4.0-8k',
'ernie-4.0-turbo-8k',
'ernie-4.0-turbo-8k-preview',
'yi_34b_chat'
]
api_key: str = ''
secret_key: str = ''
def __init__(self, api_key: str, secret_key: str):
self.api_key = api_key
self.secret_key = secret_key
@staticmethod
def _to_credential_kwargs(credentials: dict) -> dict:
credentials_kwargs = {
"api_key": credentials['api_key'],
"secret_key": credentials['secret_key']
}
return credentials_kwargs
def _handle_error(self, code: int, msg: str):
error_map = {
1: InternalServerError,
2: InternalServerError,
3: BadRequestError,
4: RateLimitReachedError,
6: InvalidAuthenticationError,
13: InvalidAPIKeyError,
14: InvalidAPIKeyError,
15: InvalidAPIKeyError,
17: RateLimitReachedError,
18: RateLimitReachedError,
19: RateLimitReachedError,
100: InvalidAPIKeyError,
111: InvalidAPIKeyError,
200: InternalServerError,
336000: InternalServerError,
336001: BadRequestError,
336002: BadRequestError,
336003: BadRequestError,
336004: InvalidAuthenticationError,
336005: InvalidAPIKeyError,
336006: BadRequestError,
336007: BadRequestError,
336008: BadRequestError,
336100: InternalServerError,
336101: BadRequestError,
336102: BadRequestError,
336103: BadRequestError,
336104: BadRequestError,
336105: BadRequestError,
336200: InternalServerError,
336303: BadRequestError,
337006: BadRequestError
}
if code in error_map:
raise error_map[code](msg)
else:
raise InternalServerError(f'Unknown error: {msg}')
def _get_access_token(self) -> str:
token = BaiduAccessToken.get_access_token(self.api_key, self.secret_key)
return token.access_token

View File

@ -1,102 +1,17 @@
from collections.abc import Generator
from datetime import datetime, timedelta
from enum import Enum
from json import dumps, loads
from threading import Lock
from typing import Any, Union
from requests import Response, post
from core.model_runtime.entities.message_entities import PromptMessageTool
from core.model_runtime.model_providers.wenxin.llm.ernie_bot_errors import (
from core.model_runtime.model_providers.wenxin._common import _CommonWenxin
from core.model_runtime.model_providers.wenxin.wenxin_errors import (
BadRequestError,
InternalServerError,
InvalidAPIKeyError,
InvalidAuthenticationError,
RateLimitReachedError,
)
# map api_key to access_token
baidu_access_tokens: dict[str, 'BaiduAccessToken'] = {}
baidu_access_tokens_lock = Lock()
class BaiduAccessToken:
api_key: str
access_token: str
expires: datetime
def __init__(self, api_key: str) -> None:
self.api_key = api_key
self.access_token = ''
self.expires = datetime.now() + timedelta(days=3)
def _get_access_token(api_key: str, secret_key: str) -> str:
"""
request access token from Baidu
"""
try:
response = post(
url=f'https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id={api_key}&client_secret={secret_key}',
headers={
'Content-Type': 'application/json',
'Accept': 'application/json'
},
)
except Exception as e:
raise InvalidAuthenticationError(f'Failed to get access token from Baidu: {e}')
resp = response.json()
if 'error' in resp:
if resp['error'] == 'invalid_client':
raise InvalidAPIKeyError(f'Invalid API key or secret key: {resp["error_description"]}')
elif resp['error'] == 'unknown_error':
raise InternalServerError(f'Internal server error: {resp["error_description"]}')
elif resp['error'] == 'invalid_request':
raise BadRequestError(f'Bad request: {resp["error_description"]}')
elif resp['error'] == 'rate_limit_exceeded':
raise RateLimitReachedError(f'Rate limit reached: {resp["error_description"]}')
else:
raise Exception(f'Unknown error: {resp["error_description"]}')
return resp['access_token']
@staticmethod
def get_access_token(api_key: str, secret_key: str) -> 'BaiduAccessToken':
"""
LLM from Baidu requires access token to invoke the API.
however, we have api_key and secret_key, and access token is valid for 30 days.
so we can cache the access token for 3 days. (avoid memory leak)
it may be more efficient to use a ticker to refresh access token, but it will cause
more complexity, so we just refresh access tokens when get_access_token is called.
"""
# look up the cache and remove expired access tokens
baidu_access_tokens_lock.acquire()
now = datetime.now()
for key in list(baidu_access_tokens.keys()):
token = baidu_access_tokens[key]
if token.expires < now:
baidu_access_tokens.pop(key)
if api_key not in baidu_access_tokens:
# if access token not in cache, request it
token = BaiduAccessToken(api_key)
baidu_access_tokens[api_key] = token
# release it to enhance performance
# btw, _get_access_token will raise exception if failed, release lock here to avoid deadlock
baidu_access_tokens_lock.release()
# try to get access token
token_str = BaiduAccessToken._get_access_token(api_key, secret_key)
token.access_token = token_str
token.expires = now + timedelta(days=3)
return token
else:
# if access token in cache, return it
token = baidu_access_tokens[api_key]
baidu_access_tokens_lock.release()
return token
class ErnieMessage:
class Role(Enum):
@ -120,51 +35,7 @@ class ErnieMessage:
self.content = content
self.role = role
class ErnieBotModel:
api_bases = {
'ernie-bot': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-4k-0205',
'ernie-bot-4': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro',
'ernie-bot-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions',
'ernie-bot-turbo': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant',
'ernie-3.5-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions',
'ernie-3.5-8k-0205': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-8k-0205',
'ernie-3.5-8k-1222': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-8k-1222',
'ernie-3.5-4k-0205': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-4k-0205',
'ernie-3.5-128k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-128k',
'ernie-4.0-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro',
'ernie-4.0-8k-latest': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro',
'ernie-speed-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_speed',
'ernie-speed-128k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-speed-128k',
'ernie-speed-appbuilder': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ai_apaas',
'ernie-lite-8k-0922': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant',
'ernie-lite-8k-0308': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-lite-8k',
'ernie-character-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k',
'ernie-character-8k-0321': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k',
'ernie-4.0-turbo-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k',
'ernie-4.0-turbo-8k-preview': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k-preview',
'yi_34b_chat': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/yi_34b_chat',
}
function_calling_supports = [
'ernie-bot',
'ernie-bot-8k',
'ernie-3.5-8k',
'ernie-3.5-8k-0205',
'ernie-3.5-8k-1222',
'ernie-3.5-4k-0205',
'ernie-3.5-128k',
'ernie-4.0-8k',
'ernie-4.0-turbo-8k',
'ernie-4.0-turbo-8k-preview',
'yi_34b_chat'
]
api_key: str = ''
secret_key: str = ''
def __init__(self, api_key: str, secret_key: str):
self.api_key = api_key
self.secret_key = secret_key
class ErnieBotModel(_CommonWenxin):
def generate(self, model: str, stream: bool, messages: list[ErnieMessage],
parameters: dict[str, Any], timeout: int, tools: list[PromptMessageTool], \
@ -199,51 +70,6 @@ class ErnieBotModel:
return self._handle_chat_stream_generate_response(resp)
return self._handle_chat_generate_response(resp)
def _handle_error(self, code: int, msg: str):
error_map = {
1: InternalServerError,
2: InternalServerError,
3: BadRequestError,
4: RateLimitReachedError,
6: InvalidAuthenticationError,
13: InvalidAPIKeyError,
14: InvalidAPIKeyError,
15: InvalidAPIKeyError,
17: RateLimitReachedError,
18: RateLimitReachedError,
19: RateLimitReachedError,
100: InvalidAPIKeyError,
111: InvalidAPIKeyError,
200: InternalServerError,
336000: InternalServerError,
336001: BadRequestError,
336002: BadRequestError,
336003: BadRequestError,
336004: InvalidAuthenticationError,
336005: InvalidAPIKeyError,
336006: BadRequestError,
336007: BadRequestError,
336008: BadRequestError,
336100: InternalServerError,
336101: BadRequestError,
336102: BadRequestError,
336103: BadRequestError,
336104: BadRequestError,
336105: BadRequestError,
336200: InternalServerError,
336303: BadRequestError,
337006: BadRequestError
}
if code in error_map:
raise error_map[code](msg)
else:
raise InternalServerError(f'Unknown error: {msg}')
def _get_access_token(self) -> str:
token = BaiduAccessToken.get_access_token(self.api_key, self.secret_key)
return token.access_token
def _copy_messages(self, messages: list[ErnieMessage]) -> list[ErnieMessage]:
return [ErnieMessage(message.content, message.role) for message in messages]

View File

@ -1,17 +0,0 @@
class InvalidAuthenticationError(Exception):
pass
class InvalidAPIKeyError(Exception):
pass
class RateLimitReachedError(Exception):
pass
class InsufficientAccountBalance(Exception):
pass
class InternalServerError(Exception):
pass
class BadRequestError(Exception):
pass

View File

@ -11,24 +11,13 @@ from core.model_runtime.entities.message_entities import (
UserPromptMessage,
)
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
InvokeConnectionError,
InvokeError,
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.model_providers.wenxin.llm.ernie_bot import BaiduAccessToken, ErnieBotModel, ErnieMessage
from core.model_runtime.model_providers.wenxin.llm.ernie_bot_errors import (
BadRequestError,
InsufficientAccountBalance,
InternalServerError,
InvalidAPIKeyError,
InvalidAuthenticationError,
RateLimitReachedError,
)
from core.model_runtime.model_providers.wenxin._common import BaiduAccessToken
from core.model_runtime.model_providers.wenxin.llm.ernie_bot import ErnieBotModel, ErnieMessage
from core.model_runtime.model_providers.wenxin.wenxin_errors import invoke_error_mapping
ERNIE_BOT_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object.
The structure of the {{block}} object can be found in the instructions, use {"answer": "$your_answer"} as the default structure
@ -140,7 +129,7 @@ class ErnieBotLargeLanguageModel(LargeLanguageModel):
api_key = credentials['api_key']
secret_key = credentials['secret_key']
try:
BaiduAccessToken._get_access_token(api_key, secret_key)
BaiduAccessToken.get_access_token(api_key, secret_key)
except Exception as e:
raise CredentialsValidateFailedError(f'Credentials validation failed: {e}')
@ -254,22 +243,4 @@ class ErnieBotLargeLanguageModel(LargeLanguageModel):
:return: Invoke error mapping
"""
return {
InvokeConnectionError: [
],
InvokeServerUnavailableError: [
InternalServerError
],
InvokeRateLimitError: [
RateLimitReachedError
],
InvokeAuthorizationError: [
InvalidAuthenticationError,
InsufficientAccountBalance,
InvalidAPIKeyError,
],
InvokeBadRequestError: [
BadRequestError,
KeyError
]
}
return invoke_error_mapping()

View File

@ -0,0 +1,9 @@
model: bge-large-en
model_type: text-embedding
model_properties:
context_size: 512
max_chunks: 16
pricing:
input: '0.0005'
unit: '0.001'
currency: RMB

View File

@ -0,0 +1,9 @@
model: bge-large-zh
model_type: text-embedding
model_properties:
context_size: 512
max_chunks: 16
pricing:
input: '0.0005'
unit: '0.001'
currency: RMB

View File

@ -0,0 +1,9 @@
model: embedding-v1
model_type: text-embedding
model_properties:
context_size: 384
max_chunks: 16
pricing:
input: '0.0005'
unit: '0.001'
currency: RMB

View File

@ -0,0 +1,9 @@
model: tao-8k
model_type: text-embedding
model_properties:
context_size: 8192
max_chunks: 1
pricing:
input: '0.0005'
unit: '0.001'
currency: RMB

View File

@ -0,0 +1,184 @@
import time
from abc import abstractmethod
from collections.abc import Mapping
from json import dumps
from typing import Any, Optional
import numpy as np
from requests import Response, post
from core.model_runtime.entities.model_entities import PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from core.model_runtime.errors.invoke import InvokeError
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.model_runtime.model_providers.wenxin._common import BaiduAccessToken, _CommonWenxin
from core.model_runtime.model_providers.wenxin.wenxin_errors import (
BadRequestError,
InternalServerError,
invoke_error_mapping,
)
class TextEmbedding:
@abstractmethod
def embed_documents(self, model: str, texts: list[str], user: str) -> tuple[list[list[float]], int, int]:
raise NotImplementedError
class WenxinTextEmbedding(_CommonWenxin, TextEmbedding):
def embed_documents(self, model: str, texts: list[str], user: str) -> tuple[list[list[float]], int, int]:
access_token = self._get_access_token()
url = f'{self.api_bases[model]}?access_token={access_token}'
body = self._build_embed_request_body(model, texts, user)
headers = {
'Content-Type': 'application/json',
}
resp = post(url, data=dumps(body), headers=headers)
if resp.status_code != 200:
raise InternalServerError(f'Failed to invoke ernie bot: {resp.text}')
return self._handle_embed_response(model, resp)
def _build_embed_request_body(self, model: str, texts: list[str], user: str) -> dict[str, Any]:
if len(texts) == 0:
raise BadRequestError('The number of texts should not be zero.')
body = {
'input': texts,
'user_id': user,
}
return body
def _handle_embed_response(self, model: str, response: Response) -> tuple[list[list[float]], int, int]:
data = response.json()
if 'error_code' in data:
code = data['error_code']
msg = data['error_msg']
# raise error
self._handle_error(code, msg)
embeddings = [v['embedding'] for v in data['data']]
_usage = data['usage']
tokens = _usage['prompt_tokens']
total_tokens = _usage['total_tokens']
return embeddings, tokens, total_tokens
class WenxinTextEmbeddingModel(TextEmbeddingModel):
def _create_text_embedding(self, api_key: str, secret_key: str) -> TextEmbedding:
return WenxinTextEmbedding(api_key, secret_key)
def _invoke(self, model: str, credentials: dict, texts: list[str],
user: Optional[str] = None) -> TextEmbeddingResult:
"""
Invoke text embedding model
:param model: model name
:param credentials: model credentials
:param texts: texts to embed
:param user: unique user id
:return: embeddings result
"""
api_key = credentials['api_key']
secret_key = credentials['secret_key']
embedding: TextEmbedding = self._create_text_embedding(api_key, secret_key)
user = user if user else 'ErnieBotDefault'
context_size = self._get_context_size(model, credentials)
max_chunks = self._get_max_chunks(model, credentials)
inputs = []
indices = []
used_tokens = 0
used_total_tokens = 0
for i, text in enumerate(texts):
# Here token count is only an approximation based on the GPT2 tokenizer
num_tokens = self._get_num_tokens_by_gpt2(text)
if num_tokens >= context_size:
cutoff = int(np.floor(len(text) * (context_size / num_tokens)))
# if num tokens is larger than context length, only use the start
inputs.append(text[0:cutoff])
else:
inputs.append(text)
indices += [i]
batched_embeddings = []
_iter = range(0, len(inputs), max_chunks)
for i in _iter:
embeddings_batch, _used_tokens, _total_used_tokens = embedding.embed_documents(
model,
inputs[i: i + max_chunks],
user)
used_tokens += _used_tokens
used_total_tokens += _total_used_tokens
batched_embeddings += embeddings_batch
usage = self._calc_response_usage(model, credentials, used_tokens, used_total_tokens)
return TextEmbeddingResult(
model=model,
embeddings=batched_embeddings,
usage=usage,
)
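To make the truncation rule above concrete, a small worked example with assumed numbers:
# Assumed numbers: a 10,000-character text estimated at 2,000 tokens,
# squeezed into a 500-token context window.
text = 'x' * 10_000
cutoff = int(np.floor(len(text) * (500 / 2_000)))  # np is already imported above
assert cutoff == 2_500  # only the first 2,500 characters are sent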
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
"""
Get number of tokens for given prompt messages
:param model: model name
:param credentials: model credentials
:param texts: texts to embed
:return:
"""
if len(texts) == 0:
return 0
total_num_tokens = 0
for text in texts:
total_num_tokens += self._get_num_tokens_by_gpt2(text)
return total_num_tokens
def validate_credentials(self, model: str, credentials: Mapping) -> None:
api_key = credentials['api_key']
secret_key = credentials['secret_key']
try:
BaiduAccessToken.get_access_token(api_key, secret_key)
except Exception as e:
raise CredentialsValidateFailedError(f'Credentials validation failed: {e}')
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
return invoke_error_mapping()
def _calc_response_usage(self, model: str, credentials: dict, tokens: int, total_tokens: int) -> EmbeddingUsage:
"""
Calculate response usage
:param model: model name
:param credentials: model credentials
:param tokens: input tokens
:return: usage
"""
# get input price info
input_price_info = self.get_price(
model=model,
credentials=credentials,
price_type=PriceType.INPUT,
tokens=tokens
)
# transform usage
usage = EmbeddingUsage(
tokens=tokens,
total_tokens=total_tokens,
unit_price=input_price_info.unit_price,
price_unit=input_price_info.unit,
total_price=input_price_info.total_amount,
currency=input_price_info.currency,
latency=time.perf_counter() - self.started_at
)
return usage

View File

@ -17,6 +17,7 @@ help:
en_US: https://cloud.baidu.com/wenxin.html
supported_model_types:
- llm
- text-embedding
configurate_methods:
- predefined-model
provider_credential_schema:

View File

@ -0,0 +1,57 @@
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
InvokeConnectionError,
InvokeError,
InvokeRateLimitError,
InvokeServerUnavailableError,
)
def invoke_error_mapping() -> dict[type[InvokeError], list[type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller.
The value is the error type thrown by the model,
which needs to be converted into a unified error type for the caller.
:return: Invoke error mapping
"""
return {
InvokeConnectionError: [
],
InvokeServerUnavailableError: [
InternalServerError
],
InvokeRateLimitError: [
RateLimitReachedError
],
InvokeAuthorizationError: [
InvalidAuthenticationError,
InsufficientAccountBalance,
InvalidAPIKeyError,
],
InvokeBadRequestError: [
BadRequestError,
KeyError
]
}
class InvalidAuthenticationError(Exception):
pass
class InvalidAPIKeyError(Exception):
pass
class RateLimitReachedError(Exception):
pass
class InsufficientAccountBalance(Exception):
pass
class InternalServerError(Exception):
pass
class BadRequestError(Exception):
pass
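For context, a hedged sketch of how a runtime error might be routed through this mapping (assumed caller-side logic, not part of this module):
mapping = invoke_error_mapping()
try:
    raise RateLimitReachedError('qps limit reached')
except Exception as e:
    # find the unified error type whose list contains this exception's type
    unified = next((k for k, v in mapping.items() if type(e) in v), InvokeError)
    # unified is InvokeRateLimitError here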

View File

@ -85,7 +85,8 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
tools=tools, stop=stop, stream=stream, user=user,
extra_model_kwargs=XinferenceHelper.get_xinference_extra_parameter(
server_url=credentials['server_url'],
model_uid=credentials['model_uid']
model_uid=credentials['model_uid'],
api_key=credentials.get('api_key'),
)
)
@ -106,7 +107,8 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
extra_param = XinferenceHelper.get_xinference_extra_parameter(
server_url=credentials['server_url'],
model_uid=credentials['model_uid']
model_uid=credentials['model_uid'],
api_key=credentials.get('api_key')
)
if 'completion_type' not in credentials:
if 'chat' in extra_param.model_ability:
@ -396,7 +398,8 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
else:
extra_args = XinferenceHelper.get_xinference_extra_parameter(
server_url=credentials['server_url'],
model_uid=credentials['model_uid']
model_uid=credentials['model_uid'],
api_key=credentials.get('api_key')
)
if 'chat' in extra_args.model_ability:
@ -464,6 +467,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
xinference_client = Client(
base_url=credentials['server_url'],
api_key=credentials.get('api_key'),
)
xinference_model = xinference_client.get_model(credentials['model_uid'])

View File

@ -108,7 +108,8 @@ class XinferenceRerankModel(RerankModel):
# initialize client
client = Client(
base_url=credentials['server_url']
base_url=credentials['server_url'],
api_key=credentials.get('api_key'),
)
xinference_client = client.get_model(model_uid=credentials['model_uid'])

View File

@ -52,7 +52,8 @@ class XinferenceSpeech2TextModel(Speech2TextModel):
# initialize client
client = Client(
base_url=credentials['server_url']
base_url=credentials['server_url'],
api_key=credentials.get('api_key'),
)
xinference_client = client.get_model(model_uid=credentials['model_uid'])

View File

@ -110,14 +110,22 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
server_url = credentials['server_url']
model_uid = credentials['model_uid']
extra_args = XinferenceHelper.get_xinference_extra_parameter(server_url=server_url, model_uid=model_uid)
api_key = credentials.get('api_key')
extra_args = XinferenceHelper.get_xinference_extra_parameter(
server_url=server_url,
model_uid=model_uid,
api_key=api_key,
)
if extra_args.max_tokens:
credentials['max_tokens'] = extra_args.max_tokens
if server_url.endswith('/'):
server_url = server_url[:-1]
client = Client(base_url=server_url)
client = Client(
base_url=server_url,
api_key=api_key,
)
try:
handle = client.get_model(model_uid=model_uid)

View File

@ -81,7 +81,8 @@ class XinferenceText2SpeechModel(TTSModel):
extra_param = XinferenceHelper.get_xinference_extra_parameter(
server_url=credentials['server_url'],
model_uid=credentials['model_uid']
model_uid=credentials['model_uid'],
api_key=credentials.get('api_key'),
)
if 'text-to-audio' not in extra_param.model_ability:
@ -203,7 +204,11 @@ class XinferenceText2SpeechModel(TTSModel):
credentials['server_url'] = credentials['server_url'][:-1]
try:
handle = RESTfulAudioModelHandle(credentials['model_uid'], credentials['server_url'], auth_headers={})
api_key = credentials.get('api_key')
auth_headers = {'Authorization': f'Bearer {api_key}'} if api_key else {}
handle = RESTfulAudioModelHandle(
credentials['model_uid'], credentials['server_url'], auth_headers=auth_headers
)
model_support_voice = [x.get("value") for x in
self.get_tts_model_voices(model=model, credentials=credentials)]

View File

@ -35,13 +35,13 @@ cache_lock = Lock()
class XinferenceHelper:
@staticmethod
def get_xinference_extra_parameter(server_url: str, model_uid: str) -> XinferenceModelExtraParameter:
def get_xinference_extra_parameter(server_url: str, model_uid: str, api_key: str | None) -> XinferenceModelExtraParameter:
XinferenceHelper._clean_cache()
with cache_lock:
if model_uid not in cache:
cache[model_uid] = {
'expires': time() + 300,
'value': XinferenceHelper._get_xinference_extra_parameter(server_url, model_uid)
'value': XinferenceHelper._get_xinference_extra_parameter(server_url, model_uid, api_key)
}
return cache[model_uid]['value']
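The effect of the 300-second cache, sketched with hypothetical arguments; the second call is served from memory with no HTTP round-trip.
a = XinferenceHelper.get_xinference_extra_parameter('http://localhost:9997', 'my-model-uid', api_key=None)
b = XinferenceHelper.get_xinference_extra_parameter('http://localhost:9997', 'my-model-uid', api_key=None)  # cache hit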
@ -56,7 +56,7 @@ class XinferenceHelper:
pass
@staticmethod
def _get_xinference_extra_parameter(server_url: str, model_uid: str) -> XinferenceModelExtraParameter:
def _get_xinference_extra_parameter(server_url: str, model_uid: str, api_key: str | None) -> XinferenceModelExtraParameter:
"""
get xinference model extra parameter like model_format and model_handle_type
"""
@ -70,9 +70,10 @@ class XinferenceHelper:
session = Session()
session.mount('http://', HTTPAdapter(max_retries=3))
session.mount('https://', HTTPAdapter(max_retries=3))
headers = {'Authorization': f'Bearer {api_key}'} if api_key else {}
try:
response = session.get(url, timeout=10)
response = session.get(url, headers=headers, timeout=10)
except (MissingSchema, ConnectionError, Timeout) as e:
raise RuntimeError(f'get xinference model extra parameter failed, url: {url}, error: {e}')
if response.status_code != 200:

View File

@ -5,6 +5,7 @@ from typing import Optional
from sqlalchemy.exc import IntegrityError
from configs import dify_config
from core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity
from core.entities.provider_configuration import ProviderConfiguration, ProviderConfigurations, ProviderModelBundle
from core.entities.provider_entities import (
@ -18,12 +19,9 @@ from core.entities.provider_entities import (
)
from core.helper import encrypter
from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
from core.helper.position_helper import is_filtered
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.entities.provider_entities import (
CredentialFormSchema,
FormType,
ProviderEntity,
)
from core.model_runtime.entities.provider_entities import CredentialFormSchema, FormType, ProviderEntity
from core.model_runtime.model_providers import model_provider_factory
from extensions import ext_hosting_provider
from extensions.ext_database import db
@ -45,6 +43,7 @@ class ProviderManager:
"""
ProviderManager is a class that manages the model providers includes Hosting and Customize Model Providers.
"""
def __init__(self) -> None:
self.decoding_rsa_key = None
self.decoding_cipher_rsa = None
@ -117,6 +116,16 @@ class ProviderManager:
# Construct ProviderConfiguration objects for each provider
for provider_entity in provider_entities:
# handle include, exclude
if is_filtered(
include_set=dify_config.POSITION_PROVIDER_INCLUDES_SET,
exclude_set=dify_config.POSITION_PROVIDER_EXCLUDES_SET,
data=provider_entity,
name_func=lambda x: x.provider,
):
continue
provider_name = provider_entity.provider
provider_records = provider_name_to_provider_records_dict.get(provider_entity.provider, [])
provider_model_records = provider_name_to_provider_model_records_dict.get(provider_entity.provider, [])
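The intended include/exclude semantics, sketched under the assumption that is_filtered mirrors the usual pin/include/exclude position helpers (the real implementation lives in core.helper.position_helper):
# Assumed behaviour, for illustration only:
def _should_skip(name: str, include_set: set[str], exclude_set: set[str]) -> bool:
    if include_set and name not in include_set:
        return True  # an include list is configured and this entity is not on it
    if exclude_set and name in exclude_set:
        return True  # the entity is explicitly excluded
    return False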
@ -271,6 +280,24 @@ class ProviderManager:
)
)
def get_first_provider_first_model(self, tenant_id: str, model_type: ModelType) -> tuple[str, str]:
"""
Get the names of the first available model and its provider
:param tenant_id: workspace id
:param model_type: model type
:return: provider name, model name
"""
provider_configurations = self.get_configurations(tenant_id)
# get available models from provider_configurations
all_models = provider_configurations.get_models(
model_type=model_type,
only_active=False
)
return all_models[0].provider.provider, all_models[0].model
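A hedged usage sketch (hypothetical tenant id):
provider_manager = ProviderManager()
# note: all_models[0] assumes the workspace has at least one model configured
provider, model = provider_manager.get_first_provider_first_model('tenant-123', ModelType.LLM)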
def update_default_model_record(self, tenant_id: str, model_type: ModelType, provider: str, model: str) \
-> TenantDefaultModel:
"""
@ -811,7 +838,7 @@ class ProviderManager:
-> list[ModelSettings]:
"""
Convert to model settings.
:param provider_entity: provider entity
:param provider_model_settings: provider model settings, including the enabled and load balancing enabled flags
:param load_balancing_model_configs: load balancing model configs
:return:

View File

@ -152,8 +152,27 @@ class PGVector(BaseVector):
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
# do not support bm25 search
return []
top_k = kwargs.get("top_k", 5)
with self._get_cursor() as cur:
cur.execute(
f"""SELECT meta, text, ts_rank(to_tsvector(coalesce(text, '')), to_tsquery(%s)) AS score
FROM {self.table_name}
WHERE to_tsvector(text) @@ plainto_tsquery(%s)
ORDER BY score DESC
LIMIT {top_k}""",
# f"'{query}'" is required in order to account for whitespace in query
(f"'{query}'", f"'{query}'"),
)
docs = []
for record in cur:
metadata, text, score = record
metadata["score"] = score
docs.append(Document(page_content=text, metadata=metadata))
return docs
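Roughly how the new keyword path is used (assumed: an initialized PGVector instance and ingested documents):
docs = vector.search_by_full_text('hybrid retrieval', top_k=3)
for doc in docs:
    print(doc.metadata['score'], doc.page_content[:60])  # ranked best-first by ts_rank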
def delete(self) -> None:
with self._get_cursor() as cur:

View File

@ -21,6 +21,7 @@ Dify支持`文本` `链接` `图片` `文件BLOB` `JSON` 等多种消息类型
create an image message
:param image: the url of the image
:param save_as: save as
:return: the image message
"""
```
@ -34,6 +35,7 @@ Dify支持`文本` `链接` `图片` `文件BLOB` `JSON` 等多种消息类型
create a link message
:param link: the url of the link
:param save_as: save as
:return: the link message
"""
```
@ -47,6 +49,7 @@ Dify支持`文本` `链接` `图片` `文件BLOB` `JSON` 等多种消息类型
create a text message
:param text: the text of the message
:param save_as: save as
:return: the text message
"""
```
@ -63,6 +66,8 @@ Dify支持`文本` `链接` `图片` `文件BLOB` `JSON` 等多种消息类型
create a blob message
:param blob: the blob
:param meta: meta
:param save_as: save as
:return: the blob message
"""
```

View File

@ -1,6 +1,6 @@
import os.path
from core.helper.position_helper import get_position_map, sort_by_position_map
from core.helper.position_helper import get_tool_position_map, sort_by_position_map
from core.tools.entities.api_entities import UserToolProvider
@ -10,11 +10,11 @@ class BuiltinToolProviderSort:
@classmethod
def sort(cls, providers: list[UserToolProvider]) -> list[UserToolProvider]:
if not cls._position:
cls._position = get_position_map(os.path.join(os.path.dirname(__file__), '..'))
cls._position = get_tool_position_map(os.path.join(os.path.dirname(__file__), '..'))
def name_func(provider: UserToolProvider) -> str:
return provider.name
sorted_providers = sort_by_position_map(cls._position, providers, name_func)
return sorted_providers
return sorted_providers

View File

@ -0,0 +1,49 @@
<?xml version="1.0" encoding="utf-8"?>
<!-- Generator: Adobe Illustrator 19.2.1, SVG Export Plug-In . SVG Version: 6.00 Build 0) -->
<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN" "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd">
<svg version="1.1" id="Layer_1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" x="0px" y="0px"
viewBox="0 0 200 130.2" style="enable-background:new 0 0 200 130.2;" xml:space="preserve">
<style type="text/css">
.st0{fill:#3EB1C8;}
.st1{fill:#D8D2C4;}
.st2{fill:#4F5858;}
.st3{fill:#FFC72C;}
.st4{fill:#EF3340;}
</style>
<g>
<polygon class="st0" points="111.8,95.5 111.8,66.8 135.4,59 177.2,73.3 "/>
<polygon class="st1" points="153.6,36.8 111.8,51.2 135.4,59 177.2,44.6 "/>
<polygon class="st2" points="135.4,59 177.2,44.6 177.2,73.3 "/>
<polygon class="st3" points="177.2,0.3 177.2,29 153.6,36.8 111.8,22.5 "/>
<polygon class="st4" points="153.6,36.8 111.8,51.2 111.8,22.5 "/>
<g>
<g>
<g>
<g>
<path class="st2" d="M26.3,104.8c-0.5-3.7-4.1-6.5-8.1-6.5c-7.3,0-10.1,6.2-10.1,12.7c0,6.2,2.8,12.4,10.1,12.4
c5,0,7.8-3.4,8.4-8.3h7.9c-0.8,9.2-7.2,15.2-16.3,15.2C6.8,130.2,0,121.7,0,111c0-11,6.8-19.6,18.2-19.6c8.2,0,15,4.8,16,13.3
H26.3z"/>
<path class="st2" d="M37.4,102.5h7v5h0.1c1.4-3.4,5-5.7,8.6-5.7c0.5,0,1.1,0.1,1.6,0.3v6.9c-0.7-0.2-1.8-0.3-2.6-0.3
c-5.4,0-7.3,3.9-7.3,8.6v12.1h-7.4V102.5z"/>
<path class="st2" d="M68.7,101.8c8.5,0,13.9,5.6,13.9,14.2c0,8.5-5.5,14.1-13.9,14.1c-8.4,0-13.9-5.6-13.9-14.1
C54.9,107.4,60.3,101.8,68.7,101.8z M68.7,124.5c5,0,6.5-4.3,6.5-8.6c0-4.3-1.5-8.6-6.5-8.6c-5,0-6.5,4.3-6.5,8.6
C62.2,120.2,63.8,124.5,68.7,124.5z"/>
<path class="st2" d="M91.2,120.6c0.1,3.2,2.8,4.5,5.7,4.5c2.1,0,4.8-0.8,4.8-3.4c0-2.2-3.1-3-8.4-4.2c-4.3-0.9-8.5-2.4-8.5-7.2
c0-6.9,5.9-8.6,11.7-8.6c5.9,0,11.3,2,11.8,8.6h-7c-0.2-2.9-2.4-3.6-5-3.6c-1.7,0-4.1,0.3-4.1,2.5c0,2.6,4.2,3,8.4,4
c4.3,1,8.5,2.5,8.5,7.5c0,7.1-6.1,9.3-12.3,9.3c-6.2,0-12.3-2.3-12.6-9.5H91.2z"/>
<path class="st2" d="M118.1,120.6c0.1,3.2,2.8,4.5,5.7,4.5c2.1,0,4.8-0.8,4.8-3.4c0-2.2-3.1-3-8.4-4.2
c-4.3-0.9-8.5-2.4-8.5-7.2c0-6.9,5.9-8.6,11.7-8.6c5.9,0,11.3,2,11.8,8.6h-7c-0.2-2.9-2.4-3.6-5-3.6c-1.7,0-4.1,0.3-4.1,2.5
c0,2.6,4.2,3,8.4,4c4.3,1,8.5,2.5,8.5,7.5c0,7.1-6.1,9.3-12.3,9.3c-6.2,0-12.3-2.3-12.6-9.5H118.1z"/>
<path class="st2" d="M138.4,102.5h7v5h0.1c1.4-3.4,5-5.7,8.6-5.7c0.5,0,1.1,0.1,1.6,0.3v6.9c-0.7-0.2-1.8-0.3-2.6-0.3
c-5.4,0-7.3,3.9-7.3,8.6v12.1h-7.4V102.5z"/>
<path class="st2" d="M163.7,117.7c0.2,4.7,2.5,6.8,6.6,6.8c3,0,5.3-1.8,5.8-3.5h6.5c-2.1,6.3-6.5,9-12.6,9
c-8.5,0-13.7-5.8-13.7-14.1c0-8,5.6-14.2,13.7-14.2c9.1,0,13.6,7.7,13,15.9H163.7z M175.7,113.1c-0.7-3.7-2.3-5.7-5.9-5.7
c-4.7,0-6,3.6-6.1,5.7H175.7z"/>
<path class="st2" d="M187.2,107.5h-4.4v-4.9h4.4v-2.1c0-4.7,3-8.2,9-8.2c1.3,0,2.6,0.2,3.9,0.2V98c-0.9-0.1-1.8-0.2-2.7-0.2
c-2,0-2.8,0.8-2.8,3.1v1.6h5.1v4.9h-5.1v21.9h-7.4V107.5z"/>
</g>
</g>
</g>
</g>
</g>
</svg>

After

Width:  |  Height:  |  Size: 3.0 KiB

View File

@ -0,0 +1,20 @@
from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin.crossref.tools.query_doi import CrossRefQueryDOITool
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
class CrossRefProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict) -> None:
try:
CrossRefQueryDOITool().fork_tool_runtime(
runtime={
"credentials": credentials,
}
).invoke(
user_id='',
tool_parameters={
"doi": '10.1007/s00894-022-05373-8',
},
)
except Exception as e:
raise ToolProviderCredentialValidationError(str(e))

View File

@ -0,0 +1,29 @@
identity:
author: Sakura4036
name: crossref
label:
en_US: CrossRef
zh_Hans: CrossRef
description:
en_US: Crossref is a cross-publisher reference-linking registration and query system built on DOI technology, created in 2000. It establishes cross-database links between papers' reference lists and the cited full texts, making it very convenient for readers to access the full text of papers.
zh_Hans: Crossref是于2000年创建的使用DOI技术的跨出版商参考文献链接注册查询系统。Crossref建立了在论文的参考文献列表和引文全文之间的跨数据库链接使得读者能够非常便捷地获取文献全文。
icon: icon.svg
tags:
- search
credentials_for_provider:
mailto:
type: text-input
required: true
label:
en_US: email address
zh_Hans: email地址
pt_BR: email address
placeholder:
en_US: Please input your email address
zh_Hans: 请输入你的email地址
pt_BR: Please input your email address
help:
en_US: According to the requirements of Crossref, an email address is required
zh_Hans: 根据Crossref的要求需要提供一个邮箱地址
pt_BR: According to the requirements of Crossref, an email address is required
url: https://api.crossref.org/swagger-ui/index.html

View File

@ -0,0 +1,25 @@
from typing import Any, Union
import requests
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.errors import ToolParameterValidationError
from core.tools.tool.builtin_tool import BuiltinTool
class CrossRefQueryDOITool(BuiltinTool):
"""
Tool for querying the metadata of a publication using its DOI.
"""
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
doi = tool_parameters.get('doi')
if not doi:
raise ToolParameterValidationError('doi is required.')
# doc: https://github.com/CrossRef/rest-api-doc
url = f"https://api.crossref.org/works/{doi}"
response = requests.get(url)
response.raise_for_status()
response = response.json()
message = response.get('message', {})
return self.create_json_message(message)
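A usage sketch mirroring the provider validation above (hypothetical mailto credential; the DOI is the one used there):
result = CrossRefQueryDOITool().fork_tool_runtime(
    runtime={'credentials': {'mailto': 'user@example.com'}},
).invoke(
    user_id='',
    tool_parameters={'doi': '10.1007/s00894-022-05373-8'},
)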

View File

@ -0,0 +1,23 @@
identity:
name: crossref_query_doi
author: Sakura4036
label:
en_US: CrossRef Query DOI
zh_Hans: CrossRef DOI 查询
pt_BR: CrossRef Query DOI
description:
human:
en_US: A tool for searching literature information using CrossRef by DOI.
zh_Hans: 一个使用CrossRef通过DOI获取文献信息的工具。
pt_BR: A tool for searching literature information using CrossRef by DOI.
llm: A tool for searching literature information using CrossRef by DOI.
parameters:
- name: doi
type: string
required: true
label:
en_US: DOI
zh_Hans: DOI
pt_BR: DOI
llm_description: DOI for searching in CrossRef
form: llm

View File

@ -0,0 +1,120 @@
import time
from typing import Any, Union
import requests
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
def convert_time_str_to_seconds(time_str: str) -> int:
"""
Convert a time string to seconds.
example: 1s -> 1, 1m30s -> 90, 1h30m -> 5400, 1h30m30s -> 5430
"""
time_str = time_str.lower().strip().replace(' ', '')
seconds = 0
if 'h' in time_str:
hours, time_str = time_str.split('h')
seconds += int(hours) * 3600
if 'm' in time_str:
minutes, time_str = time_str.split('m')
seconds += int(minutes) * 60
if 's' in time_str:
seconds += int(time_str.replace('s', ''))
return seconds
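The docstring's examples, checked directly:
assert convert_time_str_to_seconds('1s') == 1
assert convert_time_str_to_seconds('1m30s') == 90
assert convert_time_str_to_seconds('1h30m') == 5400
assert convert_time_str_to_seconds('1h30m30s') == 5430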
class CrossRefQueryTitleAPI:
"""
Tool for querying the metadata of a publication using its title.
Crossref API doc: https://github.com/CrossRef/rest-api-doc
"""
query_url_template: str = "https://api.crossref.org/works?query.bibliographic={query}&rows={rows}&offset={offset}&sort={sort}&order={order}&mailto={mailto}"
rate_limit: int = 50
rate_interval: float = 1
max_limit: int = 1000
def __init__(self, mailto: str):
self.mailto = mailto
def _query(self, query: str, rows: int = 5, offset: int = 0, sort: str = 'relevance', order: str = 'desc', fuzzy_query: bool = False) -> list[dict]:
"""
Query the metadata of a publication using its title.
:param query: the title of the publication
:param rows: the number of results to return
:param sort: the sort field
:param order: the sort order
:param fuzzy_query: whether to return all items that match the query
"""
url = self.query_url_template.format(query=query, rows=rows, offset=offset, sort=sort, order=order, mailto=self.mailto)
response = requests.get(url)
response.raise_for_status()
rate_limit = int(response.headers['x-ratelimit-limit'])
# convert time string to seconds
rate_interval = convert_time_str_to_seconds(response.headers['x-ratelimit-interval'])
self.rate_limit = rate_limit
self.rate_interval = rate_interval
response = response.json()
if response['status'] != 'ok':
return []
message = response['message']
if fuzzy_query:
# fuzzy query return all items
return message['items']
else:
for paper in message['items']:
title = paper['title'][0]
if title.lower() != query.lower():
continue
return [paper]
return []
def query(self, query: str, rows: int = 5, sort: str = 'relevance', order: str = 'desc', fuzzy_query: bool = False) -> list[dict]:
"""
Query the metadata of a publication using its title.
:param query: the title of the publication
:param rows: the number of results to return
:param sort: the sort field
:param order: the sort order
:param fuzzy_query: whether to return all items that match the query
"""
rows = min(rows, self.max_limit)
if rows > self.rate_limit:
# query multiple times
query_times = rows // self.rate_limit + 1
results = []
for i in range(query_times):
result = self._query(query, rows=self.rate_limit, offset=i * self.rate_limit, sort=sort, order=order, fuzzy_query=fuzzy_query)
if fuzzy_query:
results.extend(result)
else:
# fuzzy_query=False, only one result
if result:
return result
time.sleep(self.rate_interval)
return results
else:
# query once
return self._query(query, rows, sort=sort, order=order, fuzzy_query=fuzzy_query)
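A hedged example (hypothetical mailto address): asking for 60 fuzzy results against the default rate limit of 50 triggers two paged requests, separated by rate_interval seconds.
api = CrossRefQueryTitleAPI(mailto='user@example.com')
papers = api.query('attention is all you need', rows=60, fuzzy_query=True)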
class CrossRefQueryTitleTool(BuiltinTool):
"""
Tool for querying the metadata of a publication using its title.
"""
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
query = tool_parameters.get('query')
fuzzy_query = tool_parameters.get('fuzzy_query', False)
rows = tool_parameters.get('rows', 3)
sort = tool_parameters.get('sort', 'relevance')
order = tool_parameters.get('order', 'desc')
mailto = self.runtime.credentials['mailto']
result = CrossRefQueryTitleAPI(mailto).query(query, rows, sort, order, fuzzy_query)
return [self.create_json_message(r) for r in result]

View File

@ -0,0 +1,105 @@
identity:
name: crossref_query_title
author: Sakura4036
label:
en_US: CrossRef Title Query
zh_Hans: CrossRef 标题查询
pt_BR: CrossRef Title Query
description:
human:
en_US: A tool for querying literature information using CrossRef by title.
zh_Hans: 一个使用CrossRef通过标题搜索文献信息的工具。
pt_BR: A tool for querying literature information using CrossRef by title.
llm: A tool for querying literature information using CrossRef by title.
parameters:
- name: query
type: string
required: true
label:
en_US: Query
zh_Hans: 查询语句
pt_BR: Query
human_description:
en_US: Query bibliographic information, useful for citation look up. Includes titles, authors, ISSNs and publication years
zh_Hans: 用于搜索文献信息有助于查找引用。包括标题作者ISSN和出版年份
pt_BR: Query bibliographic information, useful for citation look up. Includes titles, authors, ISSNs and publication years
llm_description: keywords for querying literature in CrossRef
form: llm
- name: fuzzy_query
type: boolean
default: false
label:
en_US: Whether to fuzzy search
zh_Hans: 是否模糊搜索
pt_BR: Whether to fuzzy search
human_description:
en_US: Used to select the query type; a fuzzy query returns more results, a precise query returns one result or none
zh_Hans: 用于选择搜索类型模糊搜索返回更多结果精确搜索返回1条结果或无
pt_BR: Used to select the query type; a fuzzy query returns more results, a precise query returns one result or none
form: form
- name: limit
type: number
required: false
label:
en_US: max query number
zh_Hans: 最大搜索数
pt_BR: max query number
human_description:
en_US: Max query number (for fuzzy search, the maximum number of results; for precise search, the maximum number of matches)
zh_Hans: 最大搜索数(模糊搜索返回的最大结果数或精确搜索最大匹配数)
pt_BR: Max query number (for fuzzy search, the maximum number of results; for precise search, the maximum number of matches)
form: llm
default: 50
- name: sort
type: select
required: true
options:
- value: relevance
label:
en_US: relevance
zh_Hans: 相关性
pt_BR: relevance
- value: published
label:
en_US: publication date
zh_Hans: 出版日期
pt_BR: publication date
- value: references-count
label:
en_US: references-count
zh_Hans: 引用次数
pt_BR: references-count
default: relevance
label:
en_US: sorting field
zh_Hans: 排序字段
pt_BR: sorting field
human_description:
en_US: Sorting of query results
zh_Hans: 检索结果的排序字段
pt_BR: Sorting of query results
form: form
- name: order
type: select
required: true
options:
- value: desc
label:
en_US: descending
zh_Hans: 降序
pt_BR: descending
- value: asc
label:
en_US: ascending
zh_Hans: 升序
pt_BR: ascending
default: desc
label:
en_US: Order
zh_Hans: 排序
pt_BR: Order
human_description:
en_US: Order of query results
zh_Hans: 检索结果的排序方式
pt_BR: Order of query results
form: form

View File

@ -29,6 +29,6 @@ class GitlabProvider(BuiltinToolProviderController):
if response.status_code != 200:
raise ToolProviderCredentialValidationError((response.json()).get('message'))
except Exception as e:
raise ToolProviderCredentialValidationError("Gitlab Access Tokens and Api Version is invalid. {}".format(e))
raise ToolProviderCredentialValidationError("Gitlab Access Tokens is invalid. {}".format(e))
except Exception as e:
raise ToolProviderCredentialValidationError(str(e))

View File

@ -2,37 +2,37 @@ identity:
author: Leo.Wang
name: gitlab
label:
en_US: Gitlab
zh_Hans: Gitlab
en_US: GitLab
zh_Hans: GitLab
description:
en_US: Gitlab plugin for commit
zh_Hans: 用于获取Gitlab commit的插件
en_US: GitLab plugin, API v4 only.
zh_Hans: 用于获取GitLab内容的插件目前仅支持 API v4。
icon: gitlab.svg
credentials_for_provider:
access_tokens:
type: secret-input
required: true
label:
en_US: Gitlab access token
zh_Hans: Gitlab access token
en_US: GitLab access token
zh_Hans: GitLab access token
placeholder:
en_US: Please input your Gitlab access token
zh_Hans: 请输入你的 Gitlab access token
en_US: Please input your GitLab access token
zh_Hans: 请输入你的 GitLab access token
help:
en_US: Get your Gitlab access token from Gitlab
zh_Hans: 从 Gitlab 获取您的 access token
en_US: Get your GitLab access token from GitLab
zh_Hans: 从 GitLab 获取您的 access token
url: https://docs.gitlab.com/16.9/ee/api/oauth2.html
site_url:
type: text-input
required: false
default: 'https://gitlab.com'
label:
en_US: Gitlab site url
zh_Hans: Gitlab site url
en_US: GitLab site url
zh_Hans: GitLab site url
placeholder:
en_US: Please input your Gitlab site url
zh_Hans: 请输入你的 Gitlab site url
en_US: Please input your GitLab site url
zh_Hans: 请输入你的 GitLab site url
help:
en_US: Find your Gitlab url
zh_Hans: 找到你的Gitlab url
en_US: Find your GitLab url
zh_Hans: 找到你的 GitLab url
url: https://gitlab.com/help

View File

@ -18,6 +18,7 @@ class GitlabCommitsTool(BuiltinTool):
employee = tool_parameters.get('employee', '')
start_time = tool_parameters.get('start_time', '')
end_time = tool_parameters.get('end_time', '')
change_type = tool_parameters.get('change_type', 'all')
if not project:
return self.create_text_message('Project is required')
@ -36,11 +37,11 @@ class GitlabCommitsTool(BuiltinTool):
site_url = 'https://gitlab.com'
# Get commit content
result = self.fetch(user_id, site_url, access_token, project, employee, start_time, end_time)
result = self.fetch(user_id, site_url, access_token, project, employee, start_time, end_time, change_type)
return self.create_text_message(json.dumps(result, ensure_ascii=False))
return [self.create_json_message(item) for item in result]
def fetch(self,user_id: str, site_url: str, access_token: str, project: str, employee: str = None, start_time: str = '', end_time: str = '') -> list[dict[str, Any]]:
def fetch(self, user_id: str, site_url: str, access_token: str, project: str, employee: str | None = None, start_time: str = '', end_time: str = '', change_type: str = '') -> list[dict[str, Any]]:
domain = site_url
headers = {"PRIVATE-TOKEN": access_token}
results = []
@ -74,7 +75,7 @@ class GitlabCommitsTool(BuiltinTool):
for commit in commits:
commit_sha = commit['id']
print(f"\tCommit SHA: {commit_sha}")
author_name = commit['author_name']
diff_url = f"{domain}/api/v4/projects/{project_id}/repository/commits/{commit_sha}/diff"
diff_response = requests.get(diff_url, headers=headers)
@ -87,14 +88,23 @@ class GitlabCommitsTool(BuiltinTool):
removed_lines = diff['diff'].count('\n-')
total_changes = added_lines + removed_lines
if total_changes > 1:
final_code = ''.join([line[1:] for line in diff['diff'].split('\n') if line.startswith('+') and not line.startswith('+++')])
results.append({
"project": project_name,
"commit_sha": commit_sha,
"diff": final_code
})
print(f"Commit code:{final_code}")
if change_type == "new":
if added_lines > 1:
final_code = ''.join([line[1:] for line in diff['diff'].split('\n') if line.startswith('+') and not line.startswith('+++')])
results.append({
"commit_sha": commit_sha,
"author_name": author_name,
"diff": final_code
})
else:
if total_changes > 1:
final_code = ''.join([line[1:] for line in diff['diff'].split('\n') if (line.startswith('+') or line.startswith('-')) and not line.startswith('+++') and not line.startswith('---')])
final_code_escaped = json.dumps(final_code)[1:-1] # Escape the final code
results.append({
"commit_sha": commit_sha,
"author_name": author_name,
"diff": final_code_escaped
})
except requests.RequestException as e:
print(f"Error fetching data from GitLab: {e}")

View File

@ -2,24 +2,24 @@ identity:
name: gitlab_commits
author: Leo.Wang
label:
en_US: Gitlab Commits
zh_Hans: Gitlab代码提交内容
en_US: GitLab Commits
zh_Hans: GitLab 提交内容查询
description:
human:
en_US: A tool for query gitlab commits. Input should be a exists username.
zh_Hans: 一个用于查询gitlab代码提交记录的的工具,输入的内容应该是一个已存在的用户名或者项目名。
llm: A tool for query gitlab commits. Input should be a exists username or project.
en_US: A tool for querying GitLab commits. Input should be an existing username or project.
zh_Hans: 一个用于查询 GitLab 代码提交内容的工具，输入的内容应该是一个已存在的用户名或者项目名。
llm: A tool for querying GitLab commits. Input should be an existing username or project.
parameters:
- name: employee
- name: username
type: string
required: false
label:
en_US: employee
en_US: username
zh_Hans: 员工用户名
human_description:
en_US: employee
en_US: username
zh_Hans: 员工用户名
llm_description: employee for gitlab
llm_description: User name for GitLab
form: llm
- name: project
type: string
@ -30,7 +30,7 @@ parameters:
human_description:
en_US: project
zh_Hans: 项目名
llm_description: project for gitlab
llm_description: project for GitLab
form: llm
- name: start_time
type: string
@ -41,7 +41,7 @@ parameters:
human_description:
en_US: start_time
zh_Hans: 开始时间
llm_description: start_time for gitlab
llm_description: Start time for GitLab
form: llm
- name: end_time
type: string
@ -52,5 +52,26 @@ parameters:
human_description:
en_US: end_time
zh_Hans: 结束时间
llm_description: end_time for gitlab
llm_description: End time for GitLab
form: llm
- name: change_type
type: select
required: false
options:
- value: all
label:
en_US: all
zh_Hans: 所有
- value: new
label:
en_US: new
zh_Hans: 新增
default: all
label:
en_US: change_type
zh_Hans: 变更类型
human_description:
en_US: change_type
zh_Hans: 变更类型
llm_description: Content change type for GitLab
form: llm
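For reference, a tool_parameters payload accepted by this schema might look like the following (all values hypothetical; change_type defaults to 'all'):

tool_parameters = {
    "username": "leo.wang",      # optional: filter commits by author
    "project": "dify",           # project to search commits in
    "start_time": "2024-08-01",  # optional time window bounds
    "end_time": "2024-08-21",
    "change_type": "new",        # 'new' keeps added lines only; 'all' keeps the full diff
}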

View File

@ -0,0 +1,95 @@
from typing import Any, Union
import requests
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
class GitlabFilesTool(BuiltinTool):
def _invoke(self,
user_id: str,
tool_parameters: dict[str, Any]
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
project = tool_parameters.get('project', '')
branch = tool_parameters.get('branch', '')
path = tool_parameters.get('path', '')
if not project:
return self.create_text_message('Project is required')
if not branch:
return self.create_text_message('Branch is required')
if not path:
return self.create_text_message('Path is required')
access_token = self.runtime.credentials.get('access_tokens')
site_url = self.runtime.credentials.get('site_url')
if 'access_tokens' not in self.runtime.credentials or not self.runtime.credentials.get('access_tokens'):
return self.create_text_message("Gitlab API Access Tokens is required.")
if 'site_url' not in self.runtime.credentials or not self.runtime.credentials.get('site_url'):
site_url = 'https://gitlab.com'
# Get project ID from project name
project_id = self.get_project_id(site_url, access_token, project)
if not project_id:
return self.create_text_message(f"Project '{project}' not found.")
# Get commit content
result = self.fetch(user_id, project_id, site_url, access_token, branch, path)
return [self.create_json_message(item) for item in result]
def extract_project_name_and_path(self, path: str) -> tuple[str, str]:
parts = path.split('/', 1)
if len(parts) < 2:
return None, None
return parts[0], parts[1]
def get_project_id(self, site_url: str, access_token: str, project_name: str) -> Union[str, None]:
headers = {"PRIVATE-TOKEN": access_token}
try:
url = f"{site_url}/api/v4/projects?search={project_name}"
response = requests.get(url, headers=headers)
response.raise_for_status()
projects = response.json()
for project in projects:
if project['name'] == project_name:
return project['id']
except requests.RequestException as e:
print(f"Error fetching project ID from GitLab: {e}")
return None
def fetch(self,user_id: str, project_id: str, site_url: str, access_token: str, branch: str, path: str = None) -> list[dict[str, Any]]:
domain = site_url
headers = {"PRIVATE-TOKEN": access_token}
results = []
try:
# List files and directories in the given path
url = f"{domain}/api/v4/projects/{project_id}/repository/tree?path={path}&ref={branch}"
response = requests.get(url, headers=headers)
response.raise_for_status()
items = response.json()
for item in items:
item_path = item['path']
if item['type'] == 'tree': # It's a directory
results.extend(self.fetch(user_id, project_id, site_url, access_token, branch, item_path))
else: # It's a file
file_url = f"{domain}/api/v4/projects/{project_id}/repository/files/{item_path}/raw?ref={branch}"
file_response = requests.get(file_url, headers=headers)
file_response.raise_for_status()
file_content = file_response.text
results.append({
"path": item_path,
"branch": branch,
"content": file_content
})
except requests.RequestException as e:
print(f"Error fetching data from GitLab: {e}")
return results
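One caveat worth noting: GitLab's repository-files endpoint expects the file path URL-encoded, slashes included, so interpolating item_path raw can 404 for nested paths. A minimal sketch of the safer construction (URL values are illustrative):

from urllib.parse import quote

item_path = "src/app.py"
encoded = quote(item_path, safe='')  # 'src/app.py' -> 'src%2Fapp.py'
file_url = f"https://gitlab.com/api/v4/projects/123/repository/files/{encoded}/raw?ref=main"
print(file_url)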

View File

@ -0,0 +1,45 @@
identity:
name: gitlab_files
author: Leo.Wang
label:
en_US: GitLab Files
zh_Hans: GitLab 文件获取
description:
human:
en_US: A tool for querying GitLab files. Input should be a branch and an existing file or directory path.
zh_Hans: 一个用于查询 GitLab 文件的工具,输入的内容应该是分支和一个已存在文件或者文件夹路径。
llm: A tool for querying GitLab files. Input should be an existing file or directory path.
parameters:
- name: project
type: string
required: true
label:
en_US: project
zh_Hans: 项目
human_description:
en_US: project
zh_Hans: 项目
llm_description: Project for GitLab
form: llm
- name: branch
type: string
required: true
label:
en_US: branch
zh_Hans: 分支
human_description:
en_US: branch
zh_Hans: 分支
llm_description: Branch for GitLab
form: llm
- name: path
type: string
required: true
label:
en_US: path
zh_Hans: 文件路径
human_description:
en_US: path
zh_Hans: 文件路径
llm_description: File path for GitLab
form: llm

View File

@ -0,0 +1,43 @@
from typing import Any
from core.helper import ssrf_proxy
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
class JinaTokenizerTool(BuiltinTool):
_jina_tokenizer_endpoint = 'https://tokenize.jina.ai/'
def _invoke(
self,
user_id: str,
tool_parameters: dict[str, Any],
) -> ToolInvokeMessage:
content = tool_parameters['content']
body = {
"content": content
}
headers = {
'Content-Type': 'application/json'
}
if 'api_key' in self.runtime.credentials and self.runtime.credentials.get('api_key'):
headers['Authorization'] = "Bearer " + self.runtime.credentials.get('api_key')
if tool_parameters.get('return_chunks', False):
body['return_chunks'] = True
if tool_parameters.get('return_tokens', False):
body['return_tokens'] = True
if tokenizer := tool_parameters.get('tokenizer'):
body['tokenizer'] = tokenizer
response = ssrf_proxy.post(
self._jina_tokenizer_endpoint,
headers=headers,
json=body,
)
return self.create_json_message(response.json())
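Outside Dify's ssrf_proxy wrapper, the same endpoint can be exercised directly; a minimal sketch with requests, mirroring the body built above (the optional Authorization header is omitted):

import requests

body = {
    "content": "Hello world, this is a test.",
    "return_tokens": True,       # include tokens and their ids
    "return_chunks": True,       # segment long text into chunks
    "tokenizer": "cl100k_base",  # optional; see the options listed below
}
resp = requests.post("https://tokenize.jina.ai/", json=body, timeout=30)
resp.raise_for_status()
print(resp.json())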

View File

@ -0,0 +1,70 @@
identity:
name: jina_tokenizer
author: hjlarry
label:
en_US: JinaTokenizer
description:
human:
en_US: Free API to tokenize text and segment long text into chunks.
zh_Hans: 免费的API可以将文本tokenize也可以将长文本分割成多个部分。
llm: Free API to tokenize text and segment long text into chunks.
parameters:
- name: content
type: string
required: true
label:
en_US: Content
zh_Hans: 内容
llm_description: the content to tokenize or segment
form: llm
- name: return_tokens
type: boolean
required: false
label:
en_US: Return the tokens
zh_Hans: 是否返回tokens
human_description:
en_US: Return the tokens and their corresponding ids in the response.
zh_Hans: 返回tokens及其对应的ids。
form: form
- name: return_chunks
type: boolean
label:
en_US: Return the chunks
zh_Hans: 是否分块
human_description:
en_US: Chunking the input into semantically meaningful segments while handling a wide variety of text types and edge cases based on common structural cues.
zh_Hans: 将输入分块为具有语义意义的片段,同时根据常见的结构线索处理各种文本类型和边缘情况。
form: form
- name: tokenizer
type: select
options:
- value: cl100k_base
label:
en_US: cl100k_base
- value: o200k_base
label:
en_US: o200k_base
- value: p50k_base
label:
en_US: p50k_base
- value: r50k_base
label:
en_US: r50k_base
- value: p50k_edit
label:
en_US: p50k_edit
- value: gpt2
label:
en_US: gpt2
label:
en_US: Tokenizer
human_description:
en_US: |
· cl100k_base --- gpt-4, gpt-3.5-turbo, gpt-3.5
· o200k_base --- gpt-4o, gpt-4o-mini
· p50k_base --- text-davinci-003, text-davinci-002
· r50k_base --- text-davinci-001, text-curie-001
· p50k_edit --- text-davinci-edit-001, code-davinci-edit-001
· gpt2 --- gpt-2
form: form

View File

@ -0,0 +1,73 @@
from novita_client import (
Txt2ImgV3Embedding,
Txt2ImgV3HiresFix,
Txt2ImgV3LoRA,
Txt2ImgV3Refiner,
V3TaskImage,
)
class NovitaAiToolBase:
def _extract_loras(self, loras_str: str):
if not loras_str:
return []
loras_ori_list = loras_str.strip().split(';')
result_list = []
for lora_str in loras_ori_list:
lora_info = lora_str.strip().split(',')
lora = Txt2ImgV3LoRA(
model_name=lora_info[0].strip(),
strength=float(lora_info[1]),
)
result_list.append(lora)
return result_list
def _extract_embeddings(self, embeddings_str: str):
if not embeddings_str:
return []
embeddings_ori_list = embeddings_str.strip().split(';')
result_list = []
for embedding_str in embeddings_ori_list:
embedding = Txt2ImgV3Embedding(
model_name=embedding_str.strip()
)
result_list.append(embedding)
return result_list
def _extract_hires_fix(self, hires_fix_str: str):
hires_fix_info = hires_fix_str.strip().split(',')
if 'upscaler' in hires_fix_info:
hires_fix = Txt2ImgV3HiresFix(
target_width=int(hires_fix_info[0]),
target_height=int(hires_fix_info[1]),
strength=float(hires_fix_info[2]),
upscaler=hires_fix_info[3].strip()
)
else:
hires_fix = Txt2ImgV3HiresFix(
target_width=int(hires_fix_info[0]),
target_height=int(hires_fix_info[1]),
strength=float(hires_fix_info[2])
)
return hires_fix
def _extract_refiner(self, switch_at: str):
refiner = Txt2ImgV3Refiner(
switch_at=float(switch_at)
)
return refiner
def _is_hit_nsfw_detection(self, image: V3TaskImage, confidence_threshold: float) -> bool:
"""
is hit nsfw
"""
if image.nsfw_detection_result is None:
return False
if image.nsfw_detection_result.valid and image.nsfw_detection_result.confidence >= confidence_threshold:
return True
return False
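The helpers above infer their input formats from ';'-separated entries with ','-separated fields; the lora parsing, reduced to a standalone sketch:

loras_str = "model_a,0.8;model_b,0.5"
parsed = []
for entry in loras_str.strip().split(';'):
    name, strength = (field.strip() for field in entry.split(','))
    parsed.append((name, float(strength)))
assert parsed == [("model_a", 0.8), ("model_b", 0.5)]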

View File

@ -4,19 +4,15 @@ from typing import Any, Union
from novita_client import (
NovitaClient,
Txt2ImgV3Embedding,
Txt2ImgV3HiresFix,
Txt2ImgV3LoRA,
Txt2ImgV3Refiner,
V3TaskImage,
)
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin.novitaai._novita_tool_base import NovitaAiToolBase
from core.tools.tool.builtin_tool import BuiltinTool
class NovitaAiTxt2ImgTool(BuiltinTool):
class NovitaAiTxt2ImgTool(BuiltinTool, NovitaAiToolBase):
def _invoke(self,
user_id: str,
tool_parameters: dict[str, Any],
@ -73,65 +69,19 @@ class NovitaAiTxt2ImgTool(BuiltinTool):
# process loras
if 'loras' in res_parameters:
loras_ori_list = res_parameters.get('loras').strip().split(';')
locals_list = []
for lora_str in loras_ori_list:
lora_info = lora_str.strip().split(',')
lora = Txt2ImgV3LoRA(
model_name=lora_info[0].strip(),
strength=float(lora_info[1]),
)
locals_list.append(lora)
res_parameters['loras'] = locals_list
res_parameters['loras'] = self._extract_loras(res_parameters.get('loras'))
# process embeddings
if 'embeddings' in res_parameters:
embeddings_ori_list = res_parameters.get('embeddings').strip().split(';')
locals_list = []
for embedding_str in embeddings_ori_list:
embedding = Txt2ImgV3Embedding(
model_name=embedding_str.strip()
)
locals_list.append(embedding)
res_parameters['embeddings'] = locals_list
res_parameters['embeddings'] = self._extract_embeddings(res_parameters.get('embeddings'))
# process hires_fix
if 'hires_fix' in res_parameters:
hires_fix_ori = res_parameters.get('hires_fix')
hires_fix_info = hires_fix_ori.strip().split(',')
if 'upscaler' in hires_fix_info:
hires_fix = Txt2ImgV3HiresFix(
target_width=int(hires_fix_info[0]),
target_height=int(hires_fix_info[1]),
strength=float(hires_fix_info[2]),
upscaler=hires_fix_info[3].strip()
)
else:
hires_fix = Txt2ImgV3HiresFix(
target_width=int(hires_fix_info[0]),
target_height=int(hires_fix_info[1]),
strength=float(hires_fix_info[2])
)
res_parameters['hires_fix'] = self._extract_hires_fix(res_parameters.get('hires_fix'))
res_parameters['hires_fix'] = hires_fix
if 'refiner_switch_at' in res_parameters:
refiner = Txt2ImgV3Refiner(
switch_at=float(res_parameters.get('refiner_switch_at'))
)
del res_parameters['refiner_switch_at']
res_parameters['refiner'] = refiner
# process refiner
if 'refiner_switch_at' in res_parameters:
res_parameters['refiner'] = self._extract_refiner(res_parameters.get('refiner_switch_at'))
del res_parameters['refiner_switch_at']
return res_parameters
def _is_hit_nsfw_detection(self, image: V3TaskImage, confidence_threshold: float) -> bool:
"""
is hit nsfw
"""
if image.nsfw_detection_result is None:
return False
if image.nsfw_detection_result.valid and image.nsfw_detection_result.confidence >= confidence_threshold:
return True
return False
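The deduplication above works by mixing NovitaAiToolBase into the tool class alongside BuiltinTool; in miniature, the pattern (names illustrative):

class HelperMixin:
    # shared parsing logic, written once
    def _extract(self, raw: str) -> list[str]:
        return [part.strip() for part in raw.split(';') if part.strip()]

class BaseTool:
    def invoke(self, raw: str) -> list[str]:
        raise NotImplementedError

class ConcreteTool(BaseTool, HelperMixin):
    # the tool keeps only its invoke logic
    def invoke(self, raw: str) -> list[str]:
        return self._extract(raw)

assert ConcreteTool().invoke("a; b;") == ["a", "b"]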

View File

@ -1,6 +1,6 @@
from typing import Optional
from core.app.app_config.entities import VariableEntity
from core.app.app_config.entities import VariableEntity, VariableEntityType
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import (
@ -18,6 +18,13 @@ from models.model import App, AppMode
from models.tools import WorkflowToolProvider
from models.workflow import Workflow
VARIABLE_TO_PARAMETER_TYPE_MAPPING = {
VariableEntityType.TEXT_INPUT: ToolParameter.ToolParameterType.STRING,
VariableEntityType.PARAGRAPH: ToolParameter.ToolParameterType.STRING,
VariableEntityType.SELECT: ToolParameter.ToolParameterType.SELECT,
VariableEntityType.NUMBER: ToolParameter.ToolParameterType.NUMBER,
}
class WorkflowToolProviderController(ToolProviderController):
provider_id: str
@ -28,7 +35,7 @@ class WorkflowToolProviderController(ToolProviderController):
if not app:
raise ValueError('app not found')
controller = WorkflowToolProviderController(**{
'identity': {
'author': db_provider.user.name if db_provider.user_id and db_provider.user else '',
@ -46,7 +53,7 @@ class WorkflowToolProviderController(ToolProviderController):
'credentials_schema': {},
'provider_id': db_provider.id or '',
})
# init tools
controller.tools = [controller._get_db_provider_tool(db_provider, app)]
@ -56,7 +63,7 @@ class WorkflowToolProviderController(ToolProviderController):
@property
def provider_type(self) -> ToolProviderType:
return ToolProviderType.WORKFLOW
def _get_db_provider_tool(self, db_provider: WorkflowToolProvider, app: App) -> WorkflowTool:
"""
get db provider tool
@ -93,23 +100,11 @@ class WorkflowToolProviderController(ToolProviderController):
if variable:
parameter_type = None
options = None
if variable.type in [
VariableEntity.Type.TEXT_INPUT,
VariableEntity.Type.PARAGRAPH,
]:
parameter_type = ToolParameter.ToolParameterType.STRING
elif variable.type in [
VariableEntity.Type.SELECT
]:
parameter_type = ToolParameter.ToolParameterType.SELECT
elif variable.type in [
VariableEntity.Type.NUMBER
]:
parameter_type = ToolParameter.ToolParameterType.NUMBER
else:
if variable.type not in VARIABLE_TO_PARAMETER_TYPE_MAPPING:
raise ValueError(f'unsupported variable type {variable.type}')
if variable.type == VariableEntity.Type.SELECT and variable.options:
parameter_type = VARIABLE_TO_PARAMETER_TYPE_MAPPING[variable.type]
if variable.type == VariableEntityType.SELECT and variable.options:
options = [
ToolParameterOption(
value=option,
@ -200,7 +195,7 @@ class WorkflowToolProviderController(ToolProviderController):
"""
if self.tools is not None:
return self.tools
db_providers: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter(
WorkflowToolProvider.tenant_id == tenant_id,
WorkflowToolProvider.app_id == self.provider_id,
@ -208,11 +203,11 @@ class WorkflowToolProviderController(ToolProviderController):
if not db_providers:
return []
self.tools = [self._get_db_provider_tool(db_providers, db_providers.app)]
return self.tools
def get_tool(self, tool_name: str) -> Optional[WorkflowTool]:
"""
get tool by name
@ -226,5 +221,5 @@ class WorkflowToolProviderController(ToolProviderController):
for tool in self.tools:
if tool.identity.name == tool_name:
return tool
return None

View File

@ -10,14 +10,11 @@ from configs import dify_config
from core.agent.entities import AgentToolEntity
from core.app.entities.app_invoke_entities import InvokeFrom
from core.helper.module_import_helper import load_single_subclass_from_source
from core.helper.position_helper import is_filtered
from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.entities.api_entities import UserToolProvider, UserToolProviderTypeLiteral
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import (
ApiProviderAuthType,
ToolInvokeFrom,
ToolParameter,
)
from core.tools.entities.tool_entities import ApiProviderAuthType, ToolInvokeFrom, ToolParameter
from core.tools.errors import ToolProviderNotFoundError
from core.tools.provider.api_tool_provider import ApiToolProviderController
from core.tools.provider.builtin._positions import BuiltinToolProviderSort
@ -26,10 +23,7 @@ from core.tools.tool.api_tool import ApiTool
from core.tools.tool.builtin_tool import BuiltinTool
from core.tools.tool.tool import Tool
from core.tools.tool_label_manager import ToolLabelManager
from core.tools.utils.configuration import (
ToolConfigurationManager,
ToolParameterConfigurationManager,
)
from core.tools.utils.configuration import ToolConfigurationManager, ToolParameterConfigurationManager
from core.tools.utils.tool_parameter_converter import ToolParameterConverter
from extensions.ext_database import db
from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider
@ -37,6 +31,7 @@ from services.tools.tools_transform_service import ToolTransformService
logger = logging.getLogger(__name__)
class ToolManager:
_builtin_provider_lock = Lock()
_builtin_providers = {}
@ -106,7 +101,7 @@ class ToolManager:
tenant_id: str,
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT) \
-> Union[BuiltinTool, ApiTool]:
-> Union[BuiltinTool, ApiTool]:
"""
get the tool runtime
@ -345,7 +340,7 @@ class ToolManager:
provider_class = load_single_subclass_from_source(
module_name=f'core.tools.provider.builtin.{provider}.{provider}',
script_path=path.join(path.dirname(path.realpath(__file__)),
'provider', 'builtin', provider, f'{provider}.py'),
'provider', 'builtin', provider, f'{provider}.py'),
parent_type=BuiltinToolProviderController)
provider: BuiltinToolProviderController = provider_class()
cls._builtin_providers[provider.identity.name] = provider
@ -413,6 +408,15 @@ class ToolManager:
# append builtin providers
for provider in builtin_providers:
# handle include, exclude
if is_filtered(
include_set=dify_config.POSITION_TOOL_INCLUDES_SET,
exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET,
data=provider,
name_func=lambda x: x.identity.name
):
continue
user_provider = ToolTransformService.builtin_provider_to_user_provider(
provider_controller=provider,
db_provider=find_db_builtin_provider(provider.identity.name),
@ -472,7 +476,7 @@ class ToolManager:
@classmethod
def get_api_provider_controller(cls, tenant_id: str, provider_id: str) -> tuple[
ApiToolProviderController, dict[str, Any]]:
ApiToolProviderController, dict[str, Any]]:
"""
get the api provider
@ -592,4 +596,5 @@ class ToolManager:
else:
raise ValueError(f"provider type {provider_type} not found")
ToolManager.load_builtin_providers_cache()
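The is_filtered helper used above lives in core.helper.position_helper and is not shown in this diff; a hedged sketch of its likely contract, matching how it is called with include/exclude sets and a name_func:

from collections.abc import Callable
from typing import Any

def is_filtered(include_set: set[str], exclude_set: set[str],
                data: Any, name_func: Callable[[Any], str]) -> bool:
    # Presumed semantics (the real helper may differ): skip items that fall
    # outside a non-empty include set or inside the exclude set.
    name = name_func(data)
    if include_set and name not in include_set:
        return True
    return bool(exclude_set and name in exclude_set)

assert is_filtered({"github"}, set(), "gitlab", name_func=lambda x: x)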

View File

@ -7,14 +7,14 @@ from typing_extensions import deprecated
from core.app.segments import Segment, Variable, factory
from core.file.file_obj import FileVar
from core.workflow.enums import SystemVariable
from core.workflow.enums import SystemVariableKey
VariableValue = Union[str, int, float, dict, list, FileVar]
SYSTEM_VARIABLE_NODE_ID = 'sys'
ENVIRONMENT_VARIABLE_NODE_ID = 'env'
CONVERSATION_VARIABLE_NODE_ID = 'conversation'
SYSTEM_VARIABLE_NODE_ID = "sys"
ENVIRONMENT_VARIABLE_NODE_ID = "env"
CONVERSATION_VARIABLE_NODE_ID = "conversation"
class VariablePool(BaseModel):
@ -32,7 +32,7 @@ class VariablePool(BaseModel):
description='User inputs',
)
system_variables: Mapping[SystemVariable, Any] = Field(
system_variables: Mapping[SystemVariableKey, Any] = Field(
description='System variables',
)
@ -78,7 +78,7 @@ class VariablePool(BaseModel):
None
"""
if len(selector) < 2:
raise ValueError('Invalid selector')
raise ValueError("Invalid selector")
if value is None:
return
@ -105,13 +105,13 @@ class VariablePool(BaseModel):
ValueError: If the selector is invalid.
"""
if len(selector) < 2:
raise ValueError('Invalid selector')
raise ValueError("Invalid selector")
hash_key = hash(tuple(selector[1:]))
value = self.variable_dictionary[selector[0]].get(hash_key)
return value
@deprecated('This method is deprecated, use `get` instead.')
@deprecated("This method is deprecated, use `get` instead.")
def get_any(self, selector: Sequence[str], /) -> Any | None:
"""
Retrieves the value from the variable pool based on the given selector.
@ -126,7 +126,7 @@ class VariablePool(BaseModel):
ValueError: If the selector is invalid.
"""
if len(selector) < 2:
raise ValueError('Invalid selector')
raise ValueError("Invalid selector")
hash_key = hash(tuple(selector[1:]))
value = self.variable_dictionary[selector[0]].get(hash_key)
return value.to_object() if value else None
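Selectors are sequences whose first element is a node id ('sys', 'env', 'conversation', or a workflow node's id) and whose remainder names the variable; the pool keys on a hash of that remainder. A minimal stand-in mirroring the same contract:

from collections import defaultdict

variable_dictionary: dict[str, dict[int, str]] = defaultdict(dict)

def add(selector: list[str], value: str) -> None:
    if len(selector) < 2:
        raise ValueError("Invalid selector")
    variable_dictionary[selector[0]][hash(tuple(selector[1:]))] = value

def get(selector: list[str]) -> str | None:
    if len(selector) < 2:
        raise ValueError("Invalid selector")
    return variable_dictionary[selector[0]].get(hash(tuple(selector[1:])))

add(['sys', 'query'], 'hello')
assert get(['sys', 'query']) == 'hello'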

View File

@ -1,25 +1,13 @@
from enum import Enum
class SystemVariable(str, Enum):
class SystemVariableKey(str, Enum):
"""
System Variables.
"""
QUERY = 'query'
FILES = 'files'
CONVERSATION_ID = 'conversation_id'
USER_ID = 'user_id'
DIALOGUE_COUNT = 'dialogue_count'
@classmethod
def value_of(cls, value: str):
"""
Get value of given system variable.
:param value: system variable value
:return: system variable
"""
for system_variable in cls:
if system_variable.value == value:
return system_variable
raise ValueError(f'invalid system variable value {value}')
QUERY = "query"
FILES = "files"
CONVERSATION_ID = "conversation_id"
USER_ID = "user_id"
DIALOGUE_COUNT = "dialogue_count"

View File

@ -13,8 +13,8 @@ from models.workflow import WorkflowNodeExecutionStatus
MAX_NUMBER = dify_config.CODE_MAX_NUMBER
MIN_NUMBER = dify_config.CODE_MIN_NUMBER
MAX_PRECISION = 20
MAX_DEPTH = 5
MAX_PRECISION = dify_config.CODE_MAX_PRECISION
MAX_DEPTH = dify_config.CODE_MAX_DEPTH
MAX_STRING_LENGTH = dify_config.CODE_MAX_STRING_LENGTH
MAX_STRING_ARRAY_LENGTH = dify_config.CODE_MAX_STRING_ARRAY_LENGTH
MAX_OBJECT_ARRAY_LENGTH = dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH
@ -23,7 +23,7 @@ MAX_NUMBER_ARRAY_LENGTH = dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH
class CodeNode(BaseNode):
_node_data_cls = CodeNodeData
node_type = NodeType.CODE
_node_type = NodeType.CODE
@classmethod
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
@ -316,8 +316,8 @@ class CodeNode(BaseNode):
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,
graph_config: Mapping[str, Any],
cls,
graph_config: Mapping[str, Any],
node_id: str,
node_data: CodeNodeData
) -> Mapping[str, Sequence[str]]:

View File

@ -25,7 +25,7 @@ from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptT
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariable
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.event import InNodeEvent
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.event import RunCompletedEvent, RunEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
@ -110,7 +110,7 @@ class LLMNode(BaseNode):
# fetch prompt messages
prompt_messages, stop = self._fetch_prompt_messages(
node_data=node_data,
query=variable_pool.get_any(['sys', SystemVariable.QUERY.value])
query=variable_pool.get_any(['sys', SystemVariableKey.QUERY.value])
if node_data.memory else None,
query_prompt_template=node_data.memory.query_prompt_template if node_data.memory else None,
inputs=inputs,
@ -370,7 +370,7 @@ class LLMNode(BaseNode):
inputs[variable_selector.variable] = variable_value
return inputs # type: ignore
return inputs
def _fetch_files(self, node_data: LLMNodeData, variable_pool: VariablePool) -> list["FileVar"]:
"""
@ -382,7 +382,7 @@ class LLMNode(BaseNode):
if not node_data.vision.enabled:
return []
files = variable_pool.get_any(['sys', SystemVariable.FILES.value])
files = variable_pool.get_any(['sys', SystemVariableKey.FILES.value])
if not files:
return []
@ -543,7 +543,7 @@ class LLMNode(BaseNode):
return None
# get conversation id
conversation_id = variable_pool.get_any(['sys', SystemVariable.CONVERSATION_ID.value])
conversation_id = variable_pool.get_any(['sys', SystemVariableKey.CONVERSATION_ID.value])
if conversation_id is None:
return None
@ -722,10 +722,10 @@ class LLMNode(BaseNode):
variable_mapping['#context#'] = node_data.context.variable_selector
if node_data.vision.enabled:
variable_mapping['#files#'] = ['sys', SystemVariable.FILES.value]
variable_mapping['#files#'] = ['sys', SystemVariableKey.FILES.value]
if node_data.memory:
variable_mapping['#sys.query#'] = ['sys', SystemVariable.QUERY.value]
variable_mapping['#sys.query#'] = ['sys', SystemVariableKey.QUERY.value]
if node_data.prompt_config:
enable_jinja = False

View File

@ -1,3 +1,7 @@
from collections.abc import Sequence
from pydantic import Field
from core.app.app_config.entities import VariableEntity
from core.workflow.entities.base_node_data_entities import BaseNodeData
@ -6,4 +10,4 @@ class StartNodeData(BaseNodeData):
"""
Start Node Data
"""
variables: list[VariableEntity] = []
variables: Sequence[VariableEntity] = Field(default_factory=list)
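Field(default_factory=list) builds a fresh list per instance rather than sharing one default object; a quick check of the behavior (sketched with str items so it runs standalone):

from collections.abc import Sequence
from pydantic import BaseModel, Field

class StartNodeDataSketch(BaseModel):
    variables: Sequence[str] = Field(default_factory=list)  # fresh list per instance

a, b = StartNodeDataSketch(), StartNodeDataSketch()
assert a.variables is not b.variables  # no shared mutable default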

View File

@ -3,6 +3,7 @@ from collections.abc import Mapping, Sequence
from typing import Any
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import SYSTEM_VARIABLE_NODE_ID, VariablePool
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.start.entities import StartNodeData
from models.workflow import WorkflowNodeExecutionStatus
@ -17,22 +18,22 @@ class StartNode(BaseNode):
Run node
:return:
"""
# Get cleaned inputs
cleaned_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs)
node_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs)
system_inputs = self.graph_runtime_state.variable_pool.system_variables
for var in self.graph_runtime_state.variable_pool.system_variables:
cleaned_inputs['sys.' + var.value] = self.graph_runtime_state.variable_pool.system_variables[var]
for var in system_inputs:
node_inputs[SYSTEM_VARIABLE_NODE_ID + '.' + var] = system_inputs[var]
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=cleaned_inputs,
outputs=cleaned_inputs
inputs=node_inputs,
outputs=node_inputs
)
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,
graph_config: Mapping[str, Any],
cls,
graph_config: Mapping[str, Any],
node_id: str,
node_data: StartNodeData
) -> Mapping[str, Sequence[str]]:

View File

@ -2,7 +2,7 @@ from collections.abc import Mapping, Sequence
from os import path
from typing import Any, cast
from core.app.segments import ArrayAnyVariable, parser
from core.app.segments import ArrayAnySegment, ArrayAnyVariable, parser
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
from core.file.file_obj import FileTransferMethod, FileType, FileVar
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
@ -11,7 +11,7 @@ from core.tools.tool_manager import ToolManager
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariable
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.tool.entities import ToolNodeData
from core.workflow.utils.variable_template_parser import VariableTemplateParser
@ -141,8 +141,8 @@ class ToolNode(BaseNode):
return result
def _fetch_files(self, variable_pool: VariablePool) -> list[FileVar]:
variable = variable_pool.get(['sys', SystemVariable.FILES.value])
assert isinstance(variable, ArrayAnyVariable)
variable = variable_pool.get(['sys', SystemVariableKey.FILES.value])
assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment)
return list(variable.value) if variable else []
def _convert_tool_messages(self, messages: list[ToolInvokeMessage])\

View File

@ -1,109 +1,8 @@
from collections.abc import Sequence
from enum import Enum
from typing import Optional, cast
from .node import VariableAssignerNode
from .node_data import VariableAssignerData, WriteMode
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.app.segments import SegmentType, Variable, factory
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode
from extensions.ext_database import db
from models import ConversationVariable, WorkflowNodeExecutionStatus
class VariableAssignerNodeError(Exception):
pass
class WriteMode(str, Enum):
OVER_WRITE = 'over-write'
APPEND = 'append'
CLEAR = 'clear'
class VariableAssignerData(BaseNodeData):
title: str = 'Variable Assigner'
desc: Optional[str] = 'Assign a value to a variable'
assigned_variable_selector: Sequence[str]
write_mode: WriteMode
input_variable_selector: Sequence[str]
class VariableAssignerNode(BaseNode):
_node_data_cls: type[BaseNodeData] = VariableAssignerData
_node_type: NodeType = NodeType.CONVERSATION_VARIABLE_ASSIGNER
def _run(self) -> NodeRunResult:
data = cast(VariableAssignerData, self.node_data)
# Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject
original_variable = self.graph_runtime_state.variable_pool.get(data.assigned_variable_selector)
if not isinstance(original_variable, Variable):
raise VariableAssignerNodeError('assigned variable not found')
match data.write_mode:
case WriteMode.OVER_WRITE:
income_value = self.graph_runtime_state.variable_pool.get(data.input_variable_selector)
if not income_value:
raise VariableAssignerNodeError('input value not found')
updated_variable = original_variable.model_copy(update={'value': income_value.value})
case WriteMode.APPEND:
income_value = self.graph_runtime_state.variable_pool.get(data.input_variable_selector)
if not income_value:
raise VariableAssignerNodeError('input value not found')
updated_value = original_variable.value + [income_value.value]
updated_variable = original_variable.model_copy(update={'value': updated_value})
case WriteMode.CLEAR:
income_value = get_zero_value(original_variable.value_type)
updated_variable = original_variable.model_copy(update={'value': income_value.to_object()})
case _:
raise VariableAssignerNodeError(f'unsupported write mode: {data.write_mode}')
# Over write the variable.
self.graph_runtime_state.variable_pool.add(data.assigned_variable_selector, updated_variable)
# Update conversation variable.
# TODO: Find a better way to use the database.
conversation_id = self.graph_runtime_state.variable_pool.get(['sys', 'conversation_id'])
if not conversation_id:
raise VariableAssignerNodeError('conversation_id not found')
update_conversation_variable(conversation_id=conversation_id.text, variable=updated_variable)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs={
'value': income_value.to_object(),
},
)
def update_conversation_variable(conversation_id: str, variable: Variable):
stmt = select(ConversationVariable).where(
ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id
)
with Session(db.engine) as session:
row = session.scalar(stmt)
if not row:
raise VariableAssignerNodeError('conversation variable not found in the database')
row.data = variable.model_dump_json()
session.commit()
def get_zero_value(t: SegmentType):
match t:
case SegmentType.ARRAY_OBJECT | SegmentType.ARRAY_STRING | SegmentType.ARRAY_NUMBER:
return factory.build_segment([])
case SegmentType.OBJECT:
return factory.build_segment({})
case SegmentType.STRING:
return factory.build_segment('')
case SegmentType.NUMBER:
return factory.build_segment(0)
case _:
raise VariableAssignerNodeError(f'unsupported variable type: {t}')
__all__ = [
'VariableAssignerNode',
'VariableAssignerData',
'WriteMode',
]

Some files were not shown because too many files have changed in this diff.