From 7cb75cb2e7415ba4aa71b54b8b3dc87de2715753 Mon Sep 17 00:00:00 2001 From: Yeuoly <45712896+Yeuoly@users.noreply.github.com> Date: Wed, 24 Jan 2024 20:14:45 +0800 Subject: [PATCH] feat: add tool labels (#2178) --- api/controllers/service_api/app/message.py | 1 + api/core/app_runner/generate_task_pipeline.py | 4 ++- api/core/features/assistant_base_runner.py | 16 ++++++++++ api/core/tools/tool_manager.py | 25 ++++++++++++++- api/fields/conversation_fields.py | 3 +- api/fields/message_fields.py | 1 + ...a5a70d_add_tool_labels_to_agent_thought.py | 32 +++++++++++++++++++ api/models/model.py | 11 +++++++ 8 files changed, 90 insertions(+), 3 deletions(-) create mode 100644 api/migrations/versions/380c6aa5a70d_add_tool_labels_to_agent_thought.py diff --git a/api/controllers/service_api/app/message.py b/api/controllers/service_api/app/message.py index aca995f993..c90a1fb1e2 100644 --- a/api/controllers/service_api/app/message.py +++ b/api/controllers/service_api/app/message.py @@ -44,6 +44,7 @@ class MessageListApi(AppApiResource): 'position': fields.Integer, 'thought': fields.String, 'tool': fields.String, + 'tool_labels': fields.Raw, 'tool_input': fields.String, 'created_at': TimestampField, 'observation': fields.String, diff --git a/api/core/app_runner/generate_task_pipeline.py b/api/core/app_runner/generate_task_pipeline.py index 58a99a52c2..4bb8b1dd70 100644 --- a/api/core/app_runner/generate_task_pipeline.py +++ b/api/core/app_runner/generate_task_pipeline.py @@ -18,6 +18,7 @@ from core.model_runtime.entities.message_entities import (AssistantPromptMessage from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.tools.tool_file_manager import ToolFileManager +from core.tools.tool_manager import ToolManager from core.model_runtime.utils.encoders import jsonable_encoder from core.prompt.prompt_template import PromptTemplateParser from events.message_event import message_was_created @@ -281,7 +282,7 @@ class GenerateTaskPipeline: self._task_state.llm_result.message.content = annotation.content elif isinstance(event, QueueAgentThoughtEvent): - agent_thought = ( + agent_thought: MessageAgentThought = ( db.session.query(MessageAgentThought) .filter(MessageAgentThought.id == event.agent_thought_id) .first() @@ -298,6 +299,7 @@ class GenerateTaskPipeline: 'thought': agent_thought.thought, 'observation': agent_thought.observation, 'tool': agent_thought.tool, + 'tool_labels': agent_thought.tool_labels, 'tool_input': agent_thought.tool_input, 'created_at': int(self._message.created_at.timestamp()), 'message_files': agent_thought.files diff --git a/api/core/features/assistant_base_runner.py b/api/core/features/assistant_base_runner.py index 2a896d6fbd..32f7f1d49f 100644 --- a/api/core/features/assistant_base_runner.py +++ b/api/core/features/assistant_base_runner.py @@ -396,6 +396,7 @@ class BaseAssistantApplicationRunner(AppRunner): message_chain_id=None, thought='', tool=tool_name, + tool_labels_str='{}', tool_input=tool_input, message=message, message_token=0, @@ -469,6 +470,21 @@ class BaseAssistantApplicationRunner(AppRunner): agent_thought.tokens = llm_usage.total_tokens agent_thought.total_price = llm_usage.total_price + # check if tool labels is not empty + labels = agent_thought.tool_labels or {} + tools = agent_thought.tool.split(';') if agent_thought.tool else [] + for tool in tools: + if not tool: + continue + if tool not in labels: + tool_label = ToolManager.get_tool_label(tool) + if tool_label: + labels[tool] = tool_label.to_dict() + else: + labels[tool] = {'en_US': tool, 'zh_Hans': tool} + + agent_thought.tool_labels_str = json.dumps(labels) + db.session.commit() def get_history_prompt_messages(self) -> List[PromptMessage]: diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index 0dbb1c1116..01c79e3dd8 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -31,6 +31,7 @@ import mimetypes logger = logging.getLogger(__name__) _builtin_providers = {} +_builtin_tools_labels = {} class ToolManager: @staticmethod @@ -233,7 +234,7 @@ class ToolManager: if len(_builtin_providers) > 0: return list(_builtin_providers.values()) - builtin_providers = [] + builtin_providers: List[BuiltinToolProviderController] = [] for provider in listdir(path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin')): if provider.startswith('__'): continue @@ -264,8 +265,30 @@ class ToolManager: # cache the builtin providers for provider in builtin_providers: _builtin_providers[provider.identity.name] = provider + for tool in provider.get_tools(): + _builtin_tools_labels[tool.identity.name] = tool.identity.label + return builtin_providers + @staticmethod + def get_tool_label(tool_name: str) -> Union[I18nObject, None]: + """ + get the tool label + + :param tool_name: the name of the tool + + :return: the label of the tool + """ + global _builtin_tools_labels + if len(_builtin_tools_labels) == 0: + # init the builtin providers + ToolManager.list_builtin_providers() + + if tool_name not in _builtin_tools_labels: + return None + + return _builtin_tools_labels[tool_name] + @staticmethod def user_list_providers( user_id: str, diff --git a/api/fields/conversation_fields.py b/api/fields/conversation_fields.py index 557f047a95..f7298933c1 100644 --- a/api/fields/conversation_fields.py +++ b/api/fields/conversation_fields.py @@ -49,10 +49,11 @@ agent_thought_fields = { 'position': fields.Integer, 'thought': fields.String, 'tool': fields.String, + 'tool_labels': fields.Raw, 'tool_input': fields.String, 'created_at': TimestampField, 'observation': fields.String, - 'files': fields.List(fields.String) + 'files': fields.List(fields.String), } message_detail_fields = { diff --git a/api/fields/message_fields.py b/api/fields/message_fields.py index 397b9795b8..59a029e3e2 100644 --- a/api/fields/message_fields.py +++ b/api/fields/message_fields.py @@ -36,6 +36,7 @@ agent_thought_fields = { 'position': fields.Integer, 'thought': fields.String, 'tool': fields.String, + 'tool_labels': fields.Raw, 'tool_input': fields.String, 'created_at': TimestampField, 'observation': fields.String, diff --git a/api/migrations/versions/380c6aa5a70d_add_tool_labels_to_agent_thought.py b/api/migrations/versions/380c6aa5a70d_add_tool_labels_to_agent_thought.py new file mode 100644 index 0000000000..2031dc2ad3 --- /dev/null +++ b/api/migrations/versions/380c6aa5a70d_add_tool_labels_to_agent_thought.py @@ -0,0 +1,32 @@ +"""add tool labels to agent thought + +Revision ID: 380c6aa5a70d +Revises: dfb3b7f477da +Create Date: 2024-01-24 10:58:15.644445 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '380c6aa5a70d' +down_revision = 'dfb3b7f477da' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op: + batch_op.add_column(sa.Column('tool_labels_str', sa.Text(), server_default=sa.text("'{}'::text"), nullable=False)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op: + batch_op.drop_column('tool_labels_str') + + # ### end Alembic commands ### diff --git a/api/models/model.py b/api/models/model.py index badaac9b57..c45b0e636b 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -1003,6 +1003,7 @@ class MessageAgentThought(db.Model): position = db.Column(db.Integer, nullable=False) thought = db.Column(db.Text, nullable=True) tool = db.Column(db.Text, nullable=True) + tool_labels_str = db.Column(db.Text, nullable=False, server_default=db.text("'{}'::text")) tool_input = db.Column(db.Text, nullable=True) observation = db.Column(db.Text, nullable=True) # plugin_id = db.Column(UUID, nullable=True) ## for future design @@ -1030,6 +1031,16 @@ class MessageAgentThought(db.Model): return json.loads(self.message_files) else: return [] + + @property + def tool_labels(self) -> dict: + try: + if self.tool_labels_str: + return json.loads(self.tool_labels_str) + else: + return {} + except Exception as e: + return {} class DatasetRetrieverResource(db.Model): __tablename__ = 'dataset_retriever_resources'