From cc2d71c253b7de01f036e37ebc68af64c4de082e Mon Sep 17 00:00:00 2001 From: takatost Date: Wed, 16 Aug 2023 20:48:42 +0800 Subject: [PATCH] feat: optimize override app model config convert (#874) --- api/controllers/console/app/app.py | 73 +++++++------------ api/controllers/console/app/model_config.py | 14 +--- api/core/agent/agent/structured_chat.py | 2 +- api/core/agent/agent_executor.py | 9 ++- api/core/conversation_message_task.py | 12 +-- api/core/generator/llm_generator.py | 17 +++-- api/core/orchestrator_rule_parser.py | 16 ++-- .../suggested_questions_after_answer.py | 10 ++- api/models/model.py | 50 ++++++++++++- api/services/completion_service.py | 46 +++--------- api/services/message_service.py | 42 ++++++++--- .../generate_conversation_summary_task.py | 16 ++-- 12 files changed, 166 insertions(+), 141 deletions(-) diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index 6d589cded5..c44f4edc62 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -124,12 +124,29 @@ class AppListApi(Resource): if current_user.current_tenant.current_role not in ['admin', 'owner']: raise Forbidden() + default_model = ModelFactory.get_default_model( + tenant_id=current_user.current_tenant_id, + model_type=ModelType.TEXT_GENERATION + ) + + if default_model: + default_model_provider = default_model.provider_name + default_model_name = default_model.model_name + else: + raise ProviderNotInitializeError( + f"No Text Generation Model available. Please configure a valid provider " + f"in the Settings -> Model Provider.") + if args['model_config'] is not None: # validate config + model_config_dict = args['model_config'] + model_config_dict["model"]["provider"] = default_model_provider + model_config_dict["model"]["name"] = default_model_name + model_configuration = AppModelConfigService.validate_configuration( tenant_id=current_user.current_tenant_id, account=current_user, - config=args['model_config'] + config=model_config_dict ) app = App( @@ -141,21 +158,8 @@ class AppListApi(Resource): status='normal' ) - app_model_config = AppModelConfig( - provider="", - model_id="", - configs={}, - opening_statement=model_configuration['opening_statement'], - suggested_questions=json.dumps(model_configuration['suggested_questions']), - suggested_questions_after_answer=json.dumps(model_configuration['suggested_questions_after_answer']), - speech_to_text=json.dumps(model_configuration['speech_to_text']), - more_like_this=json.dumps(model_configuration['more_like_this']), - sensitive_word_avoidance=json.dumps(model_configuration['sensitive_word_avoidance']), - model=json.dumps(model_configuration['model']), - user_input_form=json.dumps(model_configuration['user_input_form']), - pre_prompt=model_configuration['pre_prompt'], - agent_mode=json.dumps(model_configuration['agent_mode']), - ) + app_model_config = AppModelConfig() + app_model_config = app_model_config.from_model_config_dict(model_configuration) else: if 'mode' not in args or args['mode'] is None: abort(400, message="mode is required") @@ -165,20 +169,10 @@ class AppListApi(Resource): app = App(**model_config_template['app']) app_model_config = AppModelConfig(**model_config_template['model_config']) - default_model = ModelFactory.get_default_model( - tenant_id=current_user.current_tenant_id, - model_type=ModelType.TEXT_GENERATION - ) - - if default_model: - model_dict = app_model_config.model_dict - model_dict['provider'] = default_model.provider_name - model_dict['name'] = default_model.model_name - app_model_config.model = json.dumps(model_dict) - else: - raise ProviderNotInitializeError( - f"No Text Generation Model available. Please configure a valid provider " - f"in the Settings -> Model Provider.") + model_dict = app_model_config.model_dict + model_dict['provider'] = default_model_provider + model_dict['name'] = default_model_name + app_model_config.model = json.dumps(model_dict) app.name = args['name'] app.mode = args['mode'] @@ -416,22 +410,9 @@ class AppCopy(Resource): @staticmethod def create_app_model_config_copy(app_config, copy_app_id): - copy_app_model_config = AppModelConfig( - app_id=copy_app_id, - provider=app_config.provider, - model_id=app_config.model_id, - configs=app_config.configs, - opening_statement=app_config.opening_statement, - suggested_questions=app_config.suggested_questions, - suggested_questions_after_answer=app_config.suggested_questions_after_answer, - speech_to_text=app_config.speech_to_text, - more_like_this=app_config.more_like_this, - sensitive_word_avoidance=app_config.sensitive_word_avoidance, - model=app_config.model, - user_input_form=app_config.user_input_form, - pre_prompt=app_config.pre_prompt, - agent_mode=app_config.agent_mode - ) + copy_app_model_config = app_config.copy() + copy_app_model_config.app_id = copy_app_id + return copy_app_model_config @setup_required diff --git a/api/controllers/console/app/model_config.py b/api/controllers/console/app/model_config.py index d0c648ba16..c8392e521f 100644 --- a/api/controllers/console/app/model_config.py +++ b/api/controllers/console/app/model_config.py @@ -35,20 +35,8 @@ class ModelConfigResource(Resource): new_app_model_config = AppModelConfig( app_id=app_model.id, - provider="", - model_id="", - configs={}, - opening_statement=model_configuration['opening_statement'], - suggested_questions=json.dumps(model_configuration['suggested_questions']), - suggested_questions_after_answer=json.dumps(model_configuration['suggested_questions_after_answer']), - speech_to_text=json.dumps(model_configuration['speech_to_text']), - more_like_this=json.dumps(model_configuration['more_like_this']), - sensitive_word_avoidance=json.dumps(model_configuration['sensitive_word_avoidance']), - model=json.dumps(model_configuration['model']), - user_input_form=json.dumps(model_configuration['user_input_form']), - pre_prompt=model_configuration['pre_prompt'], - agent_mode=json.dumps(model_configuration['agent_mode']), ) + new_app_model_config = new_app_model_config.from_model_config_dict(model_configuration) db.session.add(new_app_model_config) db.session.flush() diff --git a/api/core/agent/agent/structured_chat.py b/api/core/agent/agent/structured_chat.py index dc868590bf..c16214c632 100644 --- a/api/core/agent/agent/structured_chat.py +++ b/api/core/agent/agent/structured_chat.py @@ -112,7 +112,7 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin): "I don't know how to respond to that."}, "") def summarize_messages(self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs): - if len(intermediate_steps) >= 2: + if len(intermediate_steps) >= 2 and self.summary_llm: should_summary_intermediate_steps = intermediate_steps[self.moving_summary_index:-1] should_summary_messages = [AIMessage(content=observation) for _, observation in should_summary_intermediate_steps] diff --git a/api/core/agent/agent_executor.py b/api/core/agent/agent_executor.py index f345e631d3..ada5ab49a8 100644 --- a/api/core/agent/agent_executor.py +++ b/api/core/agent/agent_executor.py @@ -65,7 +65,8 @@ class AgentExecutor: llm=self.configuration.model_instance.client, tools=self.configuration.tools, output_parser=StructuredChatOutputParser(), - summary_llm=self.configuration.summary_model_instance.client, + summary_llm=self.configuration.summary_model_instance.client + if self.configuration.summary_model_instance else None, verbose=True ) elif self.configuration.strategy == PlanningStrategy.FUNCTION_CALL: @@ -74,7 +75,8 @@ class AgentExecutor: llm=self.configuration.model_instance.client, tools=self.configuration.tools, extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None, # used for read chat histories memory - summary_llm=self.configuration.summary_model_instance.client, + summary_llm=self.configuration.summary_model_instance.client + if self.configuration.summary_model_instance else None, verbose=True ) elif self.configuration.strategy == PlanningStrategy.MULTI_FUNCTION_CALL: @@ -83,7 +85,8 @@ class AgentExecutor: llm=self.configuration.model_instance.client, tools=self.configuration.tools, extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None, # used for read chat histories memory - summary_llm=self.configuration.summary_model_instance.client, + summary_llm=self.configuration.summary_model_instance.client + if self.configuration.summary_model_instance else None, verbose=True ) elif self.configuration.strategy == PlanningStrategy.ROUTER: diff --git a/api/core/conversation_message_task.py b/api/core/conversation_message_task.py index e9d9f3ec80..099c19be27 100644 --- a/api/core/conversation_message_task.py +++ b/api/core/conversation_message_task.py @@ -60,17 +60,7 @@ class ConversationMessageTask: def init(self): override_model_configs = None if self.is_override: - override_model_configs = { - "model": self.app_model_config.model_dict, - "pre_prompt": self.app_model_config.pre_prompt, - "agent_mode": self.app_model_config.agent_mode_dict, - "opening_statement": self.app_model_config.opening_statement, - "suggested_questions": self.app_model_config.suggested_questions_list, - "suggested_questions_after_answer": self.app_model_config.suggested_questions_after_answer_dict, - "more_like_this": self.app_model_config.more_like_this_dict, - "sensitive_word_avoidance": self.app_model_config.sensitive_word_avoidance_dict, - "user_input_form": self.app_model_config.user_input_form_list, - } + override_model_configs = self.app_model_config.to_dict() introduction = '' system_instruction = '' diff --git a/api/core/generator/llm_generator.py b/api/core/generator/llm_generator.py index 7e53e06600..91b324c631 100644 --- a/api/core/generator/llm_generator.py +++ b/api/core/generator/llm_generator.py @@ -2,7 +2,7 @@ import logging from langchain.schema import OutputParserException -from core.model_providers.error import LLMError +from core.model_providers.error import LLMError, ProviderTokenNotInitError from core.model_providers.model_factory import ModelFactory from core.model_providers.models.entity.message import PromptMessage, MessageType from core.model_providers.models.entity.model_params import ModelKwargs @@ -108,13 +108,16 @@ class LLMGenerator: _input = prompt.format_prompt(histories=histories) - model_instance = ModelFactory.get_text_generation_model( - tenant_id=tenant_id, - model_kwargs=ModelKwargs( - max_tokens=256, - temperature=0 + try: + model_instance = ModelFactory.get_text_generation_model( + tenant_id=tenant_id, + model_kwargs=ModelKwargs( + max_tokens=256, + temperature=0 + ) ) - ) + except ProviderTokenNotInitError: + return [] prompts = [PromptMessage(content=_input.to_string())] diff --git a/api/core/orchestrator_rule_parser.py b/api/core/orchestrator_rule_parser.py index 975cac2cd8..f4eed96ff5 100644 --- a/api/core/orchestrator_rule_parser.py +++ b/api/core/orchestrator_rule_parser.py @@ -14,6 +14,7 @@ from core.callback_handler.main_chain_gather_callback_handler import MainChainGa from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceChain from core.conversation_message_task import ConversationMessageTask +from core.model_providers.error import ProviderTokenNotInitError from core.model_providers.model_factory import ModelFactory from core.model_providers.models.entity.model_params import ModelKwargs, ModelMode from core.tool.dataset_retriever_tool import DatasetRetrieverTool @@ -78,13 +79,16 @@ class OrchestratorRuleParser: elif planning_strategy == PlanningStrategy.ROUTER: planning_strategy = PlanningStrategy.REACT_ROUTER - summary_model_instance = ModelFactory.get_text_generation_model( - tenant_id=self.tenant_id, - model_kwargs=ModelKwargs( - temperature=0, - max_tokens=500 + try: + summary_model_instance = ModelFactory.get_text_generation_model( + tenant_id=self.tenant_id, + model_kwargs=ModelKwargs( + temperature=0, + max_tokens=500 + ) ) - ) + except ProviderTokenNotInitError as e: + summary_model_instance = None tools = self.to_tools( tool_configs=tool_configs, diff --git a/api/core/prompt/output_parser/suggested_questions_after_answer.py b/api/core/prompt/output_parser/suggested_questions_after_answer.py index 7898d08262..b7ec2b2944 100644 --- a/api/core/prompt/output_parser/suggested_questions_after_answer.py +++ b/api/core/prompt/output_parser/suggested_questions_after_answer.py @@ -1,7 +1,10 @@ import json +import re from typing import Any from langchain.schema import BaseOutputParser + +from core.model_providers.error import LLMError from core.prompt.prompts import SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT @@ -12,5 +15,10 @@ class SuggestedQuestionsAfterAnswerOutputParser(BaseOutputParser): def parse(self, text: str) -> Any: json_string = text.strip() - json_obj = json.loads(json_string) + action_match = re.search(r".*(\[\".+\"\]).*", json_string, re.DOTALL) + if action_match is not None: + json_obj = json.loads(action_match.group(1).strip(), strict=False) + else: + raise LLMError("Could not parse LLM output: {text}") + return json_obj diff --git a/api/models/model.py b/api/models/model.py index ec6e392c1b..e8b3ec5e50 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -108,7 +108,7 @@ class AppModelConfig(db.Model): def suggested_questions_after_answer_dict(self) -> dict: return json.loads(self.suggested_questions_after_answer) if self.suggested_questions_after_answer \ else {"enabled": False} - + @property def speech_to_text_dict(self) -> dict: return json.loads(self.speech_to_text) if self.speech_to_text \ @@ -148,6 +148,46 @@ class AppModelConfig(db.Model): "agent_mode": self.agent_mode_dict } + def from_model_config_dict(self, model_config: dict): + self.provider = "" + self.model_id = "" + self.configs = {} + self.opening_statement = model_config['opening_statement'] + self.suggested_questions = json.dumps(model_config['suggested_questions']) + self.suggested_questions_after_answer = json.dumps(model_config['suggested_questions_after_answer']) + self.speech_to_text = json.dumps(model_config['speech_to_text']) \ + if model_config.get('speech_to_text') else None + self.more_like_this = json.dumps(model_config['more_like_this']) + self.sensitive_word_avoidance = json.dumps(model_config['sensitive_word_avoidance']) + self.model = json.dumps(model_config['model']) + self.user_input_form = json.dumps(model_config['user_input_form']) + self.pre_prompt = model_config['pre_prompt'] + self.agent_mode = json.dumps(model_config['agent_mode']) + + return self + + def copy(self): + new_app_model_config = AppModelConfig( + id=self.id, + app_id=self.app_id, + provider="", + model_id="", + configs={}, + opening_statement=self.opening_statement, + suggested_questions=self.suggested_questions, + suggested_questions_after_answer=self.suggested_questions_after_answer, + speech_to_text=self.speech_to_text, + more_like_this=self.more_like_this, + sensitive_word_avoidance=self.sensitive_word_avoidance, + model=self.model, + user_input_form=self.user_input_form, + pre_prompt=self.pre_prompt, + agent_mode=self.agent_mode + ) + + return new_app_model_config + + class RecommendedApp(db.Model): __tablename__ = 'recommended_apps' __table_args__ = ( @@ -234,7 +274,8 @@ class Conversation(db.Model): updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) messages = db.relationship("Message", backref="conversation", lazy='select', passive_deletes="all") - message_annotations = db.relationship("MessageAnnotation", backref="conversation", lazy='select', passive_deletes="all") + message_annotations = db.relationship("MessageAnnotation", backref="conversation", lazy='select', + passive_deletes="all") is_deleted = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) @@ -429,7 +470,7 @@ class Message(db.Model): @property def agent_thoughts(self): - return db.session.query(MessageAgentThought).filter(MessageAgentThought.message_id == self.id)\ + return db.session.query(MessageAgentThought).filter(MessageAgentThought.message_id == self.id) \ .order_by(MessageAgentThought.position.asc()).all() @@ -557,7 +598,8 @@ class Site(db.Model): @property def app_base_url(self): - return (current_app.config['APP_WEB_URL'] if current_app.config['APP_WEB_URL'] else request.host_url.rstrip('/')) + return ( + current_app.config['APP_WEB_URL'] if current_app.config['APP_WEB_URL'] else request.host_url.rstrip('/')) class ApiToken(db.Model): diff --git a/api/services/completion_service.py b/api/services/completion_service.py index 8899cdc11b..3c05e99c9e 100644 --- a/api/services/completion_service.py +++ b/api/services/completion_service.py @@ -63,26 +63,23 @@ class CompletionService: raise ConversationCompletedError() if not conversation.override_model_configs: - app_model_config = db.session.query(AppModelConfig).get(conversation.app_model_config_id) + app_model_config = db.session.query(AppModelConfig).filter( + AppModelConfig.id == conversation.app_model_config_id, + AppModelConfig.app_id == app_model.id + ).first() if not app_model_config: raise AppModelConfigBrokenError() else: conversation_override_model_configs = json.loads(conversation.override_model_configs) + app_model_config = AppModelConfig( id=conversation.app_model_config_id, app_id=app_model.id, - provider="", - model_id="", - configs="", - opening_statement=conversation_override_model_configs['opening_statement'], - suggested_questions=json.dumps(conversation_override_model_configs['suggested_questions']), - model=json.dumps(conversation_override_model_configs['model']), - user_input_form=json.dumps(conversation_override_model_configs['user_input_form']), - pre_prompt=conversation_override_model_configs['pre_prompt'], - agent_mode=json.dumps(conversation_override_model_configs['agent_mode']), ) + app_model_config = app_model_config.from_model_config_dict(conversation_override_model_configs) + if is_model_config_override: # build new app model config if 'model' not in args['model_config']: @@ -99,19 +96,8 @@ class CompletionService: app_model_config_model = app_model_config.model_dict app_model_config_model['completion_params'] = completion_params - app_model_config = AppModelConfig( - id=app_model_config.id, - app_id=app_model.id, - provider="", - model_id="", - configs="", - opening_statement=app_model_config.opening_statement, - suggested_questions=app_model_config.suggested_questions, - model=json.dumps(app_model_config_model), - user_input_form=app_model_config.user_input_form, - pre_prompt=app_model_config.pre_prompt, - agent_mode=app_model_config.agent_mode, - ) + app_model_config = app_model_config.copy() + app_model_config.model = json.dumps(app_model_config_model) else: if app_model.app_model_config_id is None: raise AppModelConfigBrokenError() @@ -135,20 +121,10 @@ class CompletionService: app_model_config = AppModelConfig( id=app_model_config.id, app_id=app_model.id, - provider="", - model_id="", - configs="", - opening_statement=model_config['opening_statement'], - suggested_questions=json.dumps(model_config['suggested_questions']), - suggested_questions_after_answer=json.dumps(model_config['suggested_questions_after_answer']), - more_like_this=json.dumps(model_config['more_like_this']), - sensitive_word_avoidance=json.dumps(model_config['sensitive_word_avoidance']), - model=json.dumps(model_config['model']), - user_input_form=json.dumps(model_config['user_input_form']), - pre_prompt=model_config['pre_prompt'], - agent_mode=json.dumps(model_config['agent_mode']), ) + app_model_config = app_model_config.from_model_config_dict(model_config) + # clean input by app_model_config form rules inputs = cls.get_cleaned_inputs(inputs, app_model_config) diff --git a/api/services/message_service.py b/api/services/message_service.py index 5c60017a97..fe205bcce0 100644 --- a/api/services/message_service.py +++ b/api/services/message_service.py @@ -1,3 +1,4 @@ +import json from typing import Optional, Union, List from core.completion import Completion @@ -5,8 +6,10 @@ from core.generator.llm_generator import LLMGenerator from libs.infinite_scroll_pagination import InfiniteScrollPagination from extensions.ext_database import db from models.account import Account -from models.model import App, EndUser, Message, MessageFeedback +from models.model import App, EndUser, Message, MessageFeedback, AppModelConfig from services.conversation_service import ConversationService +from services.errors.app_model_config import AppModelConfigBrokenError +from services.errors.conversation import ConversationNotExistsError, ConversationCompletedError from services.errors.message import FirstMessageNotExistsError, MessageNotExistsError, LastMessageNotExistsError, \ SuggestedQuestionsAfterAnswerDisabledError @@ -172,12 +175,6 @@ class MessageService: if not user: raise ValueError('user cannot be None') - app_model_config = app_model.app_model_config - suggested_questions_after_answer = app_model_config.suggested_questions_after_answer_dict - - if check_enabled and suggested_questions_after_answer.get("enabled", False) is False: - raise SuggestedQuestionsAfterAnswerDisabledError() - message = cls.get_message( app_model=app_model, user=user, @@ -190,10 +187,38 @@ class MessageService: user=user ) + if not conversation: + raise ConversationNotExistsError() + + if conversation.status != 'normal': + raise ConversationCompletedError() + + if not conversation.override_model_configs: + app_model_config = db.session.query(AppModelConfig).filter( + AppModelConfig.id == conversation.app_model_config_id, + AppModelConfig.app_id == app_model.id + ).first() + + if not app_model_config: + raise AppModelConfigBrokenError() + else: + conversation_override_model_configs = json.loads(conversation.override_model_configs) + app_model_config = AppModelConfig( + id=conversation.app_model_config_id, + app_id=app_model.id, + ) + + app_model_config = app_model_config.from_model_config_dict(conversation_override_model_configs) + + suggested_questions_after_answer = app_model_config.suggested_questions_after_answer_dict + + if check_enabled and suggested_questions_after_answer.get("enabled", False) is False: + raise SuggestedQuestionsAfterAnswerDisabledError() + # get memory of conversation (read-only) memory = Completion.get_memory_from_conversation( tenant_id=app_model.tenant_id, - app_model_config=app_model.app_model_config, + app_model_config=app_model_config, conversation=conversation, max_token_limit=3000, message_limit=3, @@ -209,4 +234,3 @@ class MessageService: ) return questions - diff --git a/api/tasks/generate_conversation_summary_task.py b/api/tasks/generate_conversation_summary_task.py index 40e5919717..791f141d5b 100644 --- a/api/tasks/generate_conversation_summary_task.py +++ b/api/tasks/generate_conversation_summary_task.py @@ -6,7 +6,7 @@ from celery import shared_task from werkzeug.exceptions import NotFound from core.generator.llm_generator import LLMGenerator -from core.model_providers.error import LLMError +from core.model_providers.error import LLMError, ProviderTokenNotInitError from extensions.ext_database import db from models.model import Conversation, Message @@ -40,10 +40,16 @@ def generate_conversation_summary_task(conversation_id: str): conversation.summary = LLMGenerator.generate_conversation_summary(app_model.tenant_id, history_messages) db.session.add(conversation) db.session.commit() - - end_at = time.perf_counter() - logging.info(click.style('Conversation summary generated: {} latency: {}'.format(conversation_id, end_at - start_at), fg='green')) - except LLMError: + except (LLMError, ProviderTokenNotInitError): + conversation.summary = '[No Summary]' + db.session.commit() pass except Exception as e: + conversation.summary = '[No Summary]' + db.session.commit() logging.exception(e) + + end_at = time.perf_counter() + logging.info( + click.style('Conversation summary generated: {} latency: {}'.format(conversation_id, end_at - start_at), + fg='green'))