diff --git a/.github/workflows/tool-tests.yaml b/.github/workflows/tool-tests.yaml new file mode 100644 index 0000000000..3ea7c2492a --- /dev/null +++ b/.github/workflows/tool-tests.yaml @@ -0,0 +1,26 @@ +name: Run Tool Pytest + +on: + pull_request: + branches: + - main + +jobs: + test: + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.10' + cache: 'pip' + cache-dependency-path: ./api/requirements.txt + + - name: Install dependencies + run: pip install -r ./api/requirements.txt + + - name: Run pytest + run: pytest ./api/tests/integration_tests/tools/test_all_provider.py diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index c4376631ba..e1c087a6cd 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -91,6 +91,8 @@ To validate your set up, head over to [http://localhost:3000](http://localhost:3 If you are adding a model provider, [this guide](https://github.com/langgenius/dify/blob/main/api/core/model_runtime/README.md) is for you. +If you are adding a tool provider to Agent or Workflow, [this guide](./api/core/tools/README.md) is for you. 
+ To help you quickly navigate where your contribution fits, a brief, annotated outline of Dify's backend & frontend is as follows: ### Backend diff --git a/api/app.py b/api/app.py index b7234b6a17..e46cb84bb8 100644 --- a/api/app.py +++ b/api/app.py @@ -30,7 +30,7 @@ from flask import Flask, Response, request from flask_cors import CORS from libs.passport import PassportService # DO NOT REMOVE BELOW -from models import account, dataset, model, source, task, tool, web +from models import account, dataset, model, source, task, tool, web, tools from services.account_service import AccountService # DO NOT REMOVE ABOVE diff --git a/api/commands.py b/api/commands.py index a9a04c9f13..1a2af23da2 100644 --- a/api/commands.py +++ b/api/commands.py @@ -22,7 +22,7 @@ from libs.password import hash_password, password_pattern, valid_password from libs.rsa import generate_key_pair from models.account import InvitationCode, Tenant, TenantAccountJoin from models.dataset import Dataset, DatasetCollectionBinding, DatasetQuery, Document -from models.model import Account, App, AppModelConfig, Message, MessageAnnotation +from models.model import Account, App, AppModelConfig, Message, MessageAnnotation, InstalledApp from models.provider import Provider, ProviderModel, ProviderQuotaType, ProviderType from qdrant_client.http.models import TextIndexParams, TextIndexType, TokenizerType from tqdm import tqdm @@ -775,6 +775,66 @@ def add_annotation_question_field_value(): click.echo( click.style(f'Congratulations! add annotation question value successful. 
Deal count {message_annotation_deal_count}', fg='green')) +@click.command('migrate-universal-chat-to-installed-app', help='Migrate universal chat to installed app.') +@click.option("--batch-size", default=500, help="Number of records to migrate in each batch.") +def migrate_universal_chat_to_installed_app(batch_size): + total_records = db.session.query(App).filter( + App.is_universal == True + ).count() + if total_records == 0: + click.secho("No data to migrate.", fg='green') + return + + num_batches = (total_records + batch_size - 1) // batch_size + + with tqdm(total=total_records, desc="Migrating Data") as pbar: + for i in range(num_batches): + offset = i * batch_size + limit = min(batch_size, total_records - offset) + + click.secho(f"Fetching batch {i + 1}/{num_batches} from source database...", fg='green') + + data_batch: list[App] = db.session.query(App) \ + .filter(App.is_universal == True) \ + .order_by(App.created_at) \ + .offset(offset).limit(limit).all() + + if not data_batch: + click.secho("No more data to migrate.", fg='green') + break + + try: + click.secho(f"Migrating {len(data_batch)} records...", fg='green') + for data in data_batch: + # check if the app is already installed + installed_app = db.session.query(InstalledApp).filter( + InstalledApp.app_id == data.id + ).first() + + if installed_app: + continue + + # insert installed app + installed_app = InstalledApp( + app_id=data.id, + tenant_id=data.tenant_id, + position=0, + app_owner_tenant_id=data.tenant_id, + is_pinned=True, + last_used_at=datetime.datetime.utcnow(), + ) + + db.session.add(installed_app) + + db.session.commit() + + except Exception as e: + click.secho(f"Error while migrating data: {e}, app_id: {data.id}", fg='red') + continue + + click.secho(f"Successfully migrated batch {i + 1}/{num_batches}.", fg='green') + + pbar.update(len(data_batch)) def register_commands(app): app.cli.add_command(reset_password) @@ -791,3 +851,4 @@ def register_commands(app): 
app.cli.add_command(migrate_default_input_to_dataset_query_variable) app.cli.add_command(add_qdrant_full_text_index) app.cli.add_command(add_annotation_question_field_value) + app.cli.add_command(migrate_universal_chat_to_installed_app) diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py index 1394452c80..651d45737a 100644 --- a/api/controllers/console/__init__.py +++ b/api/controllers/console/__init__.py @@ -16,7 +16,5 @@ from .billing import billing from .datasets import data_source, datasets, datasets_document, datasets_segments, file, hit_testing # Import explore controllers from .explore import audio, completion, conversation, installed_app, message, parameter, recommended_app, saved_message -# Import universal chat controllers -from .universal_chat import audio, chat, conversation, message, parameter # Import workspace controllers from .workspace import account, members, model_providers, models, tool_providers, workspace diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index 6ae0ef4806..fe3f145df1 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -16,14 +16,15 @@ from events.app_event import app_was_created, app_was_deleted from extensions.ext_database import db from fields.app_fields import (app_detail_fields, app_detail_fields_with_site, app_pagination_fields, template_list_fields) +from flask import current_app from flask_login import current_user from flask_restful import Resource, abort, inputs, marshal_with, reqparse from libs.login import login_required from models.model import App, AppModelConfig, Site +from models.tools import ApiToolProvider from services.app_model_config_service import AppModelConfigService from werkzeug.exceptions import Forbidden - def _get_app(app_id, tenant_id): app = db.session.query(App).filter(App.id == app_id, App.tenant_id == tenant_id).first() if not app: @@ -42,14 +43,30 @@ class AppListApi(Resource): parser = 
reqparse.RequestParser() parser.add_argument('page', type=inputs.int_range(1, 99999), required=False, default=1, location='args') parser.add_argument('limit', type=inputs.int_range(1, 100), required=False, default=20, location='args') + parser.add_argument('mode', type=str, choices=['chat', 'completion', 'all'], default='all', location='args', required=False) + parser.add_argument('name', type=str, location='args', required=False) args = parser.parse_args() + filters = [ + App.tenant_id == current_user.current_tenant_id, + ] + + if args['mode'] == 'completion': + filters.append(App.mode == 'completion') + elif args['mode'] == 'chat': + filters.append(App.mode == 'chat') + else: + pass + + if 'name' in args and args['name']: + filters.append(App.name.ilike(f'%{args["name"]}%')) + app_models = db.paginate( - db.select(App).where(App.tenant_id == current_user.current_tenant_id, - App.is_universal == False).order_by(App.created_at.desc()), + db.select(App).where(*filters).order_by(App.created_at.desc()), page=args['page'], per_page=args['limit'], - error_out=False) + error_out=False + ) return app_models @@ -62,7 +79,7 @@ class AppListApi(Resource): """Create app""" parser = reqparse.RequestParser() parser.add_argument('name', type=str, required=True, location='json') - parser.add_argument('mode', type=str, choices=['completion', 'chat'], location='json') + parser.add_argument('mode', type=str, choices=['completion', 'chat', 'assistant'], location='json') parser.add_argument('icon', type=str, location='json') parser.add_argument('icon_background', type=str, location='json') parser.add_argument('model_config', type=dict, location='json') @@ -178,7 +195,7 @@ class AppListApi(Resource): app_was_created.send(app) return app, 201 - + class AppTemplateApi(Resource): diff --git a/api/controllers/console/explore/installed_app.py b/api/controllers/console/explore/installed_app.py index 7bde88efbe..44c54427a4 100644 --- a/api/controllers/console/explore/installed_app.py +++ 
b/api/controllers/console/explore/installed_app.py @@ -33,8 +33,9 @@ class InstalledAppsListApi(Resource): 'app_owner_tenant_id': installed_app.app_owner_tenant_id, 'is_pinned': installed_app.is_pinned, 'last_used_at': installed_app.last_used_at, - "editable": current_user.role in ["owner", "admin"], - "uninstallable": current_tenant_id == installed_app.app_owner_tenant_id + 'editable': current_user.role in ["owner", "admin"], + 'uninstallable': current_tenant_id == installed_app.app_owner_tenant_id, + 'is_agent': installed_app.is_agent } for installed_app in installed_apps ] diff --git a/api/controllers/console/explore/message.py b/api/controllers/console/explore/message.py index 05c5619ce7..404b855bb2 100644 --- a/api/controllers/console/explore/message.py +++ b/api/controllers/console/explore/message.py @@ -17,9 +17,9 @@ from core.model_runtime.errors.invoke import InvokeError from fields.message_fields import message_infinite_scroll_pagination_fields from flask import Response, stream_with_context from flask_login import current_user -from flask_restful import marshal_with, reqparse +from flask_restful import marshal_with, reqparse, fields from flask_restful.inputs import int_range -from libs.helper import uuid_value +from libs.helper import uuid_value, TimestampField from services.completion_service import CompletionService from services.errors.app import MoreLikeThisDisabledError from services.errors.conversation import ConversationNotExistsError @@ -29,7 +29,6 @@ from werkzeug.exceptions import InternalServerError, NotFound class MessageListApi(InstalledAppResource): - @marshal_with(message_infinite_scroll_pagination_fields) def get(self, installed_app): app_model = installed_app.app @@ -51,7 +50,6 @@ class MessageListApi(InstalledAppResource): except services.errors.message.FirstMessageNotExistsError: raise NotFound("First Message Not Exists.") - class MessageFeedbackApi(InstalledAppResource): def post(self, installed_app, message_id): app_model = 
installed_app.app diff --git a/api/controllers/console/explore/parameter.py b/api/controllers/console/explore/parameter.py index 7be6966129..0a76f9d58a 100644 --- a/api/controllers/console/explore/parameter.py +++ b/api/controllers/console/explore/parameter.py @@ -1,10 +1,14 @@ # -*- coding:utf-8 -*- +import json + from controllers.console import api from controllers.console.explore.wraps import InstalledAppResource from flask import current_app from flask_restful import fields, marshal_with -from models.model import InstalledApp +from models.model import InstalledApp, AppModelConfig +from models.tools import ApiToolProvider +from extensions.ext_database import db class AppParameterApi(InstalledAppResource): """Resource for app variables.""" @@ -58,5 +62,42 @@ class AppParameterApi(InstalledAppResource): } } +class ExploreAppMetaApi(InstalledAppResource): + def get(self, installed_app: InstalledApp): + """Get app meta""" + app_model_config: AppModelConfig = installed_app.app.app_model_config + + agent_config = app_model_config.agent_mode_dict or {} + meta = { + 'tool_icons': {} + } + + # get all tools + tools = agent_config.get('tools', []) + url_prefix = (current_app.config.get("CONSOLE_API_URL") + + f"/console/api/workspaces/current/tool-provider/builtin/") + for tool in tools: + keys = list(tool.keys()) + if len(keys) >= 4: + # current tool standard + provider_type = tool.get('provider_type') + provider_id = tool.get('provider_id') + tool_name = tool.get('tool_name') + if provider_type == 'builtin': + meta['tool_icons'][tool_name] = url_prefix + provider_id + '/icon' + elif provider_type == 'api': + try: + provider: ApiToolProvider = db.session.query(ApiToolProvider).filter( + ApiToolProvider.id == provider_id + ) + meta['tool_icons'][tool_name] = json.loads(provider.icon) + except: + meta['tool_icons'][tool_name] = { + "background": "#252525", + "content": "\ud83d\ude01" + } + + return meta api.add_resource(AppParameterApi, '/installed-apps//parameters', 
endpoint='installed_app_parameters') +api.add_resource(ExploreAppMetaApi, '/installed-apps//meta', endpoint='installed_app_meta') diff --git a/api/controllers/console/explore/recommended_app.py b/api/controllers/console/explore/recommended_app.py index 71d77ee74a..92e64996b6 100644 --- a/api/controllers/console/explore/recommended_app.py +++ b/api/controllers/console/explore/recommended_app.py @@ -29,7 +29,8 @@ recommended_app_fields = { 'is_listed': fields.Boolean, 'install_count': fields.Integer, 'installed': fields.Boolean, - 'editable': fields.Boolean + 'editable': fields.Boolean, + 'is_agent': fields.Boolean } recommended_app_list_fields = { @@ -82,6 +83,7 @@ class RecommendedAppListApi(Resource): 'install_count': recommended_app.install_count, 'installed': installed, 'editable': current_user.role in ['owner', 'admin'], + "is_agent": app.is_agent } recommended_apps_result.append(recommended_app_result) diff --git a/api/controllers/console/universal_chat/audio.py b/api/controllers/console/universal_chat/audio.py index 2566448d49..0ef1e4ecee 100644 --- a/api/controllers/console/universal_chat/audio.py +++ b/api/controllers/console/universal_chat/audio.py @@ -60,5 +60,3 @@ class UniversalChatAudioApi(UniversalChatResource): logging.exception("internal server error.") raise InternalServerError() - -api.add_resource(UniversalChatAudioApi, '/universal-chat/audio-to-text') \ No newline at end of file diff --git a/api/controllers/console/universal_chat/chat.py b/api/controllers/console/universal_chat/chat.py index e0adc05c32..e69de29bb2 100644 --- a/api/controllers/console/universal_chat/chat.py +++ b/api/controllers/console/universal_chat/chat.py @@ -1,120 +0,0 @@ -import json -import logging -from typing import Generator, Union - -import services -from controllers.console import api -from controllers.console.app.error import (AppUnavailableError, CompletionRequestError, ConversationCompletedError, - ProviderModelCurrentlyNotSupportError, ProviderNotInitializeError, 
- ProviderQuotaExceededError) -from controllers.console.universal_chat.wraps import UniversalChatResource -from core.application_queue_manager import ApplicationQueueManager -from core.entities.application_entities import InvokeFrom -from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from core.model_runtime.errors.invoke import InvokeError -from flask import Response, stream_with_context -from flask_login import current_user -from flask_restful import reqparse -from libs.helper import uuid_value -from services.completion_service import CompletionService -from werkzeug.exceptions import InternalServerError, NotFound - - -class UniversalChatApi(UniversalChatResource): - def post(self, universal_app): - app_model = universal_app - - parser = reqparse.RequestParser() - parser.add_argument('query', type=str, required=True, location='json') - parser.add_argument('files', type=list, required=False, location='json') - parser.add_argument('conversation_id', type=uuid_value, location='json') - parser.add_argument('provider', type=str, required=True, location='json') - parser.add_argument('model', type=str, required=True, location='json') - parser.add_argument('tools', type=list, required=True, location='json') - parser.add_argument('retriever_from', type=str, required=False, default='universal_app', location='json') - args = parser.parse_args() - - app_model_config = app_model.app_model_config - - # update app model config - args['model_config'] = app_model_config.to_dict() - args['model_config']['model']['name'] = args['model'] - args['model_config']['model']['provider'] = args['provider'] - args['model_config']['agent_mode']['tools'] = args['tools'] - - if not args['model_config']['agent_mode']['tools']: - args['model_config']['agent_mode']['tools'] = [ - { - "current_datetime": { - "enabled": True - } - } - ] - else: - args['model_config']['agent_mode']['tools'].append({ - "current_datetime": { - "enabled": True - } 
- }) - - args['inputs'] = {} - - del args['model'] - del args['tools'] - - args['auto_generate_name'] = False - - try: - response = CompletionService.completion( - app_model=app_model, - user=current_user, - args=args, - invoke_from=InvokeFrom.EXPLORE, - streaming=True, - is_model_config_override=True, - ) - - return compact_response(response) - except services.errors.conversation.ConversationNotExistsError: - raise NotFound("Conversation Not Exists.") - except services.errors.conversation.ConversationCompletedError: - raise ConversationCompletedError() - except services.errors.app_model_config.AppModelConfigBrokenError: - logging.exception("App model config broken.") - raise AppUnavailableError() - except ProviderTokenNotInitError: - raise ProviderNotInitializeError() - except QuotaExceededError: - raise ProviderQuotaExceededError() - except ModelCurrentlyNotSupportError: - raise ProviderModelCurrentlyNotSupportError() - except InvokeError as e: - raise CompletionRequestError(e.description) - except ValueError as e: - raise e - except Exception as e: - logging.exception("internal server error.") - raise InternalServerError() - - -class UniversalChatStopApi(UniversalChatResource): - def post(self, universal_app, task_id): - ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id) - - return {'result': 'success'}, 200 - - -def compact_response(response: Union[dict, Generator]) -> Response: - if isinstance(response, dict): - return Response(response=json.dumps(response), status=200, mimetype='application/json') - else: - def generate() -> Generator: - for chunk in response: - yield chunk - - return Response(stream_with_context(generate()), status=200, - mimetype='text/event-stream') - - -api.add_resource(UniversalChatApi, '/universal-chat/messages') -api.add_resource(UniversalChatStopApi, '/universal-chat/messages//stop') diff --git a/api/controllers/console/universal_chat/conversation.py 
b/api/controllers/console/universal_chat/conversation.py deleted file mode 100644 index af141b67a5..0000000000 --- a/api/controllers/console/universal_chat/conversation.py +++ /dev/null @@ -1,110 +0,0 @@ -# -*- coding:utf-8 -*- -from controllers.console import api -from controllers.console.universal_chat.wraps import UniversalChatResource -from fields.conversation_fields import (conversation_with_model_config_fields, - conversation_with_model_config_infinite_scroll_pagination_fields) -from flask_login import current_user -from flask_restful import fields, marshal_with, reqparse -from flask_restful.inputs import int_range -from libs.helper import TimestampField, uuid_value -from services.conversation_service import ConversationService -from services.errors.conversation import ConversationNotExistsError, LastConversationNotExistsError -from services.web_conversation_service import WebConversationService -from werkzeug.exceptions import NotFound - - -class UniversalChatConversationListApi(UniversalChatResource): - - @marshal_with(conversation_with_model_config_infinite_scroll_pagination_fields) - def get(self, universal_app): - app_model = universal_app - - parser = reqparse.RequestParser() - parser.add_argument('last_id', type=uuid_value, location='args') - parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args') - parser.add_argument('pinned', type=str, choices=['true', 'false', None], location='args') - args = parser.parse_args() - - pinned = None - if 'pinned' in args and args['pinned'] is not None: - pinned = True if args['pinned'] == 'true' else False - - try: - return WebConversationService.pagination_by_last_id( - app_model=app_model, - user=current_user, - last_id=args['last_id'], - limit=args['limit'], - pinned=pinned - ) - except LastConversationNotExistsError: - raise NotFound("Last Conversation Not Exists.") - - -class UniversalChatConversationApi(UniversalChatResource): - def delete(self, universal_app, c_id): - 
app_model = universal_app - conversation_id = str(c_id) - - try: - ConversationService.delete(app_model, conversation_id, current_user) - except ConversationNotExistsError: - raise NotFound("Conversation Not Exists.") - - WebConversationService.unpin(app_model, conversation_id, current_user) - - return {"result": "success"}, 204 - - -class UniversalChatConversationRenameApi(UniversalChatResource): - - @marshal_with(conversation_with_model_config_fields) - def post(self, universal_app, c_id): - app_model = universal_app - conversation_id = str(c_id) - - parser = reqparse.RequestParser() - parser.add_argument('name', type=str, required=False, location='json') - parser.add_argument('auto_generate', type=bool, required=False, default=False, location='json') - args = parser.parse_args() - - try: - return ConversationService.rename( - app_model, - conversation_id, - current_user, - args['name'], - args['auto_generate'] - ) - except ConversationNotExistsError: - raise NotFound("Conversation Not Exists.") - - -class UniversalChatConversationPinApi(UniversalChatResource): - - def patch(self, universal_app, c_id): - app_model = universal_app - conversation_id = str(c_id) - - try: - WebConversationService.pin(app_model, conversation_id, current_user) - except ConversationNotExistsError: - raise NotFound("Conversation Not Exists.") - - return {"result": "success"} - - -class UniversalChatConversationUnPinApi(UniversalChatResource): - def patch(self, universal_app, c_id): - app_model = universal_app - conversation_id = str(c_id) - WebConversationService.unpin(app_model, conversation_id, current_user) - - return {"result": "success"} - - -api.add_resource(UniversalChatConversationRenameApi, '/universal-chat/conversations//name') -api.add_resource(UniversalChatConversationListApi, '/universal-chat/conversations') -api.add_resource(UniversalChatConversationApi, '/universal-chat/conversations/') -api.add_resource(UniversalChatConversationPinApi, 
'/universal-chat/conversations//pin') -api.add_resource(UniversalChatConversationUnPinApi, '/universal-chat/conversations//unpin') diff --git a/api/controllers/console/universal_chat/message.py b/api/controllers/console/universal_chat/message.py deleted file mode 100644 index 503615d751..0000000000 --- a/api/controllers/console/universal_chat/message.py +++ /dev/null @@ -1,145 +0,0 @@ -# -*- coding:utf-8 -*- -import logging - -import services -from controllers.console import api -from controllers.console.app.error import (CompletionRequestError, ProviderModelCurrentlyNotSupportError, - ProviderNotInitializeError, ProviderQuotaExceededError) -from controllers.console.explore.error import AppSuggestedQuestionsAfterAnswerDisabledError -from controllers.console.universal_chat.wraps import UniversalChatResource -from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from core.model_runtime.errors.invoke import InvokeError -from flask_login import current_user -from flask_restful import fields, marshal_with, reqparse -from flask_restful.inputs import int_range -from libs.helper import TimestampField, uuid_value -from services.errors.conversation import ConversationNotExistsError -from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError -from services.message_service import MessageService -from werkzeug.exceptions import InternalServerError, NotFound - - -class UniversalChatMessageListApi(UniversalChatResource): - feedback_fields = { - 'rating': fields.String - } - - agent_thought_fields = { - 'id': fields.String, - 'chain_id': fields.String, - 'message_id': fields.String, - 'position': fields.Integer, - 'thought': fields.String, - 'tool': fields.String, - 'tool_input': fields.String, - 'created_at': TimestampField - } - - retriever_resource_fields = { - 'id': fields.String, - 'message_id': fields.String, - 'position': fields.Integer, - 'dataset_id': fields.String, - 
'dataset_name': fields.String, - 'document_id': fields.String, - 'document_name': fields.String, - 'data_source_type': fields.String, - 'segment_id': fields.String, - 'score': fields.Float, - 'hit_count': fields.Integer, - 'word_count': fields.Integer, - 'segment_position': fields.Integer, - 'index_node_hash': fields.String, - 'content': fields.String, - 'created_at': TimestampField - } - - message_fields = { - 'id': fields.String, - 'conversation_id': fields.String, - 'inputs': fields.Raw, - 'query': fields.String, - 'answer': fields.String, - 'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True), - 'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)), - 'created_at': TimestampField, - 'agent_thoughts': fields.List(fields.Nested(agent_thought_fields)) - } - - message_infinite_scroll_pagination_fields = { - 'limit': fields.Integer, - 'has_more': fields.Boolean, - 'data': fields.List(fields.Nested(message_fields)) - } - - @marshal_with(message_infinite_scroll_pagination_fields) - def get(self, universal_app): - app_model = universal_app - - parser = reqparse.RequestParser() - parser.add_argument('conversation_id', required=True, type=uuid_value, location='args') - parser.add_argument('first_id', type=uuid_value, location='args') - parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args') - args = parser.parse_args() - - try: - return MessageService.pagination_by_first_id(app_model, current_user, - args['conversation_id'], args['first_id'], args['limit']) - except services.errors.conversation.ConversationNotExistsError: - raise NotFound("Conversation Not Exists.") - except services.errors.message.FirstMessageNotExistsError: - raise NotFound("First Message Not Exists.") - - -class UniversalChatMessageFeedbackApi(UniversalChatResource): - def post(self, universal_app, message_id): - app_model = universal_app - message_id = str(message_id) - - parser = 
reqparse.RequestParser() - parser.add_argument('rating', type=str, choices=['like', 'dislike', None], location='json') - args = parser.parse_args() - - try: - MessageService.create_feedback(app_model, message_id, current_user, args['rating']) - except services.errors.message.MessageNotExistsError: - raise NotFound("Message Not Exists.") - - return {'result': 'success'} - - -class UniversalChatMessageSuggestedQuestionApi(UniversalChatResource): - def get(self, universal_app, message_id): - app_model = universal_app - message_id = str(message_id) - - try: - questions = MessageService.get_suggested_questions_after_answer( - app_model=app_model, - user=current_user, - message_id=message_id - ) - except MessageNotExistsError: - raise NotFound("Message not found") - except ConversationNotExistsError: - raise NotFound("Conversation not found") - except SuggestedQuestionsAfterAnswerDisabledError: - raise AppSuggestedQuestionsAfterAnswerDisabledError() - except ProviderTokenNotInitError: - raise ProviderNotInitializeError() - except QuotaExceededError: - raise ProviderQuotaExceededError() - except ModelCurrentlyNotSupportError: - raise ProviderModelCurrentlyNotSupportError() - except InvokeError as e: - raise CompletionRequestError(e.description) - except Exception: - logging.exception("internal server error.") - raise InternalServerError() - - return {'data': questions} - - -api.add_resource(UniversalChatMessageListApi, '/universal-chat/messages') -api.add_resource(UniversalChatMessageFeedbackApi, '/universal-chat/messages//feedbacks') -api.add_resource(UniversalChatMessageSuggestedQuestionApi, '/universal-chat/messages//suggested-questions') diff --git a/api/controllers/console/universal_chat/parameter.py b/api/controllers/console/universal_chat/parameter.py deleted file mode 100644 index dca86e39c1..0000000000 --- a/api/controllers/console/universal_chat/parameter.py +++ /dev/null @@ -1,38 +0,0 @@ -# -*- coding:utf-8 -*- -import json - -from controllers.console import 
api -from controllers.console.universal_chat.wraps import UniversalChatResource -from flask_restful import fields, marshal_with -from models.model import App - - -class UniversalChatParameterApi(UniversalChatResource): - """Resource for app variables.""" - parameters_fields = { - 'opening_statement': fields.String, - 'suggested_questions': fields.Raw, - 'suggested_questions_after_answer': fields.Raw, - 'speech_to_text': fields.Raw, - 'retriever_resource': fields.Raw, - 'annotation_reply': fields.Raw - } - - @marshal_with(parameters_fields) - def get(self, universal_app: App): - """Retrieve app parameters.""" - app_model = universal_app - app_model_config = app_model.app_model_config - app_model_config.retriever_resource = json.dumps({'enabled': True}) - - return { - 'opening_statement': app_model_config.opening_statement, - 'suggested_questions': app_model_config.suggested_questions_list, - 'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict, - 'speech_to_text': app_model_config.speech_to_text_dict, - 'retriever_resource': app_model_config.retriever_resource_dict, - 'annotation_reply': app_model_config.annotation_reply_dict, - } - - -api.add_resource(UniversalChatParameterApi, '/universal-chat/parameters') diff --git a/api/controllers/console/universal_chat/wraps.py b/api/controllers/console/universal_chat/wraps.py deleted file mode 100644 index 3e5600639e..0000000000 --- a/api/controllers/console/universal_chat/wraps.py +++ /dev/null @@ -1,86 +0,0 @@ -import json -from functools import wraps - -from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required -from extensions.ext_database import db -from flask_login import current_user -from flask_restful import Resource -from libs.login import login_required -from models.model import App, AppModelConfig - - -def universal_chat_app_required(view=None): - def decorator(view): - @wraps(view) - def decorated(*args, 
**kwargs): - # get universal chat app - universal_app = db.session.query(App).filter( - App.tenant_id == current_user.current_tenant_id, - App.is_universal == True - ).first() - - if universal_app is None: - # create universal app if not exists - universal_app = App( - tenant_id=current_user.current_tenant_id, - name='Universal Chat', - mode='chat', - is_universal=True, - icon='', - icon_background='', - api_rpm=0, - api_rph=0, - enable_site=False, - enable_api=False, - status='normal' - ) - - db.session.add(universal_app) - db.session.flush() - - app_model_config = AppModelConfig( - provider="", - model_id="", - configs={}, - opening_statement='', - suggested_questions=json.dumps([]), - suggested_questions_after_answer=json.dumps({'enabled': True}), - speech_to_text=json.dumps({'enabled': True}), - retriever_resource=json.dumps({'enabled': True}), - more_like_this=None, - sensitive_word_avoidance=None, - model=json.dumps({ - "provider": "openai", - "name": "gpt-3.5-turbo-16k", - "completion_params": { - "max_tokens": 800, - "temperature": 0.8, - "top_p": 1, - "presence_penalty": 0, - "frequency_penalty": 0 - } - }), - user_input_form=json.dumps([]), - pre_prompt='', - agent_mode=json.dumps({"enabled": True, "strategy": "function_call", "tools": []}), - ) - - app_model_config.app_id = universal_app.id - db.session.add(app_model_config) - db.session.flush() - - universal_app.app_model_config_id = app_model_config.id - db.session.commit() - - return view(universal_app, *args, **kwargs) - return decorated - - if view: - return decorator(view) - return decorator - - -class UniversalChatResource(Resource): - # must be reversed if there are multiple decorators - method_decorators = [universal_chat_app_required, account_initialization_required, login_required, setup_required] diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index c6416e1d3b..b694e0bef7 100644 --- 
a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -1,136 +1,293 @@ import json +from libs.login import login_required +from flask_login import current_user +from flask_restful import Resource, reqparse +from flask import send_file +from werkzeug.exceptions import Forbidden + from controllers.console import api from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required -from core.tool.provider.errors import ToolValidateFailedError -from core.tool.provider.tool_provider_service import ToolProviderService -from extensions.ext_database import db -from flask_login import current_user -from flask_restful import Resource, abort, reqparse -from libs.login import login_required -from models.tool import ToolProvider, ToolProviderName -from werkzeug.exceptions import Forbidden +from services.tools_manage_service import ToolManageService + +import io class ToolProviderListApi(Resource): - @setup_required @login_required @account_initialization_required def get(self): + user_id = current_user.id tenant_id = current_user.current_tenant_id - tool_credential_dict = {} - for tool_name in ToolProviderName: - tool_credential_dict[tool_name.value] = { - 'tool_name': tool_name.value, - 'is_enabled': False, - 'credentials': None - } + return ToolManageService.list_tool_providers(user_id, tenant_id) - tool_providers = db.session.query(ToolProvider).filter(ToolProvider.tenant_id == tenant_id).all() +class ToolBuiltinProviderListToolsApi(Resource): + @setup_required + @login_required + @account_initialization_required + def get(self, provider): + user_id = current_user.id + tenant_id = current_user.current_tenant_id - for p in tool_providers: - if p.is_enabled: - tool_credential_dict[p.tool_name] = { - 'tool_name': p.tool_name, - 'is_enabled': p.is_enabled, - 'credentials': ToolProviderService(tenant_id, p.tool_name).get_credentials(obfuscated=True) - } - - 
return list(tool_credential_dict.values()) - - -class ToolProviderCredentialsApi(Resource): + return ToolManageService.list_builtin_tool_provider_tools( + user_id, + tenant_id, + provider, + ) +class ToolBuiltinProviderDeleteApi(Resource): @setup_required @login_required @account_initialization_required def post(self, provider): - if provider not in [p.value for p in ToolProviderName]: - abort(404) - - # The role of the current user in the ta table must be admin or owner if current_user.current_tenant.current_role not in ['admin', 'owner']: - raise Forbidden(f'User {current_user.id} is not authorized to update provider token, ' - f'current_role is {current_user.current_tenant.current_role}') - - parser = reqparse.RequestParser() - parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json') - args = parser.parse_args() - + raise Forbidden() + + user_id = current_user.id tenant_id = current_user.current_tenant_id - tool_provider_service = ToolProviderService(tenant_id, provider) - - try: - tool_provider_service.credentials_validate(args['credentials']) - except ToolValidateFailedError as ex: - raise ValueError(str(ex)) - - encrypted_credentials = json.dumps(tool_provider_service.encrypt_credentials(args['credentials'])) - - tenant = current_user.current_tenant - - tool_provider_model = db.session.query(ToolProvider).filter( - ToolProvider.tenant_id == tenant.id, - ToolProvider.tool_name == provider, - ).first() - - # Only allow updating token for CUSTOM provider type - if tool_provider_model: - tool_provider_model.encrypted_credentials = encrypted_credentials - tool_provider_model.is_enabled = True - else: - tool_provider_model = ToolProvider( - tenant_id=tenant.id, - tool_name=provider, - encrypted_credentials=encrypted_credentials, - is_enabled=True - ) - db.session.add(tool_provider_model) - - db.session.commit() - - return {'result': 'success'}, 201 - - -class ToolProviderCredentialsValidateApi(Resource): - + return 
ToolManageService.delete_builtin_tool_provider( + user_id, + tenant_id, + provider, + ) + +class ToolBuiltinProviderUpdateApi(Resource): @setup_required @login_required @account_initialization_required def post(self, provider): - if provider not in [p.value for p in ToolProviderName]: - abort(404) + if current_user.current_tenant.current_role not in ['admin', 'owner']: + raise Forbidden() + + user_id = current_user.id + tenant_id = current_user.current_tenant_id parser = reqparse.RequestParser() parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json') + args = parser.parse_args() - result = True - error = None + return ToolManageService.update_builtin_tool_provider( + user_id, + tenant_id, + provider, + args['credentials'], + ) +class ToolBuiltinProviderIconApi(Resource): + @setup_required + def get(self, provider): + icon_bytes, minetype = ToolManageService.get_builtin_tool_provider_icon(provider) + return send_file(io.BytesIO(icon_bytes), mimetype=minetype) + + +class ToolApiProviderAddApi(Resource): + @setup_required + @login_required + @account_initialization_required + def post(self): + if current_user.current_tenant.current_role not in ['admin', 'owner']: + raise Forbidden() + + user_id = current_user.id tenant_id = current_user.current_tenant_id - tool_provider_service = ToolProviderService(tenant_id, provider) + parser = reqparse.RequestParser() + parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json') + parser.add_argument('schema_type', type=str, required=True, nullable=False, location='json') + parser.add_argument('schema', type=str, required=True, nullable=False, location='json') + parser.add_argument('provider', type=str, required=True, nullable=False, location='json') + parser.add_argument('icon', type=dict, required=True, nullable=False, location='json') + parser.add_argument('privacy_policy', type=str, required=False, nullable=True, location='json') - try: - 
tool_provider_service.credentials_validate(args['credentials']) - except ToolValidateFailedError as ex: - result = False - error = str(ex) + args = parser.parse_args() - response = {'result': 'success' if result else 'error'} + return ToolManageService.create_api_tool_provider( + user_id, + tenant_id, + args['provider'], + args['icon'], + args['credentials'], + args['schema_type'], + args['schema'], + args.get('privacy_policy', ''), + ) - if not result: - response['error'] = error +class ToolApiProviderGetRemoteSchemaApi(Resource): + @setup_required + @login_required + @account_initialization_required + def get(self): + parser = reqparse.RequestParser() - return response + parser.add_argument('url', type=str, required=True, nullable=False, location='args') + args = parser.parse_args() + + return ToolManageService.get_api_tool_provider_remote_schema( + current_user.id, + current_user.current_tenant_id, + args['url'], + ) + +class ToolApiProviderListToolsApi(Resource): + @setup_required + @login_required + @account_initialization_required + def get(self): + user_id = current_user.id + tenant_id = current_user.current_tenant_id + + parser = reqparse.RequestParser() + + parser.add_argument('provider', type=str, required=True, nullable=False, location='args') + + args = parser.parse_args() + + return ToolManageService.list_api_tool_provider_tools( + user_id, + tenant_id, + args['provider'], + ) + +class ToolApiProviderUpdateApi(Resource): + @setup_required + @login_required + @account_initialization_required + def post(self): + if current_user.current_tenant.current_role not in ['admin', 'owner']: + raise Forbidden() + + user_id = current_user.id + tenant_id = current_user.current_tenant_id + + parser = reqparse.RequestParser() + parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json') + parser.add_argument('schema_type', type=str, required=True, nullable=False, location='json') + parser.add_argument('schema', type=str, 
required=True, nullable=False, location='json') + parser.add_argument('provider', type=str, required=True, nullable=False, location='json') + parser.add_argument('original_provider', type=str, required=True, nullable=False, location='json') + parser.add_argument('icon', type=str, required=True, nullable=False, location='json') + parser.add_argument('privacy_policy', type=str, required=True, nullable=False, location='json') + + args = parser.parse_args() + + return ToolManageService.update_api_tool_provider( + user_id, + tenant_id, + args['provider'], + args['original_provider'], + args['icon'], + args['credentials'], + args['schema_type'], + args['schema'], + args['privacy_policy'], + ) + +class ToolApiProviderDeleteApi(Resource): + @setup_required + @login_required + @account_initialization_required + def post(self): + if current_user.current_tenant.current_role not in ['admin', 'owner']: + raise Forbidden() + + user_id = current_user.id + tenant_id = current_user.current_tenant_id + + parser = reqparse.RequestParser() + + parser.add_argument('provider', type=str, required=True, nullable=False, location='json') + + args = parser.parse_args() + + return ToolManageService.delete_api_tool_provider( + user_id, + tenant_id, + args['provider'], + ) + +class ToolApiProviderGetApi(Resource): + @setup_required + @login_required + @account_initialization_required + def get(self): + user_id = current_user.id + tenant_id = current_user.current_tenant_id + + parser = reqparse.RequestParser() + + parser.add_argument('provider', type=str, required=True, nullable=False, location='args') + + args = parser.parse_args() + + return ToolManageService.get_api_tool_provider( + user_id, + tenant_id, + args['provider'], + ) + +class ToolBuiltinProviderCredentialsSchemaApi(Resource): + @setup_required + @login_required + @account_initialization_required + def get(self, provider): + return ToolManageService.list_builtin_provider_credentials_schema(provider) + +class 
ToolApiProviderSchemaApi(Resource): + @setup_required + @login_required + @account_initialization_required + def post(self): + parser = reqparse.RequestParser() + + parser.add_argument('schema', type=str, required=True, nullable=False, location='json') + + args = parser.parse_args() + + return ToolManageService.parser_api_schema( + schema=args['schema'], + ) + +class ToolApiProviderPreviousTestApi(Resource): + @setup_required + @login_required + @account_initialization_required + def post(self): + parser = reqparse.RequestParser() + + parser.add_argument('tool_name', type=str, required=True, nullable=False, location='json') + parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json') + parser.add_argument('parameters', type=dict, required=True, nullable=False, location='json') + parser.add_argument('schema_type', type=str, required=True, nullable=False, location='json') + parser.add_argument('schema', type=str, required=True, nullable=False, location='json') + + args = parser.parse_args() + + return ToolManageService.test_api_tool_preview( + current_user.current_tenant_id, + args['tool_name'], + args['credentials'], + args['parameters'], + args['schema_type'], + args['schema'], + ) api.add_resource(ToolProviderListApi, '/workspaces/current/tool-providers') -api.add_resource(ToolProviderCredentialsApi, '/workspaces/current/tool-providers//credentials') -api.add_resource(ToolProviderCredentialsValidateApi, - '/workspaces/current/tool-providers//credentials-validate') +api.add_resource(ToolBuiltinProviderListToolsApi, '/workspaces/current/tool-provider/builtin//tools') +api.add_resource(ToolBuiltinProviderDeleteApi, '/workspaces/current/tool-provider/builtin//delete') +api.add_resource(ToolBuiltinProviderUpdateApi, '/workspaces/current/tool-provider/builtin//update') +api.add_resource(ToolBuiltinProviderCredentialsSchemaApi, '/workspaces/current/tool-provider/builtin//credentials_schema') +api.add_resource(ToolBuiltinProviderIconApi, 
'/workspaces/current/tool-provider/builtin//icon') +api.add_resource(ToolApiProviderAddApi, '/workspaces/current/tool-provider/api/add') +api.add_resource(ToolApiProviderGetRemoteSchemaApi, '/workspaces/current/tool-provider/api/remote') +api.add_resource(ToolApiProviderListToolsApi, '/workspaces/current/tool-provider/api/tools') +api.add_resource(ToolApiProviderUpdateApi, '/workspaces/current/tool-provider/api/update') +api.add_resource(ToolApiProviderDeleteApi, '/workspaces/current/tool-provider/api/delete') +api.add_resource(ToolApiProviderGetApi, '/workspaces/current/tool-provider/api/get') +api.add_resource(ToolApiProviderSchemaApi, '/workspaces/current/tool-provider/api/schema') +api.add_resource(ToolApiProviderPreviousTestApi, '/workspaces/current/tool-provider/api/test/pre') diff --git a/api/controllers/files/__init__.py b/api/controllers/files/__init__.py index ff8d54c726..7c3e848b53 100644 --- a/api/controllers/files/__init__.py +++ b/api/controllers/files/__init__.py @@ -7,3 +7,4 @@ api = ExternalApi(bp) from . import image_preview +from . 
import tool_files \ No newline at end of file diff --git a/api/controllers/files/tool_files.py b/api/controllers/files/tool_files.py new file mode 100644 index 0000000000..f8b1936fc9 --- /dev/null +++ b/api/controllers/files/tool_files.py @@ -0,0 +1,47 @@ +from controllers.files import api +from flask import Response +from flask_restful import Resource, reqparse +from libs.exception import BaseHTTPException +from werkzeug.exceptions import NotFound, Forbidden + +from core.tools.tool_file_manager import ToolFileManager + +class ToolFilePreviewApi(Resource): + def get(self, file_id, extension): + file_id = str(file_id) + + parser = reqparse.RequestParser() + + parser.add_argument('timestamp', type=str, required=True, location='args') + parser.add_argument('nonce', type=str, required=True, location='args') + parser.add_argument('sign', type=str, required=True, location='args') + + args = parser.parse_args() + + if not ToolFileManager.verify_file(file_id=file_id, + timestamp=args['timestamp'], + nonce=args['nonce'], + sign=args['sign'], + ): + raise Forbidden('Invalid request.') + + try: + result = ToolFileManager.get_file_generator_by_message_file_id( + file_id, + ) + + if not result: + raise NotFound(f'file is not found') + + generator, mimetype = result + except Exception: + raise UnsupportedFileTypeError() + + return Response(generator, mimetype=mimetype) + +api.add_resource(ToolFilePreviewApi, '/files/tools/.') + +class UnsupportedFileTypeError(BaseHTTPException): + error_code = 'unsupported_file_type' + description = "File type not allowed." 
+ code = 415 diff --git a/api/controllers/service_api/app/app.py b/api/controllers/service_api/app/app.py index 2809a9135b..0be38d3083 100644 --- a/api/controllers/service_api/app/app.py +++ b/api/controllers/service_api/app/app.py @@ -3,7 +3,12 @@ from controllers.service_api import api from controllers.service_api.wraps import AppApiResource from flask import current_app from flask_restful import fields, marshal_with -from models.model import App +from models.model import App, AppModelConfig +from models.tools import ApiToolProvider + +import json + +from extensions.ext_database import db class AppParameterApi(AppApiResource): @@ -58,5 +63,42 @@ class AppParameterApi(AppApiResource): } } +class AppMetaApi(AppApiResource): + def get(self, app_model: App, end_user): + """Get app meta""" + app_model_config: AppModelConfig = app_model.app_model_config + + agent_config = app_model_config.agent_mode_dict or {} + meta = { + 'tool_icons': {} + } + + # get all tools + tools = agent_config.get('tools', []) + url_prefix = (current_app.config.get("CONSOLE_API_URL") + + f"/console/api/workspaces/current/tool-provider/builtin/") + for tool in tools: + keys = list(tool.keys()) + if len(keys) >= 4: + # current tool standard + provider_type = tool.get('provider_type') + provider_id = tool.get('provider_id') + tool_name = tool.get('tool_name') + if provider_type == 'builtin': + meta['tool_icons'][tool_name] = url_prefix + provider_id + '/icon' + elif provider_type == 'api': + try: + provider: ApiToolProvider = db.session.query(ApiToolProvider).filter( + ApiToolProvider.id == provider_id + ) + meta['tool_icons'][tool_name] = json.loads(provider.icon) + except: + meta['tool_icons'][tool_name] = { + "background": "#252525", + "content": "\ud83d\ude01" + } + + return meta api.add_resource(AppParameterApi, '/parameters') +api.add_resource(AppMetaApi, '/meta') diff --git a/api/controllers/service_api/app/message.py b/api/controllers/service_api/app/message.py index 
07c4318b84..aca995f993 100644 --- a/api/controllers/service_api/app/message.py +++ b/api/controllers/service_api/app/message.py @@ -37,6 +37,19 @@ class MessageListApi(AppApiResource): 'created_at': TimestampField } + agent_thought_fields = { + 'id': fields.String, + 'chain_id': fields.String, + 'message_id': fields.String, + 'position': fields.Integer, + 'thought': fields.String, + 'tool': fields.String, + 'tool_input': fields.String, + 'created_at': TimestampField, + 'observation': fields.String, + 'message_files': fields.List(fields.String, attribute='files') + } + message_fields = { 'id': fields.String, 'conversation_id': fields.String, @@ -46,7 +59,8 @@ class MessageListApi(AppApiResource): 'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'), 'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True), 'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)), - 'created_at': TimestampField + 'created_at': TimestampField, + 'agent_thoughts': fields.List(fields.Nested(agent_thought_fields)) } message_infinite_scroll_pagination_fields = { diff --git a/api/controllers/web/app.py b/api/controllers/web/app.py index 22b274c72d..a51259f2e4 100644 --- a/api/controllers/web/app.py +++ b/api/controllers/web/app.py @@ -3,7 +3,12 @@ from controllers.web import api from controllers.web.wraps import WebApiResource from flask import current_app from flask_restful import fields, marshal_with -from models.model import App +from models.model import App, AppModelConfig +from models.tools import ApiToolProvider + +from extensions.ext_database import db + +import json class AppParameterApi(WebApiResource): @@ -57,5 +62,42 @@ class AppParameterApi(WebApiResource): } } +class AppMeta(WebApiResource): + def get(self, app_model: App, end_user): + """Get app meta""" + app_model_config: AppModelConfig = app_model.app_model_config + + agent_config = app_model_config.agent_mode_dict or {} + meta = { + 
'tool_icons': {} + } + + # get all tools + tools = agent_config.get('tools', []) + url_prefix = (current_app.config.get("CONSOLE_API_URL") + + f"/console/api/workspaces/current/tool-provider/builtin/") + for tool in tools: + keys = list(tool.keys()) + if len(keys) >= 4: + # current tool standard + provider_type = tool.get('provider_type') + provider_id = tool.get('provider_id') + tool_name = tool.get('tool_name') + if provider_type == 'builtin': + meta['tool_icons'][tool_name] = url_prefix + provider_id + '/icon' + elif provider_type == 'api': + try: + provider: ApiToolProvider = db.session.query(ApiToolProvider).filter( + ApiToolProvider.id == provider_id + ) + meta['tool_icons'][tool_name] = json.loads(provider.icon) + except: + meta['tool_icons'][tool_name] = { + "background": "#252525", + "content": "\ud83d\ude01" + } + + return meta api.add_resource(AppParameterApi, '/parameters') +api.add_resource(AppMeta, '/meta') \ No newline at end of file diff --git a/api/controllers/web/message.py b/api/controllers/web/message.py index f354933946..2712e84691 100644 --- a/api/controllers/web/message.py +++ b/api/controllers/web/message.py @@ -14,6 +14,7 @@ from core.entities.application_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError from fields.conversation_fields import message_file_fields +from fields.message_fields import agent_thought_fields from flask import Response, stream_with_context from flask_restful import fields, marshal_with, reqparse from flask_restful.inputs import int_range @@ -59,7 +60,8 @@ class MessageListApi(WebApiResource): 'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'), 'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True), 'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)), - 'created_at': TimestampField + 
'created_at': TimestampField, + 'agent_thoughts': fields.List(fields.Nested(agent_thought_fields)) } message_infinite_scroll_pagination_fields = { diff --git a/api/core/agent/agent_executor.py b/api/core/agent/agent_executor.py index 4aa48337e5..2565fb2315 100644 --- a/api/core/agent/agent_executor.py +++ b/api/core/agent/agent_executor.py @@ -13,8 +13,8 @@ from core.entities.message_entities import prompt_messages_to_lc_messages from core.helper import moderation from core.memory.token_buffer_memory import TokenBufferMemory from core.model_runtime.errors.invoke import InvokeError -from core.tool.dataset_multi_retriever_tool import DatasetMultiRetrieverTool -from core.tool.dataset_retriever_tool import DatasetRetrieverTool +from core.tools.tool.dataset_retriever.dataset_multi_retriever_tool import DatasetMultiRetrieverTool +from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool from langchain.agents import AgentExecutor as LCAgentExecutor from langchain.agents import BaseMultiActionAgent, BaseSingleActionAgent from langchain.callbacks.manager import Callbacks diff --git a/api/core/app_runner/agent_app_runner.py b/api/core/app_runner/agent_app_runner.py deleted file mode 100644 index cc375056ce..0000000000 --- a/api/core/app_runner/agent_app_runner.py +++ /dev/null @@ -1,251 +0,0 @@ -import json -import logging -from typing import cast - -from core.agent.agent.agent_llm_callback import AgentLLMCallback -from core.app_runner.app_runner import AppRunner -from core.application_queue_manager import ApplicationQueueManager -from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler -from core.entities.application_entities import ApplicationGenerateEntity, ModelConfigEntity, PromptTemplateEntity -from core.features.agent_runner import AgentRunnerFeature -from core.memory.token_buffer_memory import TokenBufferMemory -from core.model_manager import ModelInstance -from 
core.model_runtime.entities.llm_entities import LLMUsage -from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from extensions.ext_database import db -from models.model import App, Conversation, Message, MessageAgentThought, MessageChain - -logger = logging.getLogger(__name__) - - -class AgentApplicationRunner(AppRunner): - """ - Agent Application Runner - """ - - def run(self, application_generate_entity: ApplicationGenerateEntity, - queue_manager: ApplicationQueueManager, - conversation: Conversation, - message: Message) -> None: - """ - Run agent application - :param application_generate_entity: application generate entity - :param queue_manager: application queue manager - :param conversation: conversation - :param message: message - :return: - """ - app_record = db.session.query(App).filter(App.id == application_generate_entity.app_id).first() - if not app_record: - raise ValueError(f"App not found") - - app_orchestration_config = application_generate_entity.app_orchestration_config_entity - - inputs = application_generate_entity.inputs - query = application_generate_entity.query - files = application_generate_entity.files - - # Pre-calculate the number of tokens of the prompt messages, - # and return the rest number of tokens by model context token size limit and max token size limit. - # If the rest number of tokens is not enough, raise exception. 
- # Include: prompt template, inputs, query(optional), files(optional) - # Not Include: memory, external data, dataset context - self.get_pre_calculate_rest_tokens( - app_record=app_record, - model_config=app_orchestration_config.model_config, - prompt_template_entity=app_orchestration_config.prompt_template, - inputs=inputs, - files=files, - query=query - ) - - memory = None - if application_generate_entity.conversation_id: - # get memory of conversation (read-only) - model_instance = ModelInstance( - provider_model_bundle=app_orchestration_config.model_config.provider_model_bundle, - model=app_orchestration_config.model_config.model - ) - - memory = TokenBufferMemory( - conversation=conversation, - model_instance=model_instance - ) - - # reorganize all inputs and template to prompt messages - # Include: prompt template, inputs, query(optional), files(optional) - # memory(optional) - prompt_messages, stop = self.organize_prompt_messages( - app_record=app_record, - model_config=app_orchestration_config.model_config, - prompt_template_entity=app_orchestration_config.prompt_template, - inputs=inputs, - files=files, - query=query, - context=None, - memory=memory - ) - - # Create MessageChain - message_chain = self._init_message_chain( - message=message, - query=query - ) - - # add agent callback to record agent thoughts - agent_callback = AgentLoopGatherCallbackHandler( - model_config=app_orchestration_config.model_config, - message=message, - queue_manager=queue_manager, - message_chain=message_chain - ) - - # init LLM Callback - agent_llm_callback = AgentLLMCallback( - agent_callback=agent_callback - ) - - agent_runner = AgentRunnerFeature( - tenant_id=application_generate_entity.tenant_id, - app_orchestration_config=app_orchestration_config, - model_config=app_orchestration_config.model_config, - config=app_orchestration_config.agent, - queue_manager=queue_manager, - message=message, - user_id=application_generate_entity.user_id, - 
agent_llm_callback=agent_llm_callback, - callback=agent_callback, - memory=memory - ) - - # agent run - result = agent_runner.run( - query=query, - invoke_from=application_generate_entity.invoke_from - ) - - if result: - self._save_message_chain( - message_chain=message_chain, - output_text=result - ) - - if (result - and app_orchestration_config.prompt_template.prompt_type == PromptTemplateEntity.PromptType.SIMPLE - and app_orchestration_config.prompt_template.simple_prompt_template - ): - # Direct output if agent result exists and has pre prompt - self.direct_output( - queue_manager=queue_manager, - app_orchestration_config=app_orchestration_config, - prompt_messages=prompt_messages, - stream=application_generate_entity.stream, - text=result, - usage=self._get_usage_of_all_agent_thoughts( - model_config=app_orchestration_config.model_config, - message=message - ) - ) - else: - # As normal LLM run, agent result as context - context = result - - # reorganize all inputs and template to prompt messages - # Include: prompt template, inputs, query(optional), files(optional) - # memory(optional), external data, dataset context(optional) - prompt_messages, stop = self.organize_prompt_messages( - app_record=app_record, - model_config=app_orchestration_config.model_config, - prompt_template_entity=app_orchestration_config.prompt_template, - inputs=inputs, - files=files, - query=query, - context=context, - memory=memory - ) - - # Re-calculate the max tokens if sum(prompt_token + max_tokens) over model token limit - self.recale_llm_max_tokens( - model_config=app_orchestration_config.model_config, - prompt_messages=prompt_messages - ) - - # Invoke model - model_instance = ModelInstance( - provider_model_bundle=app_orchestration_config.model_config.provider_model_bundle, - model=app_orchestration_config.model_config.model - ) - - invoke_result = model_instance.invoke_llm( - prompt_messages=prompt_messages, - model_parameters=app_orchestration_config.model_config.parameters, - 
stop=stop, - stream=application_generate_entity.stream, - user=application_generate_entity.user_id, - ) - - # handle invoke result - self._handle_invoke_result( - invoke_result=invoke_result, - queue_manager=queue_manager, - stream=application_generate_entity.stream - ) - - def _init_message_chain(self, message: Message, query: str) -> MessageChain: - """ - Init MessageChain - :param message: message - :param query: query - :return: - """ - message_chain = MessageChain( - message_id=message.id, - type="AgentExecutor", - input=json.dumps({ - "input": query - }) - ) - - db.session.add(message_chain) - db.session.commit() - - return message_chain - - def _save_message_chain(self, message_chain: MessageChain, output_text: str) -> None: - """ - Save MessageChain - :param message_chain: message chain - :param output_text: output text - :return: - """ - message_chain.output = json.dumps({ - "output": output_text - }) - db.session.commit() - - def _get_usage_of_all_agent_thoughts(self, model_config: ModelConfigEntity, - message: Message) -> LLMUsage: - """ - Get usage of all agent thoughts - :param model_config: model config - :param message: message - :return: - """ - agent_thoughts = (db.session.query(MessageAgentThought) - .filter(MessageAgentThought.message_id == message.id).all()) - - all_message_tokens = 0 - all_answer_tokens = 0 - for agent_thought in agent_thoughts: - all_message_tokens += agent_thought.message_token - all_answer_tokens += agent_thought.answer_token - - model_type_instance = model_config.provider_model_bundle.model_type_instance - model_type_instance = cast(LargeLanguageModel, model_type_instance) - - return model_type_instance._calc_response_usage( - model_config.model, - model_config.credentials, - all_message_tokens, - all_answer_tokens - ) diff --git a/api/core/app_runner/app_runner.py b/api/core/app_runner/app_runner.py index a2edbfc3ab..c7c5474b2a 100644 --- a/api/core/app_runner/app_runner.py +++ b/api/core/app_runner/app_runner.py @@ -2,7 
+2,8 @@ import time from typing import Generator, List, Optional, Tuple, Union, cast from core.application_queue_manager import ApplicationQueueManager, PublishFrom -from core.entities.application_entities import AppOrchestrationConfigEntity, ModelConfigEntity, PromptTemplateEntity +from core.entities.application_entities import AppOrchestrationConfigEntity, ModelConfigEntity, \ + PromptTemplateEntity, ExternalDataVariableEntity, ApplicationGenerateEntity, InvokeFrom from core.file.file_obj import FileObj from core.memory.token_buffer_memory import TokenBufferMemory from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage @@ -10,9 +11,12 @@ from core.model_runtime.entities.message_entities import AssistantPromptMessage, from core.model_runtime.entities.model_entities import ModelPropertyKey from core.model_runtime.errors.invoke import InvokeBadRequestError from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from core.features.hosting_moderation import HostingModerationFeature +from core.features.moderation import ModerationFeature +from core.features.external_data_fetch import ExternalDataFetchFeature +from core.features.annotation_reply import AnnotationReplyFeature from core.prompt.prompt_transform import PromptTransform -from models.model import App - +from models.model import App, MessageAnnotation, Message class AppRunner: def get_pre_calculate_rest_tokens(self, app_record: App, @@ -199,7 +203,8 @@ class AppRunner: def _handle_invoke_result(self, invoke_result: Union[LLMResult, Generator], queue_manager: ApplicationQueueManager, - stream: bool) -> None: + stream: bool, + agent: bool = False) -> None: """ Handle invoke result :param invoke_result: invoke result @@ -210,16 +215,19 @@ class AppRunner: if not stream: self._handle_invoke_result_direct( invoke_result=invoke_result, - queue_manager=queue_manager + queue_manager=queue_manager, + agent=agent ) else: 
self._handle_invoke_result_stream( invoke_result=invoke_result, - queue_manager=queue_manager + queue_manager=queue_manager, + agent=agent ) def _handle_invoke_result_direct(self, invoke_result: LLMResult, - queue_manager: ApplicationQueueManager) -> None: + queue_manager: ApplicationQueueManager, + agent: bool) -> None: """ Handle invoke result direct :param invoke_result: invoke result @@ -232,7 +240,8 @@ class AppRunner: ) def _handle_invoke_result_stream(self, invoke_result: Generator, - queue_manager: ApplicationQueueManager) -> None: + queue_manager: ApplicationQueueManager, + agent: bool) -> None: """ Handle invoke result :param invoke_result: invoke result @@ -244,7 +253,10 @@ class AppRunner: text = '' usage = None for result in invoke_result: - queue_manager.publish_chunk_message(result, PublishFrom.APPLICATION_MANAGER) + if not agent: + queue_manager.publish_chunk_message(result, PublishFrom.APPLICATION_MANAGER) + else: + queue_manager.publish_agent_chunk_message(result, PublishFrom.APPLICATION_MANAGER) text += result.delta.message.content @@ -271,3 +283,101 @@ class AppRunner: llm_result=llm_result, pub_from=PublishFrom.APPLICATION_MANAGER ) + + def moderation_for_inputs(self, app_id: str, + tenant_id: str, + app_orchestration_config_entity: AppOrchestrationConfigEntity, + inputs: dict, + query: str) -> Tuple[bool, dict, str]: + """ + Process sensitive_word_avoidance. 
+ :param app_id: app id + :param tenant_id: tenant id + :param app_orchestration_config_entity: app orchestration config entity + :param inputs: inputs + :param query: query + :return: + """ + moderation_feature = ModerationFeature() + return moderation_feature.check( + app_id=app_id, + tenant_id=tenant_id, + app_orchestration_config_entity=app_orchestration_config_entity, + inputs=inputs, + query=query, + ) + + def check_hosting_moderation(self, application_generate_entity: ApplicationGenerateEntity, + queue_manager: ApplicationQueueManager, + prompt_messages: list[PromptMessage]) -> bool: + """ + Check hosting moderation + :param application_generate_entity: application generate entity + :param queue_manager: queue manager + :param prompt_messages: prompt messages + :return: + """ + hosting_moderation_feature = HostingModerationFeature() + moderation_result = hosting_moderation_feature.check( + application_generate_entity=application_generate_entity, + prompt_messages=prompt_messages + ) + + if moderation_result: + self.direct_output( + queue_manager=queue_manager, + app_orchestration_config=application_generate_entity.app_orchestration_config_entity, + prompt_messages=prompt_messages, + text="I apologize for any confusion, " \ + "but I'm an AI assistant to be helpful, harmless, and honest.", + stream=application_generate_entity.stream + ) + + return moderation_result + + def fill_in_inputs_from_external_data_tools(self, tenant_id: str, + app_id: str, + external_data_tools: list[ExternalDataVariableEntity], + inputs: dict, + query: str) -> dict: + """ + Fill in variable inputs from external data tools if exists. 
+ + :param tenant_id: workspace id + :param app_id: app id + :param external_data_tools: external data tools configs + :param inputs: the inputs + :param query: the query + :return: the filled inputs + """ + external_data_fetch_feature = ExternalDataFetchFeature() + return external_data_fetch_feature.fetch( + tenant_id=tenant_id, + app_id=app_id, + external_data_tools=external_data_tools, + inputs=inputs, + query=query + ) + + def query_app_annotations_to_reply(self, app_record: App, + message: Message, + query: str, + user_id: str, + invoke_from: InvokeFrom) -> Optional[MessageAnnotation]: + """ + Query app annotations to reply + :param app_record: app record + :param message: message + :param query: query + :param user_id: user id + :param invoke_from: invoke from + :return: + """ + annotation_reply_feature = AnnotationReplyFeature() + return annotation_reply_feature.query( + app_record=app_record, + message=message, + query=query, + user_id=user_id, + invoke_from=invoke_from + ) \ No newline at end of file diff --git a/api/core/app_runner/assistant_app_runner.py b/api/core/app_runner/assistant_app_runner.py new file mode 100644 index 0000000000..24da556369 --- /dev/null +++ b/api/core/app_runner/assistant_app_runner.py @@ -0,0 +1,342 @@ +import json +import logging +from typing import cast + +from core.app_runner.app_runner import AppRunner +from core.features.assistant_cot_runner import AssistantCotApplicationRunner +from core.features.assistant_fc_runner import AssistantFunctionCallApplicationRunner +from core.entities.application_entities import ApplicationGenerateEntity, ModelConfigEntity, \ + AgentEntity +from core.application_queue_manager import ApplicationQueueManager, PublishFrom +from core.memory.token_buffer_memory import TokenBufferMemory +from core.model_manager import ModelInstance +from core.model_runtime.entities.llm_entities import LLMUsage +from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from 
core.moderation.base import ModerationException +from core.tools.entities.tool_entities import ToolRuntimeVariablePool +from extensions.ext_database import db +from models.model import Conversation, Message, App, MessageChain, MessageAgentThought +from models.tools import ToolConversationVariables + +logger = logging.getLogger(__name__) + +class AssistantApplicationRunner(AppRunner): + """ + Assistant Application Runner + """ + def run(self, application_generate_entity: ApplicationGenerateEntity, + queue_manager: ApplicationQueueManager, + conversation: Conversation, + message: Message) -> None: + """ + Run assistant application + :param application_generate_entity: application generate entity + :param queue_manager: application queue manager + :param conversation: conversation + :param message: message + :return: + """ + app_record = db.session.query(App).filter(App.id == application_generate_entity.app_id).first() + if not app_record: + raise ValueError(f"App not found") + + app_orchestration_config = application_generate_entity.app_orchestration_config_entity + + inputs = application_generate_entity.inputs + query = application_generate_entity.query + files = application_generate_entity.files + + # Pre-calculate the number of tokens of the prompt messages, + # and return the rest number of tokens by model context token size limit and max token size limit. + # If the rest number of tokens is not enough, raise exception. 
+ # Include: prompt template, inputs, query(optional), files(optional) + # Not Include: memory, external data, dataset context + self.get_pre_calculate_rest_tokens( + app_record=app_record, + model_config=app_orchestration_config.model_config, + prompt_template_entity=app_orchestration_config.prompt_template, + inputs=inputs, + files=files, + query=query + ) + + memory = None + if application_generate_entity.conversation_id: + # get memory of conversation (read-only) + model_instance = ModelInstance( + provider_model_bundle=app_orchestration_config.model_config.provider_model_bundle, + model=app_orchestration_config.model_config.model + ) + + memory = TokenBufferMemory( + conversation=conversation, + model_instance=model_instance + ) + + # organize all inputs and template to prompt messages + # Include: prompt template, inputs, query(optional), files(optional) + # memory(optional) + prompt_messages, _ = self.organize_prompt_messages( + app_record=app_record, + model_config=app_orchestration_config.model_config, + prompt_template_entity=app_orchestration_config.prompt_template, + inputs=inputs, + files=files, + query=query, + memory=memory + ) + + # moderation + try: + # process sensitive_word_avoidance + _, inputs, query = self.moderation_for_inputs( + app_id=app_record.id, + tenant_id=application_generate_entity.tenant_id, + app_orchestration_config_entity=app_orchestration_config, + inputs=inputs, + query=query, + ) + except ModerationException as e: + self.direct_output( + queue_manager=queue_manager, + app_orchestration_config=app_orchestration_config, + prompt_messages=prompt_messages, + text=str(e), + stream=application_generate_entity.stream + ) + return + + if query: + # annotation reply + annotation_reply = self.query_app_annotations_to_reply( + app_record=app_record, + message=message, + query=query, + user_id=application_generate_entity.user_id, + invoke_from=application_generate_entity.invoke_from + ) + + if annotation_reply: + 
queue_manager.publish_annotation_reply( + message_annotation_id=annotation_reply.id, + pub_from=PublishFrom.APPLICATION_MANAGER + ) + self.direct_output( + queue_manager=queue_manager, + app_orchestration_config=app_orchestration_config, + prompt_messages=prompt_messages, + text=annotation_reply.content, + stream=application_generate_entity.stream + ) + return + + # fill in variable inputs from external data tools if exists + external_data_tools = app_orchestration_config.external_data_variables + if external_data_tools: + inputs = self.fill_in_inputs_from_external_data_tools( + tenant_id=app_record.tenant_id, + app_id=app_record.id, + external_data_tools=external_data_tools, + inputs=inputs, + query=query + ) + + # reorganize all inputs and template to prompt messages + # Include: prompt template, inputs, query(optional), files(optional) + # memory(optional), external data, dataset context(optional) + prompt_messages, _ = self.organize_prompt_messages( + app_record=app_record, + model_config=app_orchestration_config.model_config, + prompt_template_entity=app_orchestration_config.prompt_template, + inputs=inputs, + files=files, + query=query, + memory=memory + ) + + # check hosting moderation + hosting_moderation_result = self.check_hosting_moderation( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + prompt_messages=prompt_messages + ) + + if hosting_moderation_result: + return + + agent_entity = app_orchestration_config.agent + + # load tool variables + tool_conversation_variables = self._load_tool_variables(conversation_id=conversation.id, + user_id=application_generate_entity.user_id, + tanent_id=application_generate_entity.tenant_id) + + # convert db variables to tool variables + tool_variables = self._convert_db_variables_to_tool_variables(tool_conversation_variables) + + message_chain = self._init_message_chain( + message=message, + query=query + ) + + # init model instance + model_instance = ModelInstance( + 
provider_model_bundle=app_orchestration_config.model_config.provider_model_bundle, + model=app_orchestration_config.model_config.model + ) + prompt_message, _ = self.organize_prompt_messages( + app_record=app_record, + model_config=app_orchestration_config.model_config, + prompt_template_entity=app_orchestration_config.prompt_template, + inputs=inputs, + files=files, + query=query, + memory=memory, + ) + + # start agent runner + if agent_entity.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT: + assistant_cot_runner = AssistantCotApplicationRunner( + tenant_id=application_generate_entity.tenant_id, + application_generate_entity=application_generate_entity, + app_orchestration_config=app_orchestration_config, + model_config=app_orchestration_config.model_config, + config=agent_entity, + queue_manager=queue_manager, + message=message, + user_id=application_generate_entity.user_id, + memory=memory, + prompt_messages=prompt_message, + variables_pool=tool_variables, + db_variables=tool_conversation_variables, + ) + invoke_result = assistant_cot_runner.run( + model_instance=model_instance, + conversation=conversation, + message=message, + query=query, + ) + elif agent_entity.strategy == AgentEntity.Strategy.FUNCTION_CALLING: + assistant_cot_runner = AssistantFunctionCallApplicationRunner( + tenant_id=application_generate_entity.tenant_id, + application_generate_entity=application_generate_entity, + app_orchestration_config=app_orchestration_config, + model_config=app_orchestration_config.model_config, + config=agent_entity, + queue_manager=queue_manager, + message=message, + user_id=application_generate_entity.user_id, + memory=memory, + prompt_messages=prompt_message, + variables_pool=tool_variables, + db_variables=tool_conversation_variables + ) + invoke_result = assistant_cot_runner.run( + model_instance=model_instance, + conversation=conversation, + message=message, + query=query, + ) + + # handle invoke result + self._handle_invoke_result( + 
invoke_result=invoke_result, + queue_manager=queue_manager, + stream=application_generate_entity.stream, + agent=True + ) + + def _load_tool_variables(self, conversation_id: str, user_id: str, tanent_id: str) -> ToolConversationVariables: + """ + load tool variables from database + """ + tool_variables: ToolConversationVariables = db.session.query(ToolConversationVariables).filter( + ToolConversationVariables.conversation_id == conversation_id, + ToolConversationVariables.tenant_id == tanent_id + ).first() + + if tool_variables: + # save tool variables to session, so that we can update it later + db.session.add(tool_variables) + else: + # create new tool variables + tool_variables = ToolConversationVariables( + conversation_id=conversation_id, + user_id=user_id, + tenant_id=tanent_id, + variables_str='[]', + ) + db.session.add(tool_variables) + db.session.commit() + + return tool_variables + + def _convert_db_variables_to_tool_variables(self, db_variables: ToolConversationVariables) -> ToolRuntimeVariablePool: + """ + convert db variables to tool variables + """ + return ToolRuntimeVariablePool(**{ + 'conversation_id': db_variables.conversation_id, + 'user_id': db_variables.user_id, + 'tenant_id': db_variables.tenant_id, + 'pool': db_variables.variables + }) + + def _init_message_chain(self, message: Message, query: str) -> MessageChain: + """ + Init MessageChain + :param message: message + :param query: query + :return: + """ + message_chain = MessageChain( + message_id=message.id, + type="AgentExecutor", + input=json.dumps({ + "input": query + }) + ) + + db.session.add(message_chain) + db.session.commit() + + return message_chain + + def _save_message_chain(self, message_chain: MessageChain, output_text: str) -> None: + """ + Save MessageChain + :param message_chain: message chain + :param output_text: output text + :return: + """ + message_chain.output = json.dumps({ + "output": output_text + }) + db.session.commit() + + def 
_get_usage_of_all_agent_thoughts(self, model_config: ModelConfigEntity, + message: Message) -> LLMUsage: + """ + Get usage of all agent thoughts + :param model_config: model config + :param message: message + :return: + """ + agent_thoughts = (db.session.query(MessageAgentThought) + .filter(MessageAgentThought.message_id == message.id).all()) + + all_message_tokens = 0 + all_answer_tokens = 0 + for agent_thought in agent_thoughts: + all_message_tokens += agent_thought.message_tokens + all_answer_tokens += agent_thought.answer_tokens + + model_type_instance = model_config.provider_model_bundle.model_type_instance + model_type_instance = cast(LargeLanguageModel, model_type_instance) + + return model_type_instance._calc_response_usage( + model_config.model, + model_config.credentials, + all_message_tokens, + all_answer_tokens + ) diff --git a/api/core/app_runner/basic_app_runner.py b/api/core/app_runner/basic_app_runner.py index 97d684a37f..0864786530 100644 --- a/api/core/app_runner/basic_app_runner.py +++ b/api/core/app_runner/basic_app_runner.py @@ -1,23 +1,18 @@ import logging -from typing import Optional, Tuple +from typing import Optional from core.app_runner.app_runner import AppRunner from core.application_queue_manager import ApplicationQueueManager, PublishFrom from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler -from core.entities.application_entities import (ApplicationGenerateEntity, AppOrchestrationConfigEntity, DatasetEntity, - ExternalDataVariableEntity, InvokeFrom, ModelConfigEntity) -from core.features.annotation_reply import AnnotationReplyFeature +from core.entities.application_entities import (ApplicationGenerateEntity, DatasetEntity, + InvokeFrom, ModelConfigEntity) from core.features.dataset_retrieval import DatasetRetrievalFeature -from core.features.external_data_fetch import ExternalDataFetchFeature -from core.features.hosting_moderation import HostingModerationFeature -from core.features.moderation 
import ModerationFeature from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance -from core.model_runtime.entities.message_entities import PromptMessage from core.moderation.base import ModerationException from core.prompt.prompt_transform import AppMode from extensions.ext_database import db -from models.model import App, Conversation, Message, MessageAnnotation +from models.model import App, Conversation, Message logger = logging.getLogger(__name__) @@ -213,76 +208,6 @@ class BasicApplicationRunner(AppRunner): stream=application_generate_entity.stream ) - def moderation_for_inputs(self, app_id: str, - tenant_id: str, - app_orchestration_config_entity: AppOrchestrationConfigEntity, - inputs: dict, - query: str) -> Tuple[bool, dict, str]: - """ - Process sensitive_word_avoidance. - :param app_id: app id - :param tenant_id: tenant id - :param app_orchestration_config_entity: app orchestration config entity - :param inputs: inputs - :param query: query - :return: - """ - moderation_feature = ModerationFeature() - return moderation_feature.check( - app_id=app_id, - tenant_id=tenant_id, - app_orchestration_config_entity=app_orchestration_config_entity, - inputs=inputs, - query=query, - ) - - def query_app_annotations_to_reply(self, app_record: App, - message: Message, - query: str, - user_id: str, - invoke_from: InvokeFrom) -> Optional[MessageAnnotation]: - """ - Query app annotations to reply - :param app_record: app record - :param message: message - :param query: query - :param user_id: user id - :param invoke_from: invoke from - :return: - """ - annotation_reply_feature = AnnotationReplyFeature() - return annotation_reply_feature.query( - app_record=app_record, - message=message, - query=query, - user_id=user_id, - invoke_from=invoke_from - ) - - def fill_in_inputs_from_external_data_tools(self, tenant_id: str, - app_id: str, - external_data_tools: list[ExternalDataVariableEntity], - inputs: dict, - query: str) -> 
dict: - """ - Fill in variable inputs from external data tools if exists. - - :param tenant_id: workspace id - :param app_id: app id - :param external_data_tools: external data tools configs - :param inputs: the inputs - :param query: the query - :return: the filled inputs - """ - external_data_fetch_feature = ExternalDataFetchFeature() - return external_data_fetch_feature.fetch( - tenant_id=tenant_id, - app_id=app_id, - external_data_tools=external_data_tools, - inputs=inputs, - query=query - ) - def retrieve_dataset_context(self, tenant_id: str, app_record: App, queue_manager: ApplicationQueueManager, @@ -334,31 +259,4 @@ class BasicApplicationRunner(AppRunner): hit_callback=hit_callback, memory=memory ) - - def check_hosting_moderation(self, application_generate_entity: ApplicationGenerateEntity, - queue_manager: ApplicationQueueManager, - prompt_messages: list[PromptMessage]) -> bool: - """ - Check hosting moderation - :param application_generate_entity: application generate entity - :param queue_manager: queue manager - :param prompt_messages: prompt messages - :return: - """ - hosting_moderation_feature = HostingModerationFeature() - moderation_result = hosting_moderation_feature.check( - application_generate_entity=application_generate_entity, - prompt_messages=prompt_messages - ) - - if moderation_result: - self.direct_output( - queue_manager=queue_manager, - app_orchestration_config=application_generate_entity.app_orchestration_config_entity, - prompt_messages=prompt_messages, - text="I apologize for any confusion, " \ - "but I'm an AI assistant to be helpful, harmless, and honest.", - stream=application_generate_entity.stream - ) - - return moderation_result + \ No newline at end of file diff --git a/api/core/app_runner/generate_task_pipeline.py b/api/core/app_runner/generate_task_pipeline.py index 0e84c1c4c3..58a99a52c2 100644 --- a/api/core/app_runner/generate_task_pipeline.py +++ b/api/core/app_runner/generate_task_pipeline.py @@ -8,7 +8,8 @@ from 
core.application_queue_manager import ApplicationQueueManager, PublishFrom from core.entities.application_entities import ApplicationGenerateEntity, InvokeFrom from core.entities.queue_entities import (AnnotationReplyEvent, QueueAgentThoughtEvent, QueueErrorEvent, QueueMessageEndEvent, QueueMessageEvent, QueueMessageReplaceEvent, - QueuePingEvent, QueueRetrieverResourcesEvent, QueueStopEvent) + QueuePingEvent, QueueRetrieverResourcesEvent, QueueStopEvent, + QueueMessageFileEvent, QueueAgentMessageEvent) from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage from core.model_runtime.entities.message_entities import (AssistantPromptMessage, ImagePromptMessageContent, @@ -16,11 +17,12 @@ from core.model_runtime.entities.message_entities import (AssistantPromptMessage TextPromptMessageContent) from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from core.tools.tool_file_manager import ToolFileManager from core.model_runtime.utils.encoders import jsonable_encoder from core.prompt.prompt_template import PromptTemplateParser from events.message_event import message_was_created from extensions.ext_database import db -from models.model import Conversation, Message, MessageAgentThought +from models.model import Conversation, Message, MessageAgentThought, MessageFile from pydantic import BaseModel from services.annotation_service import AppAnnotationService @@ -284,6 +286,7 @@ class GenerateTaskPipeline: .filter(MessageAgentThought.id == event.agent_thought_id) .first() ) + db.session.refresh(agent_thought) if agent_thought: response = { @@ -293,16 +296,48 @@ class GenerateTaskPipeline: 'message_id': self._message.id, 'position': agent_thought.position, 'thought': agent_thought.thought, + 
'observation': agent_thought.observation, 'tool': agent_thought.tool, 'tool_input': agent_thought.tool_input, - 'created_at': int(self._message.created_at.timestamp()) + 'created_at': int(self._message.created_at.timestamp()), + 'message_files': agent_thought.files } if self._conversation.mode == 'chat': response['conversation_id'] = self._conversation.id yield self._yield_response(response) - elif isinstance(event, QueueMessageEvent): + elif isinstance(event, QueueMessageFileEvent): + message_file: MessageFile = ( + db.session.query(MessageFile) + .filter(MessageFile.id == event.message_file_id) + .first() + ) + # get extension + if '.' in message_file.url: + extension = f'.{message_file.url.split(".")[-1]}' + if len(extension) > 10: + extension = '.bin' + else: + extension = '.bin' + # add sign url + url = ToolFileManager.sign_file(file_id=message_file.id, extension=extension) + + if message_file: + response = { + 'event': 'message_file', + 'id': message_file.id, + 'type': message_file.type, + 'belongs_to': message_file.belongs_to or 'user', + 'url': url + } + + if self._conversation.mode == 'chat': + response['conversation_id'] = self._conversation.id + + yield self._yield_response(response) + + elif isinstance(event, (QueueMessageEvent, QueueAgentMessageEvent)): chunk = event.chunk delta_text = chunk.delta.message.content if delta_text is None: @@ -332,7 +367,7 @@ class GenerateTaskPipeline: self._output_moderation_handler.append_new_token(delta_text) self._task_state.llm_result.message.content += delta_text - response = self._handle_chunk(delta_text) + response = self._handle_chunk(delta_text, agent=isinstance(event, QueueAgentMessageEvent)) yield self._yield_response(response) elif isinstance(event, QueueMessageReplaceEvent): response = { @@ -384,14 +419,14 @@ class GenerateTaskPipeline: extras=self._application_generate_entity.extras ) - def _handle_chunk(self, text: str) -> dict: + def _handle_chunk(self, text: str, agent: bool = False) -> dict: """ Handle 
completed event. :param text: text :return: """ response = { - 'event': 'message', + 'event': 'message' if not agent else 'agent_message', 'id': self._message.id, 'task_id': self._application_generate_entity.task_id, 'message_id': self._message.id, diff --git a/api/core/application_manager.py b/api/core/application_manager.py index 7a0bed3ded..100725f6d7 100644 --- a/api/core/application_manager.py +++ b/api/core/application_manager.py @@ -4,7 +4,7 @@ import threading import uuid from typing import Any, Generator, Optional, Tuple, Union, cast -from core.app_runner.agent_app_runner import AgentApplicationRunner +from core.app_runner.assistant_app_runner import AssistantApplicationRunner from core.app_runner.basic_app_runner import BasicApplicationRunner from core.app_runner.generate_task_pipeline import GenerateTaskPipeline from core.application_queue_manager import ApplicationQueueManager, ConversationTaskStoppedException, PublishFrom @@ -13,7 +13,7 @@ from core.entities.application_entities import (AdvancedChatPromptTemplateEntity ApplicationGenerateEntity, AppOrchestrationConfigEntity, DatasetEntity, DatasetRetrieveConfigEntity, ExternalDataVariableEntity, FileUploadEntity, InvokeFrom, ModelConfigEntity, PromptTemplateEntity, - SensitiveWordAvoidanceEntity) + SensitiveWordAvoidanceEntity, AgentPromptEntity) from core.entities.model_entities import ModelStatus from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.file.file_obj import FileObj @@ -23,6 +23,7 @@ from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeErr from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.prompt.prompt_template import PromptTemplateParser from core.provider_manager import ProviderManager +from core.tools.prompt.template import REACT_PROMPT_TEMPLATES from extensions.ext_database import db from flask import Flask, current_app from models.account import 
Account @@ -93,6 +94,9 @@ class ApplicationManager: extras=extras ) + if not stream and application_generate_entity.app_orchestration_config_entity.agent: + raise ValueError("Agent app is not supported in blocking mode.") + # init generate records ( conversation, @@ -151,7 +155,7 @@ class ApplicationManager: if application_generate_entity.app_orchestration_config_entity.agent: # agent app - runner = AgentApplicationRunner() + runner = AssistantApplicationRunner() runner.run( application_generate_entity=application_generate_entity, queue_manager=queue_manager, @@ -354,6 +358,8 @@ class ApplicationManager: # external data variables properties['external_data_variables'] = [] + + # old external_data_tools external_data_tools = copy_app_model_config_dict.get('external_data_tools', []) for external_data_tool in external_data_tools: if 'enabled' not in external_data_tool or not external_data_tool['enabled']: @@ -366,6 +372,19 @@ class ApplicationManager: config=external_data_tool['config'] ) ) + + # current external_data_tools + for variable in copy_app_model_config_dict.get('user_input_form', []): + typ = list(variable.keys())[0] + if typ == 'external_data_tool': + val = variable[typ] + properties['external_data_variables'].append( + ExternalDataVariableEntity( + variable=val['variable'], + type=val['type'], + config=val['config'] + ) + ) # show retrieve source show_retrieve_source = False @@ -375,15 +394,64 @@ class ApplicationManager: show_retrieve_source = True properties['show_retrieve_source'] = show_retrieve_source + + dataset_ids = [] + if 'datasets' in copy_app_model_config_dict.get('dataset_configs', {}): + datasets = copy_app_model_config_dict.get('dataset_configs', {}).get('datasets', { + 'strategy': 'router', + 'datasets': [] + }) + + + for dataset in datasets.get('datasets', []): + keys = list(dataset.keys()) + if len(keys) == 0 or keys[0] != 'dataset': + continue + dataset = dataset['dataset'] + + if 'enabled' not in dataset or not dataset['enabled']: + 
continue + + dataset_id = dataset.get('id', None) + if dataset_id: + dataset_ids.append(dataset_id) + else: + datasets = {'strategy': 'router', 'datasets': []} if 'agent_mode' in copy_app_model_config_dict and copy_app_model_config_dict['agent_mode'] \ and 'enabled' in copy_app_model_config_dict['agent_mode'] and copy_app_model_config_dict['agent_mode'][ 'enabled']: - agent_dict = copy_app_model_config_dict.get('agent_mode') - agent_strategy = agent_dict.get('strategy', 'router') - if agent_strategy in ['router', 'react_router']: - dataset_ids = [] - for tool in agent_dict.get('tools', []): + agent_dict = copy_app_model_config_dict.get('agent_mode', {}) + agent_strategy = agent_dict.get('strategy', 'cot') + + if agent_strategy == 'function_call': + strategy = AgentEntity.Strategy.FUNCTION_CALLING + elif agent_strategy == 'cot' or agent_strategy == 'react': + strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT + else: + # old configs, try to detect default strategy + if copy_app_model_config_dict['model']['provider'] == 'openai': + strategy = AgentEntity.Strategy.FUNCTION_CALLING + else: + strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT + + agent_tools = [] + for tool in agent_dict.get('tools', []): + keys = tool.keys() + if len(keys) >= 4: + if "enabled" not in tool or not tool["enabled"]: + continue + + agent_tool_properties = { + 'provider_type': tool['provider_type'], + 'provider_id': tool['provider_id'], + 'tool_name': tool['tool_name'], + 'tool_parameters': tool['tool_parameters'] if 'tool_parameters' in tool else {} + } + + agent_tools.append(AgentToolEntity(**agent_tool_properties)) + elif len(keys) == 1: + # old standard key = list(tool.keys())[0] if key != 'dataset': @@ -397,58 +465,57 @@ class ApplicationManager: dataset_id = tool_item['id'] dataset_ids.append(dataset_id) - dataset_configs = copy_app_model_config_dict.get('dataset_configs', {'retrieval_model': 'single'}) - query_variable = copy_app_model_config_dict.get('dataset_query_variable') - if 
dataset_configs['retrieval_model'] == 'single': - properties['dataset'] = DatasetEntity( - dataset_ids=dataset_ids, - retrieve_config=DatasetRetrieveConfigEntity( - query_variable=query_variable, - retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of( - dataset_configs['retrieval_model'] - ), - single_strategy=agent_strategy - ) - ) - else: - properties['dataset'] = DatasetEntity( - dataset_ids=dataset_ids, - retrieve_config=DatasetRetrieveConfigEntity( - query_variable=query_variable, - retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of( - dataset_configs['retrieval_model'] - ), - top_k=dataset_configs.get('top_k'), - score_threshold=dataset_configs.get('score_threshold'), - reranking_model=dataset_configs.get('reranking_model') - ) - ) + agent_prompt = agent_dict.get('prompt', None) or {} + # check model mode + model_mode = copy_app_model_config_dict.get('model', {}).get('mode', 'completion') + if model_mode == 'completion': + agent_prompt_entity = AgentPromptEntity( + first_prompt=agent_prompt.get('first_prompt', REACT_PROMPT_TEMPLATES['english']['completion']['prompt']), + next_iteration=agent_prompt.get('next_iteration', REACT_PROMPT_TEMPLATES['english']['completion']['agent_scratchpad']), + ) else: - if agent_strategy == 'react': - strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT - else: - strategy = AgentEntity.Strategy.FUNCTION_CALLING + agent_prompt_entity = AgentPromptEntity( + first_prompt=agent_prompt.get('first_prompt', REACT_PROMPT_TEMPLATES['english']['chat']['prompt']), + next_iteration=agent_prompt.get('next_iteration', REACT_PROMPT_TEMPLATES['english']['chat']['agent_scratchpad']), + ) - agent_tools = [] - for tool in agent_dict.get('tools', []): - key = list(tool.keys())[0] - tool_item = tool[key] + properties['agent'] = AgentEntity( + provider=properties['model_config'].provider, + model=properties['model_config'].model, + strategy=strategy, + prompt=agent_prompt_entity, + tools=agent_tools, + 
max_iteration=agent_dict.get('max_iteration', 5) + ) - agent_tool_properties = { - "tool_id": key - } + if len(dataset_ids) > 0: + # dataset configs + dataset_configs = copy_app_model_config_dict.get('dataset_configs', {'retrieval_model': 'single'}) + query_variable = copy_app_model_config_dict.get('dataset_query_variable') - if "enabled" not in tool_item or not tool_item["enabled"]: - continue - - agent_tool_properties["config"] = tool_item - agent_tools.append(AgentToolEntity(**agent_tool_properties)) - - properties['agent'] = AgentEntity( - provider=properties['model_config'].provider, - model=properties['model_config'].model, - strategy=strategy, - tools=agent_tools + if dataset_configs['retrieval_model'] == 'single': + properties['dataset'] = DatasetEntity( + dataset_ids=dataset_ids, + retrieve_config=DatasetRetrieveConfigEntity( + query_variable=query_variable, + retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of( + dataset_configs['retrieval_model'] + ), + single_strategy=datasets.get('strategy', 'router') + ) + ) + else: + properties['dataset'] = DatasetEntity( + dataset_ids=dataset_ids, + retrieve_config=DatasetRetrieveConfigEntity( + query_variable=query_variable, + retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of( + dataset_configs['retrieval_model'] + ), + top_k=dataset_configs.get('top_k'), + score_threshold=dataset_configs.get('score_threshold'), + reranking_model=dataset_configs.get('reranking_model') + ) ) # file upload @@ -601,6 +668,7 @@ class ApplicationManager: message_id=message.id, type=file.type.value, transfer_method=file.transfer_method.value, + belongs_to='user', url=file.url, upload_file_id=file.upload_file_id, created_by_role=('account' if account_id else 'end_user'), diff --git a/api/core/application_queue_manager.py b/api/core/application_queue_manager.py index 09b92c5f84..605255f3bf 100644 --- a/api/core/application_queue_manager.py +++ b/api/core/application_queue_manager.py @@ -7,10 +7,10 
@@ from core.entities.application_entities import InvokeFrom from core.entities.queue_entities import (AnnotationReplyEvent, AppQueueEvent, QueueAgentThoughtEvent, QueueErrorEvent, QueueMessage, QueueMessageEndEvent, QueueMessageEvent, QueueMessageReplaceEvent, QueuePingEvent, QueueRetrieverResourcesEvent, - QueueStopEvent) + QueueStopEvent, QueueMessageFileEvent, QueueAgentMessageEvent) from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk from extensions.ext_redis import redis_client -from models.model import MessageAgentThought +from models.model import MessageAgentThought, MessageFile from sqlalchemy.orm import DeclarativeMeta @@ -96,6 +96,18 @@ class ApplicationQueueManager: chunk=chunk ), pub_from) + def publish_agent_chunk_message(self, chunk: LLMResultChunk, pub_from: PublishFrom) -> None: + """ + Publish agent chunk message to channel + + :param chunk: chunk + :param pub_from: publish from + :return: + """ + self.publish(QueueAgentMessageEvent( + chunk=chunk + ), pub_from) + def publish_message_replace(self, text: str, pub_from: PublishFrom) -> None: """ Publish message replace @@ -144,6 +156,17 @@ class ApplicationQueueManager: agent_thought_id=message_agent_thought.id ), pub_from) + def publish_message_file(self, message_file: MessageFile, pub_from: PublishFrom) -> None: + """ + Publish agent thought + :param message_file: message file + :param pub_from: publish from + :return: + """ + self.publish(QueueMessageFileEvent( + message_file_id=message_file.id + ), pub_from) + def publish_error(self, e, pub_from: PublishFrom) -> None: """ Publish error diff --git a/api/core/callback_handler/agent_tool_callback_handler.py b/api/core/callback_handler/agent_tool_callback_handler.py new file mode 100644 index 0000000000..c1037035ee --- /dev/null +++ b/api/core/callback_handler/agent_tool_callback_handler.py @@ -0,0 +1,74 @@ +import os +from typing import Any, Dict, Optional, Union +from pydantic import BaseModel + +from 
langchain.callbacks.base import BaseCallbackHandler +from langchain.input import print_text + +class DifyAgentCallbackHandler(BaseCallbackHandler, BaseModel): + """Callback Handler that prints to std out.""" + color: Optional[str] = '' + current_loop = 1 + + def __init__(self, color: Optional[str] = None) -> None: + super().__init__() + """Initialize callback handler.""" + # use a specific color is not specified + self.color = color or 'green' + self.current_loop = 1 + + def on_tool_start( + self, + tool_name: str, + tool_inputs: Dict[str, Any], + ) -> None: + """Do nothing.""" + print_text("\n[on_tool_start] ToolCall:" + tool_name + "\n" + str(tool_inputs) + "\n", color=self.color) + + def on_tool_end( + self, + tool_name: str, + tool_inputs: Dict[str, Any], + tool_outputs: str, + ) -> None: + """If not the final action, print out observation.""" + print_text("\n[on_tool_end]\n", color=self.color) + print_text("Tool: " + tool_name + "\n", color=self.color) + print_text("Inputs: " + str(tool_inputs) + "\n", color=self.color) + print_text("Outputs: " + str(tool_outputs) + "\n", color=self.color) + print_text("\n") + + def on_tool_error( + self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any + ) -> None: + """Do nothing.""" + print_text("\n[on_tool_error] Error: " + str(error) + "\n", color='red') + + def on_agent_start( + self, thought: str + ) -> None: + """Run on agent start.""" + if thought: + print_text("\n[on_agent_start] \nCurrent Loop: " + \ + str(self.current_loop) + \ + "\nThought: " + thought + "\n", color=self.color) + else: + print_text("\n[on_agent_start] \nCurrent Loop: " + str(self.current_loop) + "\n", color=self.color) + + def on_agent_finish( + self, color: Optional[str] = None, **kwargs: Any + ) -> None: + """Run on agent end.""" + print_text("\n[on_agent_finish]\n Loop: " + str(self.current_loop) + "\n", color=self.color) + + self.current_loop += 1 + + @property + def ignore_agent(self) -> bool: + """Whether to ignore agent 
callbacks.""" + return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true' + + @property + def ignore_chat_model(self) -> bool: + """Whether to ignore chat model callbacks.""" + return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true' diff --git a/api/core/entities/application_entities.py b/api/core/entities/application_entities.py index 47a1ac6510..95a9d90f97 100644 --- a/api/core/entities/application_entities.py +++ b/api/core/entities/application_entities.py @@ -1,11 +1,12 @@ from enum import Enum -from typing import Any, Optional, cast +from typing import Optional, Any, cast, Literal, Union + +from pydantic import BaseModel from core.entities.provider_configuration import ProviderModelBundle from core.file.file_obj import FileObj from core.model_runtime.entities.message_entities import PromptMessageRole from core.model_runtime.entities.model_entities import AIModelEntity -from pydantic import BaseModel class ModelConfigEntity(BaseModel): @@ -153,9 +154,35 @@ class AgentToolEntity(BaseModel): """ Agent Tool Entity. """ - tool_id: str - config: dict[str, Any] = {} + provider_type: Literal["builtin", "api"] + provider_id: str + tool_name: str + tool_parameters: dict[str, Any] = {} +class AgentPromptEntity(BaseModel): + """ + Agent Prompt Entity. + """ + first_prompt: str + next_iteration: str + +class AgentScratchpadUnit(BaseModel): + """ + Agent First Prompt Entity. + """ + + class Action(BaseModel): + """ + Action Entity. 
+ """ + action_name: str + action_input: Union[dict, str] + + agent_response: Optional[str] = None + thought: Optional[str] = None + action_str: Optional[str] = None + observation: Optional[str] = None + action: Optional[Action] = None class AgentEntity(BaseModel): """ @@ -171,8 +198,9 @@ class AgentEntity(BaseModel): provider: str model: str strategy: Strategy - tools: list[AgentToolEntity] = [] - + prompt: Optional[AgentPromptEntity] = None + tools: list[AgentToolEntity] = None + max_iteration: int = 5 class AppOrchestrationConfigEntity(BaseModel): """ diff --git a/api/core/entities/queue_entities.py b/api/core/entities/queue_entities.py index 858b00ea64..d6ef28b138 100644 --- a/api/core/entities/queue_entities.py +++ b/api/core/entities/queue_entities.py @@ -10,11 +10,13 @@ class QueueEvent(Enum): QueueEvent enum """ MESSAGE = "message" + AGENT_MESSAGE = "agent_message" MESSAGE_REPLACE = "message-replace" MESSAGE_END = "message-end" RETRIEVER_RESOURCES = "retriever-resources" ANNOTATION_REPLY = "annotation-reply" AGENT_THOUGHT = "agent-thought" + MESSAGE_FILE = "message-file" ERROR = "error" PING = "ping" STOP = "stop" @@ -33,7 +35,14 @@ class QueueMessageEvent(AppQueueEvent): """ event = QueueEvent.MESSAGE chunk: LLMResultChunk - + +class QueueAgentMessageEvent(AppQueueEvent): + """ + QueueMessageEvent entity + """ + event = QueueEvent.AGENT_MESSAGE + chunk: LLMResultChunk + class QueueMessageReplaceEvent(AppQueueEvent): """ @@ -73,7 +82,13 @@ class QueueAgentThoughtEvent(AppQueueEvent): """ event = QueueEvent.AGENT_THOUGHT agent_thought_id: str - + +class QueueMessageFileEvent(AppQueueEvent): + """ + QueueAgentThoughtEvent entity + """ + event = QueueEvent.MESSAGE_FILE + message_file_id: str class QueueErrorEvent(AppQueueEvent): """ diff --git a/api/core/features/agent_runner.py b/api/core/features/agent_runner.py index ba9c3218fa..2a7e373dbf 100644 --- a/api/core/features/agent_runner.py +++ b/api/core/features/agent_runner.py @@ -1,30 +1,27 @@ import logging 
-from typing import List, Optional, cast +from typing import cast, Optional, List + +from langchain import WikipediaAPIWrapper +from langchain.callbacks.base import BaseCallbackHandler +from langchain.tools import BaseTool, WikipediaQueryRun, Tool +from pydantic import BaseModel, Field from core.agent.agent.agent_llm_callback import AgentLLMCallback -from core.agent.agent_executor import AgentConfiguration, AgentExecutor, PlanningStrategy +from core.agent.agent_executor import PlanningStrategy, AgentConfiguration, AgentExecutor from core.application_queue_manager import ApplicationQueueManager from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler -from core.entities.application_entities import (AgentEntity, AgentToolEntity, AppOrchestrationConfigEntity, InvokeFrom, - ModelConfigEntity) +from core.entities.application_entities import ModelConfigEntity, InvokeFrom, \ + AgentEntity, AgentToolEntity, AppOrchestrationConfigEntity from core.memory.token_buffer_memory import TokenBufferMemory from core.model_runtime.entities.model_entities import ModelFeature, ModelType from core.model_runtime.model_providers import model_provider_factory from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from core.tool.current_datetime_tool import DatetimeTool -from core.tool.dataset_retriever_tool import DatasetRetrieverTool -from core.tool.provider.serpapi_provider import SerpAPIToolProvider -from core.tool.serpapi_wrapper import OptimizedSerpAPIInput, OptimizedSerpAPIWrapper -from core.tool.web_reader_tool import WebReaderTool +from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool from extensions.ext_database import db -from langchain import WikipediaAPIWrapper -from 
langchain.callbacks.base import BaseCallbackHandler -from langchain.tools import BaseTool, Tool, WikipediaQueryRun from models.dataset import Dataset from models.model import Message -from pydantic import BaseModel, Field logger = logging.getLogger(__name__) @@ -132,55 +129,6 @@ class AgentRunnerFeature: logger.exception("agent_executor run failed") return None - def to_tools(self, tool_configs: list[AgentToolEntity], - invoke_from: InvokeFrom, - callbacks: list[BaseCallbackHandler]) \ - -> Optional[List[BaseTool]]: - """ - Convert tool configs to tools - :param tool_configs: tool configs - :param invoke_from: invoke from - :param callbacks: callbacks - """ - tools = [] - for tool_config in tool_configs: - tool = None - if tool_config.tool_id == "dataset": - tool = self.to_dataset_retriever_tool( - tool_config=tool_config.config, - invoke_from=invoke_from - ) - elif tool_config.tool_id == "web_reader": - tool = self.to_web_reader_tool( - tool_config=tool_config.config, - invoke_from=invoke_from - ) - elif tool_config.tool_id == "google_search": - tool = self.to_google_search_tool( - tool_config=tool_config.config, - invoke_from=invoke_from - ) - elif tool_config.tool_id == "wikipedia": - tool = self.to_wikipedia_tool( - tool_config=tool_config.config, - invoke_from=invoke_from - ) - elif tool_config.tool_id == "current_datetime": - tool = self.to_current_datetime_tool( - tool_config=tool_config.config, - invoke_from=invoke_from - ) - - if tool: - if tool.callbacks is not None: - tool.callbacks.extend(callbacks) - else: - tool.callbacks = callbacks - - tools.append(tool) - - return tools - def to_dataset_retriever_tool(self, tool_config: dict, invoke_from: InvokeFrom) \ -> Optional[BaseTool]: @@ -247,78 +195,4 @@ class AgentRunnerFeature: retriever_from=invoke_from.to_source() ) - return tool - - def to_web_reader_tool(self, tool_config: dict, - invoke_from: InvokeFrom) -> Optional[BaseTool]: - """ - A tool for reading web pages - :param tool_config: tool config - 
:param invoke_from: invoke from - :return: - """ - model_parameters = { - "temperature": 0, - "max_tokens": 500 - } - - tool = WebReaderTool( - model_config=self.model_config, - model_parameters=model_parameters, - max_chunk_length=4000, - continue_reading=True - ) - - return tool - - def to_google_search_tool(self, tool_config: dict, - invoke_from: InvokeFrom) -> Optional[BaseTool]: - """ - A tool for performing a Google search and extracting snippets and webpages - :param tool_config: tool config - :param invoke_from: invoke from - :return: - """ - tool_provider = SerpAPIToolProvider(tenant_id=self.tenant_id) - func_kwargs = tool_provider.credentials_to_func_kwargs() - if not func_kwargs: - return None - - tool = Tool( - name="google_search", - description="A tool for performing a Google search and extracting snippets and webpages " - "when you need to search for something you don't know or when your information " - "is not up to date. " - "Input should be a search query.", - func=OptimizedSerpAPIWrapper(**func_kwargs).run, - args_schema=OptimizedSerpAPIInput - ) - - return tool - - def to_current_datetime_tool(self, tool_config: dict, - invoke_from: InvokeFrom) -> Optional[BaseTool]: - """ - A tool for getting the current date and time - :param tool_config: tool config - :param invoke_from: invoke from - :return: - """ - return DatetimeTool() - - def to_wikipedia_tool(self, tool_config: dict, - invoke_from: InvokeFrom) -> Optional[BaseTool]: - """ - A tool for searching Wikipedia - :param tool_config: tool config - :param invoke_from: invoke from - :return: - """ - class WikipediaInput(BaseModel): - query: str = Field(..., description="search query.") - - return WikipediaQueryRun( - name="wikipedia", - api_wrapper=WikipediaAPIWrapper(doc_content_chars_max=4000), - args_schema=WikipediaInput - ) + return tool \ No newline at end of file diff --git a/api/core/features/assistant_base_runner.py b/api/core/features/assistant_base_runner.py new file mode 100644 index 
0000000000..2a896d6fbd --- /dev/null +++ b/api/core/features/assistant_base_runner.py @@ -0,0 +1,558 @@ +import logging +import json + +from typing import Optional, List, Tuple, Union +from datetime import datetime +from mimetypes import guess_extension + +from core.app_runner.app_runner import AppRunner +from extensions.ext_database import db + +from models.model import MessageAgentThought, Message, MessageFile +from models.tools import ToolConversationVariables + +from core.tools.entities.tool_entities import ToolInvokeMessage, ToolInvokeMessageBinary, \ + ToolRuntimeVariablePool, ToolParamter +from core.tools.tool.tool import Tool +from core.tools.tool_manager import ToolManager +from core.tools.tool_file_manager import ToolFileManager +from core.tools.tool.dataset_retriever_tool import DatasetRetrieverTool +from core.app_runner.app_runner import AppRunner +from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler +from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler +from core.entities.application_entities import ModelConfigEntity, AgentEntity, AgentToolEntity +from core.application_queue_manager import ApplicationQueueManager +from core.memory.token_buffer_memory import TokenBufferMemory +from core.entities.application_entities import ModelConfigEntity, \ + AgentEntity, AppOrchestrationConfigEntity, ApplicationGenerateEntity, InvokeFrom +from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool +from core.model_runtime.entities.llm_entities import LLMUsage +from core.model_runtime.utils.encoders import jsonable_encoder +from core.file.message_file_parser import FileTransferMethod + +logger = logging.getLogger(__name__) + +class BaseAssistantApplicationRunner(AppRunner): + def __init__(self, tenant_id: str, + application_generate_entity: ApplicationGenerateEntity, + app_orchestration_config: AppOrchestrationConfigEntity, + model_config: ModelConfigEntity, + 
config: AgentEntity, + queue_manager: ApplicationQueueManager, + message: Message, + user_id: str, + memory: Optional[TokenBufferMemory] = None, + prompt_messages: Optional[List[PromptMessage]] = None, + variables_pool: Optional[ToolRuntimeVariablePool] = None, + db_variables: Optional[ToolConversationVariables] = None, + ) -> None: + """ + Agent runner + :param tenant_id: tenant id + :param app_orchestration_config: app orchestration config + :param model_config: model config + :param config: dataset config + :param queue_manager: queue manager + :param message: message + :param user_id: user id + :param agent_llm_callback: agent llm callback + :param callback: callback + :param memory: memory + """ + self.tenant_id = tenant_id + self.application_generate_entity = application_generate_entity + self.app_orchestration_config = app_orchestration_config + self.model_config = model_config + self.config = config + self.queue_manager = queue_manager + self.message = message + self.user_id = user_id + self.memory = memory + self.history_prompt_messages = prompt_messages + self.variables_pool = variables_pool + self.db_variables_pool = db_variables + + # init callback + self.agent_callback = DifyAgentCallbackHandler() + # init dataset tools + hit_callback = DatasetIndexToolCallbackHandler( + queue_manager=queue_manager, + app_id=self.application_generate_entity.app_id, + message_id=message.id, + user_id=user_id, + invoke_from=self.application_generate_entity.invoke_from, + ) + self.dataset_tools = DatasetRetrieverTool.get_dataset_tools( + tenant_id=tenant_id, + dataset_ids=app_orchestration_config.dataset.dataset_ids if app_orchestration_config.dataset else [], + retrieve_config=app_orchestration_config.dataset.retrieve_config if app_orchestration_config.dataset else None, + return_resource=app_orchestration_config.show_retrieve_source, + invoke_from=application_generate_entity.invoke_from, + hit_callback=hit_callback + ) + # get how many agent thoughts have been created + 
self.agent_thought_count = db.session.query(MessageAgentThought).filter( + MessageAgentThought.message_id == self.message.id, + ).count() + + def _repacket_app_orchestration_config(self, app_orchestration_config: AppOrchestrationConfigEntity) -> AppOrchestrationConfigEntity: + """ + Repacket app orchestration config + """ + if app_orchestration_config.prompt_template.simple_prompt_template is None: + app_orchestration_config.prompt_template.simple_prompt_template = '' + + return app_orchestration_config + + def _convert_tool_response_to_str(self, tool_response: List[ToolInvokeMessage]) -> str: + """ + Handle tool response + """ + result = '' + for response in tool_response: + if response.type == ToolInvokeMessage.MessageType.TEXT: + result += response.message + elif response.type == ToolInvokeMessage.MessageType.LINK: + result += f"result link: {response.message}. please dirct user to check it." + elif response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \ + response.type == ToolInvokeMessage.MessageType.IMAGE: + result += f"image has been created and sent to user already, you should tell user to check it now." + else: + result += f"tool response: {response.message}." 
+ + return result + + def _convert_tool_to_prompt_message_tool(self, tool: AgentToolEntity) -> Tuple[PromptMessageTool, Tool]: + """ + convert tool to prompt message tool + """ + tool_entity = ToolManager.get_tool_runtime( + provider_type=tool.provider_type, provider_name=tool.provider_id, tool_name=tool.tool_name, + tanent_id=self.application_generate_entity.tenant_id, + agent_callback=self.agent_callback + ) + tool_entity.load_variables(self.variables_pool) + + message_tool = PromptMessageTool( + name=tool.tool_name, + description=tool_entity.description.llm, + parameters={ + "type": "object", + "properties": {}, + "required": [], + } + ) + + runtime_parameters = {} + + parameters = tool_entity.parameters or [] + user_parameters = tool_entity.get_runtime_parameters() or [] + + # override parameters + for parameter in user_parameters: + # check if parameter in tool parameters + found = False + for tool_parameter in parameters: + if tool_parameter.name == parameter.name: + found = True + break + + if found: + # override parameter + tool_parameter.type = parameter.type + tool_parameter.form = parameter.form + tool_parameter.required = parameter.required + tool_parameter.default = parameter.default + tool_parameter.options = parameter.options + tool_parameter.llm_description = parameter.llm_description + else: + # add new parameter + parameters.append(parameter) + + for parameter in parameters: + parameter_type = 'string' + enum = [] + if parameter.type == ToolParamter.ToolParameterType.STRING: + parameter_type = 'string' + elif parameter.type == ToolParamter.ToolParameterType.BOOLEAN: + parameter_type = 'boolean' + elif parameter.type == ToolParamter.ToolParameterType.NUMBER: + parameter_type = 'number' + elif parameter.type == ToolParamter.ToolParameterType.SELECT: + for option in parameter.options: + enum.append(option.value) + parameter_type = 'string' + else: + raise ValueError(f"parameter type {parameter.type} is not supported") + + if parameter.form == 
ToolParamter.ToolParameterForm.FORM: + # get tool parameter from form + tool_parameter_config = tool.tool_parameters.get(parameter.name) + if not tool_parameter_config: + # get default value + tool_parameter_config = parameter.default + if not tool_parameter_config and parameter.required: + raise ValueError(f"tool parameter {parameter.name} not found in tool config") + + if parameter.type == ToolParamter.ToolParameterType.SELECT: + # check if tool_parameter_config in options + options = list(map(lambda x: x.value, parameter.options)) + if tool_parameter_config not in options: + raise ValueError(f"tool parameter {parameter.name} value {tool_parameter_config} not in options {options}") + + # convert tool parameter config to correct type + try: + if parameter.type == ToolParamter.ToolParameterType.NUMBER: + # check if tool parameter is integer + if isinstance(tool_parameter_config, int): + tool_parameter_config = tool_parameter_config + elif isinstance(tool_parameter_config, float): + tool_parameter_config = tool_parameter_config + elif isinstance(tool_parameter_config, str): + if '.' 
in tool_parameter_config: + tool_parameter_config = float(tool_parameter_config) + else: + tool_parameter_config = int(tool_parameter_config) + elif parameter.type == ToolParamter.ToolParameterType.BOOLEAN: + tool_parameter_config = bool(tool_parameter_config) + elif parameter.type not in [ToolParamter.ToolParameterType.SELECT, ToolParamter.ToolParameterType.STRING]: + tool_parameter_config = str(tool_parameter_config) + elif parameter.type == ToolParamter.ToolParameterType: + tool_parameter_config = str(tool_parameter_config) + except Exception as e: + raise ValueError(f"tool parameter {parameter.name} value {tool_parameter_config} is not correct type") + + # save tool parameter to tool entity memory + runtime_parameters[parameter.name] = tool_parameter_config + + elif parameter.form == ToolParamter.ToolParameterForm.LLM: + message_tool.parameters['properties'][parameter.name] = { + "type": parameter_type, + "description": parameter.llm_description or '', + } + + if len(enum) > 0: + message_tool.parameters['properties'][parameter.name]['enum'] = enum + + if parameter.required: + message_tool.parameters['required'].append(parameter.name) + + tool_entity.runtime.runtime_parameters.update(runtime_parameters) + + return message_tool, tool_entity + + def _convert_dataset_retriever_tool_to_prompt_message_tool(self, tool: DatasetRetrieverTool) -> PromptMessageTool: + """ + convert dataset retriever tool to prompt message tool + """ + prompt_tool = PromptMessageTool( + name=tool.identity.name, + description=tool.description.llm, + parameters={ + "type": "object", + "properties": {}, + "required": [], + } + ) + + for parameter in tool.get_runtime_parameters(): + parameter_type = 'string' + + prompt_tool.parameters['properties'][parameter.name] = { + "type": parameter_type, + "description": parameter.llm_description or '', + } + + if parameter.required: + if parameter.name not in prompt_tool.parameters['required']: + prompt_tool.parameters['required'].append(parameter.name) 
+ + return prompt_tool + + def update_prompt_message_tool(self, tool: Tool, prompt_tool: PromptMessageTool) -> PromptMessageTool: + """ + update prompt message tool + """ + # try to get tool runtime parameters + tool_runtime_parameters = tool.get_runtime_parameters() or [] + + for parameter in tool_runtime_parameters: + parameter_type = 'string' + enum = [] + if parameter.type == ToolParamter.ToolParameterType.STRING: + parameter_type = 'string' + elif parameter.type == ToolParamter.ToolParameterType.BOOLEAN: + parameter_type = 'boolean' + elif parameter.type == ToolParamter.ToolParameterType.NUMBER: + parameter_type = 'number' + elif parameter.type == ToolParamter.ToolParameterType.SELECT: + for option in parameter.options: + enum.append(option.value) + parameter_type = 'string' + else: + raise ValueError(f"parameter type {parameter.type} is not supported") + + if parameter.form == ToolParamter.ToolParameterForm.LLM: + prompt_tool.parameters['properties'][parameter.name] = { + "type": parameter_type, + "description": parameter.llm_description or '', + } + + if len(enum) > 0: + prompt_tool.parameters['properties'][parameter.name]['enum'] = enum + + if parameter.required: + if parameter.name not in prompt_tool.parameters['required']: + prompt_tool.parameters['required'].append(parameter.name) + + return prompt_tool + + def extract_tool_response_binary(self, tool_response: List[ToolInvokeMessage]) -> List[ToolInvokeMessageBinary]: + """ + Extract tool response binary + """ + result = [] + + for response in tool_response: + if response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \ + response.type == ToolInvokeMessage.MessageType.IMAGE: + result.append(ToolInvokeMessageBinary( + mimetype=response.meta.get('mime_type', 'octet/stream'), + url=response.message, + save_as=response.save_as, + )) + elif response.type == ToolInvokeMessage.MessageType.BLOB: + result.append(ToolInvokeMessageBinary( + mimetype=response.meta.get('mime_type', 'octet/stream'), + 
url=response.message, + save_as=response.save_as, + )) + elif response.type == ToolInvokeMessage.MessageType.LINK: + # check if there is a mime type in meta + if response.meta and 'mime_type' in response.meta: + result.append(ToolInvokeMessageBinary( + mimetype=response.meta.get('mime_type', 'octet/stream') if response.meta else 'octet/stream', + url=response.message, + save_as=response.save_as, + )) + + return result + + def create_message_files(self, messages: List[ToolInvokeMessageBinary]) -> List[Tuple[MessageFile, bool]]: + """ + Create message file + + :param messages: messages + :return: message files, should save as variable + """ + result = [] + + for message in messages: + file_type = 'bin' + if 'image' in message.mimetype: + file_type = 'image' + elif 'video' in message.mimetype: + file_type = 'video' + elif 'audio' in message.mimetype: + file_type = 'audio' + elif 'text' in message.mimetype: + file_type = 'text' + elif 'pdf' in message.mimetype: + file_type = 'pdf' + elif 'zip' in message.mimetype: + file_type = 'archive' + # ... 
+ + invoke_from = self.application_generate_entity.invoke_from + + message_file = MessageFile( + message_id=self.message.id, + type=file_type, + transfer_method=FileTransferMethod.TOOL_FILE.value, + belongs_to='assistant', + url=message.url, + upload_file_id=None, + created_by_role=('account'if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end_user'), + created_by=self.user_id, + ) + db.session.add(message_file) + result.append(( + message_file, + message.save_as + )) + + db.session.commit() + + return result + + def create_agent_thought(self, message_id: str, message: str, + tool_name: str, tool_input: str, messages_ids: List[str] + ) -> MessageAgentThought: + """ + Create agent thought + """ + thought = MessageAgentThought( + message_id=message_id, + message_chain_id=None, + thought='', + tool=tool_name, + tool_input=tool_input, + message=message, + message_token=0, + message_unit_price=0, + message_price_unit=0, + message_files=json.dumps(messages_ids) if messages_ids else '', + answer='', + observation='', + answer_token=0, + answer_unit_price=0, + answer_price_unit=0, + tokens=0, + total_price=0, + position=self.agent_thought_count + 1, + currency='USD', + latency=0, + created_by_role='account', + created_by=self.user_id, + ) + + db.session.add(thought) + db.session.commit() + + self.agent_thought_count += 1 + + return thought + + def save_agent_thought(self, + agent_thought: MessageAgentThought, + tool_name: str, + tool_input: Union[str, dict], + thought: str, + observation: str, + answer: str, + messages_ids: List[str], + llm_usage: LLMUsage = None) -> MessageAgentThought: + """ + Save agent thought + """ + if thought is not None: + agent_thought.thought = thought + + if tool_name is not None: + agent_thought.tool = tool_name + + if tool_input is not None: + if isinstance(tool_input, dict): + try: + tool_input = json.dumps(tool_input, ensure_ascii=False) + except Exception as e: + tool_input = json.dumps(tool_input) + + 
agent_thought.tool_input = tool_input + + if observation is not None: + agent_thought.observation = observation + + if answer is not None: + agent_thought.answer = answer + + if messages_ids is not None and len(messages_ids) > 0: + agent_thought.message_files = json.dumps(messages_ids) + + if llm_usage: + agent_thought.message_token = llm_usage.prompt_tokens + agent_thought.message_price_unit = llm_usage.prompt_price_unit + agent_thought.message_unit_price = llm_usage.prompt_unit_price + agent_thought.answer_token = llm_usage.completion_tokens + agent_thought.answer_price_unit = llm_usage.completion_price_unit + agent_thought.answer_unit_price = llm_usage.completion_unit_price + agent_thought.tokens = llm_usage.total_tokens + agent_thought.total_price = llm_usage.total_price + + db.session.commit() + + def get_history_prompt_messages(self) -> List[PromptMessage]: + """ + Get history prompt messages + """ + if self.history_prompt_messages is None: + self.history_prompt_messages = db.session.query(PromptMessage).filter( + PromptMessage.message_id == self.message.id, + ).order_by(PromptMessage.position.asc()).all() + + return self.history_prompt_messages + + def transform_tool_invoke_messages(self, messages: List[ToolInvokeMessage]) -> List[ToolInvokeMessage]: + """ + Transform tool message into agent thought + """ + result = [] + + for message in messages: + if message.type == ToolInvokeMessage.MessageType.TEXT: + result.append(message) + elif message.type == ToolInvokeMessage.MessageType.LINK: + result.append(message) + elif message.type == ToolInvokeMessage.MessageType.IMAGE: + # try to download image + try: + file = ToolFileManager.create_file_by_url(user_id=self.user_id, tenant_id=self.tenant_id, + conversation_id=self.message.conversation_id, + file_url=message.message) + + url = f'/files/tools/{file.id}{guess_extension(file.mimetype) or ".png"}' + + result.append(ToolInvokeMessage( + type=ToolInvokeMessage.MessageType.IMAGE_LINK, + message=url, + 
save_as=message.save_as, + meta=message.meta.copy() if message.meta is not None else {}, + )) + except Exception as e: + logger.exception(e) + result.append(ToolInvokeMessage( + type=ToolInvokeMessage.MessageType.TEXT, + message=f"Failed to download image: {message.message}, you can try to download it yourself.", + meta=message.meta.copy() if message.meta is not None else {}, + save_as=message.save_as, + )) + elif message.type == ToolInvokeMessage.MessageType.BLOB: + # get mime type and save blob to storage + mimetype = message.meta.get('mime_type', 'octet/stream') + # if message is str, encode it to bytes + if isinstance(message.message, str): + message.message = message.message.encode('utf-8') + file = ToolFileManager.create_file_by_raw(user_id=self.user_id, tenant_id=self.tenant_id, + conversation_id=self.message.conversation_id, + file_binary=message.message, + mimetype=mimetype) + + url = f'/files/tools/{file.id}{guess_extension(file.mimetype) or ".bin"}' + + # check if file is image + if 'image' in mimetype: + result.append(ToolInvokeMessage( + type=ToolInvokeMessage.MessageType.IMAGE_LINK, + message=url, + save_as=message.save_as, + meta=message.meta.copy() if message.meta is not None else {}, + )) + else: + result.append(ToolInvokeMessage( + type=ToolInvokeMessage.MessageType.LINK, + message=url, + save_as=message.save_as, + meta=message.meta.copy() if message.meta is not None else {}, + )) + else: + result.append(message) + + return result + + def update_db_variables(self, tool_variables: ToolRuntimeVariablePool, db_variables: ToolConversationVariables): + """ + convert tool variables to db variables + """ + db_variables.updated_at = datetime.utcnow() + db_variables.variables_str = json.dumps(jsonable_encoder(tool_variables.pool)) + db.session.commit() \ No newline at end of file diff --git a/api/core/features/assistant_cot_runner.py b/api/core/features/assistant_cot_runner.py new file mode 100644 index 0000000000..130833c7b3 --- /dev/null +++ 
b/api/core/features/assistant_cot_runner.py @@ -0,0 +1,578 @@ +import json +import logging +import re +from typing import Literal, Union, Generator, Dict, List + +from core.entities.application_entities import AgentPromptEntity, AgentScratchpadUnit +from core.application_queue_manager import PublishFrom +from core.model_runtime.utils.encoders import jsonable_encoder +from core.model_runtime.entities.message_entities import PromptMessageTool, PromptMessage, \ + UserPromptMessage, SystemPromptMessage, AssistantPromptMessage +from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage, LLMResultChunk, LLMResultChunkDelta +from core.model_manager import ModelInstance + +from core.tools.errors import ToolInvokeError, ToolNotFoundError, \ + ToolNotSupportedError, ToolProviderNotFoundError, ToolParamterValidationError, \ + ToolProviderCredentialValidationError + +from core.features.assistant_base_runner import BaseAssistantApplicationRunner + +from models.model import Conversation, Message + +logger = logging.getLogger(__name__) + +class AssistantCotApplicationRunner(BaseAssistantApplicationRunner): + def run(self, model_instance: ModelInstance, + conversation: Conversation, + message: Message, + query: str, + ) -> Union[Generator, LLMResult]: + """ + Run Cot agent application + """ + app_orchestration_config = self.app_orchestration_config + self._repacket_app_orchestration_config(app_orchestration_config) + + agent_scratchpad: List[AgentScratchpadUnit] = [] + + # check model mode + if self.app_orchestration_config.model_config.mode == "completion": + # TODO: stop words + if 'Observation' not in app_orchestration_config.model_config.stop: + app_orchestration_config.model_config.stop.append('Observation') + + iteration_step = 1 + max_iteration_steps = min(self.app_orchestration_config.agent.max_iteration, 5) + 1 + + prompt_messages = self.history_prompt_messages + + # convert tools into ModelRuntime Tool format + prompt_messages_tools: List[PromptMessageTool] 
= [] + tool_instances = {} + for tool in self.app_orchestration_config.agent.tools if self.app_orchestration_config.agent else []: + try: + prompt_tool, tool_entity = self._convert_tool_to_prompt_message_tool(tool) + except Exception: + # api tool may be deleted + continue + # save tool entity + tool_instances[tool.tool_name] = tool_entity + # save prompt tool + prompt_messages_tools.append(prompt_tool) + + # convert dataset tools into ModelRuntime Tool format + for dataset_tool in self.dataset_tools: + prompt_tool = self._convert_dataset_retriever_tool_to_prompt_message_tool(dataset_tool) + # save prompt tool + prompt_messages_tools.append(prompt_tool) + # save tool entity + tool_instances[dataset_tool.identity.name] = dataset_tool + + function_call_state = True + llm_usage = { + 'usage': None + } + final_answer = '' + + def increse_usage(final_llm_usage_dict: Dict[str, LLMUsage], usage: LLMUsage): + if not final_llm_usage_dict['usage']: + final_llm_usage_dict['usage'] = usage + else: + llm_usage = final_llm_usage_dict['usage'] + llm_usage.prompt_tokens += usage.prompt_tokens + llm_usage.completion_tokens += usage.completion_tokens + llm_usage.prompt_price += usage.prompt_price + llm_usage.completion_price += usage.completion_price + + while function_call_state and iteration_step <= max_iteration_steps: + # continue to run until there is not any tool call + function_call_state = False + + if iteration_step == max_iteration_steps: + # the last iteration, remove all tools + prompt_messages_tools = [] + + message_file_ids = [] + agent_thought = self.create_agent_thought( + message_id=message.id, + message='', + tool_name='', + tool_input='', + messages_ids=message_file_ids + ) + self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER) + + # update prompt messages + prompt_messages = self._originze_cot_prompt_messages( + mode=app_orchestration_config.model_config.mode, + prompt_messages=prompt_messages, + tools=prompt_messages_tools, + 
agent_scratchpad=agent_scratchpad, + agent_prompt_message=app_orchestration_config.agent.prompt, + instruction=app_orchestration_config.prompt_template.simple_prompt_template, + input=query + ) + + # recale llm max tokens + self.recale_llm_max_tokens(self.model_config, prompt_messages) + # invoke model + llm_result: LLMResult = model_instance.invoke_llm( + prompt_messages=prompt_messages, + model_parameters=app_orchestration_config.model_config.parameters, + tools=[], + stop=app_orchestration_config.model_config.stop, + stream=False, + user=self.user_id, + callbacks=[], + ) + + # check llm result + if not llm_result: + raise ValueError("failed to invoke llm") + + # get scratchpad + scratchpad = self._extract_response_scratchpad(llm_result.message.content) + agent_scratchpad.append(scratchpad) + + # get llm usage + if llm_result.usage: + increse_usage(llm_usage, llm_result.usage) + + self.save_agent_thought(agent_thought=agent_thought, + tool_name=scratchpad.action.action_name if scratchpad.action else '', + tool_input=scratchpad.action.action_input if scratchpad.action else '', + thought=scratchpad.thought, + observation='', + answer=llm_result.message.content, + messages_ids=[], + llm_usage=llm_result.usage) + + if scratchpad.action and scratchpad.action.action_name.lower() != "final answer": + self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER) + + # publish agent thought if it's not empty and there is a action + if scratchpad.thought and scratchpad.action: + # check if final answer + if not scratchpad.action.action_name.lower() == "final answer": + yield LLMResultChunk( + model=model_instance.model, + prompt_messages=prompt_messages, + delta=LLMResultChunkDelta( + index=0, + message=AssistantPromptMessage( + content=scratchpad.thought + ), + usage=llm_result.usage, + ), + system_fingerprint='' + ) + + if not scratchpad.action: + # failed to extract action, return final answer directly + final_answer = 
scratchpad.agent_response or '' + else: + if scratchpad.action.action_name.lower() == "final answer": + # action is final answer, return final answer directly + try: + final_answer = scratchpad.action.action_input if \ + isinstance(scratchpad.action.action_input, str) else \ + json.dumps(scratchpad.action.action_input) + except json.JSONDecodeError: + final_answer = f'{scratchpad.action.action_input}' + else: + function_call_state = True + + # action is tool call, invoke tool + tool_call_name = scratchpad.action.action_name + tool_call_args = scratchpad.action.action_input + tool_instance = tool_instances.get(tool_call_name) + if not tool_instance: + logger.error(f"failed to find tool instance: {tool_call_name}") + answer = f"there is not a tool named {tool_call_name}" + self.save_agent_thought(agent_thought=agent_thought, + tool_name='', + tool_input='', + thought=None, + observation=answer, + answer=answer, + messages_ids=[]) + self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER) + else: + # invoke tool + error_response = None + try: + tool_response = tool_instance.invoke( + user_id=self.user_id, + tool_paramters=tool_call_args if isinstance(tool_call_args, dict) else json.loads(tool_call_args) + ) + # transform tool response to llm friendly response + tool_response = self.transform_tool_invoke_messages(tool_response) + # extract binary data from tool invoke message + binary_files = self.extract_tool_response_binary(tool_response) + # create message file + message_files = self.create_message_files(binary_files) + # publish files + for message_file, save_as in message_files: + if save_as: + self.variables_pool.set_file(tool_name=tool_call_name, + value=message_file.id, + name=save_as) + self.queue_manager.publish_message_file(message_file, PublishFrom.APPLICATION_MANAGER) + + message_file_ids = [message_file.id for message_file, _ in message_files] + except ToolProviderCredentialValidationError as e: + error_response = f"Please 
check your tool provider credentials" + except ( + ToolNotFoundError, ToolNotSupportedError, ToolProviderNotFoundError + ) as e: + error_response = f"there is not a tool named {tool_call_name}" + except ( + ToolParamterValidationError + ) as e: + error_response = f"tool parameters validation error: {e}, please check your tool parameters" + except ToolInvokeError as e: + error_response = f"tool invoke error: {e}" + except Exception as e: + error_response = f"unknown error: {e}" + + if error_response: + observation = error_response + logger.error(error_response) + else: + observation = self._convert_tool_response_to_str(tool_response) + + # save scratchpad + scratchpad.observation = observation + scratchpad.agent_response = llm_result.message.content + + # save agent thought + self.save_agent_thought( + agent_thought=agent_thought, + tool_name=tool_call_name, + tool_input=tool_call_args, + thought=None, + observation=observation, + answer=llm_result.message.content, + messages_ids=message_file_ids, + ) + self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER) + + # update prompt tool message + for prompt_tool in prompt_messages_tools: + self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool) + + iteration_step += 1 + + yield LLMResultChunk( + model=model_instance.model, + prompt_messages=prompt_messages, + delta=LLMResultChunkDelta( + index=0, + message=AssistantPromptMessage( + content=final_answer + ), + usage=llm_usage['usage'] + ), + system_fingerprint='' + ) + + # save agent thought + self.save_agent_thought( + agent_thought=agent_thought, + tool_name='', + tool_input='', + thought=final_answer, + observation='', + answer=final_answer, + messages_ids=[] + ) + + self.update_db_variables(self.variables_pool, self.db_variables_pool) + # publish end event + self.queue_manager.publish_message_end(LLMResult( + model=model_instance.model, + prompt_messages=prompt_messages, + message=AssistantPromptMessage( + 
content=final_answer + ), + usage=llm_usage['usage'], + system_fingerprint='' + ), PublishFrom.APPLICATION_MANAGER) + + def _extract_response_scratchpad(self, content: str) -> AgentScratchpadUnit: + """ + extract response from llm response + """ + def extra_quotes() -> AgentScratchpadUnit: + agent_response = content + # try to extract all quotes + pattern = re.compile(r'```(.*?)```', re.DOTALL) + quotes = pattern.findall(content) + + # try to extract action from end to start + for i in range(len(quotes) - 1, 0, -1): + """ + 1. use json load to parse action + 2. use plain text `Action: xxx` to parse action + """ + try: + action = json.loads(quotes[i].replace('```', '')) + action_name = action.get("action") + action_input = action.get("action_input") + agent_thought = agent_response.replace(quotes[i], '') + + if action_name and action_input: + return AgentScratchpadUnit( + agent_response=content, + thought=agent_thought, + action_str=quotes[i], + action=AgentScratchpadUnit.Action( + action_name=action_name, + action_input=action_input, + ) + ) + except: + # try to parse action from plain text + action_name = re.findall(r'action: (.*)', quotes[i], re.IGNORECASE) + action_input = re.findall(r'action input: (.*)', quotes[i], re.IGNORECASE) + # delete action from agent response + agent_thought = agent_response.replace(quotes[i], '') + # remove extra quotes + agent_thought = re.sub(r'```(json)*\n*```', '', agent_thought, flags=re.DOTALL) + # remove Action: xxx from agent thought + agent_thought = re.sub(r'Action:.*', '', agent_thought, flags=re.IGNORECASE) + + if action_name and action_input: + return AgentScratchpadUnit( + agent_response=content, + thought=agent_thought, + action_str=quotes[i], + action=AgentScratchpadUnit.Action( + action_name=action_name[0], + action_input=action_input[0], + ) + ) + + def extra_json(): + agent_response = content + # try to extract all json + structures, pair_match_stack = [], [] + started_at, end_at = 0, 0 + for i in 
range(len(content)): + if content[i] == '{': + pair_match_stack.append(i) + if len(pair_match_stack) == 1: + started_at = i + elif content[i] == '}': + begin = pair_match_stack.pop() + if not pair_match_stack: + end_at = i + 1 + structures.append((content[begin:i+1], (started_at, end_at))) + + # handle the last character + if pair_match_stack: + end_at = len(content) + structures.append((content[pair_match_stack[0]:], (started_at, end_at))) + + for i in range(len(structures), 0, -1): + try: + json_content, (started_at, end_at) = structures[i - 1] + action = json.loads(json_content) + action_name = action.get("action") + action_input = action.get("action_input") + # delete json content from agent response + agent_thought = agent_response[:started_at] + agent_response[end_at:] + # remove extra quotes like ```(json)*\n\n``` + agent_thought = re.sub(r'```(json)*\n*```', '', agent_thought, flags=re.DOTALL) + # remove Action: xxx from agent thought + agent_thought = re.sub(r'Action:.*', '', agent_thought, flags=re.IGNORECASE) + + if action_name and action_input: + return AgentScratchpadUnit( + agent_response=content, + thought=agent_thought, + action_str=json_content, + action=AgentScratchpadUnit.Action( + action_name=action_name, + action_input=action_input, + ) + ) + except: + pass + + agent_scratchpad = extra_quotes() + if agent_scratchpad: + return agent_scratchpad + agent_scratchpad = extra_json() + if agent_scratchpad: + return agent_scratchpad + + return AgentScratchpadUnit( + agent_response=content, + thought=content, + action_str='', + action=None + ) + + def _check_cot_prompt_messages(self, mode: Literal["completion", "chat"], + agent_prompt_message: AgentPromptEntity, + ): + """ + check chain of thought prompt messages, a standard prompt message is like: + Respond to the human as helpfully and accurately as possible. 
+ + {{instruction}} + + You have access to the following tools: + + {{tools}} + + Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input). + Valid action values: "Final Answer" or {{tool_names}} + + Provide only ONE action per $JSON_BLOB, as shown: + + ``` + { + "action": $TOOL_NAME, + "action_input": $ACTION_INPUT + } + ``` + """ + + # parse agent prompt message + first_prompt = agent_prompt_message.first_prompt + next_iteration = agent_prompt_message.next_iteration + + if not isinstance(first_prompt, str) or not isinstance(next_iteration, str): + raise ValueError(f"first_prompt or next_iteration is required in CoT agent mode") + + # check instruction, tools, and tool_names slots + if not first_prompt.find("{{instruction}}") >= 0: + raise ValueError("{{instruction}} is required in first_prompt") + if not first_prompt.find("{{tools}}") >= 0: + raise ValueError("{{tools}} is required in first_prompt") + if not first_prompt.find("{{tool_names}}") >= 0: + raise ValueError("{{tool_names}} is required in first_prompt") + + if mode == "completion": + if not first_prompt.find("{{query}}") >= 0: + raise ValueError("{{query}} is required in first_prompt") + if not first_prompt.find("{{agent_scratchpad}}") >= 0: + raise ValueError("{{agent_scratchpad}} is required in first_prompt") + + if mode == "completion": + if not next_iteration.find("{{observation}}") >= 0: + raise ValueError("{{observation}} is required in next_iteration") + + def _convert_strachpad_list_to_str(self, agent_scratchpad: List[AgentScratchpadUnit]) -> str: + """ + convert agent scratchpad list to str + """ + next_iteration = self.app_orchestration_config.agent.prompt.next_iteration + + result = '' + for scratchpad in agent_scratchpad: + result += scratchpad.thought + next_iteration.replace("{{observation}}", scratchpad.observation) + "\n" + + return result + + def _originze_cot_prompt_messages(self, mode: Literal["completion", "chat"], + 
prompt_messages: List[PromptMessage], + tools: List[PromptMessageTool], + agent_scratchpad: List[AgentScratchpadUnit], + agent_prompt_message: AgentPromptEntity, + instruction: str, + input: str, + ) -> List[PromptMessage]: + """ + originze chain of thought prompt messages, a standard prompt message is like: + Respond to the human as helpfully and accurately as possible. + + {{instruction}} + + You have access to the following tools: + + {{tools}} + + Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input). + Valid action values: "Final Answer" or {{tool_names}} + + Provide only ONE action per $JSON_BLOB, as shown: + + ``` + {{{{ + "action": $TOOL_NAME, + "action_input": $ACTION_INPUT + }}}} + ``` + """ + + self._check_cot_prompt_messages(mode, agent_prompt_message) + + # parse agent prompt message + first_prompt = agent_prompt_message.first_prompt + + # parse tools + tools_str = self._jsonify_tool_prompt_messages(tools) + + # parse tools name + tool_names = '"' + '","'.join([tool.name for tool in tools]) + '"' + + # get system message + system_message = first_prompt.replace("{{instruction}}", instruction) \ + .replace("{{tools}}", tools_str) \ + .replace("{{tool_names}}", tool_names) + + # originze prompt messages + if mode == "chat": + # override system message + overrided = False + prompt_messages = prompt_messages.copy() + for prompt_message in prompt_messages: + if isinstance(prompt_message, SystemPromptMessage): + prompt_message.content = system_message + overrided = True + break + + if not overrided: + prompt_messages.insert(0, SystemPromptMessage( + content=system_message, + )) + + # add assistant message + if len(agent_scratchpad) > 0: + prompt_messages.append(AssistantPromptMessage( + content=agent_scratchpad[-1].thought + "\n" + agent_scratchpad[-1].observation + )) + + # add user message + if len(agent_scratchpad) > 0: + prompt_messages.append(UserPromptMessage( + content=input, + )) + + return 
prompt_messages + elif mode == "completion": + # parse agent scratchpad + agent_scratchpad_str = self._convert_strachpad_list_to_str(agent_scratchpad) + # parse prompt messages + return [UserPromptMessage( + content=first_prompt.replace("{{instruction}}", instruction) + .replace("{{tools}}", tools_str) + .replace("{{tool_names}}", tool_names) + .replace("{{query}}", input) + .replace("{{agent_scratchpad}}", agent_scratchpad_str), + )] + else: + raise ValueError(f"mode {mode} is not supported") + + def _jsonify_tool_prompt_messages(self, tools: list[PromptMessageTool]) -> str: + """ + jsonify tool prompt messages + """ + tools = jsonable_encoder(tools) + try: + return json.dumps(tools, ensure_ascii=False) + except json.JSONDecodeError: + return json.dumps(tools) \ No newline at end of file diff --git a/api/core/features/assistant_fc_runner.py b/api/core/features/assistant_fc_runner.py new file mode 100644 index 0000000000..dfd59527d9 --- /dev/null +++ b/api/core/features/assistant_fc_runner.py @@ -0,0 +1,335 @@ +import json +import logging + +from typing import Union, Generator, Dict, Any, Tuple, List + +from core.model_runtime.entities.message_entities import PromptMessage, UserPromptMessage,\ + SystemPromptMessage, AssistantPromptMessage, ToolPromptMessage, PromptMessageTool +from core.model_runtime.entities.llm_entities import LLMResultChunk, LLMResult, LLMUsage +from core.model_manager import ModelInstance +from core.application_queue_manager import PublishFrom + +from core.tools.errors import ToolInvokeError, ToolNotFoundError, \ + ToolNotSupportedError, ToolProviderNotFoundError, ToolParamterValidationError, \ + ToolProviderCredentialValidationError + +from core.features.assistant_base_runner import BaseAssistantApplicationRunner + +from models.model import Conversation, Message, MessageAgentThought + +logger = logging.getLogger(__name__) + +class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner): + def run(self, model_instance: 
ModelInstance, + conversation: Conversation, + message: Message, + query: str, + ) -> Generator[LLMResultChunk, None, None]: + """ + Run FunctionCall agent application + """ + app_orchestration_config = self.app_orchestration_config + + prompt_template = self.app_orchestration_config.prompt_template.simple_prompt_template or '' + prompt_messages = self.history_prompt_messages + prompt_messages = self.organize_prompt_messages( + prompt_template=prompt_template, + query=query, + prompt_messages=prompt_messages + ) + + # convert tools into ModelRuntime Tool format + prompt_messages_tools: List[PromptMessageTool] = [] + tool_instances = {} + for tool in self.app_orchestration_config.agent.tools if self.app_orchestration_config.agent else []: + try: + prompt_tool, tool_entity = self._convert_tool_to_prompt_message_tool(tool) + except Exception: + # api tool may be deleted + continue + # save tool entity + tool_instances[tool.tool_name] = tool_entity + # save prompt tool + prompt_messages_tools.append(prompt_tool) + + # convert dataset tools into ModelRuntime Tool format + for dataset_tool in self.dataset_tools: + prompt_tool = self._convert_dataset_retriever_tool_to_prompt_message_tool(dataset_tool) + # save prompt tool + prompt_messages_tools.append(prompt_tool) + # save tool entity + tool_instances[dataset_tool.identity.name] = dataset_tool + + iteration_step = 1 + max_iteration_steps = min(app_orchestration_config.agent.max_iteration, 5) + 1 + + # continue to run until there is not any tool call + function_call_state = True + agent_thoughts: List[MessageAgentThought] = [] + llm_usage = { + 'usage': None + } + final_answer = '' + + def increase_usage(final_llm_usage_dict: Dict[str, LLMUsage], usage: LLMUsage): + if not final_llm_usage_dict['usage']: + final_llm_usage_dict['usage'] = usage + else: + llm_usage = final_llm_usage_dict['usage'] + llm_usage.prompt_tokens += usage.prompt_tokens + llm_usage.completion_tokens += usage.completion_tokens + llm_usage.prompt_price 
+= usage.prompt_price + llm_usage.completion_price += usage.completion_price + + while function_call_state and iteration_step <= max_iteration_steps: + function_call_state = False + + if iteration_step == max_iteration_steps: + # the last iteration, remove all tools + prompt_messages_tools = [] + + message_file_ids = [] + agent_thought = self.create_agent_thought( + message_id=message.id, + message='', + tool_name='', + tool_input='', + messages_ids=message_file_ids + ) + self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER) + + # recale llm max tokens + self.recale_llm_max_tokens(self.model_config, prompt_messages) + # invoke model + chunks: Generator[LLMResultChunk, None, None] = model_instance.invoke_llm( + prompt_messages=prompt_messages, + model_parameters=app_orchestration_config.model_config.parameters, + tools=prompt_messages_tools, + stop=app_orchestration_config.model_config.stop, + stream=True, + user=self.user_id, + callbacks=[], + ) + + tool_calls: List[Tuple[str, str, Dict[str, Any]]] = [] + + # save full response + response = '' + + # save tool call names and inputs + tool_call_names = '' + tool_call_inputs = '' + + current_llm_usage = None + + for chunk in chunks: + # check if there is any tool call + if self.check_tool_calls(chunk): + function_call_state = True + tool_calls.extend(self.extract_tool_calls(chunk)) + tool_call_names = ';'.join([tool_call[1] for tool_call in tool_calls]) + try: + tool_call_inputs = json.dumps({ + tool_call[1]: tool_call[2] for tool_call in tool_calls + }, ensure_ascii=False) + except json.JSONDecodeError as e: + # ensure ascii to avoid encoding error + tool_call_inputs = json.dumps({ + tool_call[1]: tool_call[2] for tool_call in tool_calls + }) + + if chunk.delta.message and chunk.delta.message.content: + if isinstance(chunk.delta.message.content, list): + for content in chunk.delta.message.content: + response += content.data + else: + response += chunk.delta.message.content + + if 
chunk.delta.usage: + increase_usage(llm_usage, chunk.delta.usage) + current_llm_usage = chunk.delta.usage + + yield chunk + + # save thought + self.save_agent_thought( + agent_thought=agent_thought, + tool_name=tool_call_names, + tool_input=tool_call_inputs, + thought=response, + observation=None, + answer=response, + messages_ids=[], + llm_usage=current_llm_usage + ) + + self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER) + + final_answer += response + '\n' + + # call tools + tool_responses = [] + for tool_call_id, tool_call_name, tool_call_args in tool_calls: + tool_instance = tool_instances.get(tool_call_name) + if not tool_instance: + logger.error(f"failed to find tool instance: {tool_call_name}") + tool_response = { + "tool_call_id": tool_call_id, + "tool_call_name": tool_call_name, + "tool_response": f"there is not a tool named {tool_call_name}" + } + tool_responses.append(tool_response) + else: + # invoke tool + error_response = None + try: + tool_invoke_message = tool_instance.invoke( + user_id=self.user_id, + tool_paramters=tool_call_args, + ) + # transform tool invoke message to get LLM friendly message + tool_invoke_message = self.transform_tool_invoke_messages(tool_invoke_message) + # extract binary data from tool invoke message + binary_files = self.extract_tool_response_binary(tool_invoke_message) + # create message file + message_files = self.create_message_files(binary_files) + # publish files + for message_file, save_as in message_files: + if save_as: + self.variables_pool.set_file(tool_name=tool_call_name, value=message_file.id, name=save_as) + + # publish message file + self.queue_manager.publish_message_file(message_file, PublishFrom.APPLICATION_MANAGER) + # add message file ids + message_file_ids.append(message_file.id) + + except ToolProviderCredentialValidationError as e: + error_response = f"Please check your tool provider credentials" + except ( + ToolNotFoundError, ToolNotSupportedError, 
ToolProviderNotFoundError + ) as e: + error_response = f"there is not a tool named {tool_call_name}" + except ( + ToolParamterValidationError + ) as e: + error_response = f"tool parameters validation error: {e}, please check your tool parameters" + except ToolInvokeError as e: + error_response = f"tool invoke error: {e}" + except Exception as e: + error_response = f"unknown error: {e}" + + if error_response: + observation = error_response + logger.error(error_response) + tool_response = { + "tool_call_id": tool_call_id, + "tool_call_name": tool_call_name, + "tool_response": error_response + } + tool_responses.append(tool_response) + else: + observation = self._convert_tool_response_to_str(tool_invoke_message) + tool_response = { + "tool_call_id": tool_call_id, + "tool_call_name": tool_call_name, + "tool_response": observation + } + tool_responses.append(tool_response) + + prompt_messages = self.organize_prompt_messages( + prompt_template=prompt_template, + query=None, + tool_call_id=tool_call_id, + tool_call_name=tool_call_name, + tool_response=tool_response['tool_response'], + prompt_messages=prompt_messages, + ) + + if len(tool_responses) > 0: + # save agent thought + self.save_agent_thought( + agent_thought=agent_thought, + tool_name=None, + tool_input=None, + thought=None, + observation=tool_response['tool_response'], + answer=None, + messages_ids=message_file_ids + ) + self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER) + + # update prompt messages + if response.strip(): + prompt_messages.append(AssistantPromptMessage( + content=response, + )) + + # update prompt tool + for prompt_tool in prompt_messages_tools: + self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool) + + iteration_step += 1 + + self.update_db_variables(self.variables_pool, self.db_variables_pool) + # publish end event + self.queue_manager.publish_message_end(LLMResult( + model=model_instance.model, + prompt_messages=prompt_messages, + 
message=AssistantPromptMessage( + content=final_answer, + ), + usage=llm_usage['usage'], + system_fingerprint='' + ), PublishFrom.APPLICATION_MANAGER) + + def check_tool_calls(self, llm_result_chunk: LLMResultChunk) -> bool: + """ + Check if there is any tool call in llm result chunk + """ + if llm_result_chunk.delta.message.tool_calls: + return True + return False + + def extract_tool_calls(self, llm_result_chunk: LLMResultChunk) -> Union[None, List[Tuple[str, str, Dict[str, Any]]]]: + """ + Extract tool calls from llm result chunk + + Returns: + List[Tuple[str, str, Dict[str, Any]]]: [(tool_call_id, tool_call_name, tool_call_args)] + """ + tool_calls = [] + for prompt_message in llm_result_chunk.delta.message.tool_calls: + tool_calls.append(( + prompt_message.id, + prompt_message.function.name, + json.loads(prompt_message.function.arguments), + )) + + return tool_calls + + def organize_prompt_messages(self, prompt_template: str, + query: str = None, + tool_call_id: str = None, tool_call_name: str = None, tool_response: str = None, + prompt_messages: list[PromptMessage] = None + ) -> list[PromptMessage]: + """ + Organize prompt messages + """ + + if not prompt_messages: + prompt_messages = [ + SystemPromptMessage(content=prompt_template), + UserPromptMessage(content=query), + ] + else: + if tool_response: + prompt_messages = prompt_messages.copy() + prompt_messages.append( + ToolPromptMessage( + content=tool_response, + tool_call_id=tool_call_id, + name=tool_call_name, + ) + ) + + return prompt_messages \ No newline at end of file diff --git a/api/core/features/dataset_retrieval.py b/api/core/features/dataset_retrieval.py index 90ca6c42ed..f8fcea7c10 100644 --- a/api/core/features/dataset_retrieval.py +++ b/api/core/features/dataset_retrieval.py @@ -6,8 +6,8 @@ from core.entities.application_entities import DatasetEntity, DatasetRetrieveCon from core.memory.token_buffer_memory import TokenBufferMemory from core.model_runtime.entities.model_entities import 
ModelFeature from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from core.tool.dataset_multi_retriever_tool import DatasetMultiRetrieverTool -from core.tool.dataset_retriever_tool import DatasetRetrieverTool +from core.tools.tool.dataset_retriever.dataset_multi_retriever_tool import DatasetMultiRetrieverTool +from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool from extensions.ext_database import db from langchain.tools import BaseTool from models.dataset import Dataset diff --git a/api/core/file/file_obj.py b/api/core/file/file_obj.py index 3ebe531607..626dbbca43 100644 --- a/api/core/file/file_obj.py +++ b/api/core/file/file_obj.py @@ -22,6 +22,7 @@ class FileType(enum.Enum): class FileTransferMethod(enum.Enum): REMOTE_URL = 'remote_url' LOCAL_FILE = 'local_file' + TOOL_FILE = 'tool_file' @staticmethod def value_of(value): @@ -30,6 +31,16 @@ class FileTransferMethod(enum.Enum): return member raise ValueError(f"No matching enum found for value '{value}'") +class FileBelongsTo(enum.Enum): + USER = 'user' + ASSISTANT = 'assistant' + + @staticmethod + def value_of(value): + for member in FileBelongsTo: + if member.value == value: + return member + raise ValueError(f"No matching enum found for value '{value}'") class FileObj(BaseModel): id: Optional[str] diff --git a/api/core/file/message_file_parser.py b/api/core/file/message_file_parser.py index e651745e29..e70a68e70b 100644 --- a/api/core/file/message_file_parser.py +++ b/api/core/file/message_file_parser.py @@ -1,7 +1,7 @@ from typing import Dict, List, Optional, Union import requests -from core.file.file_obj import FileObj, FileTransferMethod, FileType +from core.file.file_obj import FileObj, FileTransferMethod, FileType, FileBelongsTo from services.file_service import IMAGE_EXTENSIONS from extensions.ext_database import db from models.account import Account @@ -128,6 +128,9 @@ class MessageFileParser: # group by file type and 
convert file args or message files to FileObj for file in files: + if file.belongs_to == FileBelongsTo.ASSISTANT.value: + continue + file_obj = self._to_file_obj(file, file_upload_config) if file_obj.type not in type_file_objs: continue diff --git a/api/core/file/tool_file_parser.py b/api/core/file/tool_file_parser.py new file mode 100644 index 0000000000..ea8605ac57 --- /dev/null +++ b/api/core/file/tool_file_parser.py @@ -0,0 +1,8 @@ +tool_file_manager = { + 'manager': None +} + +class ToolFileParser: + @staticmethod + def get_tool_file_manager() -> 'ToolFileManager': + return tool_file_manager['manager'] \ No newline at end of file diff --git a/api/core/model_runtime/model_providers/openai/llm/llm.py b/api/core/model_runtime/model_providers/openai/llm/llm.py index 4b6ba4be9c..f8fc5db99a 100644 --- a/api/core/model_runtime/model_providers/openai/llm/llm.py +++ b/api/core/model_runtime/model_providers/openai/llm/llm.py @@ -485,19 +485,37 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): :return: llm response chunk generator """ full_assistant_content = '' + delta_assistant_message_function_call_storage: ChoiceDeltaFunctionCall = None for chunk in response: if len(chunk.choices) == 0: continue delta = chunk.choices[0] - if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == ''): + if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == '') and \ + delta.delta.function_call is None: continue # assistant_message_tool_calls = delta.delta.tool_calls assistant_message_function_call = delta.delta.function_call # extract tool calls from response + if delta_assistant_message_function_call_storage is not None: + # handle process of stream function call + if assistant_message_function_call: + # message has not ended ever + delta_assistant_message_function_call_storage.arguments += assistant_message_function_call.arguments + continue + else: + # message has ended + 
assistant_message_function_call = delta_assistant_message_function_call_storage + delta_assistant_message_function_call_storage = None + else: + if assistant_message_function_call: + # start of stream function call + delta_assistant_message_function_call_storage = assistant_message_function_call + continue + # tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls) function_call = self._extract_response_function_call(assistant_message_function_call) tool_calls = [function_call] if function_call else [] diff --git a/api/core/tools/README.md b/api/core/tools/README.md new file mode 100644 index 0000000000..c7ee81422e --- /dev/null +++ b/api/core/tools/README.md @@ -0,0 +1,25 @@ +# Tools + +This module implements built-in tools used in Agent Assistants and Workflows within Dify. You could define and display your own tools in this module, without modifying the frontend logic. This decoupling allows for easier horizontal scaling of Dify's capabilities. + +## Feature Introduction + +The tools provided for Agents and Workflows are currently divided into two categories: +- `Built-in Tools` are internally implemented within our product and are hardcoded for use in Agents and Workflows. +- `Api-Based Tools` leverage third-party APIs for implementation. You don't need to code to integrate these -- simply provide interface definitions in formats like `OpenAPI` , `Swagger`, or the `OpenAI-plugin` on the front-end. + +### Built-in Tool Providers +![Alt text](docs/zh_Hans/images/index/image.png) + +### API Tool Providers +![Alt text](docs/zh_Hans/images/index/image-1.png) + +## Tool Integration + +To enable developers to build flexible and powerful tools, we provide two guides: + +### [Quick Integration 👈🏻](./docs/en_US/tool_scale_out.md) +Quick integration aims at quickly getting you up to speed with tool integration by walking over an example Google Search tool. 
+ +### [Advanced Integration 👈🏻](./docs/en_US/advanced_scale_out.md) +Advanced integration will offer a deeper dive into the module interfaces, and explain how to implement more complex capabilities, such as generating images, combining multiple tools, and managing the flow of parameters, images, and files between different tools. \ No newline at end of file diff --git a/api/core/tools/README_CN.md b/api/core/tools/README_CN.md new file mode 100644 index 0000000000..fda5d0630c --- /dev/null +++ b/api/core/tools/README_CN.md @@ -0,0 +1,27 @@ +# Tools + +该模块提供了各Agent和Workflow中会使用的内置工具的调用、鉴权接口,并为 Dify 提供了统一的工具供应商的信息和凭据表单规则。 + +- 一方面将工具和业务代码解耦,方便开发者对模型横向扩展, +- 另一方面提供了只需在后端定义供应商和工具,即可在前端页面直接展示,无需修改前端逻辑。 + +## 功能介绍 + +对于给Agent和Workflow提供的工具,我们当前将其分为两类: +- `Built-in Tools` 内置工具,即Dify内部实现的工具,通过硬编码的方式提供给Agent和Workflow使用。 +- `Api-Based Tools` 基于API的工具,即通过调用第三方API实现的工具,`Api-Based Tool`不需要再额外定义,只需提供`OpenAPI` `Swagger` `OpenAI plugin`等接口文档即可。 + +### 内置工具供应商 +![Alt text](docs/zh_Hans/images/index/image.png) + +### API工具供应商 +![Alt text](docs/zh_Hans/images/index/image-1.png) + +## 工具接入 +为了实现更灵活更强大的功能,Tools提供了一系列的接口,帮助开发者快速构建想要的工具,本文作为开发者的入门指南,将会以[快速接入](./docs/zh_Hans/tool_scale_out.md)和[高级接入](./docs/zh_Hans/advanced_scale_out.md)两部分介绍如何接入工具。 + +### [快速接入 👈🏻](./docs/zh_Hans/tool_scale_out.md) +快速接入可以帮助你在10~20分钟内完成工具的接入,但是这种接入方式只能实现简单的功能,如果你想要实现更复杂的功能,可以参考下面的高级接入。 + +### [高级接入 👈🏻](./docs/zh_Hans/advanced_scale_out.md) +高级接入将介绍如何实现更复杂的功能配置,包括实现图生图、实现多个工具的组合、实现参数、图片、文件在多个工具之间的流转。 \ No newline at end of file diff --git a/api/core/tools/docs/en_US/advanced_scale_out.md b/api/core/tools/docs/en_US/advanced_scale_out.md new file mode 100644 index 0000000000..c6f516e1de --- /dev/null +++ b/api/core/tools/docs/en_US/advanced_scale_out.md @@ -0,0 +1,266 @@ +# Advanced Tool Integration + +Before starting with this advanced guide, please make sure you have a basic understanding of the tool integration process in Dify. Check out [Quick Integration](./tool_scale_out.md) for a quick runthrough. 
+ +## Tool Interface + +We have defined a series of helper methods in the `Tool` class to help developers quickly build more complex tools. + +### Message Return + +Dify supports various message types such as `text`, `link`, `image`, and `file BLOB`. You can return different types of messages to the LLM and users through the following interfaces. + +Please note, some parameters in the following interfaces will be introduced in later sections. + +#### Image URL +You only need to pass the URL of the image, and Dify will automatically download the image and return it to the user. + +```python + def create_image_message(self, image: str, save_as: str = '') -> ToolInvokeMessage: + """ + create an image message + + :param image: the url of the image + :return: the image message + """ +``` + +#### Link +If you need to return a link, you can use the following interface. + +```python + def create_link_message(self, link: str, save_as: str = '') -> ToolInvokeMessage: + """ + create a link message + + :param link: the url of the link + :return: the link message + """ +``` + +#### Text +If you need to return a text message, you can use the following interface. + +```python + def create_text_message(self, text: str, save_as: str = '') -> ToolInvokeMessage: + """ + create a text message + + :param text: the text of the message + :return: the text message + """ +``` + +#### File BLOB +If you need to return the raw data of a file, such as images, audio, video, PPT, Word, Excel, etc., you can use the following interface. 
+ +- `blob` The raw data of the file, of bytes type +- `meta` The metadata of the file, if you know the type of the file, it is best to pass a `mime_type`, otherwise Dify will use `octet/stream` as the default type + +```python + def create_blob_message(self, blob: bytes, meta: dict = None, save_as: str = '') -> ToolInvokeMessage: + """ + create a blob message + + :param blob: the blob + :return: the blob message + """ +``` + +### Shortcut Tools + +In large model applications, we have two common needs: +- First, summarize a long text in advance, and then pass the summary content to the LLM to prevent the original text from being too long for the LLM to handle +- The content obtained by the tool is a link, and the web page information needs to be crawled before it can be returned to the LLM + +To help developers quickly implement these two needs, we provide the following two shortcut tools. + +#### Text Summary Tool + +This tool takes in an user_id and the text to be summarized, and returns the summarized text. Dify will use the default model of the current workspace to summarize the long text. + +```python + def summary(self, user_id: str, content: str) -> str: + """ + summary the content + + :param user_id: the user id + :param content: the content + :return: the summary + """ +``` + +#### Web Page Crawling Tool + +This tool takes in web page link to be crawled and a user_agent (which can be empty), and returns a string containing the information of the web page. The `user_agent` is an optional parameter that can be used to identify the tool. If not passed, Dify will use the default `user_agent`. + +```python + def get_url(self, url: str, user_agent: str = None) -> str: + """ + get url + """ the crawled result +``` + +### Variable Pool + +We have introduced a variable pool in `Tool` to store variables, files, etc. generated during the tool's operation. These variables can be used by other tools during the tool's operation. 
+ +Next, we will use `DallE3` and `Vectorizer.AI` as examples to introduce how to use the variable pool. + +- `DallE3` is an image generation tool that can generate images based on text. Here, we will let `DallE3` generate a logo for a coffee shop +- `Vectorizer.AI` is a vector image conversion tool that can convert images into vector images, so that the images can be infinitely enlarged without distortion. Here, we will convert the PNG icon generated by `DallE3` into a vector image, so that it can be truly used by designers. + +#### DallE3 +First, we use DallE3. After creating the image, we save the image to the variable pool. The code is as follows: + +```python +from typing import Any, Dict, List, Union +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool + +from base64 import b64decode + +from openai import OpenAI + +class DallE3Tool(BuiltinTool): + def _invoke(self, + user_id: str, + tool_paramters: Dict[str, Any], + ) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]: + """ + invoke tools + """ + client = OpenAI( + api_key=self.runtime.credentials['openai_api_key'], + ) + + # prompt + prompt = tool_paramters.get('prompt', '') + if not prompt: + return self.create_text_message('Please input prompt') + + # call openapi dalle3 + response = client.images.generate( + prompt=prompt, model='dall-e-3', + size='1024x1024', n=1, style='vivid', quality='standard', + response_format='b64_json' + ) + + result = [] + for image in response.data: + # Save all images to the variable pool through the save_as parameter. The variable name is self.VARIABLE_KEY.IMAGE.value. If new images are generated later, they will overwrite the previous images. + result.append(self.create_blob_message(blob=b64decode(image.b64_json), + meta={ 'mime_type': 'image/png' }, + save_as=self.VARIABLE_KEY.IMAGE.value)) + + return result +``` + +Note that we used `self.VARIABLE_KEY.IMAGE.value` as the variable name of the image. 
In order for developers' tools to cooperate with each other, we defined this `KEY`. You can use it freely, or you can choose not to use this `KEY`. Passing a custom KEY is also acceptable. + +#### Vectorizer.AI +Next, we use Vectorizer.AI to convert the PNG icon generated by DallE3 into a vector image. Let's go through the functions we defined here. The code is as follows: + +```python +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParamter +from core.tools.errors import ToolProviderCredentialValidationError + +from typing import Any, Dict, List, Union +from httpx import post +from base64 import b64decode + +class VectorizerTool(BuiltinTool): + def _invoke(self, user_id: str, tool_paramters: Dict[str, Any]) \ + -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]: + """ + Tool invocation, the image variable name needs to be passed in from here, so that we can get the image from the variable pool + """ + + + def get_runtime_parameters(self) -> List[ToolParamter]: + """ + Override the tool parameter list, we can dynamically generate the parameter list based on the actual situation in the current variable pool, so that the LLM can generate the form based on the parameter list + """ + + + def is_tool_avaliable(self) -> bool: + """ + Whether the current tool is available, if there is no image in the current variable pool, then we don't need to display this tool, just return False here + """ +``` + +Next, let's implement these three functions + +```python +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParamter +from core.tools.errors import ToolProviderCredentialValidationError + +from typing import Any, Dict, List, Union +from httpx import post +from base64 import b64decode + +class VectorizerTool(BuiltinTool): + def _invoke(self, user_id: str, tool_paramters: Dict[str, Any]) \ + -> Union[ToolInvokeMessage, 
List[ToolInvokeMessage]]: + """ + invoke tools + """ + api_key_name = self.runtime.credentials.get('api_key_name', None) + api_key_value = self.runtime.credentials.get('api_key_value', None) + + if not api_key_name or not api_key_value: + raise ToolProviderCredentialValidationError('Please input api key name and value') + + # Get image_id, the definition of image_id can be found in get_runtime_parameters + image_id = tool_paramters.get('image_id', '') + if not image_id: + return self.create_text_message('Please input image id') + + # Get the image generated by DallE from the variable pool + image_binary = self.get_variable_file(self.VARIABLE_KEY.IMAGE) + if not image_binary: + return self.create_text_message('Image not found, please request user to generate image firstly.') + + # Generate vector image + response = post( + 'https://vectorizer.ai/api/v1/vectorize', + files={ 'image': image_binary }, + data={ 'mode': 'test' }, + auth=(api_key_name, api_key_value), + timeout=30 + ) + + if response.status_code != 200: + raise Exception(response.text) + + return [ + self.create_text_message('the vectorized svg is saved as an image.'), + self.create_blob_message(blob=response.content, + meta={'mime_type': 'image/svg+xml'}) + ] + + def get_runtime_parameters(self) -> List[ToolParamter]: + """ + override the runtime parameters + """ + # Here, we override the tool parameter list, define the image_id, and set its option list to all images in the current variable pool. The configuration here is consistent with the configuration in yaml. 
+ return [ + ToolParamter.get_simple_instance( + name='image_id', + llm_description=f'the image id that you want to vectorize, \ + and the image id should be specified in \ + {[i.name for i in self.list_default_image_variables()]}', + type=ToolParamter.ToolParameterType.SELECT, + required=True, + options=[i.name for i in self.list_default_image_variables()] + ) + ] + + def is_tool_avaliable(self) -> bool: + # Only when there are images in the variable pool, the LLM needs to use this tool + return len(self.list_default_image_variables()) > 0 +``` + +It's worth noting that we didn't actually use `image_id` here. We assumed that there must be an image in the default variable pool when calling this tool, so we directly used `image_binary = self.get_variable_file(self.VARIABLE_KEY.IMAGE)` to get the image. In cases where the model's capabilities are weak, we recommend developers to do the same, which can effectively improve fault tolerance and avoid the model passing incorrect parameters. \ No newline at end of file diff --git a/api/core/tools/docs/en_US/tool_scale_out.md b/api/core/tools/docs/en_US/tool_scale_out.md new file mode 100644 index 0000000000..90aa38bfdd --- /dev/null +++ b/api/core/tools/docs/en_US/tool_scale_out.md @@ -0,0 +1,212 @@ +# Quick Tool Integration + +Here, we will use GoogleSearch as an example to demonstrate how to quickly integrate a tool. + +## 1. Prepare the Tool Provider yaml + +### Introduction +This yaml declares a new tool provider, and includes information like the provider's name, icon, author, and other details that are fetched by the frontend for display. + +### Example + +We need to create a `google` module (folder) under `core/tools/provider/builtin`, and create `google.yaml`. The name must be consistent with the module name. + +Subsequently, all operations related to this tool will be carried out under this module. 
+ +```yaml +identity: # Basic information of the tool provider + author: Dify # Author + name: google # Name, unique, no duplication with other providers + label: # Label for frontend display + en_US: Google # English label + zh_Hans: Google # Chinese label + description: # Description for frontend display + en_US: Google # English description + zh_Hans: Google # Chinese description + icon: icon.svg # Icon, needs to be placed in the _assets folder of the current module + +``` + - The `identity` field is mandatory, it contains the basic information of the tool provider, including author, name, label, description, icon, etc. + - The icon needs to be placed in the `_assets` folder of the current module, you can refer to [here](../../provider/builtin/google/_assets/icon.svg). + +## 2. Prepare Provider Credentials + +Google, as a third-party tool, uses the API provided by SerpApi, which requires an API Key to use. This means that this tool needs a credential to use. For tools like `wikipedia`, there is no need to fill in the credential field, you can refer to [here](../../provider/builtin/wikipedia/wikipedia.yaml). 
+ +After configuring the credential field, the effect is as follows: +```yaml +identity: + author: Dify + name: google + label: + en_US: Google + zh_Hans: Google + description: + en_US: Google + zh_Hans: Google + icon: icon.svg +credentails_for_provider: # Credential field + serpapi_api_key: # Credential field name + type: secret-input # Credential field type + required: true # Required or not + label: # Credential field label + en_US: SerpApi API key # English label + zh_Hans: SerpApi API key # Chinese label + placeholder: # Credential field placeholder + en_US: Please input your SerpApi API key # English placeholder + zh_Hans: 请输入你的 SerpApi API key # Chinese placeholder + help: # Credential field help text + en_US: Get your SerpApi API key from SerpApi # English help text + zh_Hans: 从 SerpApi 获取您的 SerpApi API key # Chinese help text + url: https://serpapi.com/manage-api-key # Credential field help link + +``` + +- `type`: Credential field type, currently can be either `secret-input`, `text-input`, or `select` , corresponding to password input box, text input box, and drop-down box, respectively. If set to `secret-input`, it will mask the input content on the frontend, and the backend will encrypt the input content. + +## 3. Prepare Tool yaml +A provider can have multiple tools, each tool needs a yaml file to describe, this file contains the basic information, parameters, output, etc. of the tool. + +Still taking GoogleSearch as an example, we need to create a `tools` module under the `google` module, and create `tools/google_search.yaml`, the content is as follows. 
+ +```yaml +identity: # Basic information of the tool + name: google_search # Tool name, unique, no duplication with other tools + author: Dify # Author + label: # Label for frontend display + en_US: GoogleSearch # English label + zh_Hans: 谷歌搜索 # Chinese label +description: # Description for frontend display + human: # Introduction for frontend display, supports multiple languages + en_US: A tool for performing a Google SERP search and extracting snippets and webpages.Input should be a search query. + zh_Hans: 一个用于执行 Google SERP 搜索并提取片段和网页的工具。输入应该是一个搜索查询。 + llm: A tool for performing a Google SERP search and extracting snippets and webpages.Input should be a search query. # Introduction passed to LLM, in order to make LLM better understand this tool, we suggest to write as detailed information about this tool as possible here, so that LLM can understand and use this tool +parameters: # Parameter list + - name: query # Parameter name + type: string # Parameter type + required: true # Required or not + label: # Parameter label + en_US: Query string # English label + zh_Hans: 查询语句 # Chinese label + human_description: # Introduction for frontend display, supports multiple languages + en_US: used for searching + zh_Hans: 用于搜索网页内容 + llm_description: key words for searching # Introduction passed to LLM, similarly, in order to make LLM better understand this parameter, we suggest to write as detailed information about this parameter as possible here, so that LLM can understand this parameter + form: llm # Form type, llm means this parameter needs to be inferred by Agent, the frontend will not display this parameter + - name: result_type + type: select # Parameter type + required: true + options: # Drop-down box options + - value: text + label: + en_US: text + zh_Hans: 文本 + - value: link + label: + en_US: link + zh_Hans: 链接 + default: link + label: + en_US: Result type + zh_Hans: 结果类型 + human_description: + en_US: used for selecting the result type, text or link + zh_Hans: 
用于选择结果类型,使用文本还是链接进行展示 + form: form # Form type, form means this parameter needs to be filled in by the user on the frontend before the conversation starts + +``` + +- The `identity` field is mandatory, it contains the basic information of the tool, including name, author, label, description, etc. +- `parameters` Parameter list + - `name` Parameter name, unique, no duplication with other parameters + - `type` Parameter type, currently supports `string`, `number`, `boolean`, `select` four types, corresponding to string, number, boolean, drop-down box + - `required` Required or not + - In `llm` mode, if the parameter is required, the Agent is required to infer this parameter + - In `form` mode, if the parameter is required, the user is required to fill in this parameter on the frontend before the conversation starts + - `options` Parameter options + - In `llm` mode, Dify will pass all options to LLM, LLM can infer based on these options + - In `form` mode, when `type` is `select`, the frontend will display these options + - `default` Default value + - `label` Parameter label, for frontend display + - `human_description` Introduction for frontend display, supports multiple languages + - `llm_description` Introduction passed to LLM, in order to make LLM better understand this parameter, we suggest to write as detailed information about this parameter as possible here, so that LLM can understand this parameter + - `form` Form type, currently supports `llm`, `form` two types, corresponding to Agent self-inference and frontend filling + +## 4. Add Tool Logic +After completing the tool configuration, we can start writing the tool code that defines how it is invoked. + +Create `google_search.py` under the `google/tools` module, the content is as follows. 
+ +```python +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.entities.tool_entities import ToolInvokeMessage + +from typing import Any, Dict, List, Union + +class GoogleSearchTool(BuiltinTool): + def _invoke(self, + user_id: str, + tool_paramters: Dict[str, Any], + ) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]: + """ + invoke tools + """ + query = tool_paramters['query'] + result_type = tool_paramters['result_type'] + api_key = self.runtime.credentials['serpapi_api_key'] + # TODO: search with serpapi + result = SerpAPI(api_key).run(query, result_type=result_type) + + if result_type == 'text': + return self.create_text_message(text=result) + return self.create_link_message(link=result) +``` + +### Parameters +The overall logic of the tool is in the `_invoke` method, this method accepts two parameters: `user_id` and `tool_paramters`, which represent the user ID and tool parameters respectively + +### Return Data +When the tool returns, you can choose to return one message or multiple messages, here we return one message, using `create_text_message` and `create_link_message` can create a text message or a link message. + +## 5. Add Provider Code +Finally, we need to create a provider class under the provider module to implement the provider's credential verification logic. If the credential verification fails, it will throw a `ToolProviderCredentialValidationError` exception. + +Create `google.py` under the `google` module, the content is as follows. 
+ +```python +from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType +from core.tools.tool.tool import Tool +from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController +from core.tools.errors import ToolProviderCredentialValidationError + +from core.tools.provider.builtin.google.tools.google_search import GoogleSearchTool + +from typing import Any, Dict + +class GoogleProvider(BuiltinToolProviderController): + def _validate_credentials(self, credentials: Dict[str, Any]) -> None: + try: + # 1. Here you need to instantiate a GoogleSearchTool with GoogleSearchTool(), it will automatically load the yaml configuration of GoogleSearchTool, but at this time it does not have credential information inside + # 2. Then you need to use the fork_tool_runtime method to pass the current credential information to GoogleSearchTool + # 3. Finally, invoke it, the parameters need to be passed according to the parameter rules configured in the yaml of GoogleSearchTool + GoogleSearchTool().fork_tool_runtime( + meta={ + "credentials": credentials, + } + ).invoke( + user_id='', + tool_paramters={ + "query": "test", + "result_type": "link" + }, + ) + except Exception as e: + raise ToolProviderCredentialValidationError(str(e)) +``` + +## Completion +After the above steps are completed, we can see this tool on the frontend, and it can be used in the Agent. + +Of course, because google_search needs a credential, before using it, you also need to input your credentials on the frontend. 
+ +![Alt text](../zh_Hans/images/index/image-2.png) \ No newline at end of file diff --git a/api/core/tools/docs/zh_Hans/advanced_scale_out.md b/api/core/tools/docs/zh_Hans/advanced_scale_out.md new file mode 100644 index 0000000000..520b48d289 --- /dev/null +++ b/api/core/tools/docs/zh_Hans/advanced_scale_out.md @@ -0,0 +1,266 @@ +# 高级接入Tool + +在开始高级接入之前,请确保你已经阅读过[快速接入](./tool_scale_out.md),并对Dify的工具接入流程有了基本的了解。 + +## 工具接口 + +我们在`Tool`类中定义了一系列快捷方法,用于帮助开发者快速构较为复杂的工具 + +### 消息返回 + +Dify支持`文本` `链接` `图片` `文件BLOB` 等多种消息类型,你可以通过以下几个接口返回不同类型的消息给LLM和用户。 + +注意,在下面的接口中的部分参数将在后面的章节中介绍。 + +#### 图片URL +只需要传递图片的URL即可,Dify会自动下载图片并返回给用户。 + +```python + def create_image_message(self, image: str, save_as: str = '') -> ToolInvokeMessage: + """ + create an image message + + :param image: the url of the image + :return: the image message + """ +``` + +#### 链接 +如果你需要返回一个链接,可以使用以下接口。 + +```python + def create_link_message(self, link: str, save_as: str = '') -> ToolInvokeMessage: + """ + create a link message + + :param link: the url of the link + :return: the link message + """ +``` + +#### 文本 +如果你需要返回一个文本消息,可以使用以下接口。 + +```python + def create_text_message(self, text: str, save_as: str = '') -> ToolInvokeMessage: + """ + create a text message + + :param text: the text of the message + :return: the text message + """ +``` + +#### 文件BLOB +如果你需要返回文件的原始数据,如图片、音频、视频、PPT、Word、Excel等,可以使用以下接口。 + +- `blob` 文件的原始数据,bytes类型 +- `meta` 文件的元数据,如果你知道该文件的类型,最好传递一个`mime_type`,否则Dify将使用`octet/stream`作为默认类型 + +```python + def create_blob_message(self, blob: bytes, meta: dict = None, save_as: str = '') -> ToolInvokeMessage: + """ + create a blob message + + :param blob: the blob + :return: the blob message + """ +``` + +### 快捷工具 + +在大模型应用中,我们有两种常见的需求: +- 先将很长的文本进行提前总结,然后再将总结内容传递给LLM,以防止原文本过长导致LLM无法处理 +- 工具获取到的内容是一个链接,需要爬取网页信息后再返回给LLM + +为了帮助开发者快速实现这两种需求,我们提供了以下两个快捷工具。 + +#### 文本总结工具 + +该工具需要传入user_id和需要进行总结的文本,返回一个总结后的文本,Dify会使用当前工作空间的默认模型对长文本进行总结。 + +```python + def summary(self, user_id: str, content: 
str) -> str: + """ + summary the content + + :param user_id: the user id + :param content: the content + :return: the summary + """ +``` + +#### 网页爬取工具 + +该工具需要传入需要爬取的网页链接和一个user_agent(可为空),返回一个包含该网页信息的字符串,其中`user_agent`是可选参数,可以用来识别工具,如果不传递,Dify将使用默认的`user_agent`。 + +```python + def get_url(self, url: str, user_agent: str = None) -> str: + """ + get url + """ the crawled result +``` + +### 变量池 + +我们在`Tool`中引入了一个变量池,用于存储工具运行过程中产生的变量、文件等,这些变量可以在工具运行过程中被其他工具使用。 + +下面,我们以`DallE3`和`Vectorizer.AI`为例,介绍如何使用变量池。 + +- `DallE3`是一个图片生成工具,它可以根据文本生成图片,在这里,我们将让`DallE3`生成一个咖啡厅的Logo +- `Vectorizer.AI`是一个矢量图转换工具,它可以将图片转换为矢量图,使得图片可以无限放大而不失真,在这里,我们将`DallE3`生成的PNG图标转换为矢量图,从而可以真正被设计师使用。 + +#### DallE3 +首先我们使用DallE3,在创建完图片以后,我们将图片保存到变量池中,代码如下 + +```python +from typing import Any, Dict, List, Union +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool + +from base64 import b64decode + +from openai import OpenAI + +class DallE3Tool(BuiltinTool): + def _invoke(self, + user_id: str, + tool_paramters: Dict[str, Any], + ) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]: + """ + invoke tools + """ + client = OpenAI( + api_key=self.runtime.credentials['openai_api_key'], + ) + + # prompt + prompt = tool_paramters.get('prompt', '') + if not prompt: + return self.create_text_message('Please input prompt') + + # call openapi dalle3 + response = client.images.generate( + prompt=prompt, model='dall-e-3', + size='1024x1024', n=1, style='vivid', quality='standard', + response_format='b64_json' + ) + + result = [] + for image in response.data: + # 将所有图片通过save_as参数保存到变量池中,变量名为self.VARIABLE_KEY.IMAGE.value,如果如果后续有新的图片生成,那么将会覆盖之前的图片 + result.append(self.create_blob_message(blob=b64decode(image.b64_json), + meta={ 'mime_type': 'image/png' }, + save_as=self.VARIABLE_KEY.IMAGE.value)) + + return result +``` + +我们可以注意到这里我们使用了`self.VARIABLE_KEY.IMAGE.value`作为图片的变量名,为了便于开发者们的工具能够互相配合,我们定义了这个`KEY`,大家可以自由使用,也可以不使用这个`KEY`,传递一个自定义的KEY也是可以的。 + 
+#### Vectorizer.AI +接下来我们使用Vectorizer.AI,将DallE3生成的PNG图标转换为矢量图,我们先来过一遍我们在这里定义的函数,代码如下 + +```python +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParamter +from core.tools.errors import ToolProviderCredentialValidationError + +from typing import Any, Dict, List, Union +from httpx import post +from base64 import b64decode + +class VectorizerTool(BuiltinTool): + def _invoke(self, user_id: str, tool_paramters: Dict[str, Any]) \ + -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]: + """ + 工具调用,图片变量名需要从这里传递进来,从而我们就可以从变量池中获取到图片 + """ + + + def get_runtime_parameters(self) -> List[ToolParamter]: + """ + 重写工具参数列表,我们可以根据当前变量池里的实际情况来动态生成参数列表,从而LLM可以根据参数列表来生成表单 + """ + + + def is_tool_avaliable(self) -> bool: + """ + 当前工具是否可用,如果当前变量池中没有图片,那么我们就不需要展示这个工具,这里返回False即可 + """ +``` + +接下来我们来实现这三个函数 + +```python +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParamter +from core.tools.errors import ToolProviderCredentialValidationError + +from typing import Any, Dict, List, Union +from httpx import post +from base64 import b64decode + +class VectorizerTool(BuiltinTool): + def _invoke(self, user_id: str, tool_paramters: Dict[str, Any]) \ + -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]: + """ + invoke tools + """ + api_key_name = self.runtime.credentials.get('api_key_name', None) + api_key_value = self.runtime.credentials.get('api_key_value', None) + + if not api_key_name or not api_key_value: + raise ToolProviderCredentialValidationError('Please input api key name and value') + + # 获取image_id,image_id的定义可以在get_runtime_parameters中找到 + image_id = tool_paramters.get('image_id', '') + if not image_id: + return self.create_text_message('Please input image id') + + # 从变量池中获取到之前DallE生成的图片 + image_binary = self.get_variable_file(self.VARIABLE_KEY.IMAGE) + if not image_binary: + return self.create_text_message('Image not found, 
please request user to generate image firstly.') + + # 生成矢量图 + response = post( + 'https://vectorizer.ai/api/v1/vectorize', + files={ 'image': image_binary }, + data={ 'mode': 'test' }, + auth=(api_key_name, api_key_value), + timeout=30 + ) + + if response.status_code != 200: + raise Exception(response.text) + + return [ + self.create_text_message('the vectorized svg is saved as an image.'), + self.create_blob_message(blob=response.content, + meta={'mime_type': 'image/svg+xml'}) + ] + + def get_runtime_parameters(self) -> List[ToolParamter]: + """ + override the runtime parameters + """ + # 这里,我们重写了工具参数列表,定义了image_id,并设置了它的选项列表为当前变量池中的所有图片,这里的配置与yaml中的配置是一致的 + return [ + ToolParamter.get_simple_instance( + name='image_id', + llm_description=f'the image id that you want to vectorize, \ + and the image id should be specified in \ + {[i.name for i in self.list_default_image_variables()]}', + type=ToolParamter.ToolParameterType.SELECT, + required=True, + options=[i.name for i in self.list_default_image_variables()] + ) + ] + + def is_tool_avaliable(self) -> bool: + # 只有当变量池中有图片时,LLM才需要使用这个工具 + return len(self.list_default_image_variables()) > 0 +``` + +可以注意到的是,我们这里其实并没有使用到`image_id`,我们已经假设了调用这个工具的时候一定有一张图片在默认的变量池中,所以直接使用了`image_binary = self.get_variable_file(self.VARIABLE_KEY.IMAGE)`来获取图片,在模型能力较弱的情况下,我们建议开发者们也这样做,可以有效提升容错率,避免模型传递错误的参数。 \ No newline at end of file diff --git a/api/core/tools/docs/zh_Hans/images/index/image-1.png b/api/core/tools/docs/zh_Hans/images/index/image-1.png new file mode 100644 index 0000000000..3bb146ec90 Binary files /dev/null and b/api/core/tools/docs/zh_Hans/images/index/image-1.png differ diff --git a/api/core/tools/docs/zh_Hans/images/index/image-2.png b/api/core/tools/docs/zh_Hans/images/index/image-2.png new file mode 100644 index 0000000000..9ddc4d5fb7 Binary files /dev/null and b/api/core/tools/docs/zh_Hans/images/index/image-2.png differ diff --git a/api/core/tools/docs/zh_Hans/images/index/image.png 
b/api/core/tools/docs/zh_Hans/images/index/image.png new file mode 100644 index 0000000000..f6ce3a6b62 Binary files /dev/null and b/api/core/tools/docs/zh_Hans/images/index/image.png differ diff --git a/api/core/tools/docs/zh_Hans/tool_scale_out.md b/api/core/tools/docs/zh_Hans/tool_scale_out.md new file mode 100644 index 0000000000..08746c70fc --- /dev/null +++ b/api/core/tools/docs/zh_Hans/tool_scale_out.md @@ -0,0 +1,212 @@ +# 快速接入Tool + +这里我们以GoogleSearch为例,介绍如何快速接入一个工具。 + +## 1. 准备工具供应商yaml + +### 介绍 +这个yaml将包含工具供应商的信息,包括供应商名称、图标、作者等详细信息,以帮助前端灵活展示。 + +### 示例 + +我们需要在 `core/tools/provider/builtin`下创建一个`google`模块(文件夹),并创建`google.yaml`,名称必须与模块名称一致。 + +后续,我们关于这个工具的所有操作都将在这个模块下进行。 + +```yaml +identity: # 工具供应商的基本信息 + author: Dify # 作者 + name: google # 名称,唯一,不允许和其他供应商重名 + label: # 标签,用于前端展示 + en_US: Google # 英文标签 + zh_Hans: Google # 中文标签 + description: # 描述,用于前端展示 + en_US: Google # 英文描述 + zh_Hans: Google # 中文描述 + icon: icon.svg # 图标,需要放置在当前模块的_assets文件夹下 + +``` + - `identity` 字段是必须的,它包含了工具供应商的基本信息,包括作者、名称、标签、描述、图标等 + - 图标需要放置在当前模块的`_assets`文件夹下,可以参考[这里](../../provider/builtin/google/_assets/icon.svg)。 + +## 2. 
准备供应商凭据 + +Google作为一个第三方工具,使用了SerpApi提供的API,而SerpApi需要一个API Key才能使用,那么就意味着这个工具需要一个凭据才可以使用,而像`wikipedia`这样的工具,就不需要填写凭据字段,可以参考[这里](../../provider/builtin/wikipedia/wikipedia.yaml)。 + +配置好凭据字段后效果如下: +```yaml +identity: + author: Dify + name: google + label: + en_US: Google + zh_Hans: Google + description: + en_US: Google + zh_Hans: Google + icon: icon.svg +credentails_for_provider: # 凭据字段 + serpapi_api_key: # 凭据字段名称 + type: secret-input # 凭据字段类型 + required: true # 是否必填 + label: # 凭据字段标签 + en_US: SerpApi API key # 英文标签 + zh_Hans: SerpApi API key # 中文标签 + placeholder: # 凭据字段占位符 + en_US: Please input your SerpApi API key # 英文占位符 + zh_Hans: 请输入你的 SerpApi API key # 中文占位符 + help: # 凭据字段帮助文本 + en_US: Get your SerpApi API key from SerpApi # 英文帮助文本 + zh_Hans: 从 SerpApi 获取您的 SerpApi API key # 中文帮助文本 + url: https://serpapi.com/manage-api-key # 凭据字段帮助链接 + +``` + +- `type`:凭据字段类型,目前支持`secret-input`、`text-input`、`select` 三种类型,分别对应密码输入框、文本输入框、下拉框,如果为`secret-input`,则会在前端隐藏输入内容,并且后端会对输入内容进行加密。 + +## 3. 准备工具yaml +一个供应商底下可以有多个工具,每个工具都需要一个yaml文件来描述,这个文件包含了工具的基本信息、参数、输出等。 + +仍然以GoogleSearch为例,我们需要在`google`模块下创建一个`tools`模块,并创建`tools/google_search.yaml`,内容如下。 + +```yaml +identity: # 工具的基本信息 + name: google_search # 工具名称,唯一,不允许和其他工具重名 + author: Dify # 作者 + label: # 标签,用于前端展示 + en_US: GoogleSearch # 英文标签 + zh_Hans: 谷歌搜索 # 中文标签 +description: # 描述,用于前端展示 + human: # 用于前端展示的介绍,支持多语言 + en_US: A tool for performing a Google SERP search and extracting snippets and webpages.Input should be a search query. + zh_Hans: 一个用于执行 Google SERP 搜索并提取片段和网页的工具。输入应该是一个搜索查询。 + llm: A tool for performing a Google SERP search and extracting snippets and webpages.Input should be a search query. 
# 传递给LLM的介绍,为了使得LLM更好理解这个工具,我们建议在这里写上关于这个工具尽可能详细的信息,让LLM能够理解并使用这个工具 +parameters: # 参数列表 + - name: query # 参数名称 + type: string # 参数类型 + required: true # 是否必填 + label: # 参数标签 + en_US: Query string # 英文标签 + zh_Hans: 查询语句 # 中文标签 + human_description: # 用于前端展示的介绍,支持多语言 + en_US: used for searching + zh_Hans: 用于搜索网页内容 + llm_description: key words for searching # 传递给LLM的介绍,同上,为了使得LLM更好理解这个参数,我们建议在这里写上关于这个参数尽可能详细的信息,让LLM能够理解这个参数 + form: llm # 表单类型,llm表示这个参数需要由Agent自行推理出来,前端将不会展示这个参数 + - name: result_type + type: select # 参数类型 + required: true + options: # 下拉框选项 + - value: text + label: + en_US: text + zh_Hans: 文本 + - value: link + label: + en_US: link + zh_Hans: 链接 + default: link + label: + en_US: Result type + zh_Hans: 结果类型 + human_description: + en_US: used for selecting the result type, text or link + zh_Hans: 用于选择结果类型,使用文本还是链接进行展示 + form: form # 表单类型,form表示这个参数需要由用户在对话开始前在前端填写 + +``` + +- `identity` 字段是必须的,它包含了工具的基本信息,包括名称、作者、标签、描述等 +- `parameters` 参数列表 + - `name` 参数名称,唯一,不允许和其他参数重名 + - `type` 参数类型,目前支持`string`、`number`、`boolean`、`select` 四种类型,分别对应字符串、数字、布尔值、下拉框 + - `required` 是否必填 + - 在`llm`模式下,如果参数为必填,则会要求Agent必须要推理出这个参数 + - 在`form`模式下,如果参数为必填,则会要求用户在对话开始前在前端填写这个参数 + - `options` 参数选项 + - 在`llm`模式下,Dify会将所有选项传递给LLM,LLM可以根据这些选项进行推理 + - 在`form`模式下,`type`为`select`时,前端会展示这些选项 + - `default` 默认值 + - `label` 参数标签,用于前端展示 + - `human_description` 用于前端展示的介绍,支持多语言 + - `llm_description` 传递给LLM的介绍,为了使得LLM更好理解这个参数,我们建议在这里写上关于这个参数尽可能详细的信息,让LLM能够理解这个参数 + - `form` 表单类型,目前支持`llm`、`form`两种类型,分别对应Agent自行推理和前端填写 + +## 4. 
准备工具代码 +当完成工具的配置以后,我们就可以开始编写工具代码了,主要用于实现工具的逻辑。 + +在`google/tools`模块下创建`google_search.py`,内容如下。 + +```python +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.entities.tool_entities import ToolInvokeMessage + +from typing import Any, Dict, List, Union + +class GoogleSearchTool(BuiltinTool): + def _invoke(self, + user_id: str, + tool_paramters: Dict[str, Any], + ) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]: + """ + invoke tools + """ + query = tool_paramters['query'] + result_type = tool_paramters['result_type'] + api_key = self.runtime.credentials['serpapi_api_key'] + # TODO: search with serpapi + result = SerpAPI(api_key).run(query, result_type=result_type) + + if result_type == 'text': + return self.create_text_message(text=result) + return self.create_link_message(link=result) +``` + +### 参数 +工具的整体逻辑都在`_invoke`方法中,这个方法接收两个参数:`user_id`和`tool_paramters`,分别表示用户ID和工具参数 + +### 返回数据 +在工具返回时,你可以选择返回一个消息或者多个消息,这里我们返回一个消息,使用`create_text_message`和`create_link_message`可以创建一个文本消息或者一个链接消息。 + +## 5. 准备供应商代码 +最后,我们需要在供应商模块下创建一个供应商类,用于实现供应商的凭据验证逻辑,如果凭据验证失败,将会抛出`ToolProviderCredentialValidationError`异常。 + +在`google`模块下创建`google.py`,内容如下。 + +```python +from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType +from core.tools.tool.tool import Tool +from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController +from core.tools.errors import ToolProviderCredentialValidationError + +from core.tools.provider.builtin.google.tools.google_search import GoogleSearchTool + +from typing import Any, Dict + +class GoogleProvider(BuiltinToolProviderController): + def _validate_credentials(self, credentials: Dict[str, Any]) -> None: + try: + # 1. 此处需要使用GoogleSearchTool()实例化一个GoogleSearchTool,它会自动加载GoogleSearchTool的yaml配置,但是此时它内部没有凭据信息 + # 2. 随后需要使用fork_tool_runtime方法,将当前的凭据信息传递给GoogleSearchTool + # 3. 
最后invoke即可,参数需要根据GoogleSearchTool的yaml中配置的参数规则进行传递 + GoogleSearchTool().fork_tool_runtime( + meta={ + "credentials": credentials, + } + ).invoke( + user_id='', + tool_paramters={ + "query": "test", + "result_type": "link" + }, + ) + except Exception as e: + raise ToolProviderCredentialValidationError(str(e)) +``` + +## 完成 +当上述步骤完成以后,我们就可以在前端看到这个工具了,并且可以在Agent中使用这个工具。 + +当然,因为google_search需要一个凭据,在使用之前,还需要在前端配置它的凭据。 + +![Alt text](images/index/image-2.png) diff --git a/api/core/tools/entities/common_entities.py b/api/core/tools/entities/common_entities.py new file mode 100644 index 0000000000..13c27b57ee --- /dev/null +++ b/api/core/tools/entities/common_entities.py @@ -0,0 +1,22 @@ +from typing import Optional + +from pydantic import BaseModel + + +class I18nObject(BaseModel): + """ + Model class for i18n object. + """ + zh_Hans: Optional[str] = None + en_US: str + + def __init__(self, **data): + super().__init__(**data) + if not self.zh_Hans: + self.zh_Hans = self.en_US + + def to_dict(self) -> dict: + return { + 'zh_Hans': self.zh_Hans, + 'en_US': self.en_US, + } \ No newline at end of file diff --git a/api/core/tools/entities/constant.py b/api/core/tools/entities/constant.py new file mode 100644 index 0000000000..2e75fedf99 --- /dev/null +++ b/api/core/tools/entities/constant.py @@ -0,0 +1,3 @@ +class DEFAULT_PROVIDERS: + API_BASED = '__api_based' + APP_BASED = '__app_based' \ No newline at end of file diff --git a/api/core/tools/entities/tool_bundle.py b/api/core/tools/entities/tool_bundle.py new file mode 100644 index 0000000000..51ff11b9f8 --- /dev/null +++ b/api/core/tools/entities/tool_bundle.py @@ -0,0 +1,34 @@ +from pydantic import BaseModel +from typing import Dict, Optional, Any, List + +from core.tools.entities.tool_entities import ToolProviderType, ToolParamter + +class ApiBasedToolBundle(BaseModel): + """ + This class is used to store the schema information of an api based tool. such as the url, the method, the parameters, etc. 
+ """ + # server_url + server_url: str + # method + method: str + # summary + summary: Optional[str] = None + # operation_id + operation_id: str = None + # parameters + parameters: Optional[List[ToolParamter]] = None + # author + author: str + # icon + icon: Optional[str] = None + # openapi operation + openapi: dict + +class AppToolBundle(BaseModel): + """ + This class is used to store the schema information of an tool for an app. + """ + type: ToolProviderType + credential: Optional[Dict[str, Any]] = None + provider_id: str + tool_name: str \ No newline at end of file diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py new file mode 100644 index 0000000000..b17bce5401 --- /dev/null +++ b/api/core/tools/entities/tool_entities.py @@ -0,0 +1,305 @@ +from pydantic import BaseModel, Field +from enum import Enum +from typing import Optional, List, Dict, Any, Union, cast + +from core.tools.entities.common_entities import I18nObject + +class ToolProviderType(Enum): + """ + Enum class for tool provider + """ + BUILT_IN = "built-in" + APP_BASED = "app-based" + API_BASED = "api-based" + + @classmethod + def value_of(cls, value: str) -> 'ToolProviderType': + """ + Get value of given mode. + + :param value: mode value + :return: mode + """ + for mode in cls: + if mode.value == value: + return mode + raise ValueError(f'invalid mode value {value}') + +class ApiProviderSchemaType(Enum): + """ + Enum class for api provider schema type. + """ + OPENAPI = "openapi" + SWAGGER = "swagger" + OPENAI_PLUGIN = "openai_plugin" + OPENAI_ACTIONS = "openai_actions" + + @classmethod + def value_of(cls, value: str) -> 'ApiProviderSchemaType': + """ + Get value of given mode. + + :param value: mode value + :return: mode + """ + for mode in cls: + if mode.value == value: + return mode + raise ValueError(f'invalid mode value {value}') + +class ApiProviderAuthType(Enum): + """ + Enum class for api provider auth type. 
+ """ + NONE = "none" + API_KEY = "api_key" + + @classmethod + def value_of(cls, value: str) -> 'ApiProviderAuthType': + """ + Get value of given mode. + + :param value: mode value + :return: mode + """ + for mode in cls: + if mode.value == value: + return mode + raise ValueError(f'invalid mode value {value}') + +class ToolInvokeMessage(BaseModel): + class MessageType(Enum): + TEXT = "text" + IMAGE = "image" + LINK = "link" + BLOB = "blob" + IMAGE_LINK = "image_link" + + type: MessageType = MessageType.TEXT + """ + plain text, image url or link url + """ + message: Union[str, bytes] = None + meta: Dict[str, Any] = None + save_as: str = '' + +class ToolInvokeMessageBinary(BaseModel): + mimetype: str = Field(..., description="The mimetype of the binary") + url: str = Field(..., description="The url of the binary") + save_as: str = '' + +class ToolParamterOption(BaseModel): + value: str = Field(..., description="The value of the option") + label: I18nObject = Field(..., description="The label of the option") + +class ToolParamter(BaseModel): + class ToolParameterType(Enum): + STRING = "string" + NUMBER = "number" + BOOLEAN = "boolean" + SELECT = "select" + + class ToolParameterForm(Enum): + SCHEMA = "schema" # should be set while adding tool + FORM = "form" # should be set before invoking tool + LLM = "llm" # will be set by LLM + + name: str = Field(..., description="The name of the parameter") + label: I18nObject = Field(..., description="The label presented to the user") + human_description: I18nObject = Field(..., description="The description presented to the user") + type: ToolParameterType = Field(..., description="The type of the parameter") + form: ToolParameterForm = Field(..., description="The form of the parameter, schema/form/llm") + llm_description: Optional[str] = None + required: Optional[bool] = False + default: Optional[str] = None + min: Optional[Union[float, int]] = None + max: Optional[Union[float, int]] = None + options: 
Optional[List[ToolParamterOption]] = None + + @classmethod + def get_simple_instance(cls, + name: str, llm_description: str, type: ToolParameterType, + required: bool, options: Optional[List[str]] = None) -> 'ToolParamter': + """ + get a simple tool parameter + + :param name: the name of the parameter + :param llm_description: the description presented to the LLM + :param type: the type of the parameter + :param required: if the parameter is required + :param options: the options of the parameter + """ + # convert options to ToolParamterOption + if options: + options = [ToolParamterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option)) for option in options] + return cls( + name=name, + label=I18nObject(en_US='', zh_Hans=''), + human_description=I18nObject(en_US='', zh_Hans=''), + type=type, + form=cls.ToolParameterForm.LLM, + llm_description=llm_description, + required=required, + options=options, + ) + +class ToolProviderIdentity(BaseModel): + author: str = Field(..., description="The author of the tool") + name: str = Field(..., description="The name of the tool") + description: I18nObject = Field(..., description="The description of the tool") + icon: str = Field(..., description="The icon of the tool") + label: I18nObject = Field(..., description="The label of the tool") + +class ToolDescription(BaseModel): + human: I18nObject = Field(..., description="The description presented to the user") + llm: str = Field(..., description="The description presented to the LLM") + +class ToolIdentity(BaseModel): + author: str = Field(..., description="The author of the tool") + name: str = Field(..., description="The name of the tool") + label: I18nObject = Field(..., description="The label of the tool") + +class ToolCredentialsOption(BaseModel): + value: str = Field(..., description="The value of the option") + label: I18nObject = Field(..., description="The label of the option") + +class ToolProviderCredentials(BaseModel): + class CredentialsType(Enum): + 
SECRET_INPUT = "secret-input" + TEXT_INPUT = "text-input" + SELECT = "select" + + @classmethod + def value_of(cls, value: str) -> "ToolProviderCredentials.CredentialsType": + """ + Get value of given mode. + + :param value: mode value + :return: mode + """ + for mode in cls: + if mode.value == value: + return mode + raise ValueError(f'invalid mode value {value}') + + @staticmethod + def defaut(value: str) -> str: + return "" + + name: str = Field(..., description="The name of the credentials") + type: CredentialsType = Field(..., description="The type of the credentials") + required: bool = False + default: Optional[str] = None + options: Optional[List[ToolCredentialsOption]] = None + label: Optional[I18nObject] = None + help: Optional[I18nObject] = None + url: Optional[str] = None + placeholder: Optional[I18nObject] = None + + def to_dict(self) -> dict: + return { + 'name': self.name, + 'type': self.type.value, + 'required': self.required, + 'default': self.default, + 'options': self.options, + 'help': self.help.to_dict() if self.help else None, + 'label': self.label.to_dict(), + 'url': self.url, + 'placeholder': self.placeholder.to_dict() if self.placeholder else None, + } + +class ToolRuntimeVariableType(Enum): + TEXT = "text" + IMAGE = "image" + +class ToolRuntimeVariable(BaseModel): + type: ToolRuntimeVariableType = Field(..., description="The type of the variable") + name: str = Field(..., description="The name of the variable") + position: int = Field(..., description="The position of the variable") + tool_name: str = Field(..., description="The name of the tool") + +class ToolRuntimeTextVariable(ToolRuntimeVariable): + value: str = Field(..., description="The value of the variable") + +class ToolRuntimeImageVariable(ToolRuntimeVariable): + value: str = Field(..., description="The path of the image") + +class ToolRuntimeVariablePool(BaseModel): + conversation_id: str = Field(..., description="The conversation id") + user_id: str = Field(..., description="The 
user id") + tenant_id: str = Field(..., description="The tenant id of assistant") + + pool: List[ToolRuntimeVariable] = Field(..., description="The pool of variables") + + def __init__(self, **data: Any): + pool = data.get('pool', []) + # convert pool into correct type + for index, variable in enumerate(pool): + if variable['type'] == ToolRuntimeVariableType.TEXT.value: + pool[index] = ToolRuntimeTextVariable(**variable) + elif variable['type'] == ToolRuntimeVariableType.IMAGE.value: + pool[index] = ToolRuntimeImageVariable(**variable) + super().__init__(**data) + + def dict(self) -> dict: + return { + 'conversation_id': self.conversation_id, + 'user_id': self.user_id, + 'tenant_id': self.tenant_id, + 'pool': [variable.dict() for variable in self.pool], + } + + def set_text(self, tool_name: str, name: str, value: str) -> None: + """ + set a text variable + """ + for variable in self.pool: + if variable.name == name: + if variable.type == ToolRuntimeVariableType.TEXT: + variable = cast(ToolRuntimeTextVariable, variable) + variable.value = value + return + + variable = ToolRuntimeTextVariable( + type=ToolRuntimeVariableType.TEXT, + name=name, + position=len(self.pool), + tool_name=tool_name, + value=value, + ) + + self.pool.append(variable) + + def set_file(self, tool_name: str, value: str, name: str = None) -> None: + """ + set an image variable + + :param tool_name: the name of the tool + :param value: the id of the file + """ + # check how many image variables are there + image_variable_count = 0 + for variable in self.pool: + if variable.type == ToolRuntimeVariableType.IMAGE: + image_variable_count += 1 + + if name is None: + name = f"file_{image_variable_count}" + + for variable in self.pool: + if variable.name == name: + if variable.type == ToolRuntimeVariableType.IMAGE: + variable = cast(ToolRuntimeImageVariable, variable) + variable.value = value + return + + variable = ToolRuntimeImageVariable( + type=ToolRuntimeVariableType.IMAGE, + name=name, + 
position=len(self.pool), + tool_name=tool_name, + value=value, + ) + + self.pool.append(variable) \ No newline at end of file diff --git a/api/core/tools/entities/user_entities.py b/api/core/tools/entities/user_entities.py new file mode 100644 index 0000000000..ce4723f071 --- /dev/null +++ b/api/core/tools/entities/user_entities.py @@ -0,0 +1,48 @@ +from pydantic import BaseModel +from enum import Enum +from typing import List, Dict, Optional + +from core.tools.entities.common_entities import I18nObject +from core.tools.entities.tool_entities import ToolProviderCredentials +from core.tools.tool.tool import ToolParamter + +class UserToolProvider(BaseModel): + class ProviderType(Enum): + BUILTIN = "builtin" + APP = "app" + API = "api" + + id: str + author: str + name: str # identifier + description: I18nObject + icon: str + label: I18nObject # label + type: ProviderType + team_credentials: dict = None + is_team_authorization: bool = False + allow_delete: bool = True + + def to_dict(self) -> dict: + return { + 'id': self.id, + 'author': self.author, + 'name': self.name, + 'description': self.description.to_dict(), + 'icon': self.icon, + 'label': self.label.to_dict(), + 'type': self.type.value, + 'team_credentials': self.team_credentials, + 'is_team_authorization': self.is_team_authorization, + 'allow_delete': self.allow_delete + } + +class UserToolProviderCredentials(BaseModel): + credentails: Dict[str, ToolProviderCredentials] + +class UserTool(BaseModel): + author: str + name: str # identifier + label: I18nObject # label + description: I18nObject + parameters: Optional[List[ToolParamter]] \ No newline at end of file diff --git a/api/core/tools/errors.py b/api/core/tools/errors.py new file mode 100644 index 0000000000..28fc241bb0 --- /dev/null +++ b/api/core/tools/errors.py @@ -0,0 +1,20 @@ +class ToolProviderNotFoundError(ValueError): + pass + +class ToolNotFoundError(ValueError): + pass + +class ToolParamterValidationError(ValueError): + pass + +class 
ToolProviderCredentialValidationError(ValueError): + pass + +class ToolNotSupportedError(ValueError): + pass + +class ToolInvokeError(ValueError): + pass + +class ToolApiSchemaError(ValueError): + pass \ No newline at end of file diff --git a/api/core/tools/model/errors.py b/api/core/tools/model/errors.py new file mode 100644 index 0000000000..6e242b349a --- /dev/null +++ b/api/core/tools/model/errors.py @@ -0,0 +1,2 @@ +class InvokeModelError(Exception): + pass \ No newline at end of file diff --git a/api/core/tools/model/tool_model_manager.py b/api/core/tools/model/tool_model_manager.py new file mode 100644 index 0000000000..0a8e131344 --- /dev/null +++ b/api/core/tools/model/tool_model_manager.py @@ -0,0 +1,174 @@ +""" + For some reason, model will be used in tools like WebScraperTool, WikipediaSearchTool etc. + + Therefore, a model manager is needed to list/invoke/validate models. +""" + +from core.model_runtime.entities.message_entities import PromptMessage +from core.model_runtime.entities.llm_entities import LLMResult +from core.model_runtime.entities.model_entities import ModelType +from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel, ModelPropertyKey +from core.model_runtime.errors.invoke import InvokeRateLimitError, InvokeBadRequestError, \ + InvokeConnectionError, InvokeAuthorizationError, InvokeServerUnavailableError +from core.model_runtime.utils.encoders import jsonable_encoder +from core.model_manager import ModelManager + +from core.tools.model.errors import InvokeModelError + +from extensions.ext_database import db + +from models.tools import ToolModelInvoke + +from typing import List, cast +import json + +class ToolModelManager: + @staticmethod + def get_max_llm_context_tokens( + tenant_id: str, + ) -> int: + """ + get max llm context tokens of the model + """ + model_manager = ModelManager() + model_instance = model_manager.get_default_model_instance( + tenant_id=tenant_id, model_type=ModelType.LLM, + ) + 
+ if not model_instance: + raise InvokeModelError(f'Model not found') + + llm_model = cast(LargeLanguageModel, model_instance.model_type_instance) + schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials) + + if not schema: + raise InvokeModelError(f'No model schema found') + + max_tokens = schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE, None) + if max_tokens is None: + return 2048 + + return max_tokens + + @staticmethod + def calculate_tokens( + tenant_id: str, + prompt_messages: List[PromptMessage] + ) -> int: + """ + calculate tokens from prompt messages and model parameters + """ + + # get model instance + model_manager = ModelManager() + model_instance = model_manager.get_default_model_instance( + tenant_id=tenant_id, model_type=ModelType.LLM + ) + + if not model_instance: + raise InvokeModelError(f'Model not found') + + llm_model = cast(LargeLanguageModel, model_instance.model_type_instance) + + # get tokens + tokens = llm_model.get_num_tokens(model_instance.model, model_instance.credentials, prompt_messages) + + return tokens + + @staticmethod + def invoke( + user_id: str, tenant_id: str, + tool_type: str, tool_name: str, + prompt_messages: List[PromptMessage] + ) -> LLMResult: + """ + invoke model with parameters in user's own context + + :param user_id: user id + :param tenant_id: tenant id, the tenant id of the creator of the tool + :param tool_provider: tool provider + :param tool_id: tool id + :param tool_name: tool name + :param provider: model provider + :param model: model name + :param model_parameters: model parameters + :param prompt_messages: prompt messages + :return: AssistantPromptMessage + """ + + # get model manager + model_manager = ModelManager() + # get model instance + model_instance = model_manager.get_default_model_instance( + tenant_id=tenant_id, model_type=ModelType.LLM, + ) + + llm_model = cast(LargeLanguageModel, model_instance.model_type_instance) + + # get model credentials + 
model_credentials = model_instance.credentials + + # get prompt tokens + prompt_tokens = llm_model.get_num_tokens(model_instance.model, model_credentials, prompt_messages) + + model_parameters = { + 'temperature': 0.8, + 'top_p': 0.8, + } + + # create tool model invoke + tool_model_invoke = ToolModelInvoke( + user_id=user_id, + tenant_id=tenant_id, + provider=model_instance.provider, + tool_type=tool_type, + tool_name=tool_name, + model_parameters=json.dumps(model_parameters), + prompt_messages=json.dumps(jsonable_encoder(prompt_messages)), + model_response='', + prompt_tokens=prompt_tokens, + answer_tokens=0, + answer_unit_price=0, + answer_price_unit=0, + provider_response_latency=0, + total_price=0, + currency='USD', + ) + + db.session.add(tool_model_invoke) + db.session.commit() + + try: + response: LLMResult = llm_model.invoke( + model=model_instance.model, + credentials=model_credentials, + prompt_messages=prompt_messages, + model_parameters=model_parameters, + tools=[], stop=[], stream=False, user=user_id, callbacks=[] + ) + except InvokeRateLimitError as e: + raise InvokeModelError(f'Invoke rate limit error: {e}') + except InvokeBadRequestError as e: + raise InvokeModelError(f'Invoke bad request error: {e}') + except InvokeConnectionError as e: + raise InvokeModelError(f'Invoke connection error: {e}') + except InvokeAuthorizationError as e: + raise InvokeModelError(f'Invoke authorization error') + except InvokeServerUnavailableError as e: + raise InvokeModelError(f'Invoke server unavailable error: {e}') + except Exception as e: + raise InvokeModelError(f'Invoke error: {e}') + + # update tool model invoke + tool_model_invoke.model_response = response.message.content + if response.usage: + tool_model_invoke.answer_tokens = response.usage.completion_tokens + tool_model_invoke.answer_unit_price = response.usage.completion_unit_price + tool_model_invoke.answer_price_unit = response.usage.completion_price_unit + tool_model_invoke.provider_response_latency = 
response.usage.latency + tool_model_invoke.total_price = response.usage.total_price + tool_model_invoke.currency = response.usage.currency + + db.session.commit() + + return response \ No newline at end of file diff --git a/api/core/tools/prompt/template.py b/api/core/tools/prompt/template.py new file mode 100644 index 0000000000..3d35592279 --- /dev/null +++ b/api/core/tools/prompt/template.py @@ -0,0 +1,102 @@ +ENGLISH_REACT_COMPLETION_PROMPT_TEMPLATES = """Respond to the human as helpfully and accurately as possible. + +{{instruction}} + +You have access to the following tools: + +{{tools}} + +Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input). +Valid "action" values: "Final Answer" or {{tool_names}} + +Provide only ONE action per $JSON_BLOB, as shown: + +``` +{ + "action": $TOOL_NAME, + "action_input": $ACTION_INPUT +} +``` + +Follow this format: + +Question: input question to answer +Thought: consider previous and subsequent steps +Action: +``` +$JSON_BLOB +``` +Observation: action result +... (repeat Thought/Action/Observation N times) +Thought: I know what to respond +Action: +``` +{ + "action": "Final Answer", + "action_input": "Final response to human" +} +``` + +Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:. +Question: {{query}} +Thought: {{agent_scratchpad}}""" + +ENGLISH_REACT_COMPLETION_AGENT_SCRATCHPAD_TEMPLATES = """Observation: {{observation}} +Thought:""" + +ENGLISH_REACT_CHAT_PROMPT_TEMPLATES = """Respond to the human as helpfully and accurately as possible. + +{{instruction}} + +You have access to the following tools: + +{{tools}} + +Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input). 
+Valid "action" values: "Final Answer" or {{tool_names}} + +Provide only ONE action per $JSON_BLOB, as shown: + +``` +{ + "action": $TOOL_NAME, + "action_input": $ACTION_INPUT +} +``` + +Follow this format: + +Question: input question to answer +Thought: consider previous and subsequent steps +Action: +``` +$JSON_BLOB +``` +Observation: action result +... (repeat Thought/Action/Observation N times) +Thought: I know what to respond +Action: +``` +{ + "action": "Final Answer", + "action_input": "Final response to human" +} +``` + +Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:. +""" + +ENGLISH_REACT_CHAT_AGENT_SCRATCHPAD_TEMPLATES = "" + +REACT_PROMPT_TEMPLATES = { + 'english': { + 'chat': { + 'prompt': ENGLISH_REACT_CHAT_PROMPT_TEMPLATES, + 'agent_scratchpad': ENGLISH_REACT_CHAT_AGENT_SCRATCHPAD_TEMPLATES + }, + 'completion': { + 'prompt': ENGLISH_REACT_COMPLETION_PROMPT_TEMPLATES, + 'agent_scratchpad': ENGLISH_REACT_COMPLETION_AGENT_SCRATCHPAD_TEMPLATES + } + } +} \ No newline at end of file diff --git a/api/core/tools/provider/api_tool_provider.py b/api/core/tools/provider/api_tool_provider.py new file mode 100644 index 0000000000..1d1f09f715 --- /dev/null +++ b/api/core/tools/provider/api_tool_provider.py @@ -0,0 +1,169 @@ +from typing import Any, Dict, List +from core.tools.entities.tool_entities import ToolProviderType, ApiProviderAuthType, ToolProviderCredentials, ToolCredentialsOption +from core.tools.entities.common_entities import I18nObject +from core.tools.entities.tool_bundle import ApiBasedToolBundle +from core.tools.tool.tool import Tool +from core.tools.tool.api_tool import ApiTool +from core.tools.provider.tool_provider import ToolProviderController + +from extensions.ext_database import db + +from models.tools import ApiToolProvider + +class ApiBasedToolProviderController(ToolProviderController): + @staticmethod + 
def from_db(db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> 'ApiBasedToolProviderController': + credentials_schema = { + 'auth_type': ToolProviderCredentials( + name='auth_type', + required=True, + type=ToolProviderCredentials.CredentialsType.SELECT, + options=[ + ToolCredentialsOption(value='none', label=I18nObject(en_US='None', zh_Hans='无')), + ToolCredentialsOption(value='api_key', label=I18nObject(en_US='api_key', zh_Hans='api_key')) + ], + default='none', + help=I18nObject( + en_US='The auth type of the api provider', + zh_Hans='api provider 的认证类型' + ) + ) + } + if auth_type == ApiProviderAuthType.API_KEY: + credentials_schema = { + **credentials_schema, + 'api_key_header': ToolProviderCredentials( + name='api_key_header', + required=False, + default='api_key', + type=ToolProviderCredentials.CredentialsType.TEXT_INPUT, + help=I18nObject( + en_US='The header name of the api key', + zh_Hans='携带 api key 的 header 名称' + ) + ), + 'api_key_value': ToolProviderCredentials( + name='api_key_value', + required=True, + type=ToolProviderCredentials.CredentialsType.SECRET_INPUT, + help=I18nObject( + en_US='The api key', + zh_Hans='api key的值' + ) + ) + } + elif auth_type == ApiProviderAuthType.NONE: + pass + else: + raise ValueError(f'invalid auth type {auth_type}') + + return ApiBasedToolProviderController(**{ + 'identity': { + 'author': db_provider.user.name if db_provider.user_id and db_provider.user else '', + 'name': db_provider.name, + 'label': { + 'en_US': db_provider.name, + 'zh_Hans': db_provider.name + }, + 'description': { + 'en_US': db_provider.description, + 'zh_Hans': db_provider.description + }, + 'icon': db_provider.icon + }, + 'credentials_schema': credentials_schema + }) + + @property + def app_type(self) -> ToolProviderType: + return ToolProviderType.API_BASED + + def _validate_credentials(self, tool_name: str, credentials: Dict[str, Any]) -> None: + pass + + def validate_parameters(self, tool_name: str, tool_parameters: Dict[str, Any]) -> 
None: + pass + + def _parse_tool_bundle(self, tool_bundle: ApiBasedToolBundle) -> ApiTool: + """ + parse tool bundle to tool + + :param tool_bundle: the tool bundle + :return: the tool + """ + return ApiTool(**{ + 'api_bundle': tool_bundle, + 'identity' : { + 'author': tool_bundle.author, + 'name': tool_bundle.operation_id, + 'label': { + 'en_US': tool_bundle.operation_id, + 'zh_Hans': tool_bundle.operation_id + }, + 'icon': tool_bundle.icon if tool_bundle.icon else '' + }, + 'description': { + 'human': { + 'en_US': tool_bundle.summary or '', + 'zh_Hans': tool_bundle.summary or '' + }, + 'llm': tool_bundle.summary or '' + }, + 'parameters' : tool_bundle.parameters if tool_bundle.parameters else [], + }) + + def load_bundled_tools(self, tools: List[ApiBasedToolBundle]) -> List[ApiTool]: + """ + load bundled tools + + :param tools: the bundled tools + :return: the tools + """ + self.tools = [self._parse_tool_bundle(tool) for tool in tools] + + return self.tools + + def get_tools(self, user_id: str, tanent_id: str) -> List[ApiTool]: + """ + fetch tools from database + + :param user_id: the user id + :param tanent_id: the tanent id + :return: the tools + """ + if self.tools is not None: + return self.tools + + tools: List[Tool] = [] + + # get tanent api providers + db_providers: List[ApiToolProvider] = db.session.query(ApiToolProvider).filter( + ApiToolProvider.tenant_id == tanent_id, + ApiToolProvider.name == self.identity.name + ).all() + + if db_providers and len(db_providers) != 0: + for db_provider in db_providers: + for tool in db_provider.tools: + assistant_tool = self._parse_tool_bundle(tool) + assistant_tool.is_team_authorization = True + tools.append(assistant_tool) + + self.tools = tools + return tools + + def get_tool(self, tool_name: str) -> ApiTool: + """ + get tool by name + + :param tool_name: the name of the tool + :return: the tool + """ + if self.tools is None: + self.get_tools() + + for tool in self.tools: + if tool.identity.name == tool_name: + 
return tool + + raise ValueError(f'tool {tool_name} not found') \ No newline at end of file diff --git a/api/core/tools/provider/app_tool_provider.py b/api/core/tools/provider/app_tool_provider.py new file mode 100644 index 0000000000..20f6b55731 --- /dev/null +++ b/api/core/tools/provider/app_tool_provider.py @@ -0,0 +1,116 @@ +from typing import Any, Dict, List +from core.tools.entities.tool_entities import ToolProviderType, ToolParamter, ToolParamterOption +from core.tools.tool.tool import Tool +from core.tools.entities.common_entities import I18nObject +from core.tools.provider.tool_provider import ToolProviderController + +from extensions.ext_database import db +from models.tools import PublishedAppTool +from models.model import App, AppModelConfig + +import logging + +logger = logging.getLogger(__name__) + +class AppBasedToolProviderEntity(ToolProviderController): + @property + def app_type(self) -> ToolProviderType: + return ToolProviderType.APP_BASED + + def _validate_credentials(self, tool_name: str, credentials: Dict[str, Any]) -> None: + pass + + def validate_parameters(self, tool_name: str, tool_parameters: Dict[str, Any]) -> None: + pass + + def get_tools(self, user_id: str) -> List[Tool]: + db_tools: List[PublishedAppTool] = db.session.query(PublishedAppTool).filter( + PublishedAppTool.user_id == user_id, + ).all() + + if not db_tools or len(db_tools) == 0: + return [] + + tools: List[Tool] = [] + + for db_tool in db_tools: + tool = { + 'identity': { + 'author': db_tool.author, + 'name': db_tool.tool_name, + 'label': { + 'en_US': db_tool.tool_name, + 'zh_Hans': db_tool.tool_name + }, + 'icon': '' + }, + 'description': { + 'human': { + 'en_US': db_tool.description_i18n.en_US, + 'zh_Hans': db_tool.description_i18n.zh_Hans + }, + 'llm': db_tool.llm_description + }, + 'parameters': [] + } + # get app from db + app: App = db_tool.app + + if not app: + logger.error(f"app {db_tool.app_id} not found") + continue + + app_model_config: AppModelConfig = 
app.app_model_config
+            user_input_form_list = app_model_config.user_input_form_list
+            for input_form in user_input_form_list:
+                # each user_input_form entry is a single-key dict: {form_type: config};
+                # dict_keys is not subscriptable in Python 3, so materialize it first
+                form_type = list(input_form.keys())[0]
+                default = input_form[form_type]['default']
+                required = input_form[form_type]['required']
+                label = input_form[form_type]['label']
+                variable_name = input_form[form_type]['variable_name']
+                options = input_form[form_type].get('options', [])
+                if form_type == 'paragraph' or form_type == 'text-input':
+                    tool['parameters'].append(ToolParamter(
+                        name=variable_name,
+                        label=I18nObject(
+                            en_US=label,
+                            zh_Hans=label
+                        ),
+                        human_description=I18nObject(
+                            en_US=label,
+                            zh_Hans=label
+                        ),
+                        llm_description=label,
+                        form=ToolParamter.ToolParameterForm.FORM,
+                        type=ToolParamter.ToolParameterType.STRING,
+                        required=required,
+                        default=default
+                    ))
+                elif form_type == 'select':
+                    tool['parameters'].append(ToolParamter(
+                        name=variable_name,
+                        label=I18nObject(
+                            en_US=label,
+                            zh_Hans=label
+                        ),
+                        human_description=I18nObject(
+                            en_US=label,
+                            zh_Hans=label
+                        ),
+                        llm_description=label,
+                        form=ToolParamter.ToolParameterForm.FORM,
+                        type=ToolParamter.ToolParameterType.SELECT,
+                        required=required,
+                        default=default,
+                        options=[ToolParamterOption(
+                            value=option,
+                            label=I18nObject(
+                                en_US=option,
+                                zh_Hans=option
+                            )
+                        ) for option in options]
+                    ))
+
+            tools.append(Tool(**tool))
+        return tools
\ No newline at end of file
diff --git a/api/core/tools/provider/builtin/__init__.py b/api/core/tools/provider/builtin/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/tools/provider/builtin/_positions.py b/api/core/tools/provider/builtin/_positions.py
new file mode 100644
index 0000000000..b7c2b80187
--- /dev/null
+++ b/api/core/tools/provider/builtin/_positions.py
@@ -0,0 +1,26 @@
+from core.tools.entities.user_entities import UserToolProvider
+from typing import List
+
+position = {
+    'google': 1,
+    'wikipedia': 2,
+    'dalle': 3,
+    'webscraper': 4,
+    'wolframalpha': 5,
+    
'chart': 6,
+    'time': 7,
+    'yahoo': 8,
+    'stablediffusion': 9,
+    'vectorizer': 10,
+    'youtube': 11,
+}
+
+class BuiltinToolProviderSort:
+    @staticmethod
+    def sort(providers: List[UserToolProvider]) -> List[UserToolProvider]:
+        def sort_compare(provider: UserToolProvider) -> int:
+            return position.get(provider.name, 10000)
+
+        sorted_providers = sorted(providers, key=sort_compare)
+
+        return sorted_providers
\ No newline at end of file
diff --git a/api/core/tools/provider/builtin/chart/_assets/icon.png b/api/core/tools/provider/builtin/chart/_assets/icon.png
new file mode 100644
index 0000000000..878e56a051
Binary files /dev/null and b/api/core/tools/provider/builtin/chart/_assets/icon.png differ
diff --git a/api/core/tools/provider/builtin/chart/chart.py b/api/core/tools/provider/builtin/chart/chart.py
new file mode 100644
index 0000000000..53667f3099
--- /dev/null
+++ b/api/core/tools/provider/builtin/chart/chart.py
@@ -0,0 +1,24 @@
+from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
+from core.tools.errors import ToolProviderCredentialValidationError
+
+from core.tools.provider.builtin.chart.tools.line import LinearChartTool
+
+import matplotlib.pyplot as plt
+# use a business theme
+plt.style.use('seaborn-v0_8-darkgrid')
+
+class ChartProvider(BuiltinToolProviderController):
+    def _validate_credentials(self, credentials: dict) -> None:
+        try:
+            LinearChartTool().fork_tool_runtime(
+                meta={
+                    "credentials": credentials,
+                }
+            ).invoke(
+                user_id='',
+                tool_paramters={
+                    "data": "1;3;5;7;9;2;4;6;8;10",
+                },
+            )
+        except Exception as e:
+            raise ToolProviderCredentialValidationError(str(e))
\ No newline at end of file
diff --git a/api/core/tools/provider/builtin/chart/chart.yaml b/api/core/tools/provider/builtin/chart/chart.yaml
new file mode 100644
index 0000000000..9e953ec32a
--- /dev/null
+++ b/api/core/tools/provider/builtin/chart/chart.yaml
@@ -0,0 +1,11 @@
+identity:
+  author: Dify
+  name: chart
+  label:
+    en_US: 
ChartGenerator + zh_Hans: 图表生成 + description: + en_US: Chart Generator is a tool for generating statistical charts like bar chart, line chart, pie chart, etc. + zh_Hans: 图表生成是一个用于生成可视化图表的工具,你可以通过它来生成柱状图、折线图、饼图等各类图表 + icon: icon.png +credentails_for_provider: diff --git a/api/core/tools/provider/builtin/chart/tools/bar.py b/api/core/tools/provider/builtin/chart/tools/bar.py new file mode 100644 index 0000000000..fff60c54d6 --- /dev/null +++ b/api/core/tools/provider/builtin/chart/tools/bar.py @@ -0,0 +1,47 @@ +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.entities.tool_entities import ToolInvokeMessage +import matplotlib.pyplot as plt +import io + +from typing import Any, Dict, List, Union + +class BarChartTool(BuiltinTool): + def _invoke(self, user_id: str, tool_paramters: Dict[str, Any]) \ + -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]: + data = tool_paramters.get('data', '') + if not data: + return self.create_text_message('Please input data') + data = data.split(';') + + # if all data is int, convert to int + if all([i.isdigit() for i in data]): + data = [int(i) for i in data] + else: + data = [float(i) for i in data] + + axis = tool_paramters.get('x_axis', None) or None + if axis: + axis = axis.split(';') + if len(axis) != len(data): + axis = None + + flg, ax = plt.subplots(figsize=(10, 8)) + + if axis: + axis = [label[:10] + '...' 
if len(label) > 10 else label for label in axis] + ax.set_xticklabels(axis, rotation=45, ha='right') + ax.bar(axis, data) + else: + ax.bar(range(len(data)), data) + + buf = io.BytesIO() + flg.savefig(buf, format='png') + buf.seek(0) + plt.close(flg) + + return [ + self.create_text_message('the bar chart is saved as an image.'), + self.create_blob_message(blob=buf.read(), + meta={'mime_type': 'image/png'}) + ] + \ No newline at end of file diff --git a/api/core/tools/provider/builtin/chart/tools/bar.yaml b/api/core/tools/provider/builtin/chart/tools/bar.yaml new file mode 100644 index 0000000000..9e6787e938 --- /dev/null +++ b/api/core/tools/provider/builtin/chart/tools/bar.yaml @@ -0,0 +1,35 @@ +identity: + name: bar_chart + author: Dify + label: + en_US: Bar Chart + zh_Hans: 柱状图 + icon: icon.svg +description: + human: + en_US: Bar chart + zh_Hans: 柱状图 + llm: generate a bar chart with input data +parameters: + - name: data + type: string + required: true + label: + en_US: data + zh_Hans: 数据 + human_description: + en_US: data for generating bar chart + zh_Hans: 用于生成柱状图的数据 + llm_description: data for generating bar chart, data should be a string contains a list of numbers like "1;2;3;4;5" + form: llm + - name: x_axis + type: string + required: false + label: + en_US: X Axis + zh_Hans: x 轴 + human_description: + en_US: X axis for bar chart + zh_Hans: 柱状图的 x 轴 + llm_description: x axis for bar chart, x axis should be a string contains a list of texts like "a;b;c;1;2" in order to match the data + form: llm diff --git a/api/core/tools/provider/builtin/chart/tools/line.py b/api/core/tools/provider/builtin/chart/tools/line.py new file mode 100644 index 0000000000..2b0a44250a --- /dev/null +++ b/api/core/tools/provider/builtin/chart/tools/line.py @@ -0,0 +1,49 @@ +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.entities.tool_entities import ToolInvokeMessage +import matplotlib.pyplot as plt +import io + +from typing import Any, Dict, List, Union + 
+class LinearChartTool(BuiltinTool): + def _invoke(self, + user_id: str, + tool_paramters: Dict[str, Any], + ) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]: + data = tool_paramters.get('data', '') + if not data: + return self.create_text_message('Please input data') + data = data.split(';') + + axis = tool_paramters.get('x_axis', None) or None + if axis: + axis = axis.split(';') + if len(axis) != len(data): + axis = None + + # if all data is int, convert to int + if all([i.isdigit() for i in data]): + data = [int(i) for i in data] + else: + data = [float(i) for i in data] + + flg, ax = plt.subplots(figsize=(10, 8)) + + if axis: + axis = [label[:10] + '...' if len(label) > 10 else label for label in axis] + ax.set_xticklabels(axis, rotation=45, ha='right') + ax.plot(axis, data) + else: + ax.plot(data) + + buf = io.BytesIO() + flg.savefig(buf, format='png') + buf.seek(0) + plt.close(flg) + + return [ + self.create_text_message('the linear chart is saved as an image.'), + self.create_blob_message(blob=buf.read(), + meta={'mime_type': 'image/png'}) + ] + \ No newline at end of file diff --git a/api/core/tools/provider/builtin/chart/tools/line.yaml b/api/core/tools/provider/builtin/chart/tools/line.yaml new file mode 100644 index 0000000000..a4eb7affe9 --- /dev/null +++ b/api/core/tools/provider/builtin/chart/tools/line.yaml @@ -0,0 +1,35 @@ +identity: + name: line_chart + author: Dify + label: + en_US: Linear Chart + zh_Hans: 线性图表 + icon: icon.svg +description: + human: + en_US: linear chart + zh_Hans: 线性图表 + llm: generate a linear chart with input data +parameters: + - name: data + type: string + required: true + label: + en_US: data + zh_Hans: 数据 + human_description: + en_US: data for generating linear chart + zh_Hans: 用于生成线性图表的数据 + llm_description: data for generating linear chart, data should be a string contains a list of numbers like "1;2;3;4;5" + form: llm + - name: x_axis + type: string + required: false + label: + en_US: X Axis + zh_Hans: x 轴 + 
human_description: + en_US: X axis for linear chart + zh_Hans: 线性图表的 x 轴 + llm_description: x axis for linear chart, x axis should be a string contains a list of texts like "a;b;c;1;2" in order to match the data + form: llm diff --git a/api/core/tools/provider/builtin/chart/tools/pie.py b/api/core/tools/provider/builtin/chart/tools/pie.py new file mode 100644 index 0000000000..5fec67df11 --- /dev/null +++ b/api/core/tools/provider/builtin/chart/tools/pie.py @@ -0,0 +1,46 @@ +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.entities.tool_entities import ToolInvokeMessage +import matplotlib.pyplot as plt +import io + +from typing import Any, Dict, List, Union + +class PieChartTool(BuiltinTool): + def _invoke(self, + user_id: str, + tool_paramters: Dict[str, Any], + ) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]: + data = tool_paramters.get('data', '') + if not data: + return self.create_text_message('Please input data') + data = data.split(';') + categories = tool_paramters.get('categories', None) or None + + # if all data is int, convert to int + if all([i.isdigit() for i in data]): + data = [int(i) for i in data] + else: + data = [float(i) for i in data] + + flg, ax = plt.subplots() + + if categories: + categories = categories.split(';') + if len(categories) != len(data): + categories = None + + if categories: + ax.pie(data, labels=categories) + else: + ax.pie(data) + + buf = io.BytesIO() + flg.savefig(buf, format='png') + buf.seek(0) + plt.close(flg) + + return [ + self.create_text_message('the pie chart is saved as an image.'), + self.create_blob_message(blob=buf.read(), + meta={'mime_type': 'image/png'}) + ] \ No newline at end of file diff --git a/api/core/tools/provider/builtin/chart/tools/pie.yaml b/api/core/tools/provider/builtin/chart/tools/pie.yaml new file mode 100644 index 0000000000..2e3506d20b --- /dev/null +++ b/api/core/tools/provider/builtin/chart/tools/pie.yaml @@ -0,0 +1,35 @@ +identity: + name: pie_chart + author: 
Dify + label: + en_US: Pie Chart + zh_Hans: 饼图 + icon: icon.svg +description: + human: + en_US: Pie chart + zh_Hans: 饼图 + llm: generate a pie chart with input data +parameters: + - name: data + type: string + required: true + label: + en_US: data + zh_Hans: 数据 + human_description: + en_US: data for generating pie chart + zh_Hans: 用于生成饼图的数据 + llm_description: data for generating pie chart, data should be a string contains a list of numbers like "1;2;3;4;5" + form: llm + - name: categories + type: string + required: true + label: + en_US: Categories + zh_Hans: 分类 + human_description: + en_US: Categories for pie chart + zh_Hans: 饼图的分类 + llm_description: categories for pie chart, categories should be a string contains a list of texts like "a;b;c;1;2" in order to match the data, each category should be split by ";" + form: llm diff --git a/api/core/tools/provider/builtin/dalle/__init__.py b/api/core/tools/provider/builtin/dalle/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/tools/provider/builtin/dalle/_assets/icon.png b/api/core/tools/provider/builtin/dalle/_assets/icon.png new file mode 100644 index 0000000000..5155a73059 Binary files /dev/null and b/api/core/tools/provider/builtin/dalle/_assets/icon.png differ diff --git a/api/core/tools/provider/builtin/dalle/dalle.py b/api/core/tools/provider/builtin/dalle/dalle.py new file mode 100644 index 0000000000..b8dc0b5b82 --- /dev/null +++ b/api/core/tools/provider/builtin/dalle/dalle.py @@ -0,0 +1,23 @@ +from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController +from core.tools.provider.builtin.dalle.tools.dalle2 import DallE2Tool +from core.tools.errors import ToolProviderCredentialValidationError + +from typing import Any, Dict + +class DALLEProvider(BuiltinToolProviderController): + def _validate_credentials(self, credentials: Dict[str, Any]) -> None: + try: + DallE2Tool().fork_tool_runtime( + meta={ + "credentials": credentials, + } + ).invoke( + 
user_id='', + tool_paramters={ + "prompt": "cute girl, blue eyes, white hair, anime style", + "size": "small", + "n": 1 + }, + ) + except Exception as e: + raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file diff --git a/api/core/tools/provider/builtin/dalle/dalle.yaml b/api/core/tools/provider/builtin/dalle/dalle.yaml new file mode 100644 index 0000000000..54508ff806 --- /dev/null +++ b/api/core/tools/provider/builtin/dalle/dalle.yaml @@ -0,0 +1,47 @@ +identity: + author: Dify + name: dalle + label: + en_US: DALL-E + zh_Hans: DALL-E 绘画 + description: + en_US: DALL-E art + zh_Hans: DALL-E 绘画 + icon: icon.png +credentails_for_provider: + openai_api_key: + type: secret-input + required: true + label: + en_US: OpenAI API key + zh_Hans: OpenAI API key + help: + en_US: Please input your OpenAI API key + zh_Hans: 请输入你的 OpenAI API key + placeholder: + en_US: Please input your OpenAI API key + zh_Hans: 请输入你的 OpenAI API key + openai_organizaion_id: + type: text-input + required: false + label: + en_US: OpenAI organization ID + zh_Hans: OpenAI organization ID + help: + en_US: Please input your OpenAI organization ID + zh_Hans: 请输入你的 OpenAI organization ID + placeholder: + en_US: Please input your OpenAI organization ID + zh_Hans: 请输入你的 OpenAI organization ID + openai_base_url: + type: text-input + required: false + label: + en_US: OpenAI base URL + zh_Hans: OpenAI base URL + help: + en_US: Please input your OpenAI base URL + zh_Hans: 请输入你的 OpenAI base URL + placeholder: + en_US: Please input your OpenAI base URL + zh_Hans: 请输入你的 OpenAI base URL diff --git a/api/core/tools/provider/builtin/dalle/tools/dalle2.py b/api/core/tools/provider/builtin/dalle/tools/dalle2.py new file mode 100644 index 0000000000..2a18e05aba --- /dev/null +++ b/api/core/tools/provider/builtin/dalle/tools/dalle2.py @@ -0,0 +1,66 @@ +from typing import Any, Dict, List, Union +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool 
import BuiltinTool + +from base64 import b64decode +from os.path import join + +from openai import OpenAI + +class DallE2Tool(BuiltinTool): + def _invoke(self, + user_id: str, + tool_paramters: Dict[str, Any], + ) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]: + """ + invoke tools + """ + openai_organization = self.runtime.credentials.get('openai_organizaion_id', None) + if not openai_organization: + openai_organization = None + openai_base_url = self.runtime.credentials.get('openai_base_url', None) + if not openai_base_url: + openai_base_url = None + else: + openai_base_url = join(openai_base_url, 'v1') + + client = OpenAI( + api_key=self.runtime.credentials['openai_api_key'], + base_url=openai_base_url, + organization=openai_organization + ) + + SIZE_MAPPING = { + 'small': '256x256', + 'medium': '512x512', + 'large': '1024x1024', + } + + # prompt + prompt = tool_paramters.get('prompt', '') + if not prompt: + return self.create_text_message('Please input prompt') + + # get size + size = SIZE_MAPPING[tool_paramters.get('size', 'large')] + + # get n + n = tool_paramters.get('n', 1) + + # call openapi dalle2 + response = client.images.generate( + prompt=prompt, + model='dall-e-2', + size=size, + n=n, + response_format='b64_json' + ) + + result = [] + + for image in response.data: + result.append(self.create_blob_message(blob=b64decode(image.b64_json), + meta={ 'mime_type': 'image/png' }, + save_as=self.VARIABLE_KEY.IMAGE.value)) + + return result diff --git a/api/core/tools/provider/builtin/dalle/tools/dalle2.yaml b/api/core/tools/provider/builtin/dalle/tools/dalle2.yaml new file mode 100644 index 0000000000..ebcaf02ebc --- /dev/null +++ b/api/core/tools/provider/builtin/dalle/tools/dalle2.yaml @@ -0,0 +1,63 @@ +identity: + name: dalle2 + author: Dify + label: + en_US: DALL-E 2 + zh_Hans: DALL-E 2 绘画 + description: + en_US: DALL-E 2 is a powerful drawing tool that can draw the image you want based on your prompt + zh_Hans: DALL-E 2 
是一个强大的绘画工具,它可以根据您的提示词绘制出您想要的图像 +description: + human: + en_US: DALL-E is a text to image tool + zh_Hans: DALL-E 是一个文本到图像的工具 + llm: DALL-E is a tool used to generate images from text +parameters: + - name: prompt + type: string + required: true + label: + en_US: Prompt + zh_Hans: 提示词 + human_description: + en_US: Image prompt, you can check the official documentation of DallE 2 + zh_Hans: 图像提示词,您可以查看DallE 2 的官方文档 + llm_description: Image prompt of DallE 2, you should describe the image you want to generate as a list of words as possible as detailed + form: llm + - name: size + type: select + required: true + human_description: + en_US: used for selecting the image size + zh_Hans: 用于选择图像大小 + label: + en_US: Image size + zh_Hans: 图像大小 + form: form + options: + - value: small + label: + en_US: Small(256x256) + zh_Hans: 小(256x256) + - value: medium + label: + en_US: Medium(512x512) + zh_Hans: 中(512x512) + - value: large + label: + en_US: Large(1024x1024) + zh_Hans: 大(1024x1024) + default: large + - name: n + type: number + required: true + human_description: + en_US: used for selecting the number of images + zh_Hans: 用于选择图像数量 + label: + en_US: Number of images + zh_Hans: 图像数量 + form: form + default: 1 + min: 1 + max: 10 diff --git a/api/core/tools/provider/builtin/dalle/tools/dalle3.py b/api/core/tools/provider/builtin/dalle/tools/dalle3.py new file mode 100644 index 0000000000..0b2d8df0ee --- /dev/null +++ b/api/core/tools/provider/builtin/dalle/tools/dalle3.py @@ -0,0 +1,74 @@ +from typing import Any, Dict, List, Union +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool + +from base64 import b64decode +from os.path import join + +from openai import OpenAI + +class DallE3Tool(BuiltinTool): + def _invoke(self, + user_id: str, + tool_paramters: Dict[str, Any], + ) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]: + """ + invoke tools + """ + openai_organization = 
self.runtime.credentials.get('openai_organizaion_id', None) + if not openai_organization: + openai_organization = None + openai_base_url = self.runtime.credentials.get('openai_base_url', None) + if not openai_base_url: + openai_base_url = None + else: + openai_base_url = join(openai_base_url, 'v1') + + client = OpenAI( + api_key=self.runtime.credentials['openai_api_key'], + base_url=openai_base_url, + organization=openai_organization + ) + + SIZE_MAPPING = { + 'square': '1024x1024', + 'vertical': '1024x1792', + 'horizontal': '1792x1024', + } + + # prompt + prompt = tool_paramters.get('prompt', '') + if not prompt: + return self.create_text_message('Please input prompt') + # get size + size = SIZE_MAPPING[tool_paramters.get('size', 'square')] + # get n + n = tool_paramters.get('n', 1) + # get quality + quality = tool_paramters.get('quality', 'standard') + if quality not in ['standard', 'hd']: + return self.create_text_message('Invalid quality') + # get style + style = tool_paramters.get('style', 'vivid') + if style not in ['natural', 'vivid']: + return self.create_text_message('Invalid style') + + # call openapi dalle3 + response = client.images.generate( + prompt=prompt, + model='dall-e-3', + size=size, + n=n, + style=style, + quality=quality, + response_format='b64_json' + ) + + result = [] + + for image in response.data: + result.append(self.create_blob_message(blob=b64decode(image.b64_json), + meta={ 'mime_type': 'image/png' }, + save_as=self.VARIABLE_KEY.IMAGE.value)) + + return result diff --git a/api/core/tools/provider/builtin/dalle/tools/dalle3.yaml b/api/core/tools/provider/builtin/dalle/tools/dalle3.yaml new file mode 100644 index 0000000000..0497a3274b --- /dev/null +++ b/api/core/tools/provider/builtin/dalle/tools/dalle3.yaml @@ -0,0 +1,103 @@ +identity: + name: dalle3 + author: Dify + label: + en_US: DALL-E 3 + zh_Hans: DALL-E 3 绘画 + description: + en_US: DALL-E 3 is a powerful drawing tool that can draw the image you want based on your prompt, 
compared to DallE 2, DallE 3 has stronger drawing ability, but it will consume more resources + zh_Hans: DALL-E 3 是一个强大的绘画工具,它可以根据您的提示词绘制出您想要的图像,相比于DallE 2, DallE 3拥有更强的绘画能力,但会消耗更多的资源 +description: + human: + en_US: DALL-E is a text to image tool + zh_Hans: DALL-E 是一个文本到图像的工具 + llm: DALL-E is a tool used to generate images from text +parameters: + - name: prompt + type: string + required: true + label: + en_US: Prompt + zh_Hans: 提示词 + human_description: + en_US: Image prompt, you can check the official documentation of DallE 3 + zh_Hans: 图像提示词,您可以查看DallE 3 的官方文档 + llm_description: Image prompt of DallE 3, you should describe the image you want to generate as a list of words as possible as detailed + form: llm + - name: size + type: select + required: true + human_description: + en_US: selecting the image size + zh_Hans: 选择图像大小 + label: + en_US: Image size + zh_Hans: 图像大小 + form: form + options: + - value: square + label: + en_US: Squre(1024x1024) + zh_Hans: 方(1024x1024) + - value: vertical + label: + en_US: Vertical(1024x1792) + zh_Hans: 竖屏(1024x1792) + - value: horizontal + label: + en_US: Horizontal(1792x1024) + zh_Hans: 横屏(1792x1024) + default: square + - name: n + type: number + required: true + human_description: + en_US: selecting the number of images + zh_Hans: 选择图像数量 + label: + en_US: Number of images + zh_Hans: 图像数量 + form: form + min: 1 + max: 1 + default: 1 + - name: quality + type: select + required: true + human_description: + en_US: selecting the image quality + zh_Hans: 选择图像质量 + label: + en_US: Image quality + zh_Hans: 图像质量 + form: form + options: + - value: standard + label: + en_US: Standard + zh_Hans: 标准 + - value: hd + label: + en_US: HD + zh_Hans: 高清 + default: standard + - name: style + type: select + required: true + human_description: + en_US: selecting the image style + zh_Hans: 选择图像风格 + label: + en_US: Image style + zh_Hans: 图像风格 + form: form + options: + - value: vivid + label: + en_US: Vivid + zh_Hans: 生动 + - value: natural + label: + 
en_US: Natural + zh_Hans: 自然 + default: vivid diff --git a/api/core/tools/provider/builtin/google/_assets/icon.svg b/api/core/tools/provider/builtin/google/_assets/icon.svg new file mode 100644 index 0000000000..bebbf52d3a --- /dev/null +++ b/api/core/tools/provider/builtin/google/_assets/icon.svg @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/api/core/tools/provider/builtin/google/google.py b/api/core/tools/provider/builtin/google/google.py new file mode 100644 index 0000000000..9ce52f470f --- /dev/null +++ b/api/core/tools/provider/builtin/google/google.py @@ -0,0 +1,23 @@ +from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController +from core.tools.errors import ToolProviderCredentialValidationError + +from core.tools.provider.builtin.google.tools.google_search import GoogleSearchTool + +from typing import Any, Dict, List + +class GoogleProvider(BuiltinToolProviderController): + def _validate_credentials(self, credentials: Dict[str, Any]) -> None: + try: + GoogleSearchTool().fork_tool_runtime( + meta={ + "credentials": credentials, + } + ).invoke( + user_id='', + tool_paramters={ + "query": "test", + "result_type": "link" + }, + ) + except Exception as e: + raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file diff --git a/api/core/tools/provider/builtin/google/google.yaml b/api/core/tools/provider/builtin/google/google.yaml new file mode 100644 index 0000000000..3c0dd63a40 --- /dev/null +++ b/api/core/tools/provider/builtin/google/google.yaml @@ -0,0 +1,24 @@ +identity: + author: Dify + name: google + label: + en_US: Google + zh_Hans: Google + description: + en_US: Google + zh_Hans: GoogleSearch + icon: icon.svg +credentails_for_provider: + serpapi_api_key: + type: secret-input + required: true + label: + en_US: SerpApi API key + zh_Hans: SerpApi API key + placeholder: + en_US: Please input your SerpApi API key + zh_Hans: 请输入你的 SerpApi API key + help: + en_US: Get your SerpApi API key from 
SerpApi + zh_Hans: 从 SerpApi 获取您的 SerpApi API key + url: https://serpapi.com/manage-api-key diff --git a/api/core/tools/provider/builtin/google/tools/google_search.py b/api/core/tools/provider/builtin/google/tools/google_search.py new file mode 100644 index 0000000000..c902f1477c --- /dev/null +++ b/api/core/tools/provider/builtin/google/tools/google_search.py @@ -0,0 +1,163 @@ +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.entities.tool_entities import ToolInvokeMessage + +from typing import Any, Dict, List, Union + +import os +import sys + +from serpapi import GoogleSearch + +class HiddenPrints: + """Context manager to hide prints.""" + + def __enter__(self) -> None: + """Open file to pipe stdout to.""" + self._original_stdout = sys.stdout + sys.stdout = open(os.devnull, "w") + + def __exit__(self, *_: Any) -> None: + """Close file that stdout was piped to.""" + sys.stdout.close() + sys.stdout = self._original_stdout + + +class SerpAPI: + """ + SerpAPI tool provider. 
+ """ + + search_engine: Any #: :meta private: + serpapi_api_key: str = None + + def __init__(self, api_key: str) -> None: + """Initialize SerpAPI tool provider.""" + self.serpapi_api_key = api_key + self.search_engine = GoogleSearch + + def run(self, query: str, **kwargs: Any) -> str: + """Run query through SerpAPI and parse result.""" + typ = kwargs.get("result_type", "text") + return self._process_response(self.results(query), typ=typ) + + def results(self, query: str) -> dict: + """Run query through SerpAPI and return the raw result.""" + params = self.get_params(query) + with HiddenPrints(): + search = self.search_engine(params) + res = search.get_dict() + return res + + def get_params(self, query: str) -> Dict[str, str]: + """Get parameters for SerpAPI.""" + _params = { + "api_key": self.serpapi_api_key, + "q": query, + } + params = { + "engine": "google", + "google_domain": "google.com", + "gl": "us", + "hl": "en", + **_params + } + return params + + @staticmethod + def _process_response(res: dict, typ: str) -> str: + """Process response from SerpAPI.""" + if "error" in res.keys(): + raise ValueError(f"Got error from SerpAPI: {res['error']}") + + if typ == "text": + if "answer_box" in res.keys() and type(res["answer_box"]) == list: + res["answer_box"] = res["answer_box"][0] + if "answer_box" in res.keys() and "answer" in res["answer_box"].keys(): + toret = res["answer_box"]["answer"] + elif "answer_box" in res.keys() and "snippet" in res["answer_box"].keys(): + toret = res["answer_box"]["snippet"] + elif ( + "answer_box" in res.keys() + and "snippet_highlighted_words" in res["answer_box"].keys() + ): + toret = res["answer_box"]["snippet_highlighted_words"][0] + elif ( + "sports_results" in res.keys() + and "game_spotlight" in res["sports_results"].keys() + ): + toret = res["sports_results"]["game_spotlight"] + elif ( + "shopping_results" in res.keys() + and "title" in res["shopping_results"][0].keys() + ): + toret = res["shopping_results"][:3] + elif ( + 
"knowledge_graph" in res.keys() + and "description" in res["knowledge_graph"].keys() + ): + toret = res["knowledge_graph"]["description"] + elif "snippet" in res["organic_results"][0].keys(): + toret = res["organic_results"][0]["snippet"] + elif "link" in res["organic_results"][0].keys(): + toret = res["organic_results"][0]["link"] + elif ( + "images_results" in res.keys() + and "thumbnail" in res["images_results"][0].keys() + ): + thumbnails = [item["thumbnail"] for item in res["images_results"][:10]] + toret = thumbnails + else: + toret = "No good search result found" + elif typ == "link": + if "knowledge_graph" in res.keys() and "title" in res["knowledge_graph"].keys() \ + and "description_link" in res["knowledge_graph"].keys(): + toret = res["knowledge_graph"]["description_link"] + elif "knowledge_graph" in res.keys() and "see_results_about" in res["knowledge_graph"].keys() \ + and len(res["knowledge_graph"]["see_results_about"]) > 0: + see_result_about = res["knowledge_graph"]["see_results_about"] + toret = "" + for item in see_result_about: + if "name" not in item.keys() or "link" not in item.keys(): + continue + toret += f"[{item['name']}]({item['link']})\n" + elif "organic_results" in res.keys() and len(res["organic_results"]) > 0: + organic_results = res["organic_results"] + toret = "" + for item in organic_results: + if "title" not in item.keys() or "link" not in item.keys(): + continue + toret += f"[{item['title']}]({item['link']})\n" + elif "related_questions" in res.keys() and len(res["related_questions"]) > 0: + related_questions = res["related_questions"] + toret = "" + for item in related_questions: + if "question" not in item.keys() or "link" not in item.keys(): + continue + toret += f"[{item['question']}]({item['link']})\n" + elif "related_searches" in res.keys() and len(res["related_searches"]) > 0: + related_searches = res["related_searches"] + toret = "" + for item in related_searches: + if "query" not in item.keys() or "link" not in 
item.keys(): + continue + toret += f"[{item['query']}]({item['link']})\n" + else: + toret = "No good search result found" + return toret + +class GoogleSearchTool(BuiltinTool): + def _invoke(self, + user_id: str, + tool_paramters: Dict[str, Any], + ) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]: + """ + invoke tools + """ + query = tool_paramters['query'] + result_type = tool_paramters['result_type'] + api_key = self.runtime.credentials['serpapi_api_key'] + result = SerpAPI(api_key).run(query, result_type=result_type) + if result_type == 'text': + return self.create_text_message(text=result) + return self.create_link_message(link=result) + \ No newline at end of file diff --git a/api/core/tools/provider/builtin/google/tools/google_search.yaml b/api/core/tools/provider/builtin/google/tools/google_search.yaml new file mode 100644 index 0000000000..f3b2a96a2e --- /dev/null +++ b/api/core/tools/provider/builtin/google/tools/google_search.yaml @@ -0,0 +1,43 @@ +identity: + name: google_search + author: Dify + label: + en_US: GoogleSearch + zh_Hans: 谷歌搜索 +description: + human: + en_US: A tool for performing a Google SERP search and extracting snippets and webpages.Input should be a search query. + zh_Hans: 一个用于执行 Google SERP 搜索并提取片段和网页的工具。输入应该是一个搜索查询。 + llm: A tool for performing a Google SERP search and extracting snippets and webpages.Input should be a search query. 
+parameters: + - name: query + type: string + required: true + label: + en_US: Query string + zh_Hans: 查询语句 + human_description: + en_US: used for searching + zh_Hans: 用于搜索网页内容 + llm_description: key words for searching + form: llm + - name: result_type + type: select + required: true + options: + - value: text + label: + en_US: text + zh_Hans: 文本 + - value: link + label: + en_US: link + zh_Hans: 链接 + default: link + label: + en_US: Result type + zh_Hans: 结果类型 + human_description: + en_US: used for selecting the result type, text or link + zh_Hans: 用于选择结果类型,使用文本还是链接进行展示 + form: form diff --git a/api/core/tools/provider/builtin/stablediffusion/_assets/icon.png b/api/core/tools/provider/builtin/stablediffusion/_assets/icon.png new file mode 100644 index 0000000000..fc372b28f1 Binary files /dev/null and b/api/core/tools/provider/builtin/stablediffusion/_assets/icon.png differ diff --git a/api/core/tools/provider/builtin/stablediffusion/stablediffusion.py b/api/core/tools/provider/builtin/stablediffusion/stablediffusion.py new file mode 100644 index 0000000000..ea18349ae6 --- /dev/null +++ b/api/core/tools/provider/builtin/stablediffusion/stablediffusion.py @@ -0,0 +1,26 @@ +from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController +from core.tools.errors import ToolProviderCredentialValidationError + +from core.tools.provider.builtin.stablediffusion.tools.stable_diffusion import StableDiffusionTool + +from typing import Any, Dict + +class StableDiffusionProvider(BuiltinToolProviderController): + def _validate_credentials(self, credentials: Dict[str, Any]) -> None: + try: + StableDiffusionTool().fork_tool_runtime( + meta={ + "credentials": credentials, + } + ).invoke( + user_id='', + tool_paramters={ + "prompt": "cat", + "lora": "", + "steps": 1, + "width": 512, + "height": 512, + }, + ) + except Exception as e: + raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file diff --git 
a/api/core/tools/provider/builtin/stablediffusion/stablediffusion.yaml b/api/core/tools/provider/builtin/stablediffusion/stablediffusion.yaml new file mode 100644 index 0000000000..8bb3a3c0d6 --- /dev/null +++ b/api/core/tools/provider/builtin/stablediffusion/stablediffusion.yaml @@ -0,0 +1,29 @@ +identity: + author: Dify + name: stablediffusion + label: + en_US: Stable Diffusion + zh_Hans: Stable Diffusion + description: + en_US: Stable Diffusion is a tool for generating images which can be deployed locally. + zh_Hans: Stable Diffusion 是一个可以在本地部署的图片生成的工具。 + icon: icon.png +credentails_for_provider: + base_url: + type: secret-input + required: true + label: + en_US: Base URL + zh_Hans: StableDiffusion服务器的Base URL + placeholder: + en_US: Please input your StableDiffusion server's Base URL + zh_Hans: 请输入你的 StableDiffusion 服务器的 Base URL + model: + type: text-input + required: true + label: + en_US: Model + zh_Hans: 模型 + placeholder: + en_US: Please input your model + zh_Hans: 请输入你的模型名称 diff --git a/api/core/tools/provider/builtin/stablediffusion/tools/stable_diffusion.py b/api/core/tools/provider/builtin/stablediffusion/tools/stable_diffusion.py new file mode 100644 index 0000000000..02fc7448e0 --- /dev/null +++ b/api/core/tools/provider/builtin/stablediffusion/tools/stable_diffusion.py @@ -0,0 +1,244 @@ +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParamter, ToolParamterOption +from core.tools.entities.common_entities import I18nObject +from core.tools.errors import ToolProviderCredentialValidationError + +from typing import Any, Dict, List, Union +from httpx import post +from os.path import join +from base64 import b64decode, b64encode +from PIL import Image + +import json +import io + +from copy import deepcopy + +DRAW_TEXT_OPTIONS = { + "prompt": "", + "negative_prompt": "", + "seed": -1, + "subseed": -1, + "subseed_strength": 0, + "seed_resize_from_h": -1, + 'sampler_index': 'DPM++ SDE 
Karras', + "seed_resize_from_w": -1, + "batch_size": 1, + "n_iter": 1, + "steps": 10, + "cfg_scale": 7, + "width": 1024, + "height": 1024, + "restore_faces": False, + "do_not_save_samples": False, + "do_not_save_grid": False, + "eta": 0, + "denoising_strength": 0, + "s_min_uncond": 0, + "s_churn": 0, + "s_tmax": 0, + "s_tmin": 0, + "s_noise": 0, + "override_settings": {}, + "override_settings_restore_afterwards": True, + "refiner_switch_at": 0, + "disable_extra_networks": False, + "comments": {}, + "enable_hr": False, + "firstphase_width": 0, + "firstphase_height": 0, + "hr_scale": 2, + "hr_second_pass_steps": 0, + "hr_resize_x": 0, + "hr_resize_y": 0, + "hr_prompt": "", + "hr_negative_prompt": "", + "script_args": [], + "send_images": True, + "save_images": False, + "alwayson_scripts": {} +} + +class StableDiffusionTool(BuiltinTool): + def _invoke(self, user_id: str, tool_paramters: Dict[str, Any]) \ + -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]: + """ + invoke tools + """ + # base url + base_url = self.runtime.credentials.get('base_url', None) + if not base_url: + return self.create_text_message('Please input base_url') + model = self.runtime.credentials.get('model', None) + if not model: + return self.create_text_message('Please input model') + + # set model + try: + url = join(base_url, 'sdapi/v1/options') + response = post(url, data=json.dumps({ + 'sd_model_checkpoint': model + })) + if response.status_code != 200: + raise ToolProviderCredentialValidationError('Failed to set model, please tell user to set model') + except Exception as e: + raise ToolProviderCredentialValidationError('Failed to set model, please tell user to set model') + + + # prompt + prompt = tool_paramters.get('prompt', '') + if not prompt: + return self.create_text_message('Please input prompt') + + # get negative prompt + negative_prompt = tool_paramters.get('negative_prompt', '') + + # get size + width = tool_paramters.get('width', 1024) + height = tool_paramters.get('height', 
1024) + + # get steps + steps = tool_paramters.get('steps', 1) + + # get lora + lora = tool_paramters.get('lora', '') + + # get image id + image_id = tool_paramters.get('image_id', '') + if image_id.strip(): + image_variable = self.get_default_image_variable() + if image_variable: + image_binary = self.get_variable_file(image_variable.name) + if not image_binary: + return self.create_text_message('Image not found, please request user to generate image firstly.') + + # convert image to RGB + image = Image.open(io.BytesIO(image_binary)) + image = image.convert("RGB") + buffer = io.BytesIO() + image.save(buffer, format="PNG") + image_binary = buffer.getvalue() + image.close() + + return self.img2img(base_url=base_url, + lora=lora, + image_binary=image_binary, + prompt=prompt, + negative_prompt=negative_prompt, + width=width, + height=height, + steps=steps) + + return self.text2img(base_url=base_url, + lora=lora, + prompt=prompt, + negative_prompt=negative_prompt, + width=width, + height=height, + steps=steps) + + def img2img(self, base_url: str, lora: str, image_binary: bytes, + prompt: str, negative_prompt: str, + width: int, height: int, steps: int) \ + -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]: + """ + generate image + """ + draw_options = { + "init_images": [b64encode(image_binary).decode('utf-8')], + "prompt": "", + "negative_prompt": negative_prompt, + "denoising_strength": 0.9, + "width": width, + "height": height, + "cfg_scale": 7, + "sampler_name": "Euler a", + "restore_faces": False, + "steps": steps, + "script_args": ["outpainting mk2"] + } + + if lora: + draw_options['prompt'] = f'{lora},{prompt}' + else: + draw_options['prompt'] = prompt + + try: + url = join(base_url, 'sdapi/v1/img2img') + response = post(url, data=json.dumps(draw_options), timeout=120) + if response.status_code != 200: + return self.create_text_message('Failed to generate image') + + image = response.json()['images'][0] + + return 
self.create_blob_message(blob=b64decode(image), + meta={ 'mime_type': 'image/png' }, + save_as=self.VARIABLE_KEY.IMAGE.value) + + except Exception as e: + return self.create_text_message('Failed to generate image') + + def text2img(self, base_url: str, lora: str, prompt: str, negative_prompt: str, width: int, height: int, steps: int) \ + -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]: + """ + generate image + """ + # copy draw options + draw_options = deepcopy(DRAW_TEXT_OPTIONS) + + if lora: + draw_options['prompt'] = f'{lora},{prompt}' + + draw_options['width'] = width + draw_options['height'] = height + draw_options['steps'] = steps + draw_options['negative_prompt'] = negative_prompt + + try: + url = join(base_url, 'sdapi/v1/txt2img') + response = post(url, data=json.dumps(draw_options), timeout=120) + if response.status_code != 200: + return self.create_text_message('Failed to generate image') + + image = response.json()['images'][0] + + return self.create_blob_message(blob=b64decode(image), + meta={ 'mime_type': 'image/png' }, + save_as=self.VARIABLE_KEY.IMAGE.value) + + except Exception as e: + return self.create_text_message('Failed to generate image') + + + def get_runtime_parameters(self) -> List[ToolParamter]: + parameters = [ + ToolParamter(name='prompt', + label=I18nObject(en_US='Prompt', zh_Hans='Prompt'), + human_description=I18nObject( + en_US='Image prompt, you can check the official documentation of Stable Diffusion', + zh_Hans='图像提示词,您可以查看 Stable Diffusion 的官方文档', + ), + type=ToolParamter.ToolParameterType.STRING, + form=ToolParamter.ToolParameterForm.LLM, + llm_description='Image prompt of Stable Diffusion, you should describe the image you want to generate as a list of words as possible as detailed, the prompt must be written in English.', + required=True), + ] + if len(self.list_default_image_variables()) != 0: + parameters.append( + ToolParamter(name='image_id', + label=I18nObject(en_US='image_id', zh_Hans='image_id'), + 
human_description=I18nObject( + en_US='Image id of the image you want to generate based on, if you want to generate image based on the default image, you can leave this field empty.', + zh_Hans='您想要生成的图像的图像 ID,如果您想要基于默认图像生成图像,则可以将此字段留空。', + ), + type=ToolParamter.ToolParameterType.STRING, + form=ToolParamter.ToolParameterForm.LLM, + llm_description='Image id of the original image, you can leave this field empty if you want to generate a new image.', + required=True, + options=[ToolParamterOption( + value=i.name, + label=I18nObject(en_US=i.name, zh_Hans=i.name) + ) for i in self.list_default_image_variables()]) + ) + + return parameters diff --git a/api/core/tools/provider/builtin/stablediffusion/tools/stable_diffusion.yaml b/api/core/tools/provider/builtin/stablediffusion/tools/stable_diffusion.yaml new file mode 100644 index 0000000000..cd20a81c15 --- /dev/null +++ b/api/core/tools/provider/builtin/stablediffusion/tools/stable_diffusion.yaml @@ -0,0 +1,77 @@ +identity: + name: stable_diffusion + author: Dify + label: + en_US: Stable Diffusion WebUI + zh_Hans: Stable Diffusion WebUI +description: + human: + en_US: A tool for generating images which can be deployed locally, you can use stable-diffusion-webui to deploy it. + zh_Hans: 一个可以在本地部署的图片生成的工具,您可以使用 stable-diffusion-webui 来部署它。 + llm: draw the image you want based on your prompt. +parameters: + - name: prompt + type: string + required: true + label: + en_US: Prompt + zh_Hans: 提示词 + human_description: + en_US: Image prompt, you can check the official documentation of Stable Diffusion + zh_Hans: 图像提示词,您可以查看 Stable Diffusion 的官方文档 + llm_description: Image prompt of Stable Diffusion, you should describe the image you want to generate as a list of words as possible as detailed, the prompt must be written in English. 
+ form: llm + - name: lora + type: string + required: false + label: + en_US: Lora + zh_Hans: Lora + human_description: + en_US: Lora + zh_Hans: Lora + form: form + - name: steps + type: number + required: false + label: + en_US: Steps + zh_Hans: Steps + human_description: + en_US: Steps + zh_Hans: Steps + form: form + default: 10 + - name: width + type: number + required: false + label: + en_US: Width + zh_Hans: Width + human_description: + en_US: Width + zh_Hans: Width + form: form + default: 1024 + - name: height + type: number + required: false + label: + en_US: Height + zh_Hans: Height + human_description: + en_US: Height + zh_Hans: Height + form: form + default: 1024 + - name: negative_prompt + type: string + required: false + label: + en_US: Negative prompt + zh_Hans: Negative prompt + human_description: + en_US: Negative prompt + zh_Hans: Negative prompt + form: form + default: bad art, ugly, deformed, watermark, duplicated, discontinuous lines diff --git a/api/core/tools/provider/builtin/time/_assets/icon.svg b/api/core/tools/provider/builtin/time/_assets/icon.svg new file mode 100644 index 0000000000..6d7118aed9 --- /dev/null +++ b/api/core/tools/provider/builtin/time/_assets/icon.svg @@ -0,0 +1,3 @@ + + + \ No newline at end of file diff --git a/api/core/tools/provider/builtin/time/time.py b/api/core/tools/provider/builtin/time/time.py new file mode 100644 index 0000000000..24fd287d10 --- /dev/null +++ b/api/core/tools/provider/builtin/time/time.py @@ -0,0 +1,16 @@ +from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController +from core.tools.errors import ToolProviderCredentialValidationError + +from core.tools.provider.builtin.time.tools.current_time import CurrentTimeTool + +from typing import Any, Dict + +class WikiPediaProvider(BuiltinToolProviderController): + def _validate_credentials(self, credentials: Dict[str, Any]) -> None: + try: + CurrentTimeTool().invoke( + user_id='', + tool_paramters={}, + ) + except Exception as e: 
+ raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file diff --git a/api/core/tools/provider/builtin/time/time.yaml b/api/core/tools/provider/builtin/time/time.yaml new file mode 100644 index 0000000000..5a2c3395fa --- /dev/null +++ b/api/core/tools/provider/builtin/time/time.yaml @@ -0,0 +1,11 @@ +identity: + author: Dify + name: time + label: + en_US: CurrentTime + zh_Hans: 时间 + description: + en_US: A tool for getting the current time. + zh_Hans: 一个用于获取当前时间的工具。 + icon: icon.svg +credentails_for_provider: diff --git a/api/core/tools/provider/builtin/time/tools/current_time.py b/api/core/tools/provider/builtin/time/tools/current_time.py new file mode 100644 index 0000000000..ce380cd5bf --- /dev/null +++ b/api/core/tools/provider/builtin/time/tools/current_time.py @@ -0,0 +1,17 @@ +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool + +from typing import Any, Dict, List, Union + +from datetime import datetime, timezone + +class CurrentTimeTool(BuiltinTool): + def _invoke(self, + user_id: str, + tool_paramters: Dict[str, Any], + ) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]: + """ + invoke tools + """ + return self.create_text_message(f'{datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S %Z")}') + \ No newline at end of file diff --git a/api/core/tools/provider/builtin/time/tools/current_time.yaml b/api/core/tools/provider/builtin/time/tools/current_time.yaml new file mode 100644 index 0000000000..2d86e4db4b --- /dev/null +++ b/api/core/tools/provider/builtin/time/tools/current_time.yaml @@ -0,0 +1,12 @@ +identity: + name: current_time + author: Dify + label: + en_US: Current Time + zh_Hans: 获取当前时间 +description: + human: + en_US: A tool for getting the current time. + zh_Hans: 一个用于获取当前时间的工具。 + llm: A tool for getting the current time. 
+parameters: diff --git a/api/core/tools/provider/builtin/vectorizer/_assets/icon.png b/api/core/tools/provider/builtin/vectorizer/_assets/icon.png new file mode 100644 index 0000000000..52f18db843 Binary files /dev/null and b/api/core/tools/provider/builtin/vectorizer/_assets/icon.png differ diff --git a/api/core/tools/provider/builtin/vectorizer/tools/test_data.py b/api/core/tools/provider/builtin/vectorizer/tools/test_data.py new file mode 100644 index 0000000000..1506ac0c9d --- /dev/null +++ b/api/core/tools/provider/builtin/vectorizer/tools/test_data.py @@ -0,0 +1 @@ +VECTORIZER_ICON_PNG = 'iVBORw0KGgoAAAANSUhEUgAAAGAAAABgCAYAAADimHc4AAAACXBIWXMAACxLAAAsSwGlPZapAAAAAXNSR0IArs4c6QAAAARnQU1BAACxjwv8YQUAAAboSURBVHgB7Z09bBxFFMffRoAvcQqbguBUxu4wCUikMCZ0TmQK4NLQJCJOlQIkokgEGhQ7NCFIKEhQuIqNnIaGMxRY2GVwmlggDHS+pIHELmIXMTEULPP3eeXz7e7szO7MvE1ufpKV03nuNn7/mfcxH7tEHo/H42lXgqwG1bGw65+/aTQM6K0gpJdCoi7ypCIMui5s9Qv9R1OVTqrVxoL1jPbpvH4hrIp/rnmj5+YOhTQ++1kwmdZgT9ovRi6EF4Xhv/XGL0Sv6OLXYMu0BokjYOSDcBQfJI8xhKFP/HAlqCW8v5vqubBr8yn6maCexxiIDR376LnWmBBzQZtPEvx+L3mMAleOZKb1/XgM2EOnyWMFZJKt78UEQKpJHisk2TYmgM967JFk2z3kYcULwIwXgBkvADNeAGa8AMw8Qcwc6N55/eAh0cYmGaOzQtR/kOhQX+M6+/c23r+3RlT/i2ipTrSyRqw4F+CwMMbgANHQwG7jRywLw/wqDDNzI79xYPjqa2L262jjtYzaT0QT3xEbsck4MXUakgWOvUx08liy0ZPYEKNhel4Y6AZpgR7/8Tvq1wEQ+sMJN6Nh9kqwy+bWYwAM8elZovNv6xmlU7iLs280RNO9ls51os/h/8eBVQEig8Dt5OXUsNrno2tluZw0cI3qUXKONQHy9sYkVHqnjntLA2LnFTAv1gSA+zBhfIDvkfVO/B4xRgWZn4fbe2WAnGJFAAxn03+I7PtUXdzE90Sjl4ne+6L4d5nCigAyYyHPn7tFdPN30uJwX/qI6jtISkQZFVLdhd9SrtNPTrFSB6QZBAaYntsptpAyfvk+KYOCamVR/XrNtLqepduiFnkh3g4iIw6YLAhlOJmKwB9zaarhApr/MPREjAZVisSU1s/KYsGzhmKXClYEWLm/8xpV7btXhcv5I7lt2vtJFA3q/T07r1HopdG5l5xhxQVdn28YFn8kBJCBOZmiPHio1m5QuJzlu9ntXApgZwSsNYJslvGjtjrfm8Sq4neceFUtz3dZCzwW09Gqo2hreuPN7HZRnNqa1BP1x8lhczVNK+zT0TqkjYAF4e7Okxoo2PZX5K4IrhNpb/P8FTK2S1+TcUq1HpBFmquJYo1qEYU6RVarJE0c2ooL7C5IRwBZ5nJ9joyRtk5hA3YBdHqWzG1gBKgE/bzMaK5LqMIugKrbUDHu59/YWVRBsWhrsYZdANV5HBUXYGNlC9dFBW8LdgH6FQVYUnQvkQgm3NH8YuO7bM4LsWZBfT3qRY9OxRyJgJRz+Ij+F
DPEQ1C3GVMiWAVQ7f31u/ncytxi4wdZTbRGgdcHnpYLD/FcwSrAoOKizfKfVAiIF4kBMPK+Opfe1iWsMUB1BJh2BRgBabSNAOiFqkXYbcNFUF9P+u82FGdWTcEmgGrvh0FUppB1kC073muXEaDq/21kIjLxV9tFAC7/n5X6tkUM0PH/dcP+P0v41fvkFBYBVHs/MD0CDmVsOzEdb7JgEYDT/8uq4rpj44NSjwDTc/CyzV1gxbH7Ac4F0PH/S4ZHAOaFZLiY+2nFuQA6/t9kQMTCz1CG66tbWvWS4VwAVf9vugAbel6efqrsYbKBcwFeVNz8ajobyTppw2F84FQAnfl/kwER6wJZcWdBc7e2KZwKoOP/TVakWb0f7md+kVhwOwI0BDCFyq42rt4PSiuAiRGAEXdK4ZQlV+8HTgVwefwHvR7nhbOA0FwBGDgTIM/Z3SLXUj2hOW1wR10eSrs7Ou9eTB3jo/dzuh/gTABdn35c8dhpM3BxOmeTuXs/cDoCdDY4qe7l32pbaZxL1jF+GXo/cLotBcWVTiZU3T7RMn8rHiijW9FgauP4Ef1TLdhHWgacCgAj6tYCqGKjU/DNbqxIkMYZNs7MpxmnLuhmwYJna1dbdzHjY42hDL4/wqkA6HWuDkAngRH0iYVjRkVwnoZO/0gsuLwpkw7OBcAtwlwvfESHxctmfMBSiOG0oStj4HCF7T3+RWARwIU7QK/HbWlqls52mYJtezqMj3v34C5VOveFy8Ll4QoTsJ8Txp0RsW8/Os2im2LCtSC1RIqLw3RldTVplOKkPEYDhMAPqttnune2rzTv5Y+WKdEem2ixkWqZYSeDSUp3qwIYNOrR7cBjcbOORxkvADNeAGa8AMx4AZjxAjATf5Ab0Tp5rJBk2/iD3PAwYo8Vkmyb9CjDGfLYIaCp1rdiAnT8S5PeDVkgoDuVCsWeJxwToHZ163m3Z8hjloDGk54vn5gFbT/5eZw8phifvZz8XPlA9qmRj8JRCumi+OkljzbbrvxM0qPMm9rIqY6FXZubVBUinMbzcP3jbuXA6Mh2kMx07KPJJLfj8Xg8Hg/4H+KfFYb2WM4MAAAAAElFTkSuQmCC' \ No newline at end of file diff --git a/api/core/tools/provider/builtin/vectorizer/tools/vectorizer.py b/api/core/tools/provider/builtin/vectorizer/tools/vectorizer.py new file mode 100644 index 0000000000..d6d70b345a --- /dev/null +++ b/api/core/tools/provider/builtin/vectorizer/tools/vectorizer.py @@ -0,0 +1,74 @@ +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParamter +from core.tools.provider.builtin.vectorizer.tools.test_data import VECTORIZER_ICON_PNG +from core.tools.errors import ToolProviderCredentialValidationError + +from typing import Any, Dict, List, Union +from httpx import post +from base64 import b64decode + +class VectorizerTool(BuiltinTool): + def _invoke(self, user_id: str, tool_paramters: Dict[str, Any]) \ + -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]: + """ + invoke tools + 
""" + api_key_name = self.runtime.credentials.get('api_key_name', None) + api_key_value = self.runtime.credentials.get('api_key_value', None) + mode = tool_paramters.get('mode', 'test') + if mode == 'production': + mode = 'preview' + + if not api_key_name or not api_key_value: + raise ToolProviderCredentialValidationError('Please input api key name and value') + + image_id = tool_paramters.get('image_id', '') + if not image_id: + return self.create_text_message('Please input image id') + + if image_id.startswith('__test_'): + image_binary = b64decode(VECTORIZER_ICON_PNG) + else: + image_binary = self.get_variable_file(self.VARIABLE_KEY.IMAGE) + if not image_binary: + return self.create_text_message('Image not found, please request user to generate image firstly.') + + response = post( + 'https://vectorizer.ai/api/v1/vectorize', + files={ + 'image': image_binary + }, + data={ + 'mode': mode + } if mode == 'test' else {}, + auth=(api_key_name, api_key_value), + timeout=30 + ) + + if response.status_code != 200: + raise Exception(response.text) + + return [ + self.create_text_message('the vectorized svg is saved as an image.'), + self.create_blob_message(blob=response.content, + meta={'mime_type': 'image/svg+xml'}) + ] + + def get_runtime_parameters(self) -> List[ToolParamter]: + """ + override the runtime parameters + """ + return [ + ToolParamter.get_simple_instance( + name='image_id', + llm_description=f'the image id that you want to vectorize, \ + and the image id should be specified in \ + {[i.name for i in self.list_default_image_variables()]}', + type=ToolParamter.ToolParameterType.SELECT, + required=True, + options=[i.name for i in self.list_default_image_variables()] + ) + ] + + def is_tool_avaliable(self) -> bool: + return len(self.list_default_image_variables()) > 0 \ No newline at end of file diff --git a/api/core/tools/provider/builtin/vectorizer/tools/vectorizer.yaml b/api/core/tools/provider/builtin/vectorizer/tools/vectorizer.yaml new file mode 100644 
index 0000000000..2df4e3bf2f --- /dev/null +++ b/api/core/tools/provider/builtin/vectorizer/tools/vectorizer.yaml @@ -0,0 +1,32 @@ +identity: + name: vectorizer + author: Dify + label: + en_US: Vectorizer.AI + zh_Hans: Vectorizer.AI +description: + human: + en_US: Convert your PNG and JPG images to SVG vectors quickly and easily. Fully automatically. Using AI. + zh_Hans: 一个将 PNG 和 JPG 图像快速轻松地转换为 SVG 矢量图的工具。 + llm: A tool for converting images to SVG vectors. you should input the image id as the input of this tool. the image id can be got from parameters. +parameters: + - name: mode + type: select + required: true + options: + - value: production + label: + en_US: production + zh_Hans: 生产模式 + - value: test + label: + en_US: test + zh_Hans: 测试模式 + default: test + label: + en_US: Mode + zh_Hans: 模式 + human_description: + en_US: It is free to integrate with and test out the API in test mode, no subscription required. + zh_Hans: 在测试模式下,可以免费测试API。 + form: form diff --git a/api/core/tools/provider/builtin/vectorizer/vectorizer.py b/api/core/tools/provider/builtin/vectorizer/vectorizer.py new file mode 100644 index 0000000000..6abf11fbe1 --- /dev/null +++ b/api/core/tools/provider/builtin/vectorizer/vectorizer.py @@ -0,0 +1,23 @@ +from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController +from core.tools.errors import ToolProviderCredentialValidationError + +from core.tools.provider.builtin.vectorizer.tools.vectorizer import VectorizerTool + +from typing import Any, Dict + +class VectorizerProvider(BuiltinToolProviderController): + def _validate_credentials(self, credentials: Dict[str, Any]) -> None: + try: + VectorizerTool().fork_tool_runtime( + meta={ + "credentials": credentials, + } + ).invoke( + user_id='', + tool_paramters={ + "mode": "test", + "image_id": "__test_123" + }, + ) + except Exception as e: + raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file diff --git 
a/api/core/tools/provider/builtin/vectorizer/vectorizer.yaml b/api/core/tools/provider/builtin/vectorizer/vectorizer.yaml new file mode 100644 index 0000000000..566335e0c6 --- /dev/null +++ b/api/core/tools/provider/builtin/vectorizer/vectorizer.yaml @@ -0,0 +1,36 @@ +identity: + author: Dify + name: vectorizer + label: + en_US: Vectorizer.AI + zh_Hans: Vectorizer.AI + description: + en_US: Convert your PNG and JPG images to SVG vectors quickly and easily. Fully automatically. Using AI. + zh_Hans: 一个将 PNG 和 JPG 图像快速轻松地转换为 SVG 矢量图的工具。 + icon: icon.png +credentails_for_provider: + api_key_name: + type: secret-input + required: true + label: + en_US: Vectorizer.AI API Key name + zh_Hans: Vectorizer.AI API Key name + placeholder: + en_US: Please input your Vectorizer.AI ApiKey name + zh_Hans: 请输入你的 Vectorizer.AI ApiKey name + help: + en_US: Get your Vectorizer.AI API Key from Vectorizer.AI. + zh_Hans: 从 Vectorizer.AI 获取您的 Vectorizer.AI API Key。 + url: https://vectorizer.ai/api + api_key_value: + type: secret-input + required: true + label: + en_US: Vectorizer.AI API Key + zh_Hans: Vectorizer.AI API Key + placeholder: + en_US: Please input your Vectorizer.AI ApiKey + zh_Hans: 请输入你的 Vectorizer.AI ApiKey + help: + en_US: Get your Vectorizer.AI API Key from Vectorizer.AI. 
+ zh_Hans: 从 Vectorizer.AI 获取您的 Vectorizer.AI API Key。 diff --git a/api/core/tools/provider/builtin/webscraper/_assets/icon.svg b/api/core/tools/provider/builtin/webscraper/_assets/icon.svg new file mode 100644 index 0000000000..8123199a38 --- /dev/null +++ b/api/core/tools/provider/builtin/webscraper/_assets/icon.svg @@ -0,0 +1,3 @@ + + + \ No newline at end of file diff --git a/api/core/tools/provider/builtin/webscraper/tools/webscraper.py b/api/core/tools/provider/builtin/webscraper/tools/webscraper.py new file mode 100644 index 0000000000..d646720a36 --- /dev/null +++ b/api/core/tools/provider/builtin/webscraper/tools/webscraper.py @@ -0,0 +1,28 @@ +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.errors import ToolInvokeError + +from typing import Any, Dict, List, Union + +class WebscraperTool(BuiltinTool): + def _invoke(self, + user_id: str, + tool_paramters: Dict[str, Any], + ) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]: + """ + invoke tools + """ + try: + url = tool_paramters.get('url', '') + user_agent = tool_paramters.get('user_agent', '') + if not url: + return self.create_text_message('Please input url') + + # get webpage + result = self.get_url(url, user_agent=user_agent) + + # summarize and return + return self.create_text_message(self.summary(user_id=user_id, content=result)) + except Exception as e: + raise ToolInvokeError(str(e)) + \ No newline at end of file diff --git a/api/core/tools/provider/builtin/webscraper/tools/webscraper.yaml b/api/core/tools/provider/builtin/webscraper/tools/webscraper.yaml new file mode 100644 index 0000000000..93cf7a044d --- /dev/null +++ b/api/core/tools/provider/builtin/webscraper/tools/webscraper.yaml @@ -0,0 +1,34 @@ +identity: + name: webscraper + author: Dify + label: + en_US: Web Scraper + zh_Hans: 网页爬虫 +description: + human: + en_US: A tool for scraping webpages. 
+ zh_Hans: 一个用于爬取网页的工具。 + llm: A tool for scraping webpages. Input should be a URL. +parameters: + - name: url + type: string + required: true + label: + en_US: URL + zh_Hans: 网页链接 + human_description: + en_US: used for linking to webpages + zh_Hans: 用于链接到网页 + llm_description: url for scraping + form: llm + - name: user_agent + type: string + required: false + label: + en_US: User Agent + zh_Hans: User Agent + human_description: + en_US: used for identifying the browser. + zh_Hans: 用于识别浏览器。 + form: form + default: Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/100.0.1000.0 Safari/537.36 diff --git a/api/core/tools/provider/builtin/webscraper/webscraper.py b/api/core/tools/provider/builtin/webscraper/webscraper.py new file mode 100644 index 0000000000..9cfb7ecac2 --- /dev/null +++ b/api/core/tools/provider/builtin/webscraper/webscraper.py @@ -0,0 +1,23 @@ +from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController +from core.tools.errors import ToolProviderCredentialValidationError + +from core.tools.provider.builtin.webscraper.tools.webscraper import WebscraperTool + +from typing import Any, Dict, List + +class WebscraperProvider(BuiltinToolProviderController): + def _validate_credentials(self, credentials: Dict[str, Any]) -> None: + try: + WebscraperTool().fork_tool_runtime( + meta={ + "credentials": credentials, + } + ).invoke( + user_id='', + tool_paramters={ + 'url': 'https://www.google.com', + 'user_agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 ' + }, + ) + except Exception as e: + raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file diff --git a/api/core/tools/provider/builtin/webscraper/webscraper.yaml b/api/core/tools/provider/builtin/webscraper/webscraper.yaml new file mode 100644 index 0000000000..a1a99f99e7 --- /dev/null +++ b/api/core/tools/provider/builtin/webscraper/webscraper.yaml @@ -0,0 +1,11 @@ +identity: + author: Dify + name: 
webscraper + label: + en_US: WebScraper + zh_Hans: 网页抓取 + description: + en_US: Web Scrapper tool kit is used to scrape web + zh_Hans: 一个用于抓取网页的工具。 + icon: icon.svg +credentails_for_provider: diff --git a/api/core/tools/provider/builtin/wikipedia/_assets/icon.svg b/api/core/tools/provider/builtin/wikipedia/_assets/icon.svg new file mode 100644 index 0000000000..fe652aacf9 --- /dev/null +++ b/api/core/tools/provider/builtin/wikipedia/_assets/icon.svg @@ -0,0 +1,3 @@ + + + \ No newline at end of file diff --git a/api/core/tools/provider/builtin/wikipedia/tools/wikipedia_search.py b/api/core/tools/provider/builtin/wikipedia/tools/wikipedia_search.py new file mode 100644 index 0000000000..c2764455c6 --- /dev/null +++ b/api/core/tools/provider/builtin/wikipedia/tools/wikipedia_search.py @@ -0,0 +1,37 @@ +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.entities.tool_entities import ToolInvokeMessage + +from pydantic import BaseModel, Field + +from typing import Any, Dict, List, Union + +from langchain import WikipediaAPIWrapper +from langchain.tools import WikipediaQueryRun + +class WikipediaInput(BaseModel): + query: str = Field(..., description="search query.") + +class WikiPediaSearchTool(BuiltinTool): + def _invoke(self, + user_id: str, + tool_paramters: Dict[str, Any], + ) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]: + """ + invoke tools + """ + query = tool_paramters.get('query', '') + if not query: + return self.create_text_message('Please input query') + + tool = WikipediaQueryRun( + name="wikipedia", + api_wrapper=WikipediaAPIWrapper(doc_content_chars_max=4000), + args_schema=WikipediaInput + ) + + result = tool.run(tool_input={ + 'query': query + }) + + return self.create_text_message(self.summary(user_id=user_id,content=result)) + \ No newline at end of file diff --git a/api/core/tools/provider/builtin/wikipedia/tools/wikipedia_search.yaml b/api/core/tools/provider/builtin/wikipedia/tools/wikipedia_search.yaml new file mode 
100644 index 0000000000..3bbc65d804 --- /dev/null +++ b/api/core/tools/provider/builtin/wikipedia/tools/wikipedia_search.yaml @@ -0,0 +1,24 @@ +identity: + name: wikipedia_search + author: Dify + label: + en_US: WikipediaSearch + zh_Hans: 维基百科搜索 + icon: icon.svg +description: + human: + en_US: A tool for performing a Wikipedia search and extracting snippets and webpages. + zh_Hans: 一个用于执行维基百科搜索并提取片段和网页的工具。 + llm: A tool for performing a Wikipedia search and extracting snippets and webpages. Input should be a search query. +parameters: + - name: query + type: string + required: true + label: + en_US: Query string + zh_Hans: 查询语句 + human_description: + en_US: key words for searching + zh_Hans: 查询关键词 + llm_description: key words for searching + form: llm diff --git a/api/core/tools/provider/builtin/wikipedia/wikipedia.py b/api/core/tools/provider/builtin/wikipedia/wikipedia.py new file mode 100644 index 0000000000..11af4de739 --- /dev/null +++ b/api/core/tools/provider/builtin/wikipedia/wikipedia.py @@ -0,0 +1,20 @@ +from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController +from core.tools.errors import ToolProviderCredentialValidationError + +from core.tools.provider.builtin.wikipedia.tools.wikipedia_search import WikiPediaSearchTool + +class WikiPediaProvider(BuiltinToolProviderController): + def _validate_credentials(self, credentials: dict) -> None: + try: + WikiPediaSearchTool().fork_tool_runtime( + meta={ + "credentials": credentials, + } + ).invoke( + user_id='', + tool_paramters={ + "query": "misaka mikoto", + }, + ) + except Exception as e: + raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file diff --git a/api/core/tools/provider/builtin/wikipedia/wikipedia.yaml b/api/core/tools/provider/builtin/wikipedia/wikipedia.yaml new file mode 100644 index 0000000000..801d7d174b --- /dev/null +++ b/api/core/tools/provider/builtin/wikipedia/wikipedia.yaml @@ -0,0 +1,11 @@ +identity: + author: Dify + name: wikipedia 
+ label: + en_US: Wikipedia + zh_Hans: 维基百科 + description: + en_US: Wikipedia is a free online encyclopedia, created and edited by volunteers around the world. + zh_Hans: 维基百科是一个由全世界的志愿者创建和编辑的免费在线百科全书。 + icon: icon.svg +credentails_for_provider: diff --git a/api/core/tools/provider/builtin/wolframalpha/_assets/icon.svg b/api/core/tools/provider/builtin/wolframalpha/_assets/icon.svg new file mode 100644 index 0000000000..2caf32ee67 --- /dev/null +++ b/api/core/tools/provider/builtin/wolframalpha/_assets/icon.svg @@ -0,0 +1,23 @@ + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/api/core/tools/provider/builtin/wolframalpha/tools/wolframalpha.py b/api/core/tools/provider/builtin/wolframalpha/tools/wolframalpha.py new file mode 100644 index 0000000000..787bca32ca --- /dev/null +++ b/api/core/tools/provider/builtin/wolframalpha/tools/wolframalpha.py @@ -0,0 +1,77 @@ +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.errors import ToolProviderCredentialValidationError, ToolInvokeError + +from typing import Any, Dict, List, Union + +from httpx import get + +class WolframAlphaTool(BuiltinTool): + _base_url = 'https://api.wolframalpha.com/v2/query' + + def _invoke(self, + user_id: str, + tool_paramters: Dict[str, Any], + ) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]: + """ + invoke tools + """ + query = tool_paramters.get('query', '') + if not query: + return self.create_text_message('Please input query') + appid = self.runtime.credentials.get('appid', '') + if not appid: + raise ToolProviderCredentialValidationError('Please input appid') + + params = { + 'appid': appid, + 'input': query, + 'includepodid': 'Result', + 'format': 'plaintext', + 'output': 'json' + } + + finished = False + result = None + # try 3 times at most + counter = 0 + + while not finished and counter < 3: + counter += 1 + try: + response = get(self._base_url, params=params, 
timeout=20) + response.raise_for_status() + response_data = response.json() + except Exception as e: + raise ToolInvokeError(str(e)) + + if 'success' not in response_data['queryresult'] or response_data['queryresult']['success'] != True: + query_result = response_data.get('queryresult', {}) + if 'error' in query_result and query_result['error']: + if 'msg' in query_result['error']: + if query_result['error']['msg'] == 'Invalid appid': + raise ToolProviderCredentialValidationError('Invalid appid') + raise ToolInvokeError('Failed to invoke tool') + + if 'didyoumeans' in response_data['queryresult']: + # get the most likely interpretation + query = '' + max_score = 0 + for didyoumean in response_data['queryresult']['didyoumeans']: + if float(didyoumean['score']) > max_score: + query = didyoumean['val'] + max_score = float(didyoumean['score']) + + params['input'] = query + else: + finished = True + if 'souces' in response_data['queryresult']: + return self.create_link_message(response_data['queryresult']['sources']['url']) + elif 'pods' in response_data['queryresult']: + result = response_data['queryresult']['pods'][0]['subpods'][0]['plaintext'] + + if not finished or not result: + return self.create_text_message('No result found') + + return self.create_text_message(result) + \ No newline at end of file diff --git a/api/core/tools/provider/builtin/wolframalpha/tools/wolframalpha.yaml b/api/core/tools/provider/builtin/wolframalpha/tools/wolframalpha.yaml new file mode 100644 index 0000000000..0e7e8fad92 --- /dev/null +++ b/api/core/tools/provider/builtin/wolframalpha/tools/wolframalpha.yaml @@ -0,0 +1,23 @@ +identity: + name: wolframalpha + author: Dify + label: + en_US: WolframAlpha + zh_Hans: WolframAlpha +description: + human: + en_US: WolframAlpha is a powerful computational knowledge engine. + zh_Hans: WolframAlpha 是一个强大的计算知识引擎。 + llm: WolframAlpha is a powerful computational knowledge engine. one single query can get the answer of a question. 
+parameters: + - name: query + type: string + required: true + label: + en_US: Query string + zh_Hans: 计算语句 + human_description: + en_US: used for calculating + zh_Hans: 用于计算最终结果 + llm_description: a single query for calculating + form: llm diff --git a/api/core/tools/provider/builtin/wolframalpha/wolframalpha.py b/api/core/tools/provider/builtin/wolframalpha/wolframalpha.py new file mode 100644 index 0000000000..56e3672572 --- /dev/null +++ b/api/core/tools/provider/builtin/wolframalpha/wolframalpha.py @@ -0,0 +1,24 @@ +from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType +from core.tools.tool.tool import Tool +from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController +from core.tools.errors import ToolProviderCredentialValidationError + +from core.tools.provider.builtin.wolframalpha.tools.wolframalpha import WolframAlphaTool + +from typing import Any, Dict, List + +class GoogleProvider(BuiltinToolProviderController): + def _validate_credentials(self, credentials: Dict[str, Any]) -> None: + try: + WolframAlphaTool().fork_tool_runtime( + meta={ + "credentials": credentials, + } + ).invoke( + user_id='', + tool_paramters={ + "query": "1+2+....+111", + }, + ) + except Exception as e: + raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file diff --git a/api/core/tools/provider/builtin/wolframalpha/wolframalpha.yaml b/api/core/tools/provider/builtin/wolframalpha/wolframalpha.yaml new file mode 100644 index 0000000000..1d2a038103 --- /dev/null +++ b/api/core/tools/provider/builtin/wolframalpha/wolframalpha.yaml @@ -0,0 +1,24 @@ +identity: + author: Dify + name: wolframalpha + label: + en_US: WolframAlpha + zh_Hans: WolframAlpha + description: + en_US: WolframAlpha is a powerful computational knowledge engine. 
+ zh_Hans: WolframAlpha 是一个强大的计算知识引擎。 + icon: icon.svg +credentails_for_provider: + appid: + type: secret-input + required: true + label: + en_US: WolframAlpha AppID + zh_Hans: WolframAlpha AppID + placeholder: + en_US: Please input your WolframAlpha AppID + zh_Hans: 请输入你的 WolframAlpha AppID + help: + en_US: Get your WolframAlpha AppID from WolframAlpha, please use "full results" api access. + zh_Hans: 从 WolframAlpha 获取您的 WolframAlpha AppID,请使用 "full results" API。 + url: https://products.wolframalpha.com/api diff --git a/api/core/tools/provider/builtin/yahoo/_assets/icon.png b/api/core/tools/provider/builtin/yahoo/_assets/icon.png new file mode 100644 index 0000000000..35d756f754 Binary files /dev/null and b/api/core/tools/provider/builtin/yahoo/_assets/icon.png differ diff --git a/api/core/tools/provider/builtin/yahoo/tools/analytics.py b/api/core/tools/provider/builtin/yahoo/tools/analytics.py new file mode 100644 index 0000000000..09bf17fe63 --- /dev/null +++ b/api/core/tools/provider/builtin/yahoo/tools/analytics.py @@ -0,0 +1,69 @@ +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.entities.tool_entities import ToolInvokeMessage + +from typing import Any, Dict, List, Union +from requests.exceptions import HTTPError, ReadTimeout +from datetime import datetime + +from yfinance import download +import pandas as pd + +class YahooFinanceAnalyticsTool(BuiltinTool): + def _invoke(self, user_id: str, tool_paramters: Dict[str, Any]) \ + -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]: + """ + invoke tools + """ + symbol = tool_paramters.get('symbol', '') + if not symbol: + return self.create_text_message('Please input symbol') + + time_range = [None, None] + start_date = tool_paramters.get('start_date', '') + if start_date: + time_range[0] = start_date + else: + time_range[0] = '1800-01-01' + + end_date = tool_paramters.get('end_date', '') + if end_date: + time_range[1] = end_date + else: + time_range[1] = datetime.now().strftime('%Y-%m-%d') + 
+ stock_data = download(symbol, start=time_range[0], end=time_range[1]) + max_segments = min(15, len(stock_data)) + rows_per_segment = len(stock_data) // max_segments + summary_data = [] + for i in range(max_segments): + start_idx = i * rows_per_segment + end_idx = (i + 1) * rows_per_segment if i < max_segments - 1 else len(stock_data) + segment_data = stock_data.iloc[start_idx:end_idx] + segment_summary = { + 'Start Date': segment_data.index[0], + 'End Date': segment_data.index[-1], + 'Average Close': segment_data['Close'].mean(), + 'Average Volume': segment_data['Volume'].mean(), + 'Average Open': segment_data['Open'].mean(), + 'Average High': segment_data['High'].mean(), + 'Average Low': segment_data['Low'].mean(), + 'Average Adj Close': segment_data['Adj Close'].mean(), + 'Max Close': segment_data['Close'].max(), + 'Min Close': segment_data['Close'].min(), + 'Max Volume': segment_data['Volume'].max(), + 'Min Volume': segment_data['Volume'].min(), + 'Max Open': segment_data['Open'].max(), + 'Min Open': segment_data['Open'].min(), + 'Max High': segment_data['High'].max(), + 'Min High': segment_data['High'].min(), + } + + summary_data.append(segment_summary) + + summary_df = pd.DataFrame(summary_data) + + try: + return self.create_text_message(str(summary_df.to_dict())) + except (HTTPError, ReadTimeout): + return self.create_text_message(f'There is a internet connection problem. Please try again later.') + \ No newline at end of file diff --git a/api/core/tools/provider/builtin/yahoo/tools/analytics.yaml b/api/core/tools/provider/builtin/yahoo/tools/analytics.yaml new file mode 100644 index 0000000000..0daee2acf0 --- /dev/null +++ b/api/core/tools/provider/builtin/yahoo/tools/analytics.yaml @@ -0,0 +1,46 @@ +identity: + name: yahoo_finance_analytics + author: Dify + label: + en_US: Analytics + zh_Hans: 分析 + icon: icon.svg +description: + human: + en_US: A tool for get analytics about a ticker from Yahoo Finance. 
+ zh_Hans: 一个用于从雅虎财经获取分析数据的工具。 + llm: A tool for get analytics from Yahoo Finance. Input should be the ticker symbol like AAPL. +parameters: + - name: symbol + type: string + required: true + label: + en_US: Ticker symbol + zh_Hans: 股票代码 + human_description: + en_US: The ticker symbol of the company you want to analyze. + zh_Hans: 你想要搜索的公司的股票代码。 + llm_description: The ticker symbol of the company you want to analyze. + form: llm + - name: start_date + type: string + required: false + label: + en_US: Start date + zh_Hans: 开始日期 + human_description: + en_US: The start date of the analytics. + zh_Hans: 分析的开始日期。 + llm_description: The start date of the analytics, the format of the date must be YYYY-MM-DD like 2020-01-01. + form: llm + - name: end_date + type: string + required: false + label: + en_US: End date + zh_Hans: 结束日期 + human_description: + en_US: The end date of the analytics. + zh_Hans: 分析的结束日期。 + llm_description: The end date of the analytics, the format of the date must be YYYY-MM-DD like 2024-01-01. 
+ form: llm diff --git a/api/core/tools/provider/builtin/yahoo/tools/news.py b/api/core/tools/provider/builtin/yahoo/tools/news.py new file mode 100644 index 0000000000..1fd1188607 --- /dev/null +++ b/api/core/tools/provider/builtin/yahoo/tools/news.py @@ -0,0 +1,46 @@ +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.entities.tool_entities import ToolInvokeMessage + +from typing import Any, Dict, List, Union +from requests.exceptions import HTTPError, ReadTimeout + +import yfinance + +class YahooFinanceSearchTickerTool(BuiltinTool): + def _invoke(self,user_id: str, tool_paramters: Dict[str, Any]) \ + -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]: + ''' + invoke tools + ''' + + query = tool_paramters.get('symbol', '') + if not query: + return self.create_text_message('Please input symbol') + + try: + return self.run(ticker=query, user_id=user_id) + except (HTTPError, ReadTimeout): + return self.create_text_message(f'There is a internet connection problem. Please try again later.') + + def run(self, ticker: str, user_id: str) -> ToolInvokeMessage: + company = yfinance.Ticker(ticker) + try: + if company.isin is None: + return self.create_text_message(f'Company ticker {ticker} not found.') + except (HTTPError, ReadTimeout, ConnectionError): + return self.create_text_message(f'Company ticker {ticker} not found.') + + links = [] + try: + links = [n['link'] for n in company.news if n['type'] == 'STORY'] + except (HTTPError, ReadTimeout, ConnectionError): + if not links: + return self.create_text_message(f'There is nothing about {ticker} ticker') + if not links: + return self.create_text_message(f'No news found for company that searched with {ticker} ticker.') + + result = '\n\n'.join([ + self.get_url(link) for link in links + ]) + + return self.create_text_message(self.summary(user_id=user_id, content=result)) diff --git a/api/core/tools/provider/builtin/yahoo/tools/news.yaml b/api/core/tools/provider/builtin/yahoo/tools/news.yaml new file 
mode 100644 index 0000000000..db33c96228 --- /dev/null +++ b/api/core/tools/provider/builtin/yahoo/tools/news.yaml @@ -0,0 +1,24 @@ +identity: + name: yahoo_finance_news + author: Dify + label: + en_US: News + zh_Hans: 新闻 + icon: icon.svg +description: + human: + en_US: A tool for get news about a ticker from Yahoo Finance. + zh_Hans: 一个用于从雅虎财经获取新闻的工具。 + llm: A tool for get news from Yahoo Finance. Input should be the ticker symbol like AAPL. +parameters: + - name: symbol + type: string + required: true + label: + en_US: Ticker symbol + zh_Hans: 股票代码 + human_description: + en_US: The ticker symbol of the company you want to search. + zh_Hans: 你想要搜索的公司的股票代码。 + llm_description: The ticker symbol of the company you want to search. + form: llm diff --git a/api/core/tools/provider/builtin/yahoo/tools/ticker.py b/api/core/tools/provider/builtin/yahoo/tools/ticker.py new file mode 100644 index 0000000000..029cc7b446 --- /dev/null +++ b/api/core/tools/provider/builtin/yahoo/tools/ticker.py @@ -0,0 +1,25 @@ +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.entities.tool_entities import ToolInvokeMessage + +from typing import Any, Dict, List, Union +from requests.exceptions import HTTPError, ReadTimeout + +from yfinance import Ticker + +class YahooFinanceSearchTickerTool(BuiltinTool): + def _invoke(self, user_id: str, tool_paramters: Dict[str, Any]) \ + -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]: + """ + invoke tools + """ + query = tool_paramters.get('symbol', '') + if not query: + return self.create_text_message('Please input symbol') + + try: + return self.create_text_message(self.run(ticker=query)) + except (HTTPError, ReadTimeout): + return self.create_text_message(f'There is a internet connection problem. 
Please try again later.') + + def run(self, ticker: str) -> str: + return str(Ticker(ticker).info) \ No newline at end of file diff --git a/api/core/tools/provider/builtin/yahoo/tools/ticker.yaml b/api/core/tools/provider/builtin/yahoo/tools/ticker.yaml new file mode 100644 index 0000000000..b90cfa6327 --- /dev/null +++ b/api/core/tools/provider/builtin/yahoo/tools/ticker.yaml @@ -0,0 +1,24 @@ +identity: + name: yahoo_finance_ticker + author: Dify + label: + en_US: Ticker + zh_Hans: 股票信息 + icon: icon.svg +description: + human: + en_US: A tool for search ticker information from Yahoo Finance. + zh_Hans: 一个用于从雅虎财经搜索股票信息的工具。 + llm: A tool for search ticker information from Yahoo Finance. Input should be the ticker symbol like AAPL. +parameters: + - name: symbol + type: string + required: true + label: + en_US: Ticker symbol + zh_Hans: 股票代码 + human_description: + en_US: The ticker symbol of the company you want to search. + zh_Hans: 你想要搜索的公司的股票代码。 + llm_description: The ticker symbol of the company you want to search. 
+ form: llm diff --git a/api/core/tools/provider/builtin/yahoo/yahoo.py b/api/core/tools/provider/builtin/yahoo/yahoo.py new file mode 100644 index 0000000000..258ee3a6ce --- /dev/null +++ b/api/core/tools/provider/builtin/yahoo/yahoo.py @@ -0,0 +1,20 @@ +from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController +from core.tools.errors import ToolProviderCredentialValidationError + +from core.tools.provider.builtin.yahoo.tools.ticker import YahooFinanceSearchTickerTool + +class YahooFinanceProvider(BuiltinToolProviderController): + def _validate_credentials(self, credentials: dict) -> None: + try: + YahooFinanceSearchTickerTool().fork_tool_runtime( + meta={ + "credentials": credentials, + } + ).invoke( + user_id='', + tool_paramters={ + "ticker": "MSFT", + }, + ) + except Exception as e: + raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file diff --git a/api/core/tools/provider/builtin/yahoo/yahoo.yaml b/api/core/tools/provider/builtin/yahoo/yahoo.yaml new file mode 100644 index 0000000000..d1802321fb --- /dev/null +++ b/api/core/tools/provider/builtin/yahoo/yahoo.yaml @@ -0,0 +1,11 @@ +identity: + author: Dify + name: yahoo + label: + en_US: YahooFinance + zh_Hans: 雅虎财经 + description: + en_US: Finance, and Yahoo! get the latest news, stock quotes, and interactive chart with Yahoo! 
+ zh_Hans: 雅虎财经,获取并整理出最新的新闻、股票报价等一切你想要的财经信息。 + icon: icon.png +credentails_for_provider: diff --git a/api/core/tools/provider/builtin/youtube/_assets/icon.png b/api/core/tools/provider/builtin/youtube/_assets/icon.png new file mode 100644 index 0000000000..3ab7908a5d Binary files /dev/null and b/api/core/tools/provider/builtin/youtube/_assets/icon.png differ diff --git a/api/core/tools/provider/builtin/youtube/tools/videos.py b/api/core/tools/provider/builtin/youtube/tools/videos.py new file mode 100644 index 0000000000..ed48e94463 --- /dev/null +++ b/api/core/tools/provider/builtin/youtube/tools/videos.py @@ -0,0 +1,66 @@ +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.entities.tool_entities import ToolInvokeMessage + +from typing import Any, Dict, List, Union +from datetime import datetime + +from googleapiclient.discovery import build + +class YoutubeVideosAnalyticsTool(BuiltinTool): + def _invoke(self, user_id: str, tool_paramters: Dict[str, Any]) \ + -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]: + """ + invoke tools + """ + channel = tool_paramters.get('channel', '') + if not channel: + return self.create_text_message('Please input symbol') + + time_range = [None, None] + start_date = tool_paramters.get('start_date', '') + if start_date: + time_range[0] = start_date + else: + time_range[0] = '1800-01-01' + + end_date = tool_paramters.get('end_date', '') + if end_date: + time_range[1] = end_date + else: + time_range[1] = datetime.now().strftime('%Y-%m-%d') + + if 'google_api_key' not in self.runtime.credentials or not self.runtime.credentials['google_api_key']: + return self.create_text_message('Please input api key') + + youtube = build('youtube', 'v3', developerKey=self.runtime.credentials['google_api_key']) + + # try to get channel id + search_results = youtube.search().list(q='mrbeast', type='channel', order='relevance', part='id').execute() + channel_id = search_results['items'][0]['id']['channelId'] + + start_date, end_date 
= time_range + + start_date = datetime.strptime(start_date, '%Y-%m-%d').strftime('%Y-%m-%dT%H:%M:%SZ') + end_date = datetime.strptime(end_date, '%Y-%m-%d').strftime('%Y-%m-%dT%H:%M:%SZ') + + # get videos + time_range_videos = youtube.search().list( + part='snippet', channelId=channel_id, order='date', type='video', + publishedAfter=start_date, + publishedBefore=end_date + ).execute() + + def extract_video_data(video_list): + data = [] + for video in video_list['items']: + video_id = video['id']['videoId'] + video_info = youtube.videos().list(part='snippet,statistics', id=video_id).execute() + title = video_info['items'][0]['snippet']['title'] + views = video_info['items'][0]['statistics']['viewCount'] + data.append({'Title': title, 'Views': views}) + return data + + summary = extract_video_data(time_range_videos) + + return self.create_text_message(str(summary)) + \ No newline at end of file diff --git a/api/core/tools/provider/builtin/youtube/tools/videos.yaml b/api/core/tools/provider/builtin/youtube/tools/videos.yaml new file mode 100644 index 0000000000..3ac28d5d7f --- /dev/null +++ b/api/core/tools/provider/builtin/youtube/tools/videos.yaml @@ -0,0 +1,46 @@ +identity: + name: youtube_video_statistics + author: Dify + label: + en_US: Video statistics + zh_Hans: 视频统计 + icon: icon.svg +description: + human: + en_US: A tool for get statistics about a channel's videos. + zh_Hans: 一个用于获取油管频道视频统计数据的工具。 + llm: A tool for get statistics about a channel's videos. Input should be the name of the channel like PewDiePie. +parameters: + - name: channel + type: string + required: true + label: + en_US: Channel name + zh_Hans: 频道名 + human_description: + en_US: The name of the channel you want to search. + zh_Hans: 你想要搜索的油管频道名。 + llm_description: The name of the channel you want to search. + form: llm + - name: start_date + type: string + required: false + label: + en_US: Start date + zh_Hans: 开始日期 + human_description: + en_US: The start date of the analytics. 
+ zh_Hans: 分析的开始日期。 + llm_description: The start date of the analytics, the format of the date must be YYYY-MM-DD like 2020-01-01. + form: llm + - name: end_date + type: string + required: false + label: + en_US: End date + zh_Hans: 结束日期 + human_description: + en_US: The end date of the analytics. + zh_Hans: 分析的结束日期。 + llm_description: The end date of the analytics, the format of the date must be YYYY-MM-DD like 2024-01-01. + form: llm diff --git a/api/core/tools/provider/builtin/youtube/youtube.py b/api/core/tools/provider/builtin/youtube/youtube.py new file mode 100644 index 0000000000..047cfbe0e0 --- /dev/null +++ b/api/core/tools/provider/builtin/youtube/youtube.py @@ -0,0 +1,22 @@ +from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController +from core.tools.errors import ToolProviderCredentialValidationError + +from core.tools.provider.builtin.youtube.tools.videos import YoutubeVideosAnalyticsTool + +class YahooFinanceProvider(BuiltinToolProviderController): + def _validate_credentials(self, credentials: dict) -> None: + try: + YoutubeVideosAnalyticsTool().fork_tool_runtime( + meta={ + "credentials": credentials, + } + ).invoke( + user_id='', + tool_paramters={ + "channel": "TOKYO GIRLS COLLECTION", + "start_date": "2020-01-01", + "end_date": "2024-12-31", + }, + ) + except Exception as e: + raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file diff --git a/api/core/tools/provider/builtin/youtube/youtube.yaml b/api/core/tools/provider/builtin/youtube/youtube.yaml new file mode 100644 index 0000000000..e5d3879d52 --- /dev/null +++ b/api/core/tools/provider/builtin/youtube/youtube.yaml @@ -0,0 +1,24 @@ +identity: + author: Dify + name: youtube + label: + en_US: Youtube + zh_Hans: Youtube + description: + en_US: Youtube + zh_Hans: Youtube(油管)是全球最大的视频分享网站,用户可以在上面上传、观看和分享视频。 + icon: icon.png +credentails_for_provider: + google_api_key: + type: secret-input + required: true + label: + en_US: Google API key + zh_Hans: 
class BuiltinToolProviderController(ToolProviderController):
    """
    tool provider backed by yaml configuration shipped under ``builtin/<provider>/``

    the concrete subclass's module name selects which yaml file and tools
    directory are loaded.
    """

    def __init__(self, **data: Any) -> None:
        # API/app based providers carry their configuration in ``data`` directly
        if self.app_type == ToolProviderType.API_BASED or self.app_type == ToolProviderType.APP_BASED:
            super().__init__(**data)
            return

        # the provider name is the last segment of the concrete subclass's module path
        provider = self.__class__.__module__.split('.')[-1]
        yaml_path = path.join(path.dirname(path.realpath(__file__)), 'builtin', provider, f'{provider}.yaml')
        try:
            with open(yaml_path, 'r') as f:
                provider_yaml = load(f.read(), FullLoader)
        except Exception as e:
            # chain the original cause instead of swallowing it with a bare except
            raise ToolProviderNotFoundError(f'can not load provider yaml for {provider}') from e

        # NOTE: 'credentails_for_provider' (sic) is the key spelling used by the yaml schema
        if 'credentails_for_provider' in provider_yaml and provider_yaml['credentails_for_provider'] is not None:
            # stamp each credential definition with its own name
            for credential_name in provider_yaml['credentails_for_provider']:
                provider_yaml['credentails_for_provider'][credential_name]['name'] = credential_name

        super().__init__(**{
            'identity': provider_yaml['identity'],
            'credentials_schema': provider_yaml.get('credentails_for_provider'),
        })

    def _get_bulitin_tools(self) -> List[Tool]:
        """
        returns a list of tools that the provider can provide, loading and caching
        them from the provider's ``tools`` directory on first use

        :return: list of tools
        """
        if self.tools:
            return self.tools

        provider = self.identity.name
        tool_path = path.join(path.dirname(path.realpath(__file__)), "builtin", provider, "tools")
        # every non-dunder yaml file in the tool directory describes one tool
        tool_files = [x for x in listdir(tool_path) if x.endswith(".yaml") and not x.startswith("__")]
        tools = []
        for tool_file in tool_files:
            tool_name = tool_file.split(".")[0]
            with open(path.join(tool_path, tool_file), "r") as f:
                tool = load(f.read(), FullLoader)

            # import the sibling python module that implements the tool
            py_path = path.join(path.dirname(path.realpath(__file__)), 'builtin', provider, 'tools', f'{tool_name}.py')
            spec = importlib.util.spec_from_file_location(
                f'core.tools.provider.builtin.{provider}.tools.{tool_name}', py_path)
            mod = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(mod)

            # pick the first BuiltinTool subclass defined in the module
            classes = [x for _, x in vars(mod).items()
                       if isinstance(x, type) and x not in [BuiltinTool, Tool] and issubclass(x, BuiltinTool)]
            assistant_tool_class = classes[0]
            tools.append(assistant_tool_class(**tool))

        self.tools = tools
        return tools

    def get_credentails_schema(self) -> Dict[str, ToolProviderCredentials]:
        """
        returns the credentials schema of the provider

        :return: the credentials schema, empty dict when the provider declares none
        """
        if not self.credentials_schema:
            return {}

        return self.credentials_schema.copy()

    def user_get_credentails_schema(self) -> UserToolProviderCredentials:
        """
        returns the credentials schema of the provider, this method is used for user

        :return: the credentials schema
        """
        credentials = self.credentials_schema.copy()
        return UserToolProviderCredentials(credentails=credentials)

    def get_tools(self) -> List[Tool]:
        """
        returns a list of tools that the provider can provide

        :return: list of tools
        """
        return self._get_bulitin_tools()

    def get_tool(self, tool_name: str) -> Tool:
        """
        returns the tool that the provider can provide, or None if not found
        """
        return next(filter(lambda x: x.identity.name == tool_name, self.get_tools()), None)

    def get_parameters(self, tool_name: str) -> List[ToolParamter]:
        """
        returns the parameters of the tool

        :param tool_name: the name of the tool, defined in `get_tools`
        :return: list of parameters
        :raises ToolNotFoundError: when no tool with that name exists
        """
        tool = next(filter(lambda x: x.identity.name == tool_name, self.get_tools()), None)
        if tool is None:
            raise ToolNotFoundError(f'tool {tool_name} not found')
        return tool.parameters

    @property
    def need_credentials(self) -> bool:
        """
        returns whether the provider needs credentials

        :return: whether the provider needs credentials
        """
        return self.credentials_schema is not None and len(self.credentials_schema) != 0

    @property
    def app_type(self) -> ToolProviderType:
        """
        returns the type of the provider

        :return: type of the provider
        """
        return ToolProviderType.BUILT_IN

    def validate_parameters(self, tool_id: int, tool_name: str, tool_parameters: Dict[str, Any]) -> None:
        """
        validate the parameters of the tool and set the default value if needed

        mutates ``tool_parameters`` in place: missing optional parameters are
        filled with their declared defaults.

        :param tool_name: the name of the tool, defined in `get_tools`
        :param tool_parameters: the parameters of the tool
        :raises ToolParamterValidationError: on unknown, mistyped, out-of-range
            or missing-required parameters
        """
        tool_parameters_schema = self.get_parameters(tool_name)

        tool_parameters_need_to_validate: Dict[str, ToolParamter] = {}
        for parameter in tool_parameters_schema:
            tool_parameters_need_to_validate[parameter.name] = parameter

        for parameter in tool_parameters:
            if parameter not in tool_parameters_need_to_validate:
                raise ToolParamterValidationError(f'parameter {parameter} not found in tool {tool_name}')

            # check type
            parameter_schema = tool_parameters_need_to_validate[parameter]
            if parameter_schema.type == ToolParamter.ToolParameterType.STRING:
                if not isinstance(tool_parameters[parameter], str):
                    raise ToolParamterValidationError(f'parameter {parameter} should be string')

            elif parameter_schema.type == ToolParamter.ToolParameterType.NUMBER:
                if not isinstance(tool_parameters[parameter], (int, float)):
                    raise ToolParamterValidationError(f'parameter {parameter} should be number')

                if parameter_schema.min is not None and tool_parameters[parameter] < parameter_schema.min:
                    raise ToolParamterValidationError(f'parameter {parameter} should be greater than {parameter_schema.min}')

                if parameter_schema.max is not None and tool_parameters[parameter] > parameter_schema.max:
                    raise ToolParamterValidationError(f'parameter {parameter} should be less than {parameter_schema.max}')

            elif parameter_schema.type == ToolParamter.ToolParameterType.BOOLEAN:
                if not isinstance(tool_parameters[parameter], bool):
                    raise ToolParamterValidationError(f'parameter {parameter} should be boolean')

            elif parameter_schema.type == ToolParamter.ToolParameterType.SELECT:
                if not isinstance(tool_parameters[parameter], str):
                    raise ToolParamterValidationError(f'parameter {parameter} should be string')

                options = parameter_schema.options
                if not isinstance(options, list):
                    raise ToolParamterValidationError(f'parameter {parameter} options should be list')

                if tool_parameters[parameter] not in [x.value for x in options]:
                    raise ToolParamterValidationError(f'parameter {parameter} should be one of {options}')

            tool_parameters_need_to_validate.pop(parameter)

        for parameter in tool_parameters_need_to_validate:
            parameter_schema = tool_parameters_need_to_validate[parameter]
            if parameter_schema.required:
                raise ToolParamterValidationError(f'parameter {parameter} is required')

            # the parameter is not set currently, set the default value if needed
            if parameter_schema.default is not None:
                default_value = parameter_schema.default
                # parse default value into the correct type
                # NOTE(review): bool('false') is True — defaults are expected to
                # already be real booleans in the yaml, confirm if strings appear
                if parameter_schema.type == ToolParamter.ToolParameterType.STRING or \
                        parameter_schema.type == ToolParamter.ToolParameterType.SELECT:
                    default_value = str(default_value)
                elif parameter_schema.type == ToolParamter.ToolParameterType.NUMBER:
                    default_value = float(default_value)
                elif parameter_schema.type == ToolParamter.ToolParameterType.BOOLEAN:
                    default_value = bool(default_value)

                tool_parameters[parameter] = default_value

    def validate_credentials_format(self, credentials: Dict[str, Any]) -> None:
        """
        validate the format of the credentials of the provider and set the default value if needed

        :param credentials: the credentials of the tool
        :raises ToolProviderCredentialValidationError: on unknown, mistyped or
            missing-required credentials
        """
        credentials_schema = self.credentials_schema
        if credentials_schema is None:
            return

        credentials_need_to_validate: Dict[str, ToolProviderCredentials] = {}
        for credential_name in credentials_schema:
            credentials_need_to_validate[credential_name] = credentials_schema[credential_name]

        for credential_name in credentials:
            if credential_name not in credentials_need_to_validate:
                raise ToolProviderCredentialValidationError(f'credential {credential_name} not found in provider {self.identity.name}')

            # check type (fixed: compare the schema's .type, not the schema object itself)
            credential_schema = credentials_need_to_validate[credential_name]
            if credential_schema.type == ToolProviderCredentials.CredentialsType.SECRET_INPUT or \
                    credential_schema.type == ToolProviderCredentials.CredentialsType.TEXT_INPUT:
                if not isinstance(credentials[credential_name], str):
                    raise ToolProviderCredentialValidationError(f'credential {credential_name} should be string')

            elif credential_schema.type == ToolProviderCredentials.CredentialsType.SELECT:
                if not isinstance(credentials[credential_name], str):
                    raise ToolProviderCredentialValidationError(f'credential {credential_name} should be string')

                options = credential_schema.options
                if not isinstance(options, list):
                    raise ToolProviderCredentialValidationError(f'credential {credential_name} options should be list')

                if credentials[credential_name] not in [x.value for x in options]:
                    raise ToolProviderCredentialValidationError(f'credential {credential_name} should be one of {options}')

            credentials_need_to_validate.pop(credential_name)

        for credential_name in credentials_need_to_validate:
            credential_schema = credentials_need_to_validate[credential_name]
            if credential_schema.required:
                raise ToolProviderCredentialValidationError(f'credential {credential_name} is required')

            # the credential is not set currently, set the default value if needed
            if credential_schema.default is not None:
                default_value = credential_schema.default
                # parse default value into the correct type
                if credential_schema.type == ToolProviderCredentials.CredentialsType.SECRET_INPUT or \
                        credential_schema.type == ToolProviderCredentials.CredentialsType.TEXT_INPUT or \
                        credential_schema.type == ToolProviderCredentials.CredentialsType.SELECT:
                    default_value = str(default_value)

                credentials[credential_name] = default_value

    def validate_credentials(self, credentials: Dict[str, Any]) -> None:
        """
        validate the credentials of the provider

        :param credentials: the credentials of the tool
        """
        # validate credentials format
        self.validate_credentials_format(credentials)

        # validate credentials against the live service (subclass-specific)
        self._validate_credentials(credentials)

    @abstractmethod
    def _validate_credentials(self, credentials: Dict[str, Any]) -> None:
        """
        validate the credentials of the provider

        :param credentials: the credentials of the tool
        """
        pass
class ToolProviderController(BaseModel, ABC):
    """
    abstract base of all tool providers: holds the provider identity, its tools
    and the credentials schema, and implements shared parameter/credential
    validation.
    """

    # provider identity, loaded by the concrete subclass
    identity: Optional[ToolProviderIdentity] = None
    # cached list of tools this provider offers
    tools: Optional[List[Tool]] = None
    # credential definitions keyed by credential name; None when no credentials are needed
    credentials_schema: Optional[Dict[str, ToolProviderCredentials]] = None

    def get_credentails_schema(self) -> Dict[str, ToolProviderCredentials]:
        """
        returns the credentials schema of the provider

        :return: the credentials schema, empty dict when the provider declares none
        """
        # guard against None so providers without credentials do not crash
        # (consistent with the builtin provider implementation)
        if not self.credentials_schema:
            return {}

        return self.credentials_schema.copy()

    def user_get_credentails_schema(self) -> UserToolProviderCredentials:
        """
        returns the credentials schema of the provider, this method is used for user

        :return: the credentials schema
        """
        credentials = self.credentials_schema.copy() if self.credentials_schema else {}
        return UserToolProviderCredentials(credentails=credentials)

    @abstractmethod
    def get_tools(self) -> List[Tool]:
        """
        returns a list of tools that the provider can provide

        :return: list of tools
        """
        pass

    @abstractmethod
    def get_tool(self, tool_name: str) -> Tool:
        """
        returns a tool that the provider can provide

        :return: tool
        """
        pass

    def get_parameters(self, tool_name: str) -> List[ToolParamter]:
        """
        returns the parameters of the tool

        :param tool_name: the name of the tool, defined in `get_tools`
        :return: list of parameters
        :raises ToolNotFoundError: when no tool with that name exists
        """
        tool = next(filter(lambda x: x.identity.name == tool_name, self.get_tools()), None)
        if tool is None:
            raise ToolNotFoundError(f'tool {tool_name} not found')
        return tool.parameters

    @property
    def app_type(self) -> ToolProviderType:
        """
        returns the type of the provider

        :return: type of the provider
        """
        return ToolProviderType.BUILT_IN

    def validate_parameters(self, tool_id: int, tool_name: str, tool_parameters: Dict[str, Any]) -> None:
        """
        validate the parameters of the tool and set the default value if needed

        mutates ``tool_parameters`` in place: missing optional parameters are
        filled with their declared defaults.

        :param tool_name: the name of the tool, defined in `get_tools`
        :param tool_parameters: the parameters of the tool
        :raises ToolParamterValidationError: on unknown, mistyped, out-of-range
            or missing-required parameters
        """
        tool_parameters_schema = self.get_parameters(tool_name)

        tool_parameters_need_to_validate: Dict[str, ToolParamter] = {}
        for parameter in tool_parameters_schema:
            tool_parameters_need_to_validate[parameter.name] = parameter

        for parameter in tool_parameters:
            if parameter not in tool_parameters_need_to_validate:
                raise ToolParamterValidationError(f'parameter {parameter} not found in tool {tool_name}')

            # check type
            parameter_schema = tool_parameters_need_to_validate[parameter]
            if parameter_schema.type == ToolParamter.ToolParameterType.STRING:
                if not isinstance(tool_parameters[parameter], str):
                    raise ToolParamterValidationError(f'parameter {parameter} should be string')

            elif parameter_schema.type == ToolParamter.ToolParameterType.NUMBER:
                if not isinstance(tool_parameters[parameter], (int, float)):
                    raise ToolParamterValidationError(f'parameter {parameter} should be number')

                if parameter_schema.min is not None and tool_parameters[parameter] < parameter_schema.min:
                    raise ToolParamterValidationError(f'parameter {parameter} should be greater than {parameter_schema.min}')

                if parameter_schema.max is not None and tool_parameters[parameter] > parameter_schema.max:
                    raise ToolParamterValidationError(f'parameter {parameter} should be less than {parameter_schema.max}')

            elif parameter_schema.type == ToolParamter.ToolParameterType.BOOLEAN:
                if not isinstance(tool_parameters[parameter], bool):
                    raise ToolParamterValidationError(f'parameter {parameter} should be boolean')

            elif parameter_schema.type == ToolParamter.ToolParameterType.SELECT:
                if not isinstance(tool_parameters[parameter], str):
                    raise ToolParamterValidationError(f'parameter {parameter} should be string')

                options = parameter_schema.options
                if not isinstance(options, list):
                    raise ToolParamterValidationError(f'parameter {parameter} options should be list')

                if tool_parameters[parameter] not in [x.value for x in options]:
                    raise ToolParamterValidationError(f'parameter {parameter} should be one of {options}')

            tool_parameters_need_to_validate.pop(parameter)

        for parameter in tool_parameters_need_to_validate:
            parameter_schema = tool_parameters_need_to_validate[parameter]
            if parameter_schema.required:
                raise ToolParamterValidationError(f'parameter {parameter} is required')

            # the parameter is not set currently, set the default value if needed
            if parameter_schema.default is not None:
                default_value = parameter_schema.default
                # parse default value into the correct type
                if parameter_schema.type == ToolParamter.ToolParameterType.STRING or \
                        parameter_schema.type == ToolParamter.ToolParameterType.SELECT:
                    default_value = str(default_value)
                elif parameter_schema.type == ToolParamter.ToolParameterType.NUMBER:
                    default_value = float(default_value)
                elif parameter_schema.type == ToolParamter.ToolParameterType.BOOLEAN:
                    default_value = bool(default_value)

                tool_parameters[parameter] = default_value

    def validate_credentials_format(self, credentials: Dict[str, Any]) -> None:
        """
        validate the format of the credentials of the provider and set the default value if needed

        :param credentials: the credentials of the tool
        :raises ToolProviderCredentialValidationError: on unknown, mistyped or
            missing-required credentials
        """
        credentials_schema = self.credentials_schema
        if credentials_schema is None:
            return

        credentials_need_to_validate: Dict[str, ToolProviderCredentials] = {}
        for credential_name in credentials_schema:
            credentials_need_to_validate[credential_name] = credentials_schema[credential_name]

        for credential_name in credentials:
            if credential_name not in credentials_need_to_validate:
                raise ToolProviderCredentialValidationError(f'credential {credential_name} not found in provider {self.identity.name}')

            # check type (fixed: compare the schema's .type, not the schema object itself)
            credential_schema = credentials_need_to_validate[credential_name]
            if credential_schema.type == ToolProviderCredentials.CredentialsType.SECRET_INPUT or \
                    credential_schema.type == ToolProviderCredentials.CredentialsType.TEXT_INPUT:
                if not isinstance(credentials[credential_name], str):
                    raise ToolProviderCredentialValidationError(f'credential {credential_name} should be string')

            elif credential_schema.type == ToolProviderCredentials.CredentialsType.SELECT:
                if not isinstance(credentials[credential_name], str):
                    raise ToolProviderCredentialValidationError(f'credential {credential_name} should be string')

                options = credential_schema.options
                if not isinstance(options, list):
                    raise ToolProviderCredentialValidationError(f'credential {credential_name} options should be list')

                if credentials[credential_name] not in [x.value for x in options]:
                    raise ToolProviderCredentialValidationError(f'credential {credential_name} should be one of {options}')

            credentials_need_to_validate.pop(credential_name)

        for credential_name in credentials_need_to_validate:
            credential_schema = credentials_need_to_validate[credential_name]
            if credential_schema.required:
                raise ToolProviderCredentialValidationError(f'credential {credential_name} is required')

            # the credential is not set currently, set the default value if needed
            if credential_schema.default is not None:
                default_value = credential_schema.default
                # parse default value into the correct type
                if credential_schema.type == ToolProviderCredentials.CredentialsType.SECRET_INPUT or \
                        credential_schema.type == ToolProviderCredentials.CredentialsType.TEXT_INPUT or \
                        credential_schema.type == ToolProviderCredentials.CredentialsType.SELECT:
                    default_value = str(default_value)

                credentials[credential_name] = default_value

    def validate_credentials(self, credentials: Dict[str, Any]) -> None:
        """
        validate the credentials of the provider

        :param credentials: the credentials of the tool
        """
        # validate credentials format
        self.validate_credentials_format(credentials)

        # validate credentials against the live service (subclass-specific)
        self._validate_credentials(credentials)

    @abstractmethod
    def _validate_credentials(self, credentials: Dict[str, Any]) -> None:
        """
        validate the credentials of the provider

        :param credentials: the credentials of the tool
        """
        pass
class ApiTool(Tool):
    """
    api-based tool: executes an http request described by an OpenAPI-derived
    ``ApiBasedToolBundle``.
    """
    api_bundle: ApiBasedToolBundle

    def fork_tool_runtime(self, meta: Dict[str, Any]) -> 'Tool':
        """
        fork a new tool with meta data

        :param meta: the meta data of a tool call processing, tenant_id is required
        :return: the new tool
        """
        return self.__class__(
            identity=self.identity.copy() if self.identity else None,
            parameters=self.parameters.copy() if self.parameters else None,
            description=self.description.copy() if self.description else None,
            api_bundle=self.api_bundle.copy() if self.api_bundle else None,
            runtime=Tool.Runtime(**meta)
        )

    def validate_credentials(self, credentails: Dict[str, Any], parameters: Dict[str, Any], format_only: bool = False) -> None:
        """
        validate the credentials for Api tool

        :param format_only: when True, only assemble the request (no network call)
        """
        # assemble validate request and request parameters
        headers = self.assembling_request(parameters)

        if format_only:
            return

        response = self.do_http_request(self.api_bundle.server_url, self.api_bundle.method, headers, parameters)
        # validate response
        self.validate_and_parse_response(response)

    def assembling_request(self, parameters: Dict[str, Any]) -> Dict[str, Any]:
        """
        build the request headers from the runtime credentials and fill in
        defaults for required-but-missing parameters (mutates ``parameters``)

        :raises ToolProviderCredentialValidationError: on missing credentials or
            missing required parameters
        """
        headers = {}
        credentials = self.runtime.credentials or {}

        if 'auth_type' not in credentials:
            raise ToolProviderCredentialValidationError('Missing auth_type')

        if credentials['auth_type'] == 'api_key':
            # header name defaults to 'api_key' unless overridden
            api_key_header = credentials.get('api_key_header', 'api_key')

            if 'api_key_value' not in credentials:
                raise ToolProviderCredentialValidationError('Missing api_key_value')

            headers[api_key_header] = credentials['api_key_value']

        needed_parameters = [parameter for parameter in self.api_bundle.parameters if parameter.required]
        for parameter in needed_parameters:
            if parameter.required and parameter.name not in parameters:
                raise ToolProviderCredentialValidationError(f"Missing required parameter {parameter.name}")

            if parameter.default is not None and parameter.name not in parameters:
                parameters[parameter.name] = parameter.default

        return headers

    def validate_and_parse_response(self, response: Union[httpx.Response, requests.Response]) -> str:
        """
        validate the response and return its text body

        :raises ToolProviderCredentialValidationError: on any non-success status
        """
        if isinstance(response, httpx.Response):
            if response.status_code >= 400:
                raise ToolProviderCredentialValidationError(f"Request failed with status code {response.status_code}")
            return response.text
        elif isinstance(response, requests.Response):
            if not response.ok:
                raise ToolProviderCredentialValidationError(f"Request failed with status code {response.status_code}")
            return response.text
        else:
            raise ValueError(f'Invalid response type {type(response)}')

    def do_http_request(self, url: str, method: str, headers: Dict[str, Any], parameters: Dict[str, Any]) -> httpx.Response:
        """
        do http request depending on api bundle

        splits ``parameters`` into path/query/cookie/header/body pieces per the
        OpenAPI operation and dispatches with the matching http verb.
        """
        method = method.lower()

        params = {}
        path_params = {}
        body = {}
        cookies = {}

        # distribute parameters into their declared locations
        for parameter in self.api_bundle.openapi.get('parameters', []):
            value = ''
            if parameter['name'] in parameters:
                value = parameters[parameter['name']]
            elif parameter['required']:
                raise ToolProviderCredentialValidationError(f"Missing required parameter {parameter['name']}")

            if parameter['in'] == 'path':
                path_params[parameter['name']] = value
            elif parameter['in'] == 'query':
                params[parameter['name']] = value
            elif parameter['in'] == 'cookie':
                cookies[parameter['name']] = value
            elif parameter['in'] == 'header':
                headers[parameter['name']] = value

        # check if there is a request body and handle it
        if 'requestBody' in self.api_bundle.openapi and self.api_bundle.openapi['requestBody'] is not None:
            # handle json request body
            if 'content' in self.api_bundle.openapi['requestBody']:
                for content_type in self.api_bundle.openapi['requestBody']['content']:
                    headers['Content-Type'] = content_type
                    body_schema = self.api_bundle.openapi['requestBody']['content'][content_type]['schema']
                    required = body_schema.get('required', [])
                    properties = body_schema.get('properties', {})
                    for name, property in properties.items():
                        if name in parameters:
                            # convert string values to the declared schema type;
                            # already-typed values pass through untouched
                            try:
                                value = parameters[name]
                                if property['type'] == 'integer':
                                    value = int(value)
                                elif property['type'] == 'number':
                                    # fixed: '.' in value raised TypeError for
                                    # non-str inputs, and int() truncated floats
                                    if isinstance(value, str):
                                        value = float(value) if '.' in value else int(value)
                                elif property['type'] == 'boolean':
                                    # fixed: bool('false') is True, parse string spellings
                                    if isinstance(value, str):
                                        value = value.lower() not in ('false', '0', '')
                                    else:
                                        value = bool(value)
                                body[name] = value
                            except (ValueError, TypeError):
                                # fall back to the raw parameter when conversion fails
                                body[name] = parameters[name]
                        elif name in required:
                            raise ToolProviderCredentialValidationError(
                                f"Missing required parameter {name} in operation {self.api_bundle.operation_id}"
                            )
                        elif 'default' in property:
                            body[name] = property['default']
                        else:
                            body[name] = None
                    # only the first declared content type is used
                    break

        # replace path parameters
        for name, value in path_params.items():
            url = url.replace(f'{{{name}}}', value)

        # serialize the body for json payloads; other content types are sent as-is
        if headers.get('Content-Type') == 'application/json':
            body = dumps(body)

        # do http request; GET/HEAD/OPTIONS carry no body per http semantics
        if method == 'get':
            response = httpx.get(url, params=params, headers=headers, cookies=cookies, timeout=10, follow_redirects=True)
        elif method == 'post':
            response = httpx.post(url, params=params, headers=headers, cookies=cookies, data=body, timeout=10, follow_redirects=True)
        elif method == 'put':
            response = httpx.put(url, params=params, headers=headers, cookies=cookies, data=body, timeout=10, follow_redirects=True)
        elif method == 'delete':
            # request body data is unsupported for DELETE in httpx's shortcut API,
            # but OpenAPI 3.0 allows it, so requests is used for this verb
            response = requests.delete(url, params=params, headers=headers, cookies=cookies, data=body, timeout=10, allow_redirects=True)
        elif method == 'patch':
            response = httpx.patch(url, params=params, headers=headers, cookies=cookies, data=body, timeout=10, follow_redirects=True)
        elif method == 'head':
            response = httpx.head(url, params=params, headers=headers, cookies=cookies, timeout=10, follow_redirects=True)
        elif method == 'options':
            response = httpx.options(url, params=params, headers=headers, cookies=cookies, timeout=10, follow_redirects=True)
        else:
            raise ValueError(f'Invalid http method {method}')

        return response

    def _invoke(self, user_id: str, tool_paramters: Dict[str, Any]) -> ToolInvokeMessage | List[ToolInvokeMessage]:
        """
        invoke http request

        :return: a text message containing the response body
        """
        # assemble request
        headers = self.assembling_request(tool_paramters)

        # do http request
        response = self.do_http_request(self.api_bundle.server_url, self.api_bundle.method, headers, tool_paramters)

        # validate response
        response = self.validate_and_parse_response(response)

        # assemble invoke message
        return self.create_text_message(response)
can quickly aimed at the main point of an webpage and reproduce it in your own words but +retain the original meaning and keep the key points. +however, the text you got is too long, what you got is possible a part of the text. +Please summarize the text you got. +""" + + +class BuiltinTool(Tool): + """ + Builtin tool + + :param meta: the meta data of a tool call processing + """ + + def invoke_model( + self, user_id: str, prompt_messages: List[PromptMessage], stop: List[str] + ) -> LLMResult: + """ + invoke model + + :param model_config: the model config + :param prompt_messages: the prompt messages + :param stop: the stop words + :return: the model result + """ + # invoke model + return ToolModelManager.invoke( + user_id=user_id, + tenant_id=self.runtime.tenant_id, + tool_type='builtin', + tool_name=self.identity.name, + prompt_messages=prompt_messages, + ) + + def get_max_tokens(self) -> int: + """ + get max tokens + + :param model_config: the model config + :return: the max tokens + """ + return ToolModelManager.get_max_llm_context_tokens( + tenant_id=self.runtime.tenant_id, + ) + + def get_prompt_tokens(self, prompt_messages: List[PromptMessage]) -> int: + """ + get prompt tokens + + :param prompt_messages: the prompt messages + :return: the tokens + """ + return ToolModelManager.calculate_tokens( + tenant_id=self.runtime.tenant_id, + prompt_messages=prompt_messages + ) + + def summary(self, user_id: str, content: str) -> str: + max_tokens = self.get_max_tokens() + + if self.get_prompt_tokens(prompt_messages=[ + UserPromptMessage(content=content) + ]) < max_tokens * 0.6: + return content + + def get_prompt_tokens(content: str) -> int: + return self.get_prompt_tokens(prompt_messages=[ + SystemPromptMessage(content=_SUMMARY_PROMPT), + UserPromptMessage(content=content) + ]) + + def summarize(content: str) -> str: + summary = self.invoke_model(user_id=user_id, prompt_messages=[ + SystemPromptMessage(content=_SUMMARY_PROMPT), + UserPromptMessage(content=content) + 
], stop=[]) + + return summary.message.content + + lines = content.split('\n') + new_lines = [] + # split long line into multiple lines + for i in range(len(lines)): + line = lines[i] + if not line.strip(): + continue + if len(line) < max_tokens * 0.5: + new_lines.append(line) + elif get_prompt_tokens(line) > max_tokens * 0.7: + while get_prompt_tokens(line) > max_tokens * 0.7: + new_lines.append(line[:int(max_tokens * 0.5)]) + line = line[int(max_tokens * 0.5):] + new_lines.append(line) + else: + new_lines.append(line) + + # merge lines into messages with max tokens + messages: List[str] = [] + for i in new_lines: + if len(messages) == 0: + messages.append(i) + else: + if len(messages[-1]) + len(i) < max_tokens * 0.5: + messages[-1] += i + if get_prompt_tokens(messages[-1] + i) > max_tokens * 0.7: + messages.append(i) + else: + messages[-1] += i + + summaries = [] + for i in range(len(messages)): + message = messages[i] + summary = summarize(message) + summaries.append(summary) + + result = '\n'.join(summaries) + + if self.get_prompt_tokens(prompt_messages=[ + UserPromptMessage(content=result) + ]) > max_tokens * 0.7: + return self.summary(user_id=user_id, content=result) + + return result + + def get_url(self, url: str, user_agent: str = None) -> str: + """ + get url + """ + return get_url(url, user_agent=user_agent) \ No newline at end of file diff --git a/api/core/tool/dataset_multi_retriever_tool.py b/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py similarity index 99% rename from api/core/tool/dataset_multi_retriever_tool.py rename to api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py index c9ca7eb04a..e205401686 100644 --- a/api/core/tool/dataset_multi_retriever_tool.py +++ b/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py @@ -246,4 +246,4 @@ class DatasetMultiRetrieverTool(BaseTool): for thread in threads: thread.join() - all_documents.extend(documents) + all_documents.extend(documents) \ No 
newline at end of file diff --git a/api/core/tool/dataset_retriever_tool.py b/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py similarity index 99% rename from api/core/tool/dataset_retriever_tool.py rename to api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py index 9049b5e691..79de38ca14 100644 --- a/api/core/tool/dataset_retriever_tool.py +++ b/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py @@ -233,4 +233,4 @@ class DatasetRetrieverTool(BaseTool): return str("\n".join(document_context_list)) async def _arun(self, tool_input: str) -> str: - raise NotImplementedError() + raise NotImplementedError() \ No newline at end of file diff --git a/api/core/tools/tool/dataset_retriever_tool.py b/api/core/tools/tool/dataset_retriever_tool.py new file mode 100644 index 0000000000..800952226a --- /dev/null +++ b/api/core/tools/tool/dataset_retriever_tool.py @@ -0,0 +1,95 @@ +from typing import Any, Dict, List, Union +from core.features.dataset_retrieval import DatasetRetrievalFeature +from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParamter, ToolIdentity, ToolDescription +from core.tools.tool.tool import Tool +from core.tools.entities.common_entities import I18nObject +from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler +from core.entities.application_entities import DatasetRetrieveConfigEntity, InvokeFrom + +from langchain.tools import BaseTool + +class DatasetRetrieverTool(Tool): + langchain_tool: BaseTool + + @staticmethod + def get_dataset_tools(tenant_id: str, + dataset_ids: list[str], + retrieve_config: DatasetRetrieveConfigEntity, + return_resource: bool, + invoke_from: InvokeFrom, + hit_callback: DatasetIndexToolCallbackHandler + ) -> List['DatasetRetrieverTool']: + """ + get dataset tool + """ + # check if retrieve_config is valid + if dataset_ids is None or len(dataset_ids) == 0: + return [] + if retrieve_config is None: + return [] + + feature = 
DatasetRetrievalFeature() + + # save original retrieve strategy, and set retrieve strategy to SINGLE + # Agent only support SINGLE mode + original_retriever_mode = retrieve_config.retrieve_strategy + retrieve_config.retrieve_strategy = DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE + langchain_tools = feature.to_dataset_retriever_tool( + tenant_id=tenant_id, + dataset_ids=dataset_ids, + retrieve_config=retrieve_config, + return_resource=return_resource, + invoke_from=invoke_from, + hit_callback=hit_callback + ) + # restore retrieve strategy + retrieve_config.retrieve_strategy = original_retriever_mode + + # convert langchain tools to Tools + tools = [] + for langchain_tool in langchain_tools: + tool = DatasetRetrieverTool( + langchain_tool=langchain_tool, + identity=ToolIdentity(author='', name=langchain_tool.name, label=I18nObject(en_US='', zh_Hans='')), + parameters=[], + is_team_authorization=True, + description=ToolDescription( + human=I18nObject(en_US='', zh_Hans=''), + llm=langchain_tool.description), + runtime=DatasetRetrieverTool.Runtime() + ) + + tools.append(tool) + + return tools + + def get_runtime_parameters(self) -> List[ToolParamter]: + return [ + ToolParamter(name='query', + label=I18nObject(en_US='', zh_Hans=''), + human_description=I18nObject(en_US='', zh_Hans=''), + type=ToolParamter.ToolParameterType.STRING, + form=ToolParamter.ToolParameterForm.LLM, + llm_description='Query for the dataset to be used to retrieve the dataset.', + required=True, + default=''), + ] + + def _invoke(self, user_id: str, tool_paramters: Dict[str, Any]) -> ToolInvokeMessage | List[ToolInvokeMessage]: + """ + invoke dataset retriever tool + """ + query = tool_paramters.get('query', None) + if not query: + return self.create_text_message(text='please input query') + + # invoke dataset retriever tool + result = self.langchain_tool._run(query=query) + + return self.create_text_message(text=result) + + def validate_credentials(self, credentails: Dict[str, Any], 
from pydantic import BaseModel

from typing import List, Dict, Any, Union, Optional
from abc import abstractmethod, ABC
from enum import Enum

from core.tools.entities.tool_entities import ToolIdentity, ToolInvokeMessage,\
    ToolParamter, ToolDescription, ToolRuntimeVariablePool, ToolRuntimeVariable, ToolRuntimeImageVariable
from core.tools.tool_file_manager import ToolFileManager
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler

class Tool(BaseModel, ABC):
    """
    Abstract base class for all tools.

    Carries the tool's identity/description metadata, a per-call Runtime
    (tenant, credentials, parameter overrides) and an optional runtime
    variable pool shared with the conversation. Subclasses implement
    `_invoke`; callers go through `invoke`, which merges runtime parameter
    overrides and fires the optional agent callbacks.
    """
    identity: ToolIdentity = None
    parameters: Optional[List[ToolParamter]] = None
    description: ToolDescription = None
    # whether the workspace (team) has authorized this tool
    is_team_authorization: bool = False
    agent_callback: Optional[DifyAgentCallbackHandler] = None
    # derived in __init__: True iff an agent_callback was supplied
    use_callback: bool = False

    class Runtime(BaseModel):
        """
        Metadata of one tool-call processing: tenant, credentials and
        per-invocation parameter overrides.
        """
        def __init__(self, **data: Any):
            super().__init__(**data)
            # normalize: callers may omit runtime_parameters entirely
            if not self.runtime_parameters:
                self.runtime_parameters = {}

        tenant_id: str = None
        tool_id: str = None
        credentials: Dict[str, Any] = None
        runtime_parameters: Dict[str, Any] = None

    runtime: Runtime = None
    variables: ToolRuntimeVariablePool = None

    def __init__(self, **data: Any):
        super().__init__(**data)

        # callbacks are only fired when an agent callback handler was supplied
        self.use_callback = bool(self.agent_callback)

    class VARIABLE_KEY(Enum):
        IMAGE = 'image'

    def fork_tool_runtime(self, meta: Dict[str, Any], agent_callback: DifyAgentCallbackHandler = None) -> 'Tool':
        """
        Fork a new tool instance bound to a specific runtime.

        :param meta: the meta data of a tool call processing, tenant_id is required
        :param agent_callback: optional callback handler for the forked tool
        :return: the new tool
        """
        return self.__class__(
            identity=self.identity.copy() if self.identity else None,
            parameters=self.parameters.copy() if self.parameters else None,
            description=self.description.copy() if self.description else None,
            runtime=Tool.Runtime(**meta),
            agent_callback=agent_callback
        )

    def load_variables(self, variables: ToolRuntimeVariablePool):
        """
        Attach a runtime variable pool (loaded from the database) to this tool.

        :param variables: the pool of conversation-scoped runtime variables
        """
        self.variables = variables

    def set_image_variable(self, variable_name: str, image_key: str) -> None:
        """Set an image variable in the pool; no-op when no pool is loaded."""
        if not self.variables:
            return

        self.variables.set_file(self.identity.name, variable_name, image_key)

    def set_text_variable(self, variable_name: str, text: str) -> None:
        """Set a text variable in the pool; no-op when no pool is loaded."""
        if not self.variables:
            return

        self.variables.set_text(self.identity.name, variable_name, text)

    def get_variable(self, name: Union[str, Enum]) -> Optional[ToolRuntimeVariable]:
        """
        Look up a variable by name in the pool.

        :param name: the name of the variable (Enum members are unwrapped)
        :return: the variable, or None if absent or no pool is loaded
        """
        if not self.variables:
            return None

        if isinstance(name, Enum):
            name = name.value

        for variable in self.variables.pool:
            if variable.name == name:
                return variable

        return None

    def get_default_image_variable(self) -> Optional[ToolRuntimeVariable]:
        """Return the default image variable ('image'), or None."""
        if not self.variables:
            return None

        return self.get_variable(self.VARIABLE_KEY.IMAGE)

    def get_variable_file(self, name: Union[str, Enum]) -> Optional[bytes]:
        """
        Fetch the binary content of an image variable.

        :param name: the name of the variable
        :return: the file bytes, or None when the variable is missing,
                 is not an image variable, or the file cannot be loaded
        """
        variable = self.get_variable(name)
        if not variable:
            return None

        if not isinstance(variable, ToolRuntimeImageVariable):
            return None

        message_file_id = variable.value
        # resolve the message file id to the stored binary
        file_binary = ToolFileManager.get_file_binary_by_message_file_id(message_file_id)
        if not file_binary:
            return None

        # (bytes, mimetype) tuple — only the bytes are returned here
        return file_binary[0]

    def list_variables(self) -> List[ToolRuntimeVariable]:
        """Return all variables in the pool (empty list when no pool)."""
        if not self.variables:
            return []

        return self.variables.pool

    def list_default_image_variables(self) -> List[ToolRuntimeVariable]:
        """Return all variables whose name starts with the 'image' prefix."""
        if not self.variables:
            return []

        result = []

        for variable in self.variables.pool:
            if variable.name.startswith(self.VARIABLE_KEY.IMAGE.value):
                result.append(variable)

        return result

    def invoke(self, user_id: str, tool_paramters: Dict[str, Any]) -> List[ToolInvokeMessage]:
        """
        Public entry point: merge runtime parameter overrides, fire callbacks,
        and delegate to the subclass `_invoke`. Always returns a list.
        """
        # runtime overrides win over caller-supplied parameters
        if self.runtime.runtime_parameters:
            tool_paramters.update(self.runtime.runtime_parameters)

        # hit callback
        if self.use_callback:
            self.agent_callback.on_tool_start(
                tool_name=self.identity.name,
                tool_inputs=tool_paramters
            )

        try:
            result = self._invoke(
                user_id=user_id,
                tool_paramters=tool_paramters,
            )
        except Exception as e:
            if self.use_callback:
                self.agent_callback.on_tool_error(e)
            # bare raise preserves the original traceback (original used `raise e`)
            raise

        if not isinstance(result, list):
            result = [result]

        # hit callback
        if self.use_callback:
            self.agent_callback.on_tool_end(
                tool_name=self.identity.name,
                tool_inputs=tool_paramters,
                tool_outputs=self._convert_tool_response_to_str(result)
            )

        return result

    def _convert_tool_response_to_str(self, tool_response: List[ToolInvokeMessage]) -> str:
        """
        Flatten tool messages into a single string summary for the agent callback.
        """
        result = ''
        for response in tool_response:
            if response.type == ToolInvokeMessage.MessageType.TEXT:
                result += response.message
            elif response.type == ToolInvokeMessage.MessageType.LINK:
                result += f"result link: {response.message}. please direct user to check it."
            elif response.type in (ToolInvokeMessage.MessageType.IMAGE_LINK,
                                   ToolInvokeMessage.MessageType.IMAGE):
                result += "image has been created and sent to user already, you should tell user to check it now."
            elif response.type == ToolInvokeMessage.MessageType.BLOB:
                # truncate long blobs so the summary stays short (114-byte preview)
                if len(response.message) > 114:
                    result += str(response.message[:114]) + '...'
                else:
                    result += str(response.message)
            else:
                result += f"tool response: {response.message}."

        return result

    @abstractmethod
    def _invoke(self, user_id: str, tool_paramters: Dict[str, Any]) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
        """Subclass hook: perform the actual tool work."""
        pass

    def validate_credentials(self, credentails: Dict[str, Any], parameters: Dict[str, Any]) -> None:
        """
        Validate the credentials for this tool (default: no-op).

        :param credentails: the credentials
        :param parameters: the parameters
        """
        pass

    def get_runtime_parameters(self) -> List[ToolParamter]:
        """
        Get the runtime parameters.

        Interface for developers to dynamically change the parameters of a
        tool depending on the variables pool.

        :return: the runtime parameters
        """
        return self.parameters

    def is_tool_avaliable(self) -> bool:
        """
        Check if the tool is available (default: always True).

        :return: if the tool is available
        """
        return True

    def create_image_message(self, image: str, save_as: str = '') -> ToolInvokeMessage:
        """
        Create an image message.

        :param image: the url of the image
        :param save_as: variable name to save the result under
        :return: the image message
        """
        return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.IMAGE,
                                 message=image,
                                 save_as=save_as)

    def create_link_message(self, link: str, save_as: str = '') -> ToolInvokeMessage:
        """
        Create a link message.

        :param link: the url of the link
        :param save_as: variable name to save the result under
        :return: the link message
        """
        return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.LINK,
                                 message=link,
                                 save_as=save_as)

    def create_text_message(self, text: str, save_as: str = '') -> ToolInvokeMessage:
        """
        Create a text message.

        :param text: the text
        :param save_as: variable name to save the result under
        :return: the text message
        """
        return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.TEXT,
                                 message=text,
                                 save_as=save_as
                                 )

    def create_blob_message(self, blob: bytes, meta: dict = None, save_as: str = '') -> ToolInvokeMessage:
        """
        Create a blob message.

        :param blob: the blob bytes
        :param meta: optional metadata (e.g. mime type) attached to the blob
        :param save_as: variable name to save the result under
        :return: the blob message
        """
        return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.BLOB,
                                 message=blob, meta=meta,
                                 save_as=save_as
                                 )
import logging
import time
import os
import hmac
import base64
import hashlib

from typing import Union, Tuple, Generator
from uuid import uuid4
from mimetypes import guess_extension, guess_type
from httpx import get

from flask import current_app

from models.tools import ToolFile
from models.model import MessageFile

from extensions.ext_database import db
from extensions.ext_storage import storage

logger = logging.getLogger(__name__)

# signed preview urls expire after 5 minutes
_SIGNATURE_TTL_SECONDS = 300

class ToolFileManager:
    """Create, sign and retrieve files produced by tools."""

    @staticmethod
    def sign_file(file_id: str, extension: str) -> str:
        """
        Sign a tool file id into a temporary preview url.

        The url carries a timestamp, a random nonce and an HMAC-SHA256
        signature over "file-preview|<id>|<ts>|<nonce>" keyed by SECRET_KEY.
        """
        base_url = current_app.config.get('FILES_URL')
        file_preview_url = f'{base_url}/files/tools/{file_id}{extension}'

        timestamp = str(int(time.time()))
        nonce = os.urandom(16).hex()
        data_to_sign = f"file-preview|{file_id}|{timestamp}|{nonce}"
        secret_key = current_app.config['SECRET_KEY'].encode()
        sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
        encoded_sign = base64.urlsafe_b64encode(sign).decode()

        return f"{file_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"

    @staticmethod
    def verify_file(file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
        """
        Verify a preview url signature produced by sign_file.

        :return: True iff the signature matches and is at most 5 minutes old
        """
        data_to_sign = f"file-preview|{file_id}|{timestamp}|{nonce}"
        secret_key = current_app.config['SECRET_KEY'].encode()
        recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
        recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode()

        # constant-time comparison: the signature comes from untrusted input,
        # a plain != comparison leaks timing information
        if not hmac.compare_digest(sign, recalculated_encoded_sign):
            return False

        current_time = int(time.time())
        return current_time - int(timestamp) <= _SIGNATURE_TTL_SECONDS  # expired after 5 minutes

    @staticmethod
    def create_file_by_raw(user_id: str, tenant_id: str,
                           conversation_id: str, file_binary: bytes,
                           mimetype: str
                           ) -> ToolFile:
        """
        Persist raw bytes to storage and record a ToolFile row (committed).
        """
        extension = guess_extension(mimetype) or '.bin'
        unique_name = uuid4().hex
        filename = f"/tools/{tenant_id}/{unique_name}{extension}"
        storage.save(filename, file_binary)

        tool_file = ToolFile(user_id=user_id, tenant_id=tenant_id,
                             conversation_id=conversation_id, file_key=filename, mimetype=mimetype)

        db.session.add(tool_file)
        db.session.commit()

        return tool_file

    @staticmethod
    def create_file_by_url(user_id: str, tenant_id: str,
                           conversation_id: str, file_url: str,
                           ) -> ToolFile:
        """
        Download a file from a url, persist it and record a ToolFile row (committed).

        :raises httpx.HTTPStatusError: when the download fails
        """
        # try to download the file
        response = get(file_url)
        response.raise_for_status()
        blob = response.content
        # original used 'octet/stream', which is not a valid MIME type
        mimetype = guess_type(file_url)[0] or 'application/octet-stream'
        extension = guess_extension(mimetype) or '.bin'
        unique_name = uuid4().hex
        filename = f"/tools/{tenant_id}/{unique_name}{extension}"
        storage.save(filename, blob)

        tool_file = ToolFile(user_id=user_id, tenant_id=tenant_id,
                             conversation_id=conversation_id, file_key=filename,
                             mimetype=mimetype, original_url=file_url)

        db.session.add(tool_file)
        db.session.commit()

        return tool_file

    @staticmethod
    def create_file_by_key(user_id: str, tenant_id: str,
                           conversation_id: str, file_key: str,
                           mimetype: str
                           ) -> ToolFile:
        """
        Build a ToolFile row for an already-stored key.

        NOTE(review): unlike the sibling create_* methods this neither adds
        nor commits the row — presumably the caller persists it; confirm.
        """
        tool_file = ToolFile(user_id=user_id, tenant_id=tenant_id,
                             conversation_id=conversation_id, file_key=file_key, mimetype=mimetype)
        return tool_file

    @staticmethod
    def get_file_binary(id: str) -> Union[Tuple[bytes, str], None]:
        """
        Get a tool file's binary content.

        :param id: the id of the tool file
        :return: (bytes, mime type), or None when the row is missing
        """
        tool_file: ToolFile = db.session.query(ToolFile).filter(
            ToolFile.id == id,
        ).first()

        if not tool_file:
            return None

        blob = storage.load_once(tool_file.file_key)

        return blob, tool_file.mimetype

    @staticmethod
    def _get_tool_file_by_message_file_id(id: str) -> Union[ToolFile, None]:
        """
        Resolve a MessageFile id to its backing ToolFile row.

        The tool file id is the last path segment of the message file url,
        with the extension stripped. Returns None when either row is missing.
        """
        message_file: MessageFile = db.session.query(MessageFile).filter(
            MessageFile.id == id,
        ).first()

        if not message_file:
            # original crashed with AttributeError on a missing message file
            return None

        tool_file_id = message_file.url.split('/')[-1].split('.')[0]

        return db.session.query(ToolFile).filter(
            ToolFile.id == tool_file_id,
        ).first()

    @staticmethod
    def get_file_binary_by_message_file_id(id: str) -> Union[Tuple[bytes, str], None]:
        """
        Get a file's full binary content via its MessageFile id.

        :param id: the id of the message file
        :return: (bytes, mime type), or None when the file cannot be resolved
        """
        tool_file = ToolFileManager._get_tool_file_by_message_file_id(id)
        if not tool_file:
            return None

        blob = storage.load_once(tool_file.file_key)

        return blob, tool_file.mimetype

    @staticmethod
    def get_file_generator_by_message_file_id(id: str) -> Union[Tuple[Generator, str], None]:
        """
        Get a streaming generator over a file's content via its MessageFile id.

        :param id: the id of the message file
        :return: (generator, mime type), or None when the file cannot be resolved
        """
        tool_file = ToolFileManager._get_tool_file_by_message_file_id(id)
        if not tool_file:
            return None

        generator = storage.load_stream(tool_file.file_key)

        return generator, tool_file.mimetype

# init tool_file_parser — register this manager for the file parser to use
from core.file.tool_file_parser import tool_file_manager
tool_file_manager['manager'] = ToolFileManager
from typing import List, Dict, Any, Tuple, Union
from os import listdir, path

from core.tools.entities.tool_entities import ToolInvokeMessage, ApiProviderAuthType, ToolProviderCredentials
from core.tools.provider.tool_provider import ToolProviderController
from core.tools.tool.builtin_tool import BuiltinTool
from core.tools.tool.api_tool import ApiTool
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
from core.tools.entities.constant import DEFAULT_PROVIDERS
from core.tools.entities.common_entities import I18nObject
from core.tools.errors import ToolProviderNotFoundError
from core.tools.provider.api_tool_provider import ApiBasedToolProviderController
from core.tools.provider.app_tool_provider import AppBasedToolProviderEntity
from core.tools.entities.user_entities import UserToolProvider
from core.tools.utils.configration import ToolConfiguration
from core.tools.utils.encoder import serialize_base_model_dict
from core.tools.provider.builtin._positions import BuiltinToolProviderSort

from core.model_runtime.entities.message_entities import PromptMessage
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler

from extensions.ext_database import db

from models.tools import ApiToolProvider, BuiltinToolProvider

# `importlib.util` must be imported explicitly — a plain `import importlib`
# does not guarantee the `util` submodule attribute is bound
import importlib.util
import logging
import json
import mimetypes

logger = logging.getLogger(__name__)

# process-level cache of builtin provider controllers, keyed by provider name
_builtin_providers = {}

class ToolManager:
    """Central registry: loads builtin/api/app tool providers and resolves tools."""

    @staticmethod
    def _load_single_subclass(py_path: str, provider: str, base_class: type) -> type:
        """
        Load the provider module at py_path and return its unique subclass of base_class.

        :raises ToolProviderNotFoundError: when zero or multiple subclasses are found
        """
        spec = importlib.util.spec_from_file_location(f'core.tools.provider.builtin.{provider}.{provider}', py_path)
        mod = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(mod)

        # collect all provider classes defined (or imported) in the module
        classes = [
            obj for obj in vars(mod).values()
            if isinstance(obj, type) and obj is not base_class and issubclass(obj, base_class)
        ]
        if len(classes) == 0:
            raise ToolProviderNotFoundError(f'provider {provider} not found')
        if len(classes) > 1:
            raise ToolProviderNotFoundError(f'multiple providers found for {provider}')

        return classes[0]

    @staticmethod
    def invoke(
        provider: str,
        tool_id: str,
        tool_name: str,
        tool_parameters: Dict[str, Any],
        credentials: Dict[str, Any],
        prompt_messages: List[PromptMessage],
    ) -> List[ToolInvokeMessage]:
        """
        Invoke a tool through its provider.

        :param provider: the name of the provider
        :param tool_id: the id of the tool
        :param tool_name: the name of the tool, defined in `get_tools`
        :param tool_parameters: the parameters of the tool
        :param credentials: the credentials of the tool
        :param prompt_messages: the prompt messages that the tool can use
        :return: the messages that the tool wants to send to the user
        """
        provider_entity: ToolProviderController = None
        if provider == DEFAULT_PROVIDERS.API_BASED:
            provider_entity = ApiBasedToolProviderController()
        elif provider == DEFAULT_PROVIDERS.APP_BASED:
            provider_entity = AppBasedToolProviderEntity()

        if provider_entity is None:
            # fetch the provider from .provider.builtin
            # NOTE(review): this path omits the 'provider' segment that
            # list_builtin_providers uses ('provider/builtin/...') — confirm
            # which location is correct
            py_path = path.join(path.dirname(path.realpath(__file__)), 'builtin', provider, f'{provider}.py')
            provider_class = ToolManager._load_single_subclass(py_path, provider, ToolProviderController)
            provider_entity = provider_class()

        return provider_entity.invoke(tool_id, tool_name, tool_parameters, credentials, prompt_messages)

    @staticmethod
    def get_builtin_provider(provider: str) -> BuiltinToolProviderController:
        """
        Get a builtin provider controller by name, loading the cache on first use.

        :param provider: the name of the provider
        :return: the provider controller
        :raises ToolProviderNotFoundError: when the provider does not exist
        """
        global _builtin_providers
        if len(_builtin_providers) == 0:
            # init the builtin providers
            ToolManager.list_builtin_providers()

        if provider not in _builtin_providers:
            raise ToolProviderNotFoundError(f'builtin provider {provider} not found')

        return _builtin_providers[provider]

    @staticmethod
    def get_builtin_tool(provider: str, tool_name: str) -> BuiltinTool:
        """
        Get a builtin tool.

        :param provider: the name of the provider
        :param tool_name: the name of the tool
        :return: the tool
        """
        provider_controller = ToolManager.get_builtin_provider(provider)
        tool = provider_controller.get_tool(tool_name)

        return tool

    @staticmethod
    def get_tool(provider_type: str, provider_id: str, tool_name: str, tanent_id: str = None) \
        -> Union[BuiltinTool, ApiTool]:
        """
        Get a tool without binding a runtime.

        :param provider_type: 'builtin', 'api' or 'app'
        :param provider_id: the id/name of the provider
        :param tool_name: the name of the tool
        :param tanent_id: required for 'api' providers
        :return: the tool
        """
        if provider_type == 'builtin':
            return ToolManager.get_builtin_tool(provider_id, tool_name)
        elif provider_type == 'api':
            if tanent_id is None:
                raise ValueError('tanent id is required for api provider')
            api_provider, _ = ToolManager.get_api_provider_controller(tanent_id, provider_id)
            return api_provider.get_tool(tool_name)
        elif provider_type == 'app':
            raise NotImplementedError('app provider not implemented')
        else:
            raise ToolProviderNotFoundError(f'provider type {provider_type} not found')

    @staticmethod
    def get_tool_runtime(provider_type: str, provider_name: str, tool_name: str, tanent_id,
                         agent_callback: DifyAgentCallbackHandler = None) \
        -> Union[BuiltinTool, ApiTool]:
        """
        Get a tool forked with a tenant-bound runtime (decrypted credentials attached).

        :param provider_type: 'builtin', 'api' or 'app'
        :param provider_name: the name of the provider
        :param tool_name: the name of the tool
        :param tanent_id: tenant whose credentials are loaded
        :param agent_callback: optional agent callback handler for the runtime
        :return: the tool, ready to invoke
        """
        if provider_type == 'builtin':
            builtin_tool = ToolManager.get_builtin_tool(provider_name, tool_name)

            # providers without credential requirements get an empty credential set
            provider_controller = ToolManager.get_builtin_provider(provider_name)
            if not provider_controller.need_credentials:
                return builtin_tool.fork_tool_runtime(meta={
                    'tenant_id': tanent_id,
                    'credentials': {},
                }, agent_callback=agent_callback)

            # get stored credentials for this tenant
            builtin_provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter(
                BuiltinToolProvider.tenant_id == tanent_id,
                BuiltinToolProvider.provider == provider_name,
            ).first()

            if builtin_provider is None:
                raise ToolProviderNotFoundError(f'builtin provider {provider_name} not found')

            # decrypt the credentials
            credentials = builtin_provider.credentials
            controller = ToolManager.get_builtin_provider(provider_name)
            tool_configuration = ToolConfiguration(tenant_id=tanent_id, provider_controller=controller)

            decrypted_credentails = tool_configuration.decrypt_tool_credentials(credentials)

            return builtin_tool.fork_tool_runtime(meta={
                'tenant_id': tanent_id,
                'credentials': decrypted_credentails,
                'runtime_parameters': {}
            }, agent_callback=agent_callback)

        elif provider_type == 'api':
            if tanent_id is None:
                raise ValueError('tanent id is required for api provider')

            api_provider, credentials = ToolManager.get_api_provider_controller(tanent_id, provider_name)

            # decrypt the credentials
            tool_configuration = ToolConfiguration(tenant_id=tanent_id, provider_controller=api_provider)
            decrypted_credentails = tool_configuration.decrypt_tool_credentials(credentials)

            return api_provider.get_tool(tool_name).fork_tool_runtime(meta={
                'tenant_id': tanent_id,
                'credentials': decrypted_credentails,
            })
        elif provider_type == 'app':
            raise NotImplementedError('app provider not implemented')
        else:
            raise ToolProviderNotFoundError(f'provider type {provider_type} not found')

    @staticmethod
    def get_builtin_provider_icon(provider: str) -> Tuple[str, str]:
        """
        Get the absolute path of a builtin provider's icon.

        :param provider: the name of the provider
        :return: (absolute path of the icon, mime type of the icon)
        :raises ToolProviderNotFoundError: when the icon file is missing
        """
        # get provider
        provider_controller = ToolManager.get_builtin_provider(provider)

        absolute_path = path.join(path.dirname(path.realpath(__file__)),
                                  'provider', 'builtin', provider, '_assets', provider_controller.identity.icon)
        # check if the icon exists
        if not path.exists(absolute_path):
            raise ToolProviderNotFoundError(f'builtin provider {provider} icon not found')

        # get the mime type
        mime_type, _ = mimetypes.guess_type(absolute_path)
        mime_type = mime_type or 'application/octet-stream'

        return absolute_path, mime_type

    @staticmethod
    def list_builtin_providers() -> List[BuiltinToolProviderController]:
        """
        List all builtin providers, loading them from disk and caching on first call.
        """
        global _builtin_providers

        # use cache first
        if len(_builtin_providers) > 0:
            return list(_builtin_providers.values())

        builtin_root = path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin')
        builtin_providers = []
        for provider in listdir(builtin_root):
            # skip dunder entries such as __init__.py / __pycache__
            # (original checked this twice)
            if provider.startswith('__'):
                continue

            if path.isdir(path.join(builtin_root, provider)):
                py_path = path.join(builtin_root, provider, f'{provider}.py')
                provider_class = ToolManager._load_single_subclass(py_path, provider, BuiltinToolProviderController)
                builtin_providers.append(provider_class())

        # cache the builtin providers
        for provider in builtin_providers:
            _builtin_providers[provider.identity.name] = provider
        return builtin_providers

    @staticmethod
    def user_list_providers(
        user_id: str,
        tenant_id: str,
    ) -> List[UserToolProvider]:
        """
        List providers visible to a user: builtin providers (with masked team
        credentials where configured) plus the tenant's api-based providers.
        """
        result_providers: Dict[str, UserToolProvider] = {}
        # get builtin providers
        builtin_providers = ToolManager.list_builtin_providers()
        # append builtin providers
        for provider in builtin_providers:
            result_providers[provider.identity.name] = UserToolProvider(
                id=provider.identity.name,
                author=provider.identity.author,
                name=provider.identity.name,
                description=I18nObject(
                    en_US=provider.identity.description.en_US,
                    zh_Hans=provider.identity.description.zh_Hans,
                ),
                icon=provider.identity.icon,
                label=I18nObject(
                    en_US=provider.identity.label.en_US,
                    zh_Hans=provider.identity.label.zh_Hans,
                ),
                type=UserToolProvider.ProviderType.BUILTIN,
                team_credentials={},
                is_team_authorization=False,
            )

            # expose the credentials schema with default (empty) values
            schema = provider.get_credentails_schema()
            for name, value in schema.items():
                result_providers[provider.identity.name].team_credentials[name] = \
                    ToolProviderCredentials.CredentialsType.defaut(value.type)

            # providers without credentials are authorized out of the box
            if not provider.need_credentials:
                result_providers[provider.identity.name].is_team_authorization = True
                result_providers[provider.identity.name].allow_delete = False

        # overlay tenant credentials stored in the db
        db_builtin_providers: List[BuiltinToolProvider] = db.session.query(BuiltinToolProvider). \
            filter(BuiltinToolProvider.tenant_id == tenant_id).all()

        for db_builtin_provider in db_builtin_providers:
            provider_name = db_builtin_provider.provider
            if provider_name not in result_providers:
                # a db row may reference a builtin provider no longer on disk;
                # the original raised KeyError here
                logger.warning('builtin provider %s found in db but not on disk', provider_name)
                continue

            result_providers[provider_name].is_team_authorization = True

            # package builtin tool provider controller
            controller = ToolManager.get_builtin_provider(provider_name)

            # init tool configuration
            tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=controller)
            # decrypt then mask the stored credentials for display
            decrypted_credentails = tool_configuration.decrypt_tool_credentials(
                credentials=db_builtin_provider.credentials)
            masked_credentials = tool_configuration.mask_tool_credentials(credentials=decrypted_credentails)

            result_providers[provider_name].team_credentials = masked_credentials

        # get db api providers
        db_api_providers: List[ApiToolProvider] = db.session.query(ApiToolProvider). \
            filter(ApiToolProvider.tenant_id == tenant_id).all()

        for db_api_provider in db_api_providers:
            username = 'Anonymous'
            try:
                username = db_api_provider.user.name
            except Exception as e:
                logger.error(f'failed to get user name for api provider {db_api_provider.id}: {str(e)}')
            # add provider into providers
            provider_name = db_api_provider.name
            result_providers[provider_name] = UserToolProvider(
                id=db_api_provider.id,
                author=username,
                name=db_api_provider.name,
                description=I18nObject(
                    en_US=db_api_provider.description,
                    zh_Hans=db_api_provider.description,
                ),
                icon=db_api_provider.icon,
                label=I18nObject(
                    en_US=db_api_provider.name,
                    zh_Hans=db_api_provider.name,
                ),
                type=UserToolProvider.ProviderType.API,
                team_credentials={},
                is_team_authorization=True,
            )

            # package tool provider controller
            controller = ApiBasedToolProviderController.from_db(
                db_provider=db_api_provider,
                auth_type=ApiProviderAuthType.API_KEY
                if db_api_provider.credentials['auth_type'] == 'api_key' else ApiProviderAuthType.NONE
            )

            # init tool configuration
            tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=controller)

            # decrypt then mask the stored credentials for display
            decrypted_credentails = tool_configuration.decrypt_tool_credentials(
                credentials=db_api_provider.credentials)
            masked_credentials = tool_configuration.mask_tool_credentials(credentials=decrypted_credentails)

            result_providers[provider_name].team_credentials = masked_credentials

        return BuiltinToolProviderSort.sort(list(result_providers.values()))

    @staticmethod
    def get_api_provider_controller(tanent_id: str, provider_id: str) -> Tuple[ApiBasedToolProviderController, Dict[str, Any]]:
        """
        Get an api provider controller with its bundled tools loaded.

        :param tanent_id: tenant that owns the provider
        :param provider_id: the id of the provider
        :return: (the provider controller, the stored credentials)
        :raises ToolProviderNotFoundError: when the provider does not exist
        """
        provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
            ApiToolProvider.id == provider_id,
            ApiToolProvider.tenant_id == tanent_id,
        ).first()

        if provider is None:
            raise ToolProviderNotFoundError(f'api provider {provider_id} not found')

        controller = ApiBasedToolProviderController.from_db(
            provider, ApiProviderAuthType.API_KEY if provider.credentials['auth_type'] == 'api_key' else ApiProviderAuthType.NONE
        )
        controller.load_bundled_tools(provider.tools)

        return controller, provider.credentials

    @staticmethod
    def user_get_api_provider(provider: str, tenant_id: str) -> dict:
        """
        Get an api tool provider (by name) owned by the tenant, with masked credentials.

        :param provider: the name of the provider
        :param tenant_id: tenant that owns the provider
        :return: serialized provider dict (schema, tools, icon, masked credentials)
        :raises ValueError: when the provider has not been added
        """
        provider_name = provider
        # original rebound `provider` to the ORM row, so the not-found message
        # interpolated None instead of the requested name
        db_provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
            ApiToolProvider.tenant_id == tenant_id,
            ApiToolProvider.name == provider_name,
        ).first()

        if db_provider is None:
            raise ValueError(f'you have not added provider {provider_name}')

        try:
            credentials = json.loads(db_provider.credentials_str) or {}
        except Exception:
            # best effort: treat unparsable credentials as empty
            credentials = {}

        # package tool provider controller; .get avoids a KeyError when the
        # credentials fell back to {} above
        controller = ApiBasedToolProviderController.from_db(
            db_provider,
            ApiProviderAuthType.API_KEY if credentials.get('auth_type') == 'api_key' else ApiProviderAuthType.NONE
        )
        # init tool configuration
        tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=controller)

        decrypted_credentails = tool_configuration.decrypt_tool_credentials(credentials)
        masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentails)

        try:
            icon = json.loads(db_provider.icon)
        except Exception:
            # fall back to a default icon when the stored value is not json
            icon = {
                "background": "#252525",
                "content": "\ud83d\ude01"
            }

        return json.loads(serialize_base_model_dict({
            'schema_type': db_provider.schema_type,
            'schema': db_provider.schema,
            'tools': db_provider.tools,
            'icon': icon,
            'description': db_provider.description,
            'credentials': masked_credentials,
            'privacy_policy': db_provider.privacy_policy
        }))
from typing import Dict, Any
from pydantic import BaseModel

from core.tools.entities.tool_entities import ToolProviderCredentials
from core.tools.provider.tool_provider import ToolProviderController
from core.helper import encrypter

class ToolConfiguration(BaseModel):
    """
    Encrypts, decrypts and masks a tool provider's SECRET_INPUT credentials
    for a given tenant, driven by the provider's credentials schema.
    """
    tenant_id: str
    provider_controller: ToolProviderController

    def _deep_copy(self, credentails: Dict[str, str]) -> Dict[str, str]:
        """
        Return a copy of the credentials dict.

        NOTE(review): despite the name this is a *shallow* copy — values are
        shared with the input. Fine for flat str->str credentials; confirm
        nested values are never stored here.
        """
        return {key: value for key, value in credentails.items()}

    def encrypt_tool_credentials(self, credentails: Dict[str, str]) -> Dict[str, str]:
        """
        Encrypt tool credentials with the tenant id.

        :return: a copy of the credentials with SECRET_INPUT values encrypted
        """
        credentials = self._deep_copy(credentails)

        # only fields declared SECRET_INPUT in the schema are encrypted
        fields = self.provider_controller.get_credentails_schema()
        for field_name, field in fields.items():
            if field.type == ToolProviderCredentials.CredentialsType.SECRET_INPUT:
                if field_name in credentials:
                    encrypted = encrypter.encrypt_token(self.tenant_id, credentials[field_name])
                    credentials[field_name] = encrypted

        return credentials

    def mask_tool_credentials(self, credentials: Dict[str, Any]) -> Dict[str, Any]:
        """
        Mask tool credentials for display.

        :return: a copy of the credentials with SECRET_INPUT values masked
            (first/last two characters kept for values longer than 6 chars,
            fully starred otherwise)
        """
        credentials = self._deep_copy(credentials)

        # only fields declared SECRET_INPUT in the schema are masked
        fields = self.provider_controller.get_credentails_schema()
        for field_name, field in fields.items():
            if field.type == ToolProviderCredentials.CredentialsType.SECRET_INPUT:
                if field_name in credentials:
                    if len(credentials[field_name]) > 6:
                        credentials[field_name] = \
                            credentials[field_name][:2] + \
                            '*' * (len(credentials[field_name]) - 4) +\
                            credentials[field_name][-2:]
                    else:
                        credentials[field_name] = '*' * len(credentials[field_name])

        return credentials

    def decrypt_tool_credentials(self, credentials: Dict[str, str]) -> Dict[str, str]:
        """
        Decrypt tool credentials with the tenant id.

        :return: a copy of the credentials with SECRET_INPUT values decrypted;
            values that fail to decrypt are left as-is (best effort)
        """
        credentials = self._deep_copy(credentials)

        # only fields declared SECRET_INPUT in the schema are decrypted
        fields = self.provider_controller.get_credentails_schema()
        for field_name, field in fields.items():
            if field.type == ToolProviderCredentials.CredentialsType.SECRET_INPUT:
                if field_name in credentials:
                    try:
                        credentials[field_name] = encrypter.decrypt_token(self.tenant_id, credentials[field_name])
                    except Exception:
                        # deliberate best-effort: the stored value may be
                        # plaintext from before encryption was introduced
                        pass

        return credentials
_BaseModel(BaseModel): + __root__: List[BaseModel] + + """ + {"__root__": [BaseModel, BaseModel, ...]} + """ + return _BaseModel(__root__=l).json() + +def serialize_base_model_dict(b: dict) -> str: + class _BaseModel(BaseModel): + __root__: dict + + """ + {"__root__": {BaseModel}} + """ + return _BaseModel(__root__=b).json() diff --git a/api/core/tools/utils/parser.py b/api/core/tools/utils/parser.py new file mode 100644 index 0000000000..f61ac1cbfa --- /dev/null +++ b/api/core/tools/utils/parser.py @@ -0,0 +1,341 @@ + +from core.tools.entities.tool_bundle import ApiBasedToolBundle +from core.tools.entities.tool_entities import ToolParamter, ToolParamterOption, ApiProviderSchemaType +from core.tools.entities.common_entities import I18nObject +from core.tools.errors import ToolProviderNotFoundError, ToolNotSupportedError, \ + ToolApiSchemaError + +from typing import List, Tuple + +from yaml import FullLoader, load +from json import loads as json_loads, dumps as json_dumps +from requests import get + +class ApiBasedToolSchemaParser: + @staticmethod + def parse_openapi_to_tool_bundle(openapi: dict, extra_info: dict = None, warning: dict = None) -> List[ApiBasedToolBundle]: + warning = warning if warning is not None else {} + extra_info = extra_info if extra_info is not None else {} + + # set description to extra_info + if 'description' in openapi['info']: + extra_info['description'] = openapi['info']['description'] + else: + extra_info['description'] = '' + + if len(openapi['servers']) == 0: + raise ToolProviderNotFoundError('No server found in the openapi yaml.') + + server_url = openapi['servers'][0]['url'] + + # list all interfaces + interfaces = [] + for path, path_item in openapi['paths'].items(): + methods = ['get', 'post', 'put', 'delete', 'patch', 'head', 'options', 'trace'] + for method in methods: + if method in path_item: + interfaces.append({ + 'path': path, + 'method': method, + 'operation': path_item[method], + }) + + # get all parameters + bundles = [] 
+ for interface in interfaces: + # convert parameters + parameters = [] + if 'parameters' in interface['operation']: + for parameter in interface['operation']['parameters']: + parameters.append(ToolParamter( + name=parameter['name'], + label=I18nObject( + en_US=parameter['name'], + zh_Hans=parameter['name'] + ), + human_description=I18nObject( + en_US=parameter.get('description', ''), + zh_Hans=parameter.get('description', '') + ), + type=ToolParamter.ToolParameterType.STRING, + required=parameter.get('required', False), + form=ToolParamter.ToolParameterForm.LLM, + llm_description=parameter.get('description'), + default=parameter['default'] if 'default' in parameter else None, + )) + # create tool bundle + # check if there is a request body + if 'requestBody' in interface['operation']: + request_body = interface['operation']['requestBody'] + if 'content' in request_body: + for content_type, content in request_body['content'].items(): + # if there is a reference, get the reference and overwrite the content + if 'schema' not in content: + content + + if '$ref' in content['schema']: + # get the reference + root = openapi + reference = content['schema']['$ref'].split('/')[1:] + for ref in reference: + root = root[ref] + # overwrite the content + interface['operation']['requestBody']['content'][content_type]['schema'] = root + # parse body parameters + if 'schema' in interface['operation']['requestBody']['content'][content_type]: + body_schema = interface['operation']['requestBody']['content'][content_type]['schema'] + required = body_schema['required'] if 'required' in body_schema else [] + properties = body_schema['properties'] if 'properties' in body_schema else {} + for name, property in properties.items(): + parameters.append(ToolParamter( + name=name, + label=I18nObject( + en_US=name, + zh_Hans=name + ), + human_description=I18nObject( + en_US=property['description'] if 'description' in property else '', + zh_Hans=property['description'] if 'description' in 
property else '' + ), + type=ToolParamter.ToolParameterType.STRING, + required=name in required, + form=ToolParamter.ToolParameterForm.LLM, + llm_description=property['description'] if 'description' in property else '', + default=property['default'] if 'default' in property else None, + )) + + # check if parameters is duplicated + parameters_count = {} + for parameter in parameters: + if parameter.name not in parameters_count: + parameters_count[parameter.name] = 0 + parameters_count[parameter.name] += 1 + for name, count in parameters_count.items(): + if count > 1: + warning['duplicated_parameter'] = f'Parameter {name} is duplicated.' + + bundles.append(ApiBasedToolBundle( + server_url=server_url + interface['path'], + method=interface['method'], + summary=interface['operation']['summary'] if 'summary' in interface['operation'] else None, + operation_id=interface['operation']['operationId'], + parameters=parameters, + author='', + icon=None, + openapi=interface['operation'], + )) + + return bundles + + @staticmethod + def parse_openapi_yaml_to_tool_bundle(yaml: str, extra_info: dict = None, warning: dict = None) -> List[ApiBasedToolBundle]: + """ + parse openapi yaml to tool bundle + + :param yaml: the yaml string + :return: the tool bundle + """ + warning = warning if warning is not None else {} + extra_info = extra_info if extra_info is not None else {} + + openapi: dict = load(yaml, Loader=FullLoader) + if openapi is None: + raise ToolApiSchemaError('Invalid openapi yaml.') + return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi, extra_info=extra_info, warning=warning) + + @staticmethod + def parse_openapi_json_to_tool_bundle(json: str, extra_info: dict = None, warning: dict = None) -> List[ApiBasedToolBundle]: + """ + parse openapi yaml to tool bundle + + :param yaml: the yaml string + :return: the tool bundle + """ + warning = warning if warning is not None else {} + extra_info = extra_info if extra_info is not None else {} + + openapi: dict = 
json_loads(json) + if openapi is None: + raise ToolApiSchemaError('Invalid openapi json.') + return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi, extra_info=extra_info, warning=warning) + + @staticmethod + def parse_swagger_to_openapi(swagger: dict, extra_info: dict = None, warning: dict = None) -> dict: + """ + parse swagger to openapi + + :param swagger: the swagger dict + :return: the openapi dict + """ + # convert swagger to openapi + info = swagger.get('info', { + 'title': 'Swagger', + 'description': 'Swagger', + 'version': '1.0.0' + }) + + servers = swagger.get('servers', []) + + if len(servers) == 0: + raise ToolApiSchemaError('No server found in the swagger yaml.') + + openapi = { + 'openapi': '3.0.0', + 'info': { + 'title': info.get('title', 'Swagger'), + 'description': info.get('description', 'Swagger'), + 'version': info.get('version', '1.0.0') + }, + 'servers': swagger['servers'], + 'paths': {}, + 'components': { + 'schemas': {} + } + } + + # check paths + if 'paths' not in swagger or len(swagger['paths']) == 0: + raise ToolApiSchemaError('No paths found in the swagger yaml.') + + # convert paths + for path, path_item in swagger['paths'].items(): + openapi['paths'][path] = {} + for method, operation in path_item.items(): + if 'operationId' not in operation: + raise ToolApiSchemaError(f'No operationId found in operation {method} {path}.') + + if 'summary' not in operation or len(operation['summary']) == 0: + warning['missing_summary'] = f'No summary found in operation {method} {path}.' + + if 'description' not in operation or len(operation['description']) == 0: + warning['missing_description'] = f'No description found in operation {method} {path}.' 
+
+            openapi['paths'][path][method] = {
+                'operationId': operation['operationId'],
+                'summary': operation.get('summary', ''),
+                'description': operation.get('description', ''),
+                'parameters': operation.get('parameters', []),
+                'responses': operation.get('responses', {}),
+            }
+
+            if 'requestBody' in operation:
+                openapi['paths'][path][method]['requestBody'] = operation['requestBody']
+
+        # convert definitions
+        for name, definition in swagger['definitions'].items():
+            openapi['components']['schemas'][name] = definition
+
+        return openapi
+
+    @staticmethod
+    def parse_swagger_yaml_to_tool_bundle(yaml: str, extra_info: dict = None, warning: dict = None) -> List[ApiBasedToolBundle]:
+        """
+        parse swagger yaml to tool bundle
+
+        :param yaml: the yaml string
+        :return: the tool bundle
+        """
+        warning = warning if warning is not None else {}
+        extra_info = extra_info if extra_info is not None else {}
+
+        swagger: dict = load(yaml, Loader=FullLoader)
+
+        openapi = ApiBasedToolSchemaParser.parse_swagger_to_openapi(swagger, extra_info=extra_info, warning=warning)
+        return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi, extra_info=extra_info, warning=warning)
+
+    @staticmethod
+    def parse_swagger_json_to_tool_bundle(json: str, extra_info: dict = None, warning: dict = None) -> List[ApiBasedToolBundle]:
+        """
+        parse swagger json to tool bundle
+
+        :param json: the json string
+        :return: the tool bundle
+        """
+        warning = warning if warning is not None else {}
+        extra_info = extra_info if extra_info is not None else {}
+
+        swagger: dict = json_loads(json)
+
+        openapi = ApiBasedToolSchemaParser.parse_swagger_to_openapi(swagger, extra_info=extra_info, warning=warning)
+        return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi, extra_info=extra_info, warning=warning)
+
+    @staticmethod
+    def parse_openai_plugin_json_to_tool_bundle(json: str, extra_info: dict = None, warning: dict = None) -> List[ApiBasedToolBundle]:
+        """
+        parse openai plugin json to 
tool bundle + + :param json: the json string + :return: the tool bundle + """ + warning = warning if warning is not None else {} + extra_info = extra_info if extra_info is not None else {} + + try: + openai_plugin = json_loads(json) + api = openai_plugin['api'] + api_url = api['url'] + api_type = api['type'] + except: + raise ToolProviderNotFoundError('Invalid openai plugin json.') + + if api_type != 'openapi': + raise ToolNotSupportedError('Only openapi is supported now.') + + # get openapi yaml + response = get(api_url, headers={ + 'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) ' + }, timeout=5) + + if response.status_code != 200: + raise ToolProviderNotFoundError('cannot get openapi yaml from url.') + + return ApiBasedToolSchemaParser.parse_openapi_yaml_to_tool_bundle(response.text, extra_info=extra_info, warning=warning) + + @staticmethod + def auto_parse_to_tool_bundle(content: str, extra_info: dict = None, warning: dict = None) -> Tuple[List[ApiBasedToolBundle], str]: + """ + auto parse to tool bundle + + :param content: the content + :return: tools bundle, schema_type + """ + warning = warning if warning is not None else {} + extra_info = extra_info if extra_info is not None else {} + + json_possible = False + content = content.strip() + + if content.startswith('{') and content.endswith('}'): + json_possible = True + + if json_possible: + try: + return ApiBasedToolSchemaParser.parse_openapi_json_to_tool_bundle(content, extra_info=extra_info, warning=warning), \ + ApiProviderSchemaType.OPENAPI.value + except: + pass + + try: + return ApiBasedToolSchemaParser.parse_swagger_json_to_tool_bundle(content, extra_info=extra_info, warning=warning), \ + ApiProviderSchemaType.SWAGGER.value + except: + pass + try: + return ApiBasedToolSchemaParser.parse_openai_plugin_json_to_tool_bundle(content, extra_info=extra_info, warning=warning), \ + ApiProviderSchemaType.OPENAI_PLUGIN.value + except: + pass + else: + try: + return 
ApiBasedToolSchemaParser.parse_openapi_yaml_to_tool_bundle(content, extra_info=extra_info, warning=warning), \ + ApiProviderSchemaType.OPENAPI.value + except: + pass + + try: + return ApiBasedToolSchemaParser.parse_swagger_yaml_to_tool_bundle(content, extra_info=extra_info, warning=warning), \ + ApiProviderSchemaType.SWAGGER.value + except: + pass + + raise ToolApiSchemaError('Invalid api schema.') \ No newline at end of file diff --git a/api/core/tools/utils/web_reader_tool.py b/api/core/tools/utils/web_reader_tool.py new file mode 100644 index 0000000000..b4cfc19871 --- /dev/null +++ b/api/core/tools/utils/web_reader_tool.py @@ -0,0 +1,446 @@ +import hashlib +import json +import os +import re +import site +import subprocess +import tempfile +import unicodedata +from contextlib import contextmanager +from typing import Type, Any + +import requests +from bs4 import BeautifulSoup, NavigableString, Comment, CData +from langchain.chains import RefineDocumentsChain +from langchain.chains.summarize import refine_prompts +from langchain.schema import Document +from langchain.text_splitter import RecursiveCharacterTextSplitter +from langchain.tools.base import BaseTool +from newspaper import Article +from pydantic import BaseModel, Field +from regex import regex + +from core.chain.llm_chain import LLMChain +from core.data_loader import file_extractor +from core.data_loader.file_extractor import FileExtractor +from core.entities.application_entities import ModelConfigEntity + +FULL_TEMPLATE = """ +TITLE: {title} +AUTHORS: {authors} +PUBLISH DATE: {publish_date} +TOP_IMAGE_URL: {top_image} +TEXT: + +{text} +""" + + +class WebReaderToolInput(BaseModel): + url: str = Field(..., description="URL of the website to read") + summary: bool = Field( + default=False, + description="When the user's question requires extracting the summarizing content of the webpage, " + "set it to true." + ) + cursor: int = Field( + default=0, + description="Start reading from this character." 
+ "Use when the first response was truncated" + "and you want to continue reading the page." + "The value cannot exceed 24000.", + ) + + +class WebReaderTool(BaseTool): + """Reader tool for getting website title and contents. Gives more control than SimpleReaderTool.""" + + name: str = "web_reader" + args_schema: Type[BaseModel] = WebReaderToolInput + description: str = "use this to read a website. " \ + "If you can answer the question based on the information provided, " \ + "there is no need to use." + page_contents: str = None + url: str = None + max_chunk_length: int = 4000 + summary_chunk_tokens: int = 4000 + summary_chunk_overlap: int = 0 + summary_separators: list[str] = ["\n\n", "。", ".", " ", ""] + continue_reading: bool = True + model_config: ModelConfigEntity + model_parameters: dict[str, Any] + + def _run(self, url: str, summary: bool = False, cursor: int = 0) -> str: + try: + if not self.page_contents or self.url != url: + page_contents = get_url(url) + self.page_contents = page_contents + self.url = url + else: + page_contents = self.page_contents + except Exception as e: + return f'Read this website failed, caused by: {str(e)}.' + + if summary: + character_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder( + chunk_size=self.summary_chunk_tokens, + chunk_overlap=self.summary_chunk_overlap, + separators=self.summary_separators + ) + + texts = character_splitter.split_text(page_contents) + docs = [Document(page_content=t) for t in texts] + + if len(docs) == 0 or docs[0].page_content.endswith('TEXT:'): + return "No content found." + + # only use first 5 docs + if len(docs) > 5: + docs = docs[:5] + + chain = self.get_summary_chain() + try: + page_contents = chain.run(docs) + except Exception as e: + return f'Read this website failed, caused by: {str(e)}.' 
+ else: + page_contents = page_result(page_contents, cursor, self.max_chunk_length) + + if self.continue_reading and len(page_contents) >= self.max_chunk_length: + page_contents += f"\nPAGE WAS TRUNCATED. IF YOU FIND INFORMATION THAT CAN ANSWER QUESTION " \ + f"THEN DIRECT ANSWER AND STOP INVOKING web_reader TOOL, OTHERWISE USE " \ + f"CURSOR={cursor+len(page_contents)} TO CONTINUE READING." + + return page_contents + + async def _arun(self, url: str) -> str: + raise NotImplementedError + + def get_summary_chain(self) -> RefineDocumentsChain: + initial_chain = LLMChain( + model_config=self.model_config, + prompt=refine_prompts.PROMPT, + parameters=self.model_parameters + ) + refine_chain = LLMChain( + model_config=self.model_config, + prompt=refine_prompts.REFINE_PROMPT, + parameters=self.model_parameters + ) + return RefineDocumentsChain( + initial_llm_chain=initial_chain, + refine_llm_chain=refine_chain, + document_variable_name="text", + initial_response_name="existing_answer", + callbacks=self.callbacks + ) + + +def page_result(text: str, cursor: int, max_length: int) -> str: + """Page through `text` and return a substring of `max_length` characters starting from `cursor`.""" + return text[cursor: cursor + max_length] + + +def get_url(url: str, user_agent: str = None) -> str: + """Fetch URL and return the contents as a string.""" + headers = { + "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36" + } + if user_agent: + headers["User-Agent"] = user_agent + + supported_content_types = file_extractor.SUPPORT_URL_CONTENT_TYPES + ["text/html"] + + head_response = requests.head(url, headers=headers, allow_redirects=True, timeout=(5, 10)) + + if head_response.status_code != 200: + return "URL returned status code {}.".format(head_response.status_code) + + # check content-type + main_content_type = head_response.headers.get('Content-Type').split(';')[0].strip() + if main_content_type not in 
supported_content_types: + return "Unsupported content-type [{}] of URL.".format(main_content_type) + + if main_content_type in file_extractor.SUPPORT_URL_CONTENT_TYPES: + return FileExtractor.load_from_url(url, return_text=True) + + response = requests.get(url, headers=headers, allow_redirects=True, timeout=(5, 30)) + a = extract_using_readabilipy(response.text) + + if not a['plain_text'] or not a['plain_text'].strip(): + return get_url_from_newspaper3k(url) + + res = FULL_TEMPLATE.format( + title=a['title'], + authors=a['byline'], + publish_date=a['date'], + top_image="", + text=a['plain_text'] if a['plain_text'] else "", + ) + + return res + + +def get_url_from_newspaper3k(url: str) -> str: + + a = Article(url) + a.download() + a.parse() + + res = FULL_TEMPLATE.format( + title=a.title, + authors=a.authors, + publish_date=a.publish_date, + top_image=a.top_image, + text=a.text, + ) + + return res + + +def extract_using_readabilipy(html): + with tempfile.NamedTemporaryFile(delete=False, mode='w+') as f_html: + f_html.write(html) + f_html.close() + html_path = f_html.name + + # Call Mozilla's Readability.js Readability.parse() function via node, writing output to a temporary file + article_json_path = html_path + ".json" + jsdir = os.path.join(find_module_path('readabilipy'), 'javascript') + with chdir(jsdir): + subprocess.check_call(["node", "ExtractArticle.js", "-i", html_path, "-o", article_json_path]) + + # Read output of call to Readability.parse() from JSON file and return as Python dictionary + with open(article_json_path, "r", encoding="utf-8") as json_file: + input_json = json.loads(json_file.read()) + + # Deleting files after processing + os.unlink(article_json_path) + os.unlink(html_path) + + article_json = { + "title": None, + "byline": None, + "date": None, + "content": None, + "plain_content": None, + "plain_text": None + } + # Populate article fields from readability fields where present + if input_json: + if "title" in input_json and 
input_json["title"]: + article_json["title"] = input_json["title"] + if "byline" in input_json and input_json["byline"]: + article_json["byline"] = input_json["byline"] + if "date" in input_json and input_json["date"]: + article_json["date"] = input_json["date"] + if "content" in input_json and input_json["content"]: + article_json["content"] = input_json["content"] + article_json["plain_content"] = plain_content(article_json["content"], False, False) + article_json["plain_text"] = extract_text_blocks_as_plain_text(article_json["plain_content"]) + if "textContent" in input_json and input_json["textContent"]: + article_json["plain_text"] = input_json["textContent"] + article_json["plain_text"] = re.sub(r'\n\s*\n', '\n', article_json["plain_text"]) + + return article_json + + +def find_module_path(module_name): + for package_path in site.getsitepackages(): + potential_path = os.path.join(package_path, module_name) + if os.path.exists(potential_path): + return potential_path + + return None + +@contextmanager +def chdir(path): + """Change directory in context and return to original on exit""" + # From https://stackoverflow.com/a/37996581, couldn't find a built-in + original_path = os.getcwd() + os.chdir(path) + try: + yield + finally: + os.chdir(original_path) + + +def extract_text_blocks_as_plain_text(paragraph_html): + # Load article as DOM + soup = BeautifulSoup(paragraph_html, 'html.parser') + # Select all lists + list_elements = soup.find_all(['ul', 'ol']) + # Prefix text in all list items with "* " and make lists paragraphs + for list_element in list_elements: + plain_items = "".join(list(filter(None, [plain_text_leaf_node(li)["text"] for li in list_element.find_all('li')]))) + list_element.string = plain_items + list_element.name = "p" + # Select all text blocks + text_blocks = [s.parent for s in soup.find_all(string=True)] + text_blocks = [plain_text_leaf_node(block) for block in text_blocks] + # Drop empty paragraphs + text_blocks = list(filter(lambda p: 
p["text"] is not None, text_blocks)) + return text_blocks + + +def plain_text_leaf_node(element): + # Extract all text, stripped of any child HTML elements and normalise it + plain_text = normalise_text(element.get_text()) + if plain_text != "" and element.name == "li": + plain_text = "* {}, ".format(plain_text) + if plain_text == "": + plain_text = None + if "data-node-index" in element.attrs: + plain = {"node_index": element["data-node-index"], "text": plain_text} + else: + plain = {"text": plain_text} + return plain + + +def plain_content(readability_content, content_digests, node_indexes): + # Load article as DOM + soup = BeautifulSoup(readability_content, 'html.parser') + # Make all elements plain + elements = plain_elements(soup.contents, content_digests, node_indexes) + if node_indexes: + # Add node index attributes to nodes + elements = [add_node_indexes(element) for element in elements] + # Replace article contents with plain elements + soup.contents = elements + return str(soup) + + +def plain_elements(elements, content_digests, node_indexes): + # Get plain content versions of all elements + elements = [plain_element(element, content_digests, node_indexes) + for element in elements] + if content_digests: + # Add content digest attribute to nodes + elements = [add_content_digest(element) for element in elements] + return elements + + +def plain_element(element, content_digests, node_indexes): + # For lists, we make each item plain text + if is_leaf(element): + # For leaf node elements, extract the text content, discarding any HTML tags + # 1. Get element contents as text + plain_text = element.get_text() + # 2. Normalise the extracted text string to a canonical representation + plain_text = normalise_text(plain_text) + # 3. Update element content to be plain text + element.string = plain_text + elif is_text(element): + if is_non_printing(element): + # The simplified HTML may have come from Readability.js so might + # have non-printing text (e.g. 
Comment or CData). In this case, we + # keep the structure, but ensure that the string is empty. + element = type(element)("") + else: + plain_text = element.string + plain_text = normalise_text(plain_text) + element = type(element)(plain_text) + else: + # If not a leaf node or leaf type call recursively on child nodes, replacing + element.contents = plain_elements(element.contents, content_digests, node_indexes) + return element + + +def add_node_indexes(element, node_index="0"): + # Can't add attributes to string types + if is_text(element): + return element + # Add index to current element + element["data-node-index"] = node_index + # Add index to child elements + for local_idx, child in enumerate( + [c for c in element.contents if not is_text(c)], start=1): + # Can't add attributes to leaf string types + child_index = "{stem}.{local}".format( + stem=node_index, local=local_idx) + add_node_indexes(child, node_index=child_index) + return element + + +def normalise_text(text): + """Normalise unicode and whitespace.""" + # Normalise unicode first to try and standardise whitespace characters as much as possible before normalising them + text = strip_control_characters(text) + text = normalise_unicode(text) + text = normalise_whitespace(text) + return text + + +def strip_control_characters(text): + """Strip out unicode control characters which might break the parsing.""" + # Unicode control characters + # [Cc]: Other, Control [includes new lines] + # [Cf]: Other, Format + # [Cn]: Other, Not Assigned + # [Co]: Other, Private Use + # [Cs]: Other, Surrogate + control_chars = set(['Cc', 'Cf', 'Cn', 'Co', 'Cs']) + retained_chars = ['\t', '\n', '\r', '\f'] + + # Remove non-printing control characters + return "".join(["" if (unicodedata.category(char) in control_chars) and (char not in retained_chars) else char for char in text]) + + +def normalise_unicode(text): + """Normalise unicode such that things that are visually equivalent map to the same unicode string where 
possible.""" + normal_form = "NFKC" + text = unicodedata.normalize(normal_form, text) + return text + + +def normalise_whitespace(text): + """Replace runs of whitespace characters with a single space as this is what happens when HTML text is displayed.""" + text = regex.sub(r"\s+", " ", text) + # Remove leading and trailing whitespace + text = text.strip() + return text + +def is_leaf(element): + return (element.name in ['p', 'li']) + + +def is_text(element): + return isinstance(element, NavigableString) + + +def is_non_printing(element): + return any(isinstance(element, _e) for _e in [Comment, CData]) + + +def add_content_digest(element): + if not is_text(element): + element["data-content-digest"] = content_digest(element) + return element + + +def content_digest(element): + if is_text(element): + # Hash + trimmed_string = element.string.strip() + if trimmed_string == "": + digest = "" + else: + digest = hashlib.sha256(trimmed_string.encode('utf-8')).hexdigest() + else: + contents = element.contents + num_contents = len(contents) + if num_contents == 0: + # No hash when no child elements exist + digest = "" + elif num_contents == 1: + # If single child, use digest of child + digest = content_digest(contents[0]) + else: + # Build content digest from the "non-empty" digests of child nodes + digest = hashlib.sha256() + child_digests = list( + filter(lambda x: x != "", [content_digest(content) for content in contents])) + for child in child_digests: + digest.update(child.encode('utf-8')) + digest = digest.hexdigest() + return digest diff --git a/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py b/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py index 4f784c6648..2b202c53d0 100644 --- a/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py +++ b/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py @@ -50,17 +50,24 @@ def 
get_dataset_ids_from_model_config(app_model_config: AppModelConfig) -> set: return dataset_ids agent_mode = app_model_config.agent_mode_dict - if agent_mode.get('enabled') is False: - return dataset_ids - if not agent_mode.get('tools'): - return dataset_ids - - tools = agent_mode.get('tools') + tools = agent_mode.get('tools', []) or [] for tool in tools: + if len(list(tool.keys())) != 1: + continue + tool_type = list(tool.keys())[0] tool_config = list(tool.values())[0] if tool_type == "dataset": dataset_ids.add(tool_config.get("id")) + # get dataset from dataset_configs + dataset_configs = app_model_config.dataset_configs_dict + datasets = dataset_configs.get('datasets', {}) or {} + for dataset in datasets.get('datasets', []) or []: + keys = list(dataset.keys()) + if len(keys) == 1 and keys[0] == 'dataset': + if dataset['dataset'].get('id'): + dataset_ids.add(dataset['dataset'].get('id')) + return dataset_ids diff --git a/api/fields/app_fields.py b/api/fields/app_fields.py index f303c37864..9030b2fe4d 100644 --- a/api/fields/app_fields.py +++ b/api/fields/app_fields.py @@ -40,6 +40,7 @@ app_detail_fields = { 'id': fields.String, 'name': fields.String, 'mode': fields.String, + 'is_agent': fields.Boolean, 'icon': fields.String, 'icon_background': fields.String, 'enable_site': fields.Boolean, @@ -64,6 +65,7 @@ app_partial_fields = { 'id': fields.String, 'name': fields.String, 'mode': fields.String, + 'is_agent': fields.Boolean, 'icon': fields.String, 'icon_background': fields.String, 'enable_site': fields.Boolean, @@ -120,11 +122,13 @@ app_detail_fields_with_site = { 'enable_api': fields.Boolean, 'api_rpm': fields.Integer, 'api_rph': fields.Integer, + 'is_agent': fields.Boolean, 'is_demo': fields.Boolean, 'model_config': fields.Nested(model_config_fields, attribute='app_model_config'), 'site': fields.Nested(site_fields), 'api_base_url': fields.String, - 'created_at': TimestampField + 'created_at': TimestampField, + 'deleted_tools': fields.List(fields.String), } 
app_site_fields = { diff --git a/api/fields/conversation_fields.py b/api/fields/conversation_fields.py index 5ab73115d8..557f047a95 100644 --- a/api/fields/conversation_fields.py +++ b/api/fields/conversation_fields.py @@ -39,6 +39,20 @@ message_file_fields = { 'id': fields.String, 'type': fields.String, 'url': fields.String, + 'belongs_to': fields.String(default='user'), +} + +agent_thought_fields = { + 'id': fields.String, + 'chain_id': fields.String, + 'message_id': fields.String, + 'position': fields.Integer, + 'thought': fields.String, + 'tool': fields.String, + 'tool_input': fields.String, + 'created_at': TimestampField, + 'observation': fields.String, + 'files': fields.List(fields.String) } message_detail_fields = { @@ -58,6 +72,7 @@ message_detail_fields = { 'annotation': fields.Nested(annotation_fields, allow_null=True), 'annotation_hit_history': fields.Nested(annotation_hit_history_fields, allow_null=True), 'created_at': TimestampField, + 'agent_thoughts': fields.List(fields.Nested(agent_thought_fields)), 'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'), } diff --git a/api/fields/installed_app_fields.py b/api/fields/installed_app_fields.py index 95b2088c2d..f1c2377f46 100644 --- a/api/fields/installed_app_fields.py +++ b/api/fields/installed_app_fields.py @@ -17,6 +17,7 @@ installed_app_fields = { 'last_used_at': TimestampField, 'editable': fields.Boolean, 'uninstallable': fields.Boolean, + 'is_agent': fields.Boolean, } installed_app_list_fields = { diff --git a/api/fields/message_fields.py b/api/fields/message_fields.py index 5995abbcfa..397b9795b8 100644 --- a/api/fields/message_fields.py +++ b/api/fields/message_fields.py @@ -25,20 +25,57 @@ retriever_resource_fields = { 'created_at': TimestampField } +feedback_fields = { + 'rating': fields.String +} + +agent_thought_fields = { + 'id': fields.String, + 'chain_id': fields.String, + 'message_id': fields.String, + 'position': fields.Integer, + 'thought': fields.String, + 
'tool': fields.String, + 'tool_input': fields.String, + 'created_at': TimestampField, + 'observation': fields.String, + 'files': fields.List(fields.String) +} + +retriever_resource_fields = { + 'id': fields.String, + 'message_id': fields.String, + 'position': fields.Integer, + 'dataset_id': fields.String, + 'dataset_name': fields.String, + 'document_id': fields.String, + 'document_name': fields.String, + 'data_source_type': fields.String, + 'segment_id': fields.String, + 'score': fields.Float, + 'hit_count': fields.Integer, + 'word_count': fields.Integer, + 'segment_position': fields.Integer, + 'index_node_hash': fields.String, + 'content': fields.String, + 'created_at': TimestampField +} + message_fields = { 'id': fields.String, 'conversation_id': fields.String, 'inputs': fields.Raw, 'query': fields.String, 'answer': fields.String, - 'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'), 'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True), 'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)), - 'created_at': TimestampField + 'created_at': TimestampField, + 'agent_thoughts': fields.List(fields.Nested(agent_thought_fields)), + 'message_files': fields.List(fields.Nested(message_file_fields), attribute='files') } message_infinite_scroll_pagination_fields = { 'limit': fields.Integer, 'has_more': fields.Boolean, 'data': fields.List(fields.Nested(message_fields)) -} \ No newline at end of file +} diff --git a/api/migrations/versions/00bacef91f18_rename_api_provider_description.py b/api/migrations/versions/00bacef91f18_rename_api_provider_description.py new file mode 100644 index 0000000000..a498c90460 --- /dev/null +++ b/api/migrations/versions/00bacef91f18_rename_api_provider_description.py @@ -0,0 +1,34 @@ +"""rename api provider description + +Revision ID: 00bacef91f18 +Revises: 8ec536f3c800 +Create Date: 2024-01-07 04:07:34.482983 + +""" +from alembic import op +import sqlalchemy 
as sa + + +# revision identifiers, used by Alembic. +revision = '00bacef91f18' +down_revision = '8ec536f3c800' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('tool_api_providers', schema=None) as batch_op: + batch_op.add_column(sa.Column('description', sa.Text(), nullable=False)) + batch_op.drop_column('description_str') + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('tool_api_providers', schema=None) as batch_op: + batch_op.add_column(sa.Column('description_str', sa.TEXT(), autoincrement=False, nullable=False)) + batch_op.drop_column('description') + + # ### end Alembic commands ### diff --git a/api/migrations/versions/053da0c1d756_add_api_tool_privacy.py b/api/migrations/versions/053da0c1d756_add_api_tool_privacy.py new file mode 100644 index 0000000000..8d304ebfb9 --- /dev/null +++ b/api/migrations/versions/053da0c1d756_add_api_tool_privacy.py @@ -0,0 +1,51 @@ +"""add api tool privacy + +Revision ID: 053da0c1d756 +Revises: 4829e54d2fee +Create Date: 2024-01-12 06:47:21.656262 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '053da0c1d756' +down_revision = '4829e54d2fee' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + op.create_table('tool_conversation_variables', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('user_id', postgresql.UUID(), nullable=False), + sa.Column('tenant_id', postgresql.UUID(), nullable=False), + sa.Column('conversation_id', postgresql.UUID(), nullable=False), + sa.Column('variables_str', sa.Text(), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.PrimaryKeyConstraint('id', name='tool_conversation_variables_pkey') + ) + with op.batch_alter_table('tool_api_providers', schema=None) as batch_op: + batch_op.add_column(sa.Column('privacy_policy', sa.String(length=255), nullable=True)) + batch_op.alter_column('icon', + existing_type=sa.VARCHAR(length=256), + type_=sa.String(length=255), + existing_nullable=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('tool_api_providers', schema=None) as batch_op: + batch_op.alter_column('icon', + existing_type=sa.String(length=255), + type_=sa.VARCHAR(length=256), + existing_nullable=False) + batch_op.drop_column('privacy_policy') + + op.drop_table('tool_conversation_variables') + # ### end Alembic commands ### diff --git a/api/migrations/versions/114eed84c228_remove_tool_id_from_model_invoke.py b/api/migrations/versions/114eed84c228_remove_tool_id_from_model_invoke.py new file mode 100644 index 0000000000..1b1d77055a --- /dev/null +++ b/api/migrations/versions/114eed84c228_remove_tool_id_from_model_invoke.py @@ -0,0 +1,32 @@ +"""remove tool id from model invoke + +Revision ID: 114eed84c228 +Revises: c71211c8f604 +Create Date: 2024-01-10 04:40:57.257824 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '114eed84c228' +down_revision = 'c71211c8f604' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('tool_model_invokes', schema=None) as batch_op: + batch_op.drop_column('tool_id') + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('tool_model_invokes', schema=None) as batch_op: + batch_op.add_column(sa.Column('tool_id', postgresql.UUID(), autoincrement=False, nullable=False)) + + # ### end Alembic commands ### diff --git a/api/migrations/versions/23db93619b9d_add_message_files_into_agent_thought.py b/api/migrations/versions/23db93619b9d_add_message_files_into_agent_thought.py new file mode 100644 index 0000000000..25363ca947 --- /dev/null +++ b/api/migrations/versions/23db93619b9d_add_message_files_into_agent_thought.py @@ -0,0 +1,32 @@ +"""add message files into agent thought + +Revision ID: 23db93619b9d +Revises: 8ae9bc661daa +Create Date: 2024-01-18 08:46:37.302657 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '23db93619b9d' +down_revision = '8ae9bc661daa' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op: + batch_op.add_column(sa.Column('message_files', sa.Text(), nullable=True)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op: + batch_op.drop_column('message_files') + + # ### end Alembic commands ### diff --git a/api/migrations/versions/3ef9b2b6bee6_add_assistant_app.py b/api/migrations/versions/3ef9b2b6bee6_add_assistant_app.py new file mode 100644 index 0000000000..4a4a497993 --- /dev/null +++ b/api/migrations/versions/3ef9b2b6bee6_add_assistant_app.py @@ -0,0 +1,67 @@ +"""add_assistant_app + +Revision ID: 3ef9b2b6bee6 +Revises: 89c7899ca936 +Create Date: 2024-01-05 15:26:25.117551 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. 
+revision = '3ef9b2b6bee6' +down_revision = '89c7899ca936' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('tool_api_providers', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('name', sa.String(length=40), nullable=False), + sa.Column('schema', sa.Text(), nullable=False), + sa.Column('schema_type_str', sa.String(length=40), nullable=False), + sa.Column('user_id', postgresql.UUID(), nullable=False), + sa.Column('tenant_id', postgresql.UUID(), nullable=False), + sa.Column('description_str', sa.Text(), nullable=False), + sa.Column('tools_str', sa.Text(), nullable=False), + sa.PrimaryKeyConstraint('id', name='tool_api_provider_pkey') + ) + op.create_table('tool_builtin_providers', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', postgresql.UUID(), nullable=True), + sa.Column('user_id', postgresql.UUID(), nullable=False), + sa.Column('provider', sa.String(length=40), nullable=False), + sa.Column('encrypted_credentials', sa.Text(), nullable=True), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.PrimaryKeyConstraint('id', name='tool_builtin_provider_pkey'), + sa.UniqueConstraint('tenant_id', 'provider', name='unique_builtin_tool_provider') + ) + op.create_table('tool_published_apps', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('app_id', postgresql.UUID(), nullable=False), + sa.Column('user_id', postgresql.UUID(), nullable=False), + sa.Column('description', sa.Text(), nullable=False), + sa.Column('llm_description', sa.Text(), nullable=False), + sa.Column('query_description', sa.Text(), nullable=False), + 
sa.Column('query_name', sa.String(length=40), nullable=False), + sa.Column('tool_name', sa.String(length=40), nullable=False), + sa.Column('author', sa.String(length=40), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.ForeignKeyConstraint(['app_id'], ['apps.id'], ), + sa.PrimaryKeyConstraint('id', name='published_app_tool_pkey'), + sa.UniqueConstraint('app_id', 'user_id', name='unique_published_app_tool') + ) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table('tool_published_apps') + op.drop_table('tool_builtin_providers') + op.drop_table('tool_api_providers') + # ### end Alembic commands ### diff --git a/api/migrations/versions/4823da1d26cf_add_tool_file.py b/api/migrations/versions/4823da1d26cf_add_tool_file.py new file mode 100644 index 0000000000..797a9539b7 --- /dev/null +++ b/api/migrations/versions/4823da1d26cf_add_tool_file.py @@ -0,0 +1,37 @@ +"""add tool file + +Revision ID: 4823da1d26cf +Revises: 053da0c1d756 +Create Date: 2024-01-15 11:37:16.782718 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '4823da1d26cf' +down_revision = '053da0c1d756' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + op.create_table('tool_files', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('user_id', postgresql.UUID(), nullable=False), + sa.Column('tenant_id', postgresql.UUID(), nullable=False), + sa.Column('conversation_id', postgresql.UUID(), nullable=False), + sa.Column('file_key', sa.String(length=255), nullable=False), + sa.Column('mimetype', sa.String(length=255), nullable=False), + sa.Column('original_url', sa.String(length=255), nullable=True), + sa.PrimaryKeyConstraint('id', name='tool_file_pkey') + ) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table('tool_files') + # ### end Alembic commands ### diff --git a/api/migrations/versions/4829e54d2fee_change_message_chain_id_to_nullable.py b/api/migrations/versions/4829e54d2fee_change_message_chain_id_to_nullable.py new file mode 100644 index 0000000000..f67a18cb2b --- /dev/null +++ b/api/migrations/versions/4829e54d2fee_change_message_chain_id_to_nullable.py @@ -0,0 +1,36 @@ +"""change message chain id to nullable + +Revision ID: 4829e54d2fee +Revises: 114eed84c228 +Create Date: 2024-01-12 03:42:27.362415 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '4829e54d2fee' +down_revision = '114eed84c228' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op: + batch_op.alter_column('message_chain_id', + existing_type=postgresql.UUID(), + nullable=True) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op: + batch_op.alter_column('message_chain_id', + existing_type=postgresql.UUID(), + nullable=False) + + # ### end Alembic commands ### diff --git a/api/migrations/versions/8ae9bc661daa_add_tool_conversation_variables_idx.py b/api/migrations/versions/8ae9bc661daa_add_tool_conversation_variables_idx.py new file mode 100644 index 0000000000..c65372419e --- /dev/null +++ b/api/migrations/versions/8ae9bc661daa_add_tool_conversation_variables_idx.py @@ -0,0 +1,34 @@ +"""add tool conversation variables idx + +Revision ID: 8ae9bc661daa +Revises: 9fafbd60eca1 +Create Date: 2024-01-15 14:22:03.597692 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '8ae9bc661daa' +down_revision = '9fafbd60eca1' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('tool_conversation_variables', schema=None) as batch_op: + batch_op.create_index('conversation_id_idx', ['conversation_id'], unique=False) + batch_op.create_index('user_id_idx', ['user_id'], unique=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('tool_conversation_variables', schema=None) as batch_op: + batch_op.drop_index('user_id_idx') + batch_op.drop_index('conversation_id_idx') + + # ### end Alembic commands ### diff --git a/api/migrations/versions/8ec536f3c800_rename_api_provider_credentails.py b/api/migrations/versions/8ec536f3c800_rename_api_provider_credentails.py new file mode 100644 index 0000000000..1380512f30 --- /dev/null +++ b/api/migrations/versions/8ec536f3c800_rename_api_provider_credentails.py @@ -0,0 +1,32 @@ +"""rename api provider credentails + +Revision ID: 8ec536f3c800 +Revises: ad472b61a054 +Create Date: 2024-01-07 03:57:35.257545 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '8ec536f3c800' +down_revision = 'ad472b61a054' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('tool_api_providers', schema=None) as batch_op: + batch_op.add_column(sa.Column('credentials_str', sa.Text(), nullable=False)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('tool_api_providers', schema=None) as batch_op: + batch_op.drop_column('credentials_str') + + # ### end Alembic commands ### diff --git a/api/migrations/versions/9fafbd60eca1_add_message_file_belongs_to.py b/api/migrations/versions/9fafbd60eca1_add_message_file_belongs_to.py new file mode 100644 index 0000000000..367c2e731f --- /dev/null +++ b/api/migrations/versions/9fafbd60eca1_add_message_file_belongs_to.py @@ -0,0 +1,32 @@ +"""add message file belongs to + +Revision ID: 9fafbd60eca1 +Revises: 4823da1d26cf +Create Date: 2024-01-15 13:07:20.340896 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. 
+revision = '9fafbd60eca1' +down_revision = '4823da1d26cf' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('message_files', schema=None) as batch_op: + batch_op.add_column(sa.Column('belongs_to', sa.String(length=255), nullable=True)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('message_files', schema=None) as batch_op: + batch_op.drop_column('belongs_to') + + # ### end Alembic commands ### diff --git a/api/migrations/versions/ad472b61a054_add_api_provider_icon.py b/api/migrations/versions/ad472b61a054_add_api_provider_icon.py new file mode 100644 index 0000000000..1328326c2d --- /dev/null +++ b/api/migrations/versions/ad472b61a054_add_api_provider_icon.py @@ -0,0 +1,32 @@ +"""add api provider icon + +Revision ID: ad472b61a054 +Revises: 3ef9b2b6bee6 +Create Date: 2024-01-07 02:21:23.114790 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = 'ad472b61a054' +down_revision = '3ef9b2b6bee6' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('tool_api_providers', schema=None) as batch_op: + batch_op.add_column(sa.Column('icon', sa.String(length=256), nullable=False)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('tool_api_providers', schema=None) as batch_op: + batch_op.drop_column('icon') + + # ### end Alembic commands ### diff --git a/api/migrations/versions/c71211c8f604_add_tool_invoke_model_log.py b/api/migrations/versions/c71211c8f604_add_tool_invoke_model_log.py new file mode 100644 index 0000000000..dc96672b5e --- /dev/null +++ b/api/migrations/versions/c71211c8f604_add_tool_invoke_model_log.py @@ -0,0 +1,49 @@ +"""add tool_invoke_model_log + +Revision ID: c71211c8f604 +Revises: f25003750af4 +Create Date: 2024-01-09 11:42:50.664797 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = 'c71211c8f604' +down_revision = 'f25003750af4' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('tool_model_invokes', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('user_id', postgresql.UUID(), nullable=False), + sa.Column('tenant_id', postgresql.UUID(), nullable=False), + sa.Column('provider', sa.String(length=40), nullable=False), + sa.Column('tool_type', sa.String(length=40), nullable=False), + sa.Column('tool_name', sa.String(length=40), nullable=False), + sa.Column('tool_id', postgresql.UUID(), nullable=False), + sa.Column('model_parameters', sa.Text(), nullable=False), + sa.Column('prompt_messages', sa.Text(), nullable=False), + sa.Column('model_response', sa.Text(), nullable=False), + sa.Column('prompt_tokens', sa.Integer(), server_default=sa.text('0'), nullable=False), + sa.Column('answer_tokens', sa.Integer(), server_default=sa.text('0'), nullable=False), + sa.Column('answer_unit_price', sa.Numeric(precision=10, scale=4), nullable=False), + sa.Column('answer_price_unit', sa.Numeric(precision=10, scale=7), server_default=sa.text('0.001'), nullable=False), + 
sa.Column('provider_response_latency', sa.Float(), server_default=sa.text('0'), nullable=False), + sa.Column('total_price', sa.Numeric(precision=10, scale=7), nullable=True), + sa.Column('currency', sa.String(length=255), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.PrimaryKeyConstraint('id', name='tool_model_invoke_pkey') + ) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table('tool_model_invokes') + # ### end Alembic commands ### diff --git a/api/migrations/versions/de95f5c77138_migration_serpapi_api_key.py b/api/migrations/versions/de95f5c77138_migration_serpapi_api_key.py new file mode 100644 index 0000000000..cbb3662851 --- /dev/null +++ b/api/migrations/versions/de95f5c77138_migration_serpapi_api_key.py @@ -0,0 +1,109 @@ +"""migration serpapi_api_key + +Revision ID: de95f5c77138 +Revises: 23db93619b9d +Create Date: 2024-01-21 12:09:04.651394 + +""" +from alembic import op +import sqlalchemy as sa +from json import dumps, loads + + +# revision identifiers, used by Alembic. +revision = 'de95f5c77138' +down_revision = '23db93619b9d' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + """ + 1. select all tool_providers + 2. 
insert api_key to tool_provider_configs + + tool_providers + - id + - tenant_id + - tool_name + - encrypted_credentials + {"api_key": "$KEY"} + - created_at + - updated_at + + tool_builtin_providers + - id <- tool_providers.id + - tenant_id <- tool_providers.tenant_id + - user_id <- tenant_account_joins.account_id (tenant_account_joins.tenant_id = tool_providers.tenant_id and tenant_account_joins.role = 'owner') + - encrypted_credentials <- tool_providers.encrypted_credentials + {"serpapi_api_key": "$KEY"} + - created_at <- tool_providers.created_at + - updated_at <- tool_providers.updated_at + + """ + # select all tool_providers + tool_providers = op.get_bind().execute( + sa.text( + "SELECT * FROM tool_providers WHERE tool_name = 'serpapi'" + ) + ).fetchall() + + # insert api_key to tool_provider_configs + for tool_provider in tool_providers: + id = tool_provider['id'] + tenant_id = tool_provider['tenant_id'] + encrypted_credentials = tool_provider['encrypted_credentials'] + + try: + credentials = loads(encrypted_credentials) + api_key = credentials['api_key'] + credentials['serpapi_api_key'] = api_key + credentials.pop('api_key') + encrypted_credentials = dumps(credentials) + except Exception as e: + print(e) + continue + + # get user_id + user_id = op.get_bind().execute( + sa.text( + "SELECT account_id FROM tenant_account_joins WHERE tenant_id = :tenant_id AND role = 'owner'" + ), + tenant_id=tenant_id + ).fetchone()['account_id'] + + created_at = tool_provider['created_at'] + updated_at = tool_provider['updated_at'] + + # insert to tool_builtin_providers + # check if exists + exists = op.get_bind().execute( + sa.text( + "SELECT * FROM tool_builtin_providers WHERE tenant_id = :tenant_id AND provider = 'google'" + ), + tenant_id=tenant_id + ).fetchone() + if exists: + continue + + op.get_bind().execute( + sa.text( + "INSERT INTO tool_builtin_providers (id, tenant_id, user_id, provider, encrypted_credentials, created_at, updated_at) VALUES (:id, :tenant_id, 
:user_id, :provider, :encrypted_credentials, :created_at, :updated_at)" + ), + id=id, + tenant_id=tenant_id, + user_id=user_id, + provider='google', + encrypted_credentials=encrypted_credentials, + created_at=created_at, + updated_at=updated_at + ) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + pass + # ### end Alembic commands ### diff --git a/api/migrations/versions/f25003750af4_add_created_updated_at.py b/api/migrations/versions/f25003750af4_add_created_updated_at.py new file mode 100644 index 0000000000..8bdb6a0ff7 --- /dev/null +++ b/api/migrations/versions/f25003750af4_add_created_updated_at.py @@ -0,0 +1,34 @@ +"""add created/updated at + +Revision ID: f25003750af4 +Revises: 00bacef91f18 +Create Date: 2024-01-07 04:53:24.441861 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = 'f25003750af4' +down_revision = '00bacef91f18' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('tool_api_providers', schema=None) as batch_op: + batch_op.add_column(sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False)) + batch_op.add_column(sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('tool_api_providers', schema=None) as batch_op: + batch_op.drop_column('updated_at') + batch_op.drop_column('created_at') + + # ### end Alembic commands ### diff --git a/api/models/model.py b/api/models/model.py index dab3bc6f84..f317113e8d 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -1,11 +1,13 @@ import json +import uuid from core.file.upload_file_parser import UploadFileParser +from core.file.tool_file_parser import ToolFileParser from extensions.ext_database import db from flask import current_app, request from flask_login import UserMixin from libs.helper import generate_string -from sqlalchemy import Float +from sqlalchemy import Float, text from sqlalchemy.dialects.postgresql import UUID from .account import Account, Tenant @@ -66,7 +68,65 @@ class App(db.Model): def tenant(self): tenant = db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first() return tenant + + @property + def is_agent(self) -> bool: + app_model_config = self.app_model_config + if not app_model_config: + return False + if not app_model_config.agent_mode: + return False + if self.app_model_config.agent_mode_dict.get('enabled', False) \ + and self.app_model_config.agent_mode_dict.get('strategy', '') in ['function_call', 'react']: + return True + return False + + @property + def deleted_tools(self) -> list: + # get agent mode tools + app_model_config = self.app_model_config + if not app_model_config: + return [] + if not app_model_config.agent_mode: + return [] + agent_mode = app_model_config.agent_mode_dict + tools = agent_mode.get('tools', []) + + provider_ids = [] + for tool in tools: + keys = list(tool.keys()) + if len(keys) >= 4: + provider_type = tool.get('provider_type', '') + provider_id = tool.get('provider_id', '') + if provider_type == 'api': + # check if provider id is a uuid string, if not, skip + try: + uuid.UUID(provider_id) + except Exception: + continue + provider_ids.append(provider_id) + + if not provider_ids: + 
return [] + + api_providers = db.session.execute( + text('SELECT id FROM tool_api_providers WHERE id IN :provider_ids'), + {'provider_ids': tuple(provider_ids)} + ).fetchall() + + deleted_tools = [] + current_api_provider_ids = [str(api_provider.id) for api_provider in api_providers] + + for tool in tools: + keys = list(tool.keys()) + if len(keys) >= 4: + provider_type = tool.get('provider_type', '') + provider_id = tool.get('provider_id', '') + if provider_type == 'api' and provider_id not in current_api_provider_ids: + deleted_tools.append(tool['tool_name']) + + return deleted_tools class AppModelConfig(db.Model): __tablename__ = 'app_model_configs' @@ -168,7 +228,7 @@ class AppModelConfig(db.Model): @property def agent_mode_dict(self) -> dict: - return json.loads(self.agent_mode) if self.agent_mode else {"enabled": False, "strategy": None, "tools": []} + return json.loads(self.agent_mode) if self.agent_mode else {"enabled": False, "strategy": None, "tools": [], "prompt": None} @property def chat_prompt_config_dict(self) -> dict: @@ -337,6 +397,12 @@ class InstalledApp(db.Model): tenant = db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first() return tenant + @property + def is_agent(self) -> bool: + app = self.app + if not app: + return False + return app.is_agent class Conversation(db.Model): __tablename__ = 'conversations' @@ -582,11 +648,22 @@ class Message(db.Model): upload_file=upload_file, force_url=True ) + if message_file.transfer_method == 'tool_file': + # get extension + if '.' 
in message_file.url: + extension = f'.{message_file.url.split(".")[-1]}' + if len(extension) > 10: + extension = '.bin' + else: + extension = '.bin' + # add sign url + url = ToolFileParser.get_tool_file_manager().sign_file(file_id=message_file.id, extension=extension) files.append({ 'id': message_file.id, 'type': message_file.type, - 'url': url + 'url': url, + 'belongs_to': message_file.belongs_to if message_file.belongs_to else 'user' }) return files @@ -632,12 +709,12 @@ class MessageFile(db.Model): type = db.Column(db.String(255), nullable=False) transfer_method = db.Column(db.String(255), nullable=False) url = db.Column(db.Text, nullable=True) + belongs_to = db.Column(db.String(255), nullable=True) upload_file_id = db.Column(UUID, nullable=True) created_by_role = db.Column(db.String(255), nullable=False) created_by = db.Column(UUID, nullable=False) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - class MessageAnnotation(db.Model): __tablename__ = 'message_annotations' __table_args__ = ( @@ -912,7 +989,7 @@ class MessageAgentThought(db.Model): id = db.Column(UUID, nullable=False, server_default=db.text('uuid_generate_v4()')) message_id = db.Column(UUID, nullable=False) - message_chain_id = db.Column(UUID, nullable=False) + message_chain_id = db.Column(UUID, nullable=True) position = db.Column(db.Integer, nullable=False) thought = db.Column(db.Text, nullable=True) tool = db.Column(db.Text, nullable=True) @@ -924,6 +1001,7 @@ class MessageAgentThought(db.Model): message_token = db.Column(db.Integer, nullable=True) message_unit_price = db.Column(db.Numeric, nullable=True) message_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text('0.001')) + message_files = db.Column(db.Text, nullable=True) answer = db.Column(db.Text, nullable=True) answer_token = db.Column(db.Integer, nullable=True) answer_unit_price = db.Column(db.Numeric, nullable=True) @@ -936,6 +1014,12 @@ class 
MessageAgentThought(db.Model): created_by = db.Column(UUID, nullable=False) created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp()) + @property + def files(self) -> list: + if self.message_files: + return json.loads(self.message_files) + else: + return [] class DatasetRetrieverResource(db.Model): __tablename__ = 'dataset_retriever_resources' diff --git a/api/models/tools.py b/api/models/tools.py new file mode 100644 index 0000000000..74da1cd941 --- /dev/null +++ b/api/models/tools.py @@ -0,0 +1,227 @@ +import json +from enum import Enum +from typing import List + +from sqlalchemy.dialects.postgresql import UUID +from sqlalchemy import ForeignKey + +from extensions.ext_database import db + +from core.tools.entities.tool_bundle import ApiBasedToolBundle +from core.tools.entities.common_entities import I18nObject +from core.tools.entities.tool_entities import ApiProviderSchemaType, ToolRuntimeVariablePool + +from models.model import Tenant, Account, App + +class BuiltinToolProvider(db.Model): + """ + This table stores the tool provider information for built-in tools for each tenant. 
+ """ + __tablename__ = 'tool_builtin_providers' + __table_args__ = ( + db.PrimaryKeyConstraint('id', name='tool_builtin_provider_pkey'), + # one tenant can only have one tool provider with the same name + db.UniqueConstraint('tenant_id', 'provider', name='unique_builtin_tool_provider') + ) + + # id of the tool provider + id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) + # id of the tenant + tenant_id = db.Column(UUID, nullable=True) + # who created this tool provider + user_id = db.Column(UUID, nullable=False) + # name of the tool provider + provider = db.Column(db.String(40), nullable=False) + # credential of the tool provider + encrypted_credentials = db.Column(db.Text, nullable=True) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + + @property + def credentials(self) -> dict: + return json.loads(self.encrypted_credentials) + +class PublishedAppTool(db.Model): + """ + The table stores the apps published as a tool for each person. 
+ """ + __tablename__ = 'tool_published_apps' + __table_args__ = ( + db.PrimaryKeyConstraint('id', name='published_app_tool_pkey'), + db.UniqueConstraint('app_id', 'user_id', name='unique_published_app_tool') + ) + + # id of the tool provider + id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) + # id of the app + app_id = db.Column(UUID, ForeignKey('apps.id'), nullable=False) + # who published this tool + user_id = db.Column(UUID, nullable=False) + # description of the tool, stored in i18n format, for human + description = db.Column(db.Text, nullable=False) + # llm_description of the tool, for LLM + llm_description = db.Column(db.Text, nullable=False) + # query description, query will be seen as a parameter of the tool, to describe this parameter to llm, we need this field + query_description = db.Column(db.Text, nullable=False) + # query name, the name of the query parameter + query_name = db.Column(db.String(40), nullable=False) + # name of the tool provider + tool_name = db.Column(db.String(40), nullable=False) + # author + author = db.Column(db.String(40), nullable=False) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + + @property + def description_i18n(self) -> I18nObject: + return I18nObject(**json.loads(self.description)) + + @property + def app(self) -> App: + return db.session.query(App).filter(App.id == self.app_id).first() + +class ApiToolProvider(db.Model): + """ + The table stores the api providers. 
+ """ + __tablename__ = 'tool_api_providers' + __table_args__ = ( + db.PrimaryKeyConstraint('id', name='tool_api_provider_pkey'), + ) + + id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) + # name of the api provider + name = db.Column(db.String(40), nullable=False) + # icon + icon = db.Column(db.String(255), nullable=False) + # original schema + schema = db.Column(db.Text, nullable=False) + schema_type_str = db.Column(db.String(40), nullable=False) + # who created this tool + user_id = db.Column(UUID, nullable=False) + # tanent id + tenant_id = db.Column(UUID, nullable=False) + # description of the provider + description = db.Column(db.Text, nullable=False) + # json format tools + tools_str = db.Column(db.Text, nullable=False) + # json format credentials + credentials_str = db.Column(db.Text, nullable=False) + # privacy policy + privacy_policy = db.Column(db.String(255), nullable=True) + + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + + @property + def schema_type(self) -> ApiProviderSchemaType: + return ApiProviderSchemaType.value_of(self.schema_type_str) + + @property + def tools(self) -> List[ApiBasedToolBundle]: + return [ApiBasedToolBundle(**tool) for tool in json.loads(self.tools_str)] + + @property + def credentials(self) -> dict: + return json.loads(self.credentials_str) + + @property + def is_taned(self) -> bool: + return self.tenant_id is not None + + @property + def user(self) -> Account: + return db.session.query(Account).filter(Account.id == self.user_id).first() + + @property + def tanent(self) -> Tenant: + return db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first() + +class ToolModelInvoke(db.Model): + """ + store the invoke logs from tool invoke + """ + __tablename__ = "tool_model_invokes" + __table_args__ = ( + db.PrimaryKeyConstraint('id', 
name='tool_model_invoke_pkey'), + ) + + id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) + # who invoke this tool + user_id = db.Column(UUID, nullable=False) + # tanent id + tenant_id = db.Column(UUID, nullable=False) + # provider + provider = db.Column(db.String(40), nullable=False) + # type + tool_type = db.Column(db.String(40), nullable=False) + # tool name + tool_name = db.Column(db.String(40), nullable=False) + # invoke parameters + model_parameters = db.Column(db.Text, nullable=False) + # prompt messages + prompt_messages = db.Column(db.Text, nullable=False) + # invoke response + model_response = db.Column(db.Text, nullable=False) + + prompt_tokens = db.Column(db.Integer, nullable=False, server_default=db.text('0')) + answer_tokens = db.Column(db.Integer, nullable=False, server_default=db.text('0')) + answer_unit_price = db.Column(db.Numeric(10, 4), nullable=False) + answer_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text('0.001')) + provider_response_latency = db.Column(db.Float, nullable=False, server_default=db.text('0')) + total_price = db.Column(db.Numeric(10, 7)) + currency = db.Column(db.String(255), nullable=False) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + +class ToolConversationVariables(db.Model): + """ + store the conversation variables from tool invoke + """ + __tablename__ = "tool_conversation_variables" + __table_args__ = ( + db.PrimaryKeyConstraint('id', name='tool_conversation_variables_pkey'), + # add index for user_id and conversation_id + db.Index('user_id_idx', 'user_id'), + db.Index('conversation_id_idx', 'conversation_id'), + ) + + id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) + # conversation user id + user_id = db.Column(UUID, nullable=False) + # tanent id + tenant_id = db.Column(UUID, nullable=False) + # 
conversation id + conversation_id = db.Column(UUID, nullable=False) + # variables pool + variables_str = db.Column(db.Text, nullable=False) + + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + + @property + def variables(self) -> dict: + return json.loads(self.variables_str) + +class ToolFile(db.Model): + """ + store the file created by agent + """ + __tablename__ = "tool_files" + __table_args__ = ( + db.PrimaryKeyConstraint('id', name='tool_file_pkey'), + ) + + id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) + # conversation user id + user_id = db.Column(UUID, nullable=False) + # tanent id + tenant_id = db.Column(UUID, nullable=False) + # conversation id + conversation_id = db.Column(UUID, nullable=False) + # file key + file_key = db.Column(db.String(255), nullable=False) + # mime type + mimetype = db.Column(db.String(255), nullable=False) + # original url + original_url = db.Column(db.String(255), nullable=True) \ No newline at end of file diff --git a/api/requirements.txt b/api/requirements.txt index 9e62f9bd75..639941e39b 100644 --- a/api/requirements.txt +++ b/api/requirements.txt @@ -61,4 +61,7 @@ unstructured[docx,pptx,msg,md,ppt]~=0.10.27 bs4~=0.0.1 markdown~=3.5.1 google-generativeai~=0.3.2 -httpx[socks]~=0.24.1 \ No newline at end of file +httpx[socks]~=0.24.1 +pydub~=0.25.1 +matplotlib~=3.8.2 +yfinance~=0.2.35 diff --git a/api/services/app_model_config_service.py b/api/services/app_model_config_service.py index bf7dfab747..f4e697f356 100644 --- a/api/services/app_model_config_service.py +++ b/api/services/app_model_config_service.py @@ -185,7 +185,7 @@ class AppModelConfigService: variables = [] for item in config["user_input_form"]: key = list(item.keys())[0] - if key not in ["text-input", "select", "paragraph"]: + if key not in ["text-input", "select", "paragraph", 
"external_data_tool"]: raise ValueError("Keys in user_input_form list can only be 'text-input', 'paragraph' or 'select'") form_item = item[key] @@ -262,28 +262,39 @@ class AppModelConfigService: for tool in config["agent_mode"]["tools"]: key = list(tool.keys())[0] - if key not in SUPPORT_TOOLS: - raise ValueError("Keys in agent_mode.tools must be in the specified tool list") + if key in SUPPORT_TOOLS: + # old style, use tool name as key + tool_item = tool[key] - tool_item = tool[key] + if "enabled" not in tool_item or not tool_item["enabled"]: + tool_item["enabled"] = False - if "enabled" not in tool_item or not tool_item["enabled"]: - tool_item["enabled"] = False + if not isinstance(tool_item["enabled"], bool): + raise ValueError("enabled in agent_mode.tools must be of boolean type") - if not isinstance(tool_item["enabled"], bool): - raise ValueError("enabled in agent_mode.tools must be of boolean type") + if key == "dataset": + if 'id' not in tool_item: + raise ValueError("id is required in dataset") - if key == "dataset": - if 'id' not in tool_item: - raise ValueError("id is required in dataset") + try: + uuid.UUID(tool_item["id"]) + except ValueError: + raise ValueError("id in dataset must be of UUID type") - try: - uuid.UUID(tool_item["id"]) - except ValueError: - raise ValueError("id in dataset must be of UUID type") - - if not cls.is_dataset_exists(account, tool_item["id"]): - raise ValueError("Dataset ID does not exist, please check your permission.") + if not cls.is_dataset_exists(account, tool_item["id"]): + raise ValueError("Dataset ID does not exist, please check your permission.") + else: + # latest style, use key-value pair + if "enabled" not in tool or not tool["enabled"]: + tool["enabled"] = False + if "provider_type" not in tool: + raise ValueError("provider_type is required in agent_mode.tools") + if "provider_id" not in tool: + raise ValueError("provider_id is required in agent_mode.tools") + if "tool_name" not in tool: + raise 
ValueError("tool_name is required in agent_mode.tools") + if "tool_parameters" not in tool: + raise ValueError("tool_parameters is required in agent_mode.tools") # dataset_query_variable cls.is_dataset_query_variable_valid(config, app_mode) @@ -454,6 +465,12 @@ class AppModelConfigService: if 'dataset_configs' not in config or not config["dataset_configs"]: config["dataset_configs"] = {'retrieval_model': 'single'} + if 'datasets' not in config["dataset_configs"] or not config["dataset_configs"]["datasets"]: + config["dataset_configs"]["datasets"] = { + "strategy": "router", + "datasets": [] + } + if not isinstance(config["dataset_configs"], dict): raise ValueError("dataset_configs must be of object type") diff --git a/api/services/completion_service.py b/api/services/completion_service.py index b218839df3..6035eb1b50 100644 --- a/api/services/completion_service.py +++ b/api/services/completion_service.py @@ -227,6 +227,8 @@ class CompletionService: input_type = list(config.keys())[0] if variable not in user_inputs or not user_inputs[variable]: + if input_type == "external_data_tool": + continue if "required" in input_config and input_config["required"]: raise ValueError(f"{variable} is required in input form") else: diff --git a/api/services/tools_manage_service.py b/api/services/tools_manage_service.py new file mode 100644 index 0000000000..ee530b1c1a --- /dev/null +++ b/api/services/tools_manage_service.py @@ -0,0 +1,523 @@ +from typing import List, Tuple + +from flask import current_app + +from core.tools.tool_manager import ToolManager +from core.tools.entities.user_entities import UserToolProvider, UserTool +from core.tools.entities.tool_entities import ApiProviderSchemaType, ApiProviderAuthType, ToolProviderCredentials, \ + ToolCredentialsOption +from core.tools.entities.common_entities import I18nObject +from core.tools.entities.tool_bundle import ApiBasedToolBundle +from core.tools.provider.tool_provider import ToolProviderController +from 
core.tools.provider.api_tool_provider import ApiBasedToolProviderController +from core.tools.utils.parser import ApiBasedToolSchemaParser +from core.tools.utils.encoder import serialize_base_model_array, serialize_base_model_dict +from core.tools.utils.configration import ToolConfiguration +from core.tools.errors import ToolProviderCredentialValidationError, ToolProviderNotFoundError, ToolNotFoundError + +from extensions.ext_database import db +from models.tools import BuiltinToolProvider, ApiToolProvider + +from httpx import get + +import json + +class ToolManageService: + @staticmethod + def list_tool_providers(user_id: str, tanent_id: str): + """ + list tool providers + + :return: the list of tool providers + """ + result = [provider.to_dict() for provider in ToolManager.user_list_providers( + user_id, tanent_id + )] + + # add icon url prefix + for provider in result: + ToolManageService.repacket_provider(provider) + + return result + + @staticmethod + def repacket_provider(provider: dict): + """ + repacket provider + + :param provider: the provider dict + """ + url_prefix = (current_app.config.get("CONSOLE_API_URL") + + f"/console/api/workspaces/current/tool-provider/builtin/") + + if 'icon' in provider: + if provider['type'] == UserToolProvider.ProviderType.BUILTIN.value: + provider['icon'] = url_prefix + provider['name'] + '/icon' + elif provider['type'] == UserToolProvider.ProviderType.API.value: + try: + provider['icon'] = json.loads(provider['icon']) + except: + provider['icon'] = { + "background": "#252525", + "content": "\ud83d\ude01" + } + + @staticmethod + def list_builtin_tool_provider_tools( + user_id: str, tenant_id: str, provider: str + ): + """ + list builtin tool provider tools + """ + provider_controller: ToolProviderController = ToolManager.get_builtin_provider(provider) + tools = provider_controller.get_tools() + + result = [ + UserTool( + author=tool.identity.author, + name=tool.identity.name, + label=tool.identity.label, + 
description=tool.description.human, + parameters=tool.parameters or [] + ) for tool in tools + ] + + return json.loads( + serialize_base_model_array(result) + ) + + @staticmethod + def list_builtin_provider_credentials_schema( + provider_name + ): + """ + list builtin provider credentials schema + + :return: the list of tool providers + """ + provider = ToolManager.get_builtin_provider(provider_name) + return [ + v.to_dict() for _, v in (provider.credentials_schema or {}).items() + ] + + @staticmethod + def parser_api_schema(schema: str) -> List[ApiBasedToolBundle]: + """ + parse api schema to tool bundle + """ + try: + warnings = {} + try: + tool_bundles, schema_type = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema, warning=warnings) + except Exception as e: + raise ValueError(f'invalid schema: {str(e)}') + + credentails_schema = [ + ToolProviderCredentials( + name='auth_type', + type=ToolProviderCredentials.CredentialsType.SELECT, + required=True, + default='none', + options=[ + ToolCredentialsOption(value='none', label=I18nObject( + en_US='None', + zh_Hans='无' + )), + ToolCredentialsOption(value='api_key', label=I18nObject( + en_US='Api Key', + zh_Hans='Api Key' + )), + ], + placeholder=I18nObject( + en_US='Select auth type', + zh_Hans='选择认证方式' + ) + ), + ToolProviderCredentials( + name='api_key_header', + type=ToolProviderCredentials.CredentialsType.TEXT_INPUT, + required=False, + placeholder=I18nObject( + en_US='Enter api key header', + zh_Hans='输入 api key header,如:X-API-KEY' + ), + default='api_key', + help=I18nObject( + en_US='HTTP header name for api key', + zh_Hans='HTTP 头部字段名,用于传递 api key' + ) + ), + ToolProviderCredentials( + name='api_key_value', + type=ToolProviderCredentials.CredentialsType.TEXT_INPUT, + required=False, + placeholder=I18nObject( + en_US='Enter api key', + zh_Hans='输入 api key' + ), + default='' + ), + ] + + return json.loads(serialize_base_model_dict( + { + 'schema_type': schema_type, + 'parameters_schema': tool_bundles, + 
'credentials_schema': credentails_schema, + 'warning': warnings + } + )) + except Exception as e: + raise ValueError(f'invalid schema: {str(e)}') + + @staticmethod + def convert_schema_to_tool_bundles(schema: str, extra_info: dict = None) -> List[ApiBasedToolBundle]: + """ + convert schema to tool bundles + + :return: the list of tool bundles, description + """ + try: + tool_bundles = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema, extra_info=extra_info) + return tool_bundles + except Exception as e: + raise ValueError(f'invalid schema: {str(e)}') + + @staticmethod + def create_api_tool_provider( + user_id: str, tenant_id: str, provider_name: str, icon: dict, credentials: dict, + schema_type: str, schema: str, privacy_policy: str + ): + """ + create api tool provider + """ + if schema_type not in [member.value for member in ApiProviderSchemaType]: + raise ValueError(f'invalid schema type {schema}') + + # check if the provider exists + provider: ApiToolProvider = db.session.query(ApiToolProvider).filter( + ApiToolProvider.tenant_id == tenant_id, + ApiToolProvider.name == provider_name, + ).first() + + if provider is not None: + raise ValueError(f'provider {provider_name} already exists') + + # parse openapi to tool bundle + extra_info = {} + # extra info like description will be set here + tool_bundles, schema_type = ToolManageService.convert_schema_to_tool_bundles(schema, extra_info) + + if len(tool_bundles) > 10: + raise ValueError(f'the number of apis should be less than 10') + + # create db provider + db_provider = ApiToolProvider( + tenant_id=tenant_id, + user_id=user_id, + name=provider_name, + icon=json.dumps(icon), + schema=schema, + description=extra_info.get('description', ''), + schema_type_str=schema_type, + tools_str=serialize_base_model_array(tool_bundles), + credentials_str={}, + privacy_policy=privacy_policy + ) + + if 'auth_type' not in credentials: + raise ValueError('auth_type is required') + + # get auth type, none or api key + 
auth_type = ApiProviderAuthType.value_of(credentials['auth_type']) + + # create provider entity + provider_controller = ApiBasedToolProviderController.from_db(db_provider, auth_type) + # load tools into provider entity + provider_controller.load_bundled_tools(tool_bundles) + + # encrypt credentials + tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=provider_controller) + encrypted_credentials = tool_configuration.encrypt_tool_credentials(credentials) + db_provider.credentials_str = json.dumps(encrypted_credentials) + + db.session.add(db_provider) + db.session.commit() + + return { 'result': 'success' } + + @staticmethod + def get_api_tool_provider_remote_schema( + user_id: str, tenant_id: str, url: str + ): + """ + get api tool provider remote schema + """ + headers = { + "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36 Edg/120.0.0.0", + "Accept": "*/*", + } + + try: + response = get(url, headers=headers, timeout=10) + if response.status_code != 200: + raise ValueError(f'Got status code {response.status_code}') + schema = response.text + + # try to parse schema, avoid SSRF attack + ToolManageService.parser_api_schema(schema) + except Exception as e: + raise ValueError(f'invalid schema, please check the url you provided') + + return { + 'schema': schema + } + + @staticmethod + def list_api_tool_provider_tools( + user_id: str, tenant_id: str, provider: str + ): + """ + list api tool provider tools + """ + provider: ApiToolProvider = db.session.query(ApiToolProvider).filter( + ApiToolProvider.tenant_id == tenant_id, + ApiToolProvider.name == provider, + ).first() + + if provider is None: + raise ValueError(f'yout have not added provider {provider}') + + return json.loads( + serialize_base_model_array([ + UserTool( + author=tool_bundle.author, + name=tool_bundle.operation_id, + label=I18nObject( + en_US=tool_bundle.operation_id, + 
zh_Hans=tool_bundle.operation_id + ), + description=I18nObject( + en_US=tool_bundle.summary or '', + zh_Hans=tool_bundle.summary or '' + ), + parameters=tool_bundle.parameters + ) for tool_bundle in provider.tools + ]) + ) + + @staticmethod + def update_builtin_tool_provider( + user_id: str, tenant_id: str, provider_name: str, credentials: dict + ): + """ + update builtin tool provider + """ + try: + # get provider + provider_controller = ToolManager.get_builtin_provider(provider_name) + if not provider_controller.need_credentials: + raise ValueError(f'provider {provider_name} does not need credentials') + # validate credentials + provider_controller.validate_credentials(credentials) + # encrypt credentials + tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=provider_controller) + credentials = tool_configuration.encrypt_tool_credentials(credentials) + except (ToolProviderNotFoundError, ToolNotFoundError, ToolProviderCredentialValidationError) as e: + raise ValueError(str(e)) + + # get if the provider exists + provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter( + BuiltinToolProvider.tenant_id == tenant_id, + BuiltinToolProvider.provider == provider_name, + ).first() + + if provider is None: + # create provider + provider = BuiltinToolProvider( + tenant_id=tenant_id, + user_id=user_id, + provider=provider_name, + encrypted_credentials=json.dumps(credentials), + ) + + db.session.add(provider) + db.session.commit() + + else: + provider.encrypted_credentials = json.dumps(credentials) + + db.session.add(provider) + db.session.commit() + + return { 'result': 'success' } + + @staticmethod + def update_api_tool_provider( + user_id: str, tenant_id: str, provider_name: str, original_provider: str, icon: str, credentials: dict, + schema_type: str, schema: str, privacy_policy: str + ): + """ + update api tool provider + """ + if schema_type not in [member.value for member in ApiProviderSchemaType]: + raise 
ValueError(f'invalid schema type {schema}') + + # check if the provider exists + provider: ApiToolProvider = db.session.query(ApiToolProvider).filter( + ApiToolProvider.tenant_id == tenant_id, + ApiToolProvider.name == original_provider, + ).first() + + if provider is None: + raise ValueError(f'api provider {provider_name} does not exists') + + # parse openapi to tool bundle + extra_info = {} + # extra info like description will be set here + tool_bundles, schema_type = ToolManageService.convert_schema_to_tool_bundles(schema, extra_info) + + # update db provider + provider.name = provider_name + provider.icon = icon + provider.schema = schema + provider.description = extra_info.get('description', '') + provider.schema_type_str = ApiProviderSchemaType.OPENAPI.value + provider.tools_str = serialize_base_model_array(tool_bundles) + provider.credentials_str = json.dumps(credentials) + provider.privacy_policy = privacy_policy + + if 'auth_type' not in credentials: + raise ValueError('auth_type is required') + + # get auth type, none or api key + auth_type = ApiProviderAuthType.value_of(credentials['auth_type']) + + # create provider entity + provider_entity = ApiBasedToolProviderController.from_db(provider, auth_type) + # load tools into provider entity + provider_entity.load_bundled_tools(tool_bundles) + + db.session.add(provider) + db.session.commit() + + return { 'result': 'success' } + + @staticmethod + def delete_builtin_tool_provider( + user_id: str, tenant_id: str, provider: str + ): + """ + delete tool provider + """ + provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter( + BuiltinToolProvider.tenant_id == tenant_id, + BuiltinToolProvider.provider == provider, + ).first() + + if provider is None: + raise ValueError(f'yout have not added provider {provider}') + + db.session.delete(provider) + db.session.commit() + + return { 'result': 'success' } + + @staticmethod + def get_builtin_tool_provider_icon( + provider: str + ): + """ + get tool 
provider icon and its mimetype + """ + icon_path, mime_type = ToolManager.get_builtin_provider_icon(provider) + with open(icon_path, 'rb') as f: + icon_bytes = f.read() + + return icon_bytes, mime_type + + @staticmethod + def delete_api_tool_provider( + user_id: str, tenant_id: str, provider: str + ): + """ + delete tool provider + """ + provider: ApiToolProvider = db.session.query(ApiToolProvider).filter( + ApiToolProvider.tenant_id == tenant_id, + ApiToolProvider.name == provider, + ).first() + + if provider is None: + raise ValueError(f'yout have not added provider {provider}') + + db.session.delete(provider) + db.session.commit() + + return { 'result': 'success' } + + @staticmethod + def get_api_tool_provider( + user_id: str, tenant_id: str, provider: str + ): + """ + get api tool provider + """ + return ToolManager.user_get_api_provider(provider=provider, tenant_id=tenant_id) + + @staticmethod + def test_api_tool_preview( + tenant_id: str, tool_name: str, credentials: dict, parameters: dict, schema_type: str, schema: str + ): + """ + test api tool before adding api tool provider + + 1. 
parse schema into tool bundle + """ + if schema_type not in [member.value for member in ApiProviderSchemaType]: + raise ValueError(f'invalid schema type {schema_type}') + + if schema_type == ApiProviderSchemaType.OPENAPI.value: + tool_bundles = ApiBasedToolSchemaParser.parse_openapi_yaml_to_tool_bundle(schema) + else: + raise ValueError(f'invalid schema type {schema_type}') + + # get tool bundle + tool_bundle = next(filter(lambda tb: tb.operation_id == tool_name, tool_bundles), None) + if tool_bundle is None: + raise ValueError(f'invalid tool name {tool_name}') + + # create a fake db provider + db_provider = ApiToolProvider( + tenant_id='', user_id='', name='', icon='', + schema=schema, + description='', + schema_type_str=ApiProviderSchemaType.OPENAPI.value, + tools_str=serialize_base_model_array(tool_bundles), + credentials_str=json.dumps(credentials), + ) + + if 'auth_type' not in credentials: + raise ValueError('auth_type is required') + + # get auth type, none or api key + auth_type = ApiProviderAuthType.value_of(credentials['auth_type']) + + # create provider entity + provider_controller = ApiBasedToolProviderController.from_db(db_provider, auth_type) + # load tools into provider entity + provider_controller.load_bundled_tools(tool_bundles) + + try: + provider_controller.validate_credentials_format(credentials) + # get tool + tool = provider_controller.get_tool(tool_name) + tool = tool.fork_tool_runtime(meta={ + 'credentials': credentials, + 'tenant_id': tenant_id, + }) + tool.validate_credentials(credentials, parameters) + except Exception as e: + return { 'error': str(e) } + + return { 'result': 'success' } \ No newline at end of file diff --git a/api/tests/integration_tests/.gitignore b/api/tests/integration_tests/.gitignore new file mode 100644 index 0000000000..426667562b --- /dev/null +++ b/api/tests/integration_tests/.gitignore @@ -0,0 +1 @@ +.env.test \ No newline at end of file diff --git a/api/tests/integration_tests/tools/__init__.py 
b/api/tests/integration_tests/tools/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/integration_tests/tools/__mock_server/openapi_todo.py b/api/tests/integration_tests/tools/__mock_server/openapi_todo.py new file mode 100644 index 0000000000..ba14d365c5 --- /dev/null +++ b/api/tests/integration_tests/tools/__mock_server/openapi_todo.py @@ -0,0 +1,38 @@ +from flask import Flask, request +from flask_restful import Api, Resource + +app = Flask(__name__) +api = Api(app) + +# Mock data +todos_data = { + "global": ["Buy groceries", "Finish project"], + "user1": ["Go for a run", "Read a book"], +} + +class TodosResource(Resource): + def get(self, username): + todos = todos_data.get(username, []) + return {"todos": todos} + + def post(self, username): + data = request.get_json() + new_todo = data.get("todo") + todos_data.setdefault(username, []).append(new_todo) + return {"message": "Todo added successfully"} + + def delete(self, username): + data = request.get_json() + todo_idx = data.get("todo_idx") + todos = todos_data.get(username, []) + + if 0 <= todo_idx < len(todos): + del todos[todo_idx] + return {"message": "Todo deleted successfully"} + + return {"error": "Invalid todo index"}, 400 + +api.add_resource(TodosResource, '/todos/') + +if __name__ == '__main__': + app.run(port=5003, debug=True) diff --git a/api/tests/integration_tests/tools/test_all_provider.py b/api/tests/integration_tests/tools/test_all_provider.py new file mode 100644 index 0000000000..83eccb1b11 --- /dev/null +++ b/api/tests/integration_tests/tools/test_all_provider.py @@ -0,0 +1,9 @@ +from core.tools.tool_manager import ToolManager + +def test_tool_providers(): + """ + Test that all tool providers can be loaded + """ + providers = ToolManager.list_builtin_providers() + for provider in providers: + provider.get_tools()