From 642842d61b68422f80579703fc5d3aa0916f7c4b Mon Sep 17 00:00:00 2001 From: Jyong <76649700+JohnJyong@users.noreply.github.com> Date: Sun, 10 Sep 2023 15:17:43 +0800 Subject: [PATCH] Feat:dataset retiever resource (#1123) Co-authored-by: jyong Co-authored-by: StyleZhang --- api/controllers/console/app/app.py | 1 + api/controllers/console/app/completion.py | 2 + api/controllers/console/explore/completion.py | 2 + api/controllers/console/explore/message.py | 20 +++++++ api/controllers/console/explore/parameter.py | 2 + .../console/universal_chat/chat.py | 2 + .../console/universal_chat/message.py | 20 +++++++ .../console/universal_chat/parameter.py | 5 ++ .../console/universal_chat/wraps.py | 1 + api/controllers/service_api/app/app.py | 2 + api/controllers/service_api/app/completion.py | 4 ++ api/controllers/service_api/app/message.py | 19 +++++++ api/controllers/web/app.py | 2 + api/controllers/web/completion.py | 4 ++ api/controllers/web/message.py | 20 +++++++ .../agent/agent/multi_dataset_router_agent.py | 5 ++ .../dataset_tool_callback_handler.py | 5 +- .../index_tool_callback_handler.py | 8 ++- api/core/completion.py | 10 ++-- api/core/conversation_message_task.py | 56 ++++++++++++++++++- .../keyword_table_index.py | 2 +- .../index/vector_index/qdrant_vector_index.py | 19 +++++++ .../model_providers/models/entity/message.py | 1 + api/core/orchestrator_rule_parser.py | 29 +++++++--- .../prompt/generate_prompts/common_chat.json | 2 +- api/core/prompt/prompts.py | 2 +- api/core/tool/dataset_retriever_tool.py | 52 +++++++++++++++-- ...43972bdc_add_dataset_retriever_resource.py | 54 ++++++++++++++++++ ...3755c_add_app_config_retriever_resource.py | 32 +++++++++++ api/models/model.py | 45 ++++++++++++++- api/services/app_model_config_service.py | 16 ++++++ api/services/completion_service.py | 31 ++++++++-- 32 files changed, 442 insertions(+), 33 deletions(-) create mode 100644 api/migrations/versions/6dcb43972bdc_add_dataset_retriever_resource.py create mode 100644 
api/migrations/versions/77e83833755c_add_app_config_retriever_resource.py diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index 8dff7f3abf..97b862c76a 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -29,6 +29,7 @@ model_config_fields = { 'suggested_questions': fields.Raw(attribute='suggested_questions_list'), 'suggested_questions_after_answer': fields.Raw(attribute='suggested_questions_after_answer_dict'), 'speech_to_text': fields.Raw(attribute='speech_to_text_dict'), + 'retriever_resource': fields.Raw(attribute='retriever_resource_dict'), 'more_like_this': fields.Raw(attribute='more_like_this_dict'), 'sensitive_word_avoidance': fields.Raw(attribute='sensitive_word_avoidance_dict'), 'model': fields.Raw(attribute='model_dict'), diff --git a/api/controllers/console/app/completion.py b/api/controllers/console/app/completion.py index 3e22ca96df..02b3360a08 100644 --- a/api/controllers/console/app/completion.py +++ b/api/controllers/console/app/completion.py @@ -42,6 +42,7 @@ class CompletionMessageApi(Resource): parser.add_argument('query', type=str, location='json', default='') parser.add_argument('model_config', type=dict, required=True, location='json') parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json') + parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json') args = parser.parse_args() streaming = args['response_mode'] != 'blocking' @@ -115,6 +116,7 @@ class ChatMessageApi(Resource): parser.add_argument('model_config', type=dict, required=True, location='json') parser.add_argument('conversation_id', type=uuid_value, location='json') parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json') + parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json') args = parser.parse_args() streaming = args['response_mode'] != 'blocking' 
diff --git a/api/controllers/console/explore/completion.py b/api/controllers/console/explore/completion.py index b708367258..bdf1f3b907 100644 --- a/api/controllers/console/explore/completion.py +++ b/api/controllers/console/explore/completion.py @@ -33,6 +33,7 @@ class CompletionApi(InstalledAppResource): parser.add_argument('inputs', type=dict, required=True, location='json') parser.add_argument('query', type=str, location='json', default='') parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json') + parser.add_argument('retriever_from', type=str, required=False, default='explore_app', location='json') args = parser.parse_args() streaming = args['response_mode'] == 'streaming' @@ -92,6 +93,7 @@ class ChatApi(InstalledAppResource): parser.add_argument('query', type=str, required=True, location='json') parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json') parser.add_argument('conversation_id', type=uuid_value, location='json') + parser.add_argument('retriever_from', type=str, required=False, default='explore_app', location='json') args = parser.parse_args() streaming = args['response_mode'] == 'streaming' diff --git a/api/controllers/console/explore/message.py b/api/controllers/console/explore/message.py index 160ebee122..0349c23ef3 100644 --- a/api/controllers/console/explore/message.py +++ b/api/controllers/console/explore/message.py @@ -30,6 +30,25 @@ class MessageListApi(InstalledAppResource): 'rating': fields.String } + retriever_resource_fields = { + 'id': fields.String, + 'message_id': fields.String, + 'position': fields.Integer, + 'dataset_id': fields.String, + 'dataset_name': fields.String, + 'document_id': fields.String, + 'document_name': fields.String, + 'data_source_type': fields.String, + 'segment_id': fields.String, + 'score': fields.Float, + 'hit_count': fields.Integer, + 'word_count': fields.Integer, + 'segment_position': fields.Integer, + 'index_node_hash': 
fields.String, + 'content': fields.String, + 'created_at': TimestampField + } + message_fields = { 'id': fields.String, 'conversation_id': fields.String, @@ -37,6 +56,7 @@ class MessageListApi(InstalledAppResource): 'query': fields.String, 'answer': fields.String, 'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True), + 'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)), 'created_at': TimestampField } diff --git a/api/controllers/console/explore/parameter.py b/api/controllers/console/explore/parameter.py index 13e356cb26..fb4ce33209 100644 --- a/api/controllers/console/explore/parameter.py +++ b/api/controllers/console/explore/parameter.py @@ -24,6 +24,7 @@ class AppParameterApi(InstalledAppResource): 'suggested_questions': fields.Raw, 'suggested_questions_after_answer': fields.Raw, 'speech_to_text': fields.Raw, + 'retriever_resource': fields.Raw, 'more_like_this': fields.Raw, 'user_input_form': fields.Raw, } @@ -39,6 +40,7 @@ class AppParameterApi(InstalledAppResource): 'suggested_questions': app_model_config.suggested_questions_list, 'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict, 'speech_to_text': app_model_config.speech_to_text_dict, + 'retriever_resource': app_model_config.retriever_resource_dict, 'more_like_this': app_model_config.more_like_this_dict, 'user_input_form': app_model_config.user_input_form_list } diff --git a/api/controllers/console/universal_chat/chat.py b/api/controllers/console/universal_chat/chat.py index a6aa842042..0c7401a70a 100644 --- a/api/controllers/console/universal_chat/chat.py +++ b/api/controllers/console/universal_chat/chat.py @@ -29,9 +29,11 @@ class UniversalChatApi(UniversalChatResource): parser.add_argument('provider', type=str, required=True, location='json') parser.add_argument('model', type=str, required=True, location='json') parser.add_argument('tools', type=list, required=True, location='json') + 
parser.add_argument('retriever_from', type=str, required=False, default='universal_app', location='json') args = parser.parse_args() app_model_config = app_model.app_model_config + app_model_config # update app model config args['model_config'] = app_model_config.to_dict() diff --git a/api/controllers/console/universal_chat/message.py b/api/controllers/console/universal_chat/message.py index 07d8b37fee..8568d3e9e9 100644 --- a/api/controllers/console/universal_chat/message.py +++ b/api/controllers/console/universal_chat/message.py @@ -36,6 +36,25 @@ class UniversalChatMessageListApi(UniversalChatResource): 'created_at': TimestampField } + retriever_resource_fields = { + 'id': fields.String, + 'message_id': fields.String, + 'position': fields.Integer, + 'dataset_id': fields.String, + 'dataset_name': fields.String, + 'document_id': fields.String, + 'document_name': fields.String, + 'data_source_type': fields.String, + 'segment_id': fields.String, + 'score': fields.Float, + 'hit_count': fields.Integer, + 'word_count': fields.Integer, + 'segment_position': fields.Integer, + 'index_node_hash': fields.String, + 'content': fields.String, + 'created_at': TimestampField + } + message_fields = { 'id': fields.String, 'conversation_id': fields.String, @@ -43,6 +62,7 @@ class UniversalChatMessageListApi(UniversalChatResource): 'query': fields.String, 'answer': fields.String, 'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True), + 'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)), 'created_at': TimestampField, 'agent_thoughts': fields.List(fields.Nested(agent_thought_fields)) } diff --git a/api/controllers/console/universal_chat/parameter.py b/api/controllers/console/universal_chat/parameter.py index b492bba501..fb00ca12cf 100644 --- a/api/controllers/console/universal_chat/parameter.py +++ b/api/controllers/console/universal_chat/parameter.py @@ -1,4 +1,6 @@ # -*- coding:utf-8 -*- +import json + from flask_restful 
import marshal_with, fields from controllers.console import api @@ -14,6 +16,7 @@ class UniversalChatParameterApi(UniversalChatResource): 'suggested_questions': fields.Raw, 'suggested_questions_after_answer': fields.Raw, 'speech_to_text': fields.Raw, + 'retriever_resource': fields.Raw, } @marshal_with(parameters_fields) @@ -21,12 +24,14 @@ class UniversalChatParameterApi(UniversalChatResource): """Retrieve app parameters.""" app_model = universal_app app_model_config = app_model.app_model_config + app_model_config.retriever_resource = json.dumps({'enabled': True}) return { 'opening_statement': app_model_config.opening_statement, 'suggested_questions': app_model_config.suggested_questions_list, 'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict, 'speech_to_text': app_model_config.speech_to_text_dict, + 'retriever_resource': app_model_config.retriever_resource_dict, } diff --git a/api/controllers/console/universal_chat/wraps.py b/api/controllers/console/universal_chat/wraps.py index d51b87a13a..8ed472d3b6 100644 --- a/api/controllers/console/universal_chat/wraps.py +++ b/api/controllers/console/universal_chat/wraps.py @@ -47,6 +47,7 @@ def universal_chat_app_required(view=None): suggested_questions=json.dumps([]), suggested_questions_after_answer=json.dumps({'enabled': True}), speech_to_text=json.dumps({'enabled': True}), + retriever_resource=json.dumps({'enabled': True}), more_like_this=None, sensitive_word_avoidance=None, model=json.dumps({ diff --git a/api/controllers/service_api/app/app.py b/api/controllers/service_api/app/app.py index 481133367e..86b8642571 100644 --- a/api/controllers/service_api/app/app.py +++ b/api/controllers/service_api/app/app.py @@ -25,6 +25,7 @@ class AppParameterApi(AppApiResource): 'suggested_questions': fields.Raw, 'suggested_questions_after_answer': fields.Raw, 'speech_to_text': fields.Raw, + 'retriever_resource': fields.Raw, 'more_like_this': fields.Raw, 'user_input_form': fields.Raw, } @@ 
-39,6 +40,7 @@ class AppParameterApi(AppApiResource): 'suggested_questions': app_model_config.suggested_questions_list, 'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict, 'speech_to_text': app_model_config.speech_to_text_dict, + 'retriever_resource': app_model_config.retriever_resource_dict, 'more_like_this': app_model_config.more_like_this_dict, 'user_input_form': app_model_config.user_input_form_list } diff --git a/api/controllers/service_api/app/completion.py b/api/controllers/service_api/app/completion.py index 8a441aa2fa..a339322ea8 100644 --- a/api/controllers/service_api/app/completion.py +++ b/api/controllers/service_api/app/completion.py @@ -30,6 +30,8 @@ class CompletionApi(AppApiResource): parser.add_argument('query', type=str, location='json', default='') parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json') parser.add_argument('user', type=str, location='json') + parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json') + args = parser.parse_args() streaming = args['response_mode'] == 'streaming' @@ -91,6 +93,8 @@ class ChatApi(AppApiResource): parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json') parser.add_argument('conversation_id', type=uuid_value, location='json') parser.add_argument('user', type=str, location='json') + parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json') + args = parser.parse_args() streaming = args['response_mode'] == 'streaming' diff --git a/api/controllers/service_api/app/message.py b/api/controllers/service_api/app/message.py index ef020891ff..e482d16d4f 100644 --- a/api/controllers/service_api/app/message.py +++ b/api/controllers/service_api/app/message.py @@ -16,6 +16,24 @@ class MessageListApi(AppApiResource): feedback_fields = { 'rating': fields.String } + retriever_resource_fields = { + 'id': fields.String, + 
'message_id': fields.String, + 'position': fields.Integer, + 'dataset_id': fields.String, + 'dataset_name': fields.String, + 'document_id': fields.String, + 'document_name': fields.String, + 'data_source_type': fields.String, + 'segment_id': fields.String, + 'score': fields.Float, + 'hit_count': fields.Integer, + 'word_count': fields.Integer, + 'segment_position': fields.Integer, + 'index_node_hash': fields.String, + 'content': fields.String, + 'created_at': TimestampField + } message_fields = { 'id': fields.String, @@ -24,6 +42,7 @@ class MessageListApi(AppApiResource): 'query': fields.String, 'answer': fields.String, 'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True), + 'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)), 'created_at': TimestampField } diff --git a/api/controllers/web/app.py b/api/controllers/web/app.py index dd66707345..cffda04eea 100644 --- a/api/controllers/web/app.py +++ b/api/controllers/web/app.py @@ -24,6 +24,7 @@ class AppParameterApi(WebApiResource): 'suggested_questions': fields.Raw, 'suggested_questions_after_answer': fields.Raw, 'speech_to_text': fields.Raw, + 'retriever_resource': fields.Raw, 'more_like_this': fields.Raw, 'user_input_form': fields.Raw, } @@ -38,6 +39,7 @@ class AppParameterApi(WebApiResource): 'suggested_questions': app_model_config.suggested_questions_list, 'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict, 'speech_to_text': app_model_config.speech_to_text_dict, + 'retriever_resource': app_model_config.retriever_resource_dict, 'more_like_this': app_model_config.more_like_this_dict, 'user_input_form': app_model_config.user_input_form_list } diff --git a/api/controllers/web/completion.py b/api/controllers/web/completion.py index 25744b61af..79c0c542d1 100644 --- a/api/controllers/web/completion.py +++ b/api/controllers/web/completion.py @@ -31,6 +31,8 @@ class CompletionApi(WebApiResource): 
parser.add_argument('inputs', type=dict, required=True, location='json') parser.add_argument('query', type=str, location='json', default='') parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json') + parser.add_argument('retriever_from', type=str, required=False, default='web_app', location='json') + args = parser.parse_args() streaming = args['response_mode'] == 'streaming' @@ -88,6 +90,8 @@ class ChatApi(WebApiResource): parser.add_argument('query', type=str, required=True, location='json') parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json') parser.add_argument('conversation_id', type=uuid_value, location='json') + parser.add_argument('retriever_from', type=str, required=False, default='web_app', location='json') + args = parser.parse_args() streaming = args['response_mode'] == 'streaming' diff --git a/api/controllers/web/message.py b/api/controllers/web/message.py index f25f1e5af9..9d083f0027 100644 --- a/api/controllers/web/message.py +++ b/api/controllers/web/message.py @@ -29,6 +29,25 @@ class MessageListApi(WebApiResource): 'rating': fields.String } + retriever_resource_fields = { + 'id': fields.String, + 'message_id': fields.String, + 'position': fields.Integer, + 'dataset_id': fields.String, + 'dataset_name': fields.String, + 'document_id': fields.String, + 'document_name': fields.String, + 'data_source_type': fields.String, + 'segment_id': fields.String, + 'score': fields.Float, + 'hit_count': fields.Integer, + 'word_count': fields.Integer, + 'segment_position': fields.Integer, + 'index_node_hash': fields.String, + 'content': fields.String, + 'created_at': TimestampField + } + message_fields = { 'id': fields.String, 'conversation_id': fields.String, @@ -36,6 +55,7 @@ class MessageListApi(WebApiResource): 'query': fields.String, 'answer': fields.String, 'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True), + 'retriever_resources': 
fields.List(fields.Nested(retriever_resource_fields)), 'created_at': TimestampField } diff --git a/api/core/agent/agent/multi_dataset_router_agent.py b/api/core/agent/agent/multi_dataset_router_agent.py index ebb9531b26..c7767d1307 100644 --- a/api/core/agent/agent/multi_dataset_router_agent.py +++ b/api/core/agent/agent/multi_dataset_router_agent.py @@ -1,3 +1,4 @@ +import json from typing import Tuple, List, Any, Union, Sequence, Optional, cast from langchain.agents import OpenAIFunctionsAgent, BaseSingleActionAgent @@ -53,6 +54,10 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent): tool = next(iter(self.tools)) tool = cast(DatasetRetrieverTool, tool) rst = tool.run(tool_input={'query': kwargs['input']}) + # output = '' + # rst_json = json.loads(rst) + # for item in rst_json: + # output += f'{item["content"]}\n' return AgentFinish(return_values={"output": rst}, log=rst) if intermediate_steps: diff --git a/api/core/callback_handler/dataset_tool_callback_handler.py b/api/core/callback_handler/dataset_tool_callback_handler.py index 7d2ba4de1f..b1f9b8c602 100644 --- a/api/core/callback_handler/dataset_tool_callback_handler.py +++ b/api/core/callback_handler/dataset_tool_callback_handler.py @@ -64,12 +64,9 @@ class DatasetToolCallbackHandler(BaseCallbackHandler): llm_prefix: Optional[str] = None, **kwargs: Any, ) -> None: - # kwargs={'name': 'Search'} - # llm_prefix='Thought:' - # observation_prefix='Observation: ' - # output='53 years' pass + def on_tool_error( self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any ) -> None: diff --git a/api/core/callback_handler/index_tool_callback_handler.py b/api/core/callback_handler/index_tool_callback_handler.py index 59518667a0..ec02bdae9e 100644 --- a/api/core/callback_handler/index_tool_callback_handler.py +++ b/api/core/callback_handler/index_tool_callback_handler.py @@ -2,6 +2,7 @@ from typing import List from langchain.schema import Document +from core.conversation_message_task import ConversationMessageTask 
from extensions.ext_database import db from models.dataset import DocumentSegment @@ -9,8 +10,9 @@ from models.dataset import DocumentSegment class DatasetIndexToolCallbackHandler: """Callback handler for dataset tool.""" - def __init__(self, dataset_id: str) -> None: + def __init__(self, dataset_id: str, conversation_message_task: ConversationMessageTask) -> None: self.dataset_id = dataset_id + self.conversation_message_task = conversation_message_task def on_tool_end(self, documents: List[Document]) -> None: """Handle tool end.""" @@ -27,3 +29,7 @@ class DatasetIndexToolCallbackHandler: ) db.session.commit() + + def return_retriever_resource_info(self, resource: List): + """Handle return_retriever_resource_info.""" + self.conversation_message_task.on_dataset_query_finish(resource) diff --git a/api/core/completion.py b/api/core/completion.py index 192b16f4f9..2635b77c5d 100644 --- a/api/core/completion.py +++ b/api/core/completion.py @@ -1,3 +1,4 @@ +import json import logging import re from typing import Optional, List, Union, Tuple @@ -19,13 +20,15 @@ from core.orchestrator_rule_parser import OrchestratorRuleParser from core.prompt.prompt_builder import PromptBuilder from core.prompt.prompt_template import JinjaPromptTemplate from core.prompt.prompts import MORE_LIKE_THIS_GENERATE_PROMPT +from models.dataset import DocumentSegment, Dataset, Document from models.model import App, AppModelConfig, Account, Conversation, Message, EndUser class Completion: @classmethod def generate(cls, task_id: str, app: App, app_model_config: AppModelConfig, query: str, inputs: dict, - user: Union[Account, EndUser], conversation: Optional[Conversation], streaming: bool, is_override: bool = False): + user: Union[Account, EndUser], conversation: Optional[Conversation], streaming: bool, + is_override: bool = False, retriever_from: str = 'dev'): """ errors: ProviderTokenNotInitError """ @@ -96,7 +99,6 @@ class Completion: should_use_agent = agent_executor.should_use_agent(query) if 
should_use_agent: agent_execute_result = agent_executor.run(query) - # run the final llm try: cls.run_final_llm( @@ -118,7 +120,8 @@ class Completion: return @classmethod - def run_final_llm(cls, model_instance: BaseLLM, mode: str, app_model_config: AppModelConfig, query: str, inputs: dict, + def run_final_llm(cls, model_instance: BaseLLM, mode: str, app_model_config: AppModelConfig, query: str, + inputs: dict, agent_execute_result: Optional[AgentExecuteResult], conversation_message_task: ConversationMessageTask, memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory]): @@ -150,7 +153,6 @@ class Completion: callbacks=[LLMCallbackHandler(model_instance, conversation_message_task)], fake_response=fake_response ) - return response @classmethod diff --git a/api/core/conversation_message_task.py b/api/core/conversation_message_task.py index 75e01d6923..8b675d2144 100644 --- a/api/core/conversation_message_task.py +++ b/api/core/conversation_message_task.py @@ -1,6 +1,6 @@ import decimal import json -from typing import Optional, Union +from typing import Optional, Union, List from core.callback_handler.entity.agent_loop import AgentLoop from core.callback_handler.entity.dataset_query import DatasetQueryObj @@ -15,7 +15,8 @@ from events.message_event import message_was_created from extensions.ext_database import db from extensions.ext_redis import redis_client from models.dataset import DatasetQuery -from models.model import AppModelConfig, Conversation, Account, Message, EndUser, App, MessageAgentThought, MessageChain +from models.model import AppModelConfig, Conversation, Account, Message, EndUser, App, MessageAgentThought, \ + MessageChain, DatasetRetrieverResource class ConversationMessageTask: @@ -41,6 +42,8 @@ class ConversationMessageTask: self.message = None + self.retriever_resource = None + self.model_dict = self.app_model_config.model_dict self.provider_name = self.model_dict.get('provider') self.model_name = self.model_dict.get('name') @@ -157,7 +160,8 
@@ class ConversationMessageTask: self.message.message_tokens = message_tokens self.message.message_unit_price = message_unit_price self.message.message_price_unit = message_price_unit - self.message.answer = PromptBuilder.process_template(llm_message.completion.strip()) if llm_message.completion else '' + self.message.answer = PromptBuilder.process_template( + llm_message.completion.strip()) if llm_message.completion else '' self.message.answer_tokens = answer_tokens self.message.answer_unit_price = answer_unit_price self.message.answer_price_unit = answer_price_unit @@ -256,7 +260,36 @@ class ConversationMessageTask: db.session.add(dataset_query) + def on_dataset_query_finish(self, resource: List): + if resource and len(resource) > 0: + for item in resource: + dataset_retriever_resource = DatasetRetrieverResource( + message_id=self.message.id, + position=item.get('position'), + dataset_id=item.get('dataset_id'), + dataset_name=item.get('dataset_name'), + document_id=item.get('document_id'), + document_name=item.get('document_name'), + data_source_type=item.get('data_source_type'), + segment_id=item.get('segment_id'), + score=item.get('score') if 'score' in item else None, + hit_count=item.get('hit_count') if 'hit_count' else None, + word_count=item.get('word_count') if 'word_count' in item else None, + segment_position=item.get('segment_position') if 'segment_position' in item else None, + index_node_hash=item.get('index_node_hash') if 'index_node_hash' in item else None, + content=item.get('content'), + retriever_from=item.get('retriever_from'), + created_by=self.user.id + ) + db.session.add(dataset_retriever_resource) + db.session.flush() + self.retriever_resource = resource + + def message_end(self): + self._pub_handler.pub_message_end(self.retriever_resource) + def end(self): + self._pub_handler.pub_message_end(self.retriever_resource) self._pub_handler.pub_end() @@ -350,6 +383,23 @@ class PubHandler: self.pub_end() raise ConversationTaskStoppedException() + 
def pub_message_end(self, retriever_resource: List): + content = { + 'event': 'message_end', + 'data': { + 'task_id': self._task_id, + 'message_id': self._message.id, + 'mode': self._conversation.mode, + 'conversation_id': self._conversation.id + } + } + if retriever_resource: + content['data']['retriever_resources'] = retriever_resource + redis_client.publish(self._channel, json.dumps(content)) + + if self._is_stopped(): + self.pub_end() + raise ConversationTaskStoppedException() def pub_end(self): content = { diff --git a/api/core/index/keyword_table_index/keyword_table_index.py b/api/core/index/keyword_table_index/keyword_table_index.py index 3792120e45..7b00e9825f 100644 --- a/api/core/index/keyword_table_index/keyword_table_index.py +++ b/api/core/index/keyword_table_index/keyword_table_index.py @@ -74,7 +74,7 @@ class KeywordTableIndex(BaseIndex): DocumentSegment.document_id == document_id ).all() - ids = [segment.id for segment in segments] + ids = [segment.index_node_id for segment in segments] keyword_table = self._get_dataset_keyword_table() keyword_table = self._delete_ids_from_keyword_table(keyword_table, ids) diff --git a/api/core/index/vector_index/qdrant_vector_index.py b/api/core/index/vector_index/qdrant_vector_index.py index 3de47a628b..4814837c8f 100644 --- a/api/core/index/vector_index/qdrant_vector_index.py +++ b/api/core/index/vector_index/qdrant_vector_index.py @@ -113,6 +113,25 @@ class QdrantVectorIndex(BaseVectorIndex): ], )) + def delete_by_ids(self, ids: list[str]) -> None: + if self._is_origin(): + self.recreate_dataset(self.dataset) + return + + vector_store = self._get_vector_store() + vector_store = cast(self._get_vector_store_class(), vector_store) + + from qdrant_client.http import models + for node_id in ids: + vector_store.del_texts(models.Filter( + must=[ + models.FieldCondition( + key="metadata.doc_id", + match=models.MatchValue(value=node_id), + ), + ], + )) + def _is_origin(self): if self.dataset.index_struct_dict: 
class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix'] diff --git a/api/core/model_providers/models/entity/message.py b/api/core/model_providers/models/entity/message.py index f2fab9c4b7..921bdcf193 100644 --- a/api/core/model_providers/models/entity/message.py +++ b/api/core/model_providers/models/entity/message.py @@ -8,6 +8,7 @@ class LLMRunResult(BaseModel): content: str prompt_tokens: int completion_tokens: int + source: list = None class MessageType(enum.Enum): diff --git a/api/core/orchestrator_rule_parser.py b/api/core/orchestrator_rule_parser.py index 310fb0ae15..cceb9db1a9 100644 --- a/api/core/orchestrator_rule_parser.py +++ b/api/core/orchestrator_rule_parser.py @@ -36,8 +36,8 @@ class OrchestratorRuleParser: self.app_model_config = app_model_config def to_agent_executor(self, conversation_message_task: ConversationMessageTask, memory: Optional[BaseChatMemory], - rest_tokens: int, chain_callback: MainChainGatherCallbackHandler) \ - -> Optional[AgentExecutor]: + rest_tokens: int, chain_callback: MainChainGatherCallbackHandler, + return_resource: bool = False, retriever_from: str = 'dev') -> Optional[AgentExecutor]: if not self.app_model_config.agent_mode_dict: return None @@ -74,7 +74,7 @@ class OrchestratorRuleParser: # only OpenAI chat model (include Azure) support function call, use ReACT instead if agent_model_instance.model_mode != ModelMode.CHAT \ - or agent_model_instance.model_provider.provider_name not in ['openai', 'azure_openai']: + or agent_model_instance.model_provider.provider_name not in ['openai', 'azure_openai']: if planning_strategy in [PlanningStrategy.FUNCTION_CALL, PlanningStrategy.MULTI_FUNCTION_CALL]: planning_strategy = PlanningStrategy.REACT elif planning_strategy == PlanningStrategy.ROUTER: @@ -99,7 +99,9 @@ class OrchestratorRuleParser: tool_configs=tool_configs, conversation_message_task=conversation_message_task, rest_tokens=rest_tokens, - callbacks=[agent_callback, DifyStdOutCallbackHandler()] + 
callbacks=[agent_callback, DifyStdOutCallbackHandler()], + return_resource=return_resource, + retriever_from=retriever_from ) if len(tools) == 0: @@ -145,8 +147,10 @@ class OrchestratorRuleParser: return None - def to_tools(self, agent_model_instance: BaseLLM, tool_configs: list, conversation_message_task: ConversationMessageTask, - rest_tokens: int, callbacks: Callbacks = None) -> list[BaseTool]: + def to_tools(self, agent_model_instance: BaseLLM, tool_configs: list, + conversation_message_task: ConversationMessageTask, + rest_tokens: int, callbacks: Callbacks = None, return_resource: bool = False, + retriever_from: str = 'dev') -> list[BaseTool]: """ Convert app agent tool configs to tools @@ -155,6 +159,8 @@ class OrchestratorRuleParser: :param tool_configs: app agent tool configs :param conversation_message_task: :param callbacks: + :param return_resource: + :param retriever_from: :return: """ tools = [] @@ -166,7 +172,7 @@ class OrchestratorRuleParser: tool = None if tool_type == "dataset": - tool = self.to_dataset_retriever_tool(tool_val, conversation_message_task, rest_tokens) + tool = self.to_dataset_retriever_tool(tool_val, conversation_message_task, rest_tokens, return_resource, retriever_from) elif tool_type == "web_reader": tool = self.to_web_reader_tool(agent_model_instance) elif tool_type == "google_search": @@ -183,13 +189,15 @@ class OrchestratorRuleParser: return tools def to_dataset_retriever_tool(self, tool_config: dict, conversation_message_task: ConversationMessageTask, - rest_tokens: int) \ + rest_tokens: int, return_resource: bool = False, retriever_from: str = 'dev') \ -> Optional[BaseTool]: """ A dataset tool is a tool that can be used to retrieve information from a dataset :param rest_tokens: :param tool_config: :param conversation_message_task: + :param return_resource: + :param retriever_from: :return: """ # get dataset from dataset id @@ -208,7 +216,10 @@ class OrchestratorRuleParser: tool = DatasetRetrieverTool.from_dataset( 
dataset=dataset, k=k, - callbacks=[DatasetToolCallbackHandler(conversation_message_task)] + callbacks=[DatasetToolCallbackHandler(conversation_message_task)], + conversation_message_task=conversation_message_task, + return_resource=return_resource, + retriever_from=retriever_from ) return tool diff --git a/api/core/prompt/generate_prompts/common_chat.json b/api/core/prompt/generate_prompts/common_chat.json index c817caf36c..709a8d8866 100644 --- a/api/core/prompt/generate_prompts/common_chat.json +++ b/api/core/prompt/generate_prompts/common_chat.json @@ -10,4 +10,4 @@ ], "query_prompt": "\n\nHuman: {{query}}\n\nAssistant: ", "stops": ["\nHuman:", ""] -} \ No newline at end of file +} diff --git a/api/core/prompt/prompts.py b/api/core/prompt/prompts.py index ca857d4dd7..979fe9be96 100644 --- a/api/core/prompt/prompts.py +++ b/api/core/prompt/prompts.py @@ -105,7 +105,7 @@ GENERATOR_QA_PROMPT = ( 'Step 3: Decompose or combine multiple pieces of information and concepts.\n' 'Step 4: Generate 20 questions and answers based on these key information and concepts.' 
'The questions should be clear and detailed, and the answers should be detailed and complete.\n' - "Answer must be the language:{language} and in the following format: Q1:\nA1:\nQ2:\nA2:...\n" + "Answer according to the language:{language} and in the following format: Q1:\nA1:\nQ2:\nA2:...\n" ) RULE_CONFIG_GENERATE_TEMPLATE = """Given MY INTENDED AUDIENCES and HOPING TO SOLVE using a language model, please select \ diff --git a/api/core/tool/dataset_retriever_tool.py b/api/core/tool/dataset_retriever_tool.py index 46deff497b..4c9c9b625d 100644 --- a/api/core/tool/dataset_retriever_tool.py +++ b/api/core/tool/dataset_retriever_tool.py @@ -1,3 +1,4 @@ +import json from typing import Type from flask import current_app @@ -5,13 +6,14 @@ from langchain.tools import BaseTool from pydantic import Field, BaseModel from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler +from core.conversation_message_task import ConversationMessageTask from core.embedding.cached_embedding import CacheEmbedding from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig from core.index.vector_index.vector_index import VectorIndex from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitError from core.model_providers.model_factory import ModelFactory from extensions.ext_database import db -from models.dataset import Dataset, DocumentSegment +from models.dataset import Dataset, DocumentSegment, Document class DatasetRetrieverToolInput(BaseModel): @@ -27,6 +29,10 @@ class DatasetRetrieverTool(BaseTool): tenant_id: str dataset_id: str k: int = 3 + conversation_message_task: ConversationMessageTask + return_resource: bool + retriever_from: str + @classmethod def from_dataset(cls, dataset: Dataset, **kwargs): @@ -86,7 +92,7 @@ class DatasetRetrieverTool(BaseTool): if self.k > 0: documents = vector_index.search( query, - search_type='similarity', + search_type='similarity_score_threshold',
search_kwargs={ 'k': self.k } @@ -94,8 +100,12 @@ class DatasetRetrieverTool(BaseTool): else: documents = [] - hit_callback = DatasetIndexToolCallbackHandler(dataset.id) + hit_callback = DatasetIndexToolCallbackHandler(dataset.id, self.conversation_message_task) hit_callback.on_tool_end(documents) + document_score_list = {} + if dataset.indexing_technique != "economy": + for item in documents: + document_score_list[item.metadata['doc_id']] = item.metadata['score'] document_context_list = [] index_node_ids = [document.metadata['doc_id'] for document in documents] segments = DocumentSegment.query.filter(DocumentSegment.dataset_id == self.dataset_id, @@ -112,9 +122,43 @@ class DatasetRetrieverTool(BaseTool): float('inf'))) for segment in sorted_segments: if segment.answer: - document_context_list.append(f'question:{segment.content} \nanswer:{segment.answer}') + document_context_list.append(f'question:{segment.content} answer:{segment.answer}') else: document_context_list.append(segment.content) + if self.return_resource: + context_list = [] + resource_number = 1 + for segment in sorted_segments: + context = {} + document = Document.query.filter(Document.id == segment.document_id, + Document.enabled == True, + Document.archived == False, + ).first() + if dataset and document: + source = { + 'position': resource_number, + 'dataset_id': dataset.id, + 'dataset_name': dataset.name, + 'document_id': document.id, + 'document_name': document.name, + 'data_source_type': document.data_source_type, + 'segment_id': segment.id, + 'retriever_from': self.retriever_from + } + if dataset.indexing_technique != "economy": + source['score'] = document_score_list.get(segment.index_node_id) + if self.retriever_from == 'dev': + source['hit_count'] = segment.hit_count + source['word_count'] = segment.word_count + source['segment_position'] = segment.position + source['index_node_hash'] = segment.index_node_hash + if segment.answer: + source['content'] = f'question:{segment.content} 
\nanswer:{segment.answer}' + else: + source['content'] = segment.content + context_list.append(source) + resource_number += 1 + hit_callback.return_retriever_resource_info(context_list) return str("\n".join(document_context_list)) diff --git a/api/migrations/versions/6dcb43972bdc_add_dataset_retriever_resource.py b/api/migrations/versions/6dcb43972bdc_add_dataset_retriever_resource.py new file mode 100644 index 0000000000..255dddeec6 --- /dev/null +++ b/api/migrations/versions/6dcb43972bdc_add_dataset_retriever_resource.py @@ -0,0 +1,54 @@ +"""add_dataset_retriever_resource + +Revision ID: 6dcb43972bdc +Revises: 4bcffcd64aa4 +Create Date: 2023-09-06 16:51:27.385844 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '6dcb43972bdc' +down_revision = '4bcffcd64aa4' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('dataset_retriever_resources', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('message_id', postgresql.UUID(), nullable=False), + sa.Column('position', sa.Integer(), nullable=False), + sa.Column('dataset_id', postgresql.UUID(), nullable=False), + sa.Column('dataset_name', sa.Text(), nullable=False), + sa.Column('document_id', postgresql.UUID(), nullable=False), + sa.Column('document_name', sa.Text(), nullable=False), + sa.Column('data_source_type', sa.Text(), nullable=False), + sa.Column('segment_id', postgresql.UUID(), nullable=False), + sa.Column('score', sa.Float(), nullable=True), + sa.Column('content', sa.Text(), nullable=False), + sa.Column('hit_count', sa.Integer(), nullable=True), + sa.Column('word_count', sa.Integer(), nullable=True), + sa.Column('segment_position', sa.Integer(), nullable=True), + sa.Column('index_node_hash', sa.Text(), nullable=True), + sa.Column('retriever_from', sa.Text(), 
nullable=False), + sa.Column('created_by', postgresql.UUID(), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.PrimaryKeyConstraint('id', name='dataset_retriever_resource_pkey') + ) + with op.batch_alter_table('dataset_retriever_resources', schema=None) as batch_op: + batch_op.create_index('dataset_retriever_resource_message_id_idx', ['message_id'], unique=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('dataset_retriever_resources', schema=None) as batch_op: + batch_op.drop_index('dataset_retriever_resource_message_id_idx') + + op.drop_table('dataset_retriever_resources') + # ### end Alembic commands ### diff --git a/api/migrations/versions/77e83833755c_add_app_config_retriever_resource.py b/api/migrations/versions/77e83833755c_add_app_config_retriever_resource.py new file mode 100644 index 0000000000..405e9520e1 --- /dev/null +++ b/api/migrations/versions/77e83833755c_add_app_config_retriever_resource.py @@ -0,0 +1,32 @@ +"""add_app_config_retriever_resource + +Revision ID: 77e83833755c +Revises: 6dcb43972bdc +Create Date: 2023-09-06 17:26:40.311927 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '77e83833755c' +down_revision = '6dcb43972bdc' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('app_model_configs', schema=None) as batch_op: + batch_op.add_column(sa.Column('retriever_resource', sa.Text(), nullable=True)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('app_model_configs', schema=None) as batch_op: + batch_op.drop_column('retriever_resource') + + # ### end Alembic commands ### diff --git a/api/models/model.py b/api/models/model.py index b77363d0b8..9172410ebb 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -1,4 +1,5 @@ import json +from json import JSONDecodeError from flask import current_app, request from flask_login import UserMixin @@ -90,6 +91,7 @@ class AppModelConfig(db.Model): pre_prompt = db.Column(db.Text) agent_mode = db.Column(db.Text) sensitive_word_avoidance = db.Column(db.Text) + retriever_resource = db.Column(db.Text) @property def app(self): @@ -114,6 +116,11 @@ class AppModelConfig(db.Model): return json.loads(self.speech_to_text) if self.speech_to_text \ else {"enabled": False} + @property + def retriever_resource_dict(self) -> dict: + return json.loads(self.retriever_resource) if self.retriever_resource \ + else {"enabled": False} + @property def more_like_this_dict(self) -> dict: return json.loads(self.more_like_this) if self.more_like_this else {"enabled": False} @@ -140,6 +147,7 @@ class AppModelConfig(db.Model): "suggested_questions": self.suggested_questions_list, "suggested_questions_after_answer": self.suggested_questions_after_answer_dict, "speech_to_text": self.speech_to_text_dict, + "retriever_resource": self.retriever_resource_dict, "more_like_this": self.more_like_this_dict, "sensitive_word_avoidance": self.sensitive_word_avoidance_dict, "model": self.model_dict, @@ -164,7 +172,8 @@ class AppModelConfig(db.Model): self.user_input_form = json.dumps(model_config['user_input_form']) self.pre_prompt = model_config['pre_prompt'] self.agent_mode = json.dumps(model_config['agent_mode']) - + self.retriever_resource = json.dumps(model_config['retriever_resource']) \ + if model_config.get('retriever_resource') else None return self def copy(self): @@ -318,6 +327,7 @@ class Conversation(db.Model): model_config['suggested_questions'] = 
app_model_config.suggested_questions_list model_config['suggested_questions_after_answer'] = app_model_config.suggested_questions_after_answer_dict model_config['speech_to_text'] = app_model_config.speech_to_text_dict + model_config['retriever_resource'] = app_model_config.retriever_resource_dict model_config['more_like_this'] = app_model_config.more_like_this_dict model_config['sensitive_word_avoidance'] = app_model_config.sensitive_word_avoidance_dict model_config['user_input_form'] = app_model_config.user_input_form_list @@ -476,6 +486,11 @@ class Message(db.Model): return db.session.query(MessageAgentThought).filter(MessageAgentThought.message_id == self.id) \ .order_by(MessageAgentThought.position.asc()).all() + @property + def retriever_resources(self): + return db.session.query(DatasetRetrieverResource).filter(DatasetRetrieverResource.message_id == self.id) \ + .order_by(DatasetRetrieverResource.position.asc()).all() + class MessageFeedback(db.Model): __tablename__ = 'message_feedbacks' @@ -719,3 +734,31 @@ class MessageAgentThought(db.Model): created_by_role = db.Column(db.String, nullable=False) created_by = db.Column(UUID, nullable=False) created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp()) + + +class DatasetRetrieverResource(db.Model): + __tablename__ = 'dataset_retriever_resources' + __table_args__ = ( + db.PrimaryKeyConstraint('id', name='dataset_retriever_resource_pkey'), + db.Index('dataset_retriever_resource_message_id_idx', 'message_id'), + ) + + id = db.Column(UUID, nullable=False, server_default=db.text('uuid_generate_v4()')) + message_id = db.Column(UUID, nullable=False) + position = db.Column(db.Integer, nullable=False) + dataset_id = db.Column(UUID, nullable=False) + dataset_name = db.Column(db.Text, nullable=False) + document_id = db.Column(UUID, nullable=False) + document_name = db.Column(db.Text, nullable=False) + data_source_type = db.Column(db.Text, nullable=False) + segment_id = db.Column(UUID, 
nullable=False) + score = db.Column(db.Float, nullable=True) + content = db.Column(db.Text, nullable=False) + hit_count = db.Column(db.Integer, nullable=True) + word_count = db.Column(db.Integer, nullable=True) + segment_position = db.Column(db.Integer, nullable=True) + index_node_hash = db.Column(db.Text, nullable=True) + retriever_from = db.Column(db.Text, nullable=False) + created_by = db.Column(UUID, nullable=False) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp()) + diff --git a/api/services/app_model_config_service.py b/api/services/app_model_config_service.py index 18ca399dbd..f402b4e2b4 100644 --- a/api/services/app_model_config_service.py +++ b/api/services/app_model_config_service.py @@ -130,6 +130,21 @@ class AppModelConfigService: if not isinstance(config["speech_to_text"]["enabled"], bool): raise ValueError("enabled in speech_to_text must be of boolean type") + # return retriever resource + if 'retriever_resource' not in config or not config["retriever_resource"]: + config["retriever_resource"] = { + "enabled": False + } + + if not isinstance(config["retriever_resource"], dict): + raise ValueError("retriever_resource must be of dict type") + + if "enabled" not in config["retriever_resource"] or not config["retriever_resource"]["enabled"]: + config["retriever_resource"]["enabled"] = False + + if not isinstance(config["retriever_resource"]["enabled"], bool): + raise ValueError("enabled in retriever_resource must be of boolean type") + # more_like_this if 'more_like_this' not in config or not config["more_like_this"]: config["more_like_this"] = { @@ -327,6 +342,7 @@ class AppModelConfigService: "suggested_questions": config["suggested_questions"], "suggested_questions_after_answer": config["suggested_questions_after_answer"], "speech_to_text": config["speech_to_text"], + "retriever_resource": config["retriever_resource"], "more_like_this": config["more_like_this"], "sensitive_word_avoidance": 
config["sensitive_word_avoidance"], "model": { diff --git a/api/services/completion_service.py b/api/services/completion_service.py index 6311d5a01d..ce9d45b325 100644 --- a/api/services/completion_service.py +++ b/api/services/completion_service.py @@ -11,7 +11,8 @@ from sqlalchemy import and_ from core.completion import Completion from core.conversation_message_task import PubHandler, ConversationTaskStoppedException -from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, LLMRateLimitError, \ +from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \ + LLMRateLimitError, \ LLMAuthorizationError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError from extensions.ext_database import db from extensions.ext_redis import redis_client @@ -95,6 +96,7 @@ class CompletionService: app_model_config_model = app_model_config.model_dict app_model_config_model['completion_params'] = completion_params + app_model_config.retriever_resource = json.dumps({'enabled': True}) app_model_config = app_model_config.copy() app_model_config.model = json.dumps(app_model_config_model) @@ -145,7 +147,8 @@ class CompletionService: 'user': user, 'conversation': conversation, 'streaming': streaming, - 'is_model_config_override': is_model_config_override + 'is_model_config_override': is_model_config_override, + 'retriever_from': args['retriever_from'] if 'retriever_from' in args else 'dev' }) generate_worker_thread.start() @@ -169,7 +172,8 @@ class CompletionService: @classmethod def generate_worker(cls, flask_app: Flask, generate_task_id: str, app_model: App, app_model_config: AppModelConfig, query: str, inputs: dict, user: Union[Account, EndUser], - conversation: Conversation, streaming: bool, is_model_config_override: bool): + conversation: Conversation, streaming: bool, is_model_config_override: bool, + retriever_from: str = 'dev'): with flask_app.app_context(): try: 
if conversation: @@ -188,6 +192,7 @@ class CompletionService: conversation=conversation, streaming=streaming, is_override=is_model_config_override, + retriever_from=retriever_from ) except ConversationTaskStoppedException: pass @@ -400,7 +405,11 @@ class CompletionService: elif event == 'chain': yield "data: " + json.dumps(cls.get_chain_response_data(result.get('data'))) + "\n\n" elif event == 'agent_thought': - yield "data: " + json.dumps(cls.get_agent_thought_response_data(result.get('data'))) + "\n\n" + yield "data: " + json.dumps( + cls.get_agent_thought_response_data(result.get('data'))) + "\n\n" + elif event == 'message_end': + yield "data: " + json.dumps( + cls.get_message_end_data(result.get('data'))) + "\n\n" elif event == 'ping': yield "event: ping\n\n" else: @@ -432,6 +441,20 @@ class CompletionService: return response_data + @classmethod + def get_message_end_data(cls, data: dict): + response_data = { + 'event': 'message_end', + 'task_id': data.get('task_id'), + 'id': data.get('message_id') + } + if 'retriever_resources' in data: + response_data['retriever_resources'] = data.get('retriever_resources') + if data.get('mode') == 'chat': + response_data['conversation_id'] = data.get('conversation_id') + + return response_data + @classmethod def get_chain_response_data(cls, data: dict): response_data = {