diff --git a/api/core/agent/agent/calc_token_mixin.py b/api/core/agent/agent/calc_token_mixin.py
deleted file mode 100644
index 9c0f9c5b36..0000000000
--- a/api/core/agent/agent/calc_token_mixin.py
+++ /dev/null
@@ -1,49 +0,0 @@
-from typing import cast
-
-from core.entities.application_entities import ModelConfigEntity
-from core.model_runtime.entities.message_entities import PromptMessage
-from core.model_runtime.entities.model_entities import ModelPropertyKey
-from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
-
-
-class CalcTokenMixin:
-
-    def get_message_rest_tokens(self, model_config: ModelConfigEntity, messages: list[PromptMessage], **kwargs) -> int:
-        """
-        Got the rest tokens available for the model after excluding messages tokens and completion max tokens
-
-        :param model_config:
-        :param messages:
-        :return:
-        """
-        model_type_instance = model_config.provider_model_bundle.model_type_instance
-        model_type_instance = cast(LargeLanguageModel, model_type_instance)
-
-        model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
-
-        max_tokens = 0
-        for parameter_rule in model_config.model_schema.parameter_rules:
-            if (parameter_rule.name == 'max_tokens'
-                    or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
-                max_tokens = (model_config.parameters.get(parameter_rule.name)
-                              or model_config.parameters.get(parameter_rule.use_template)) or 0
-
-        if model_context_tokens is None:
-            return 0
-
-        if max_tokens is None:
-            max_tokens = 0
-
-        prompt_tokens = model_type_instance.get_num_tokens(
-            model_config.model,
-            model_config.credentials,
-            messages
-        )
-
-        rest_tokens = model_context_tokens - max_tokens - prompt_tokens
-
-        return rest_tokens
-
-
-class ExceededLLMTokensLimitError(Exception):
-    pass
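The deleted mixin boils down to one subtraction over the model's context budget: context size minus the reserved completion budget minus the prompt tokens. A minimal standalone sketch of that computation, with a hypothetical `count_tokens` callable standing in for the `ModelConfigEntity` plumbing:

```python
# Sketch only; context_size, max_tokens and count_tokens are stand-ins for
# the model schema / parameter-rule lookups the mixin performed.
from collections.abc import Callable

def rest_tokens(context_size: int, max_tokens: int,
                messages: list[str], count_tokens: Callable[[str], int]) -> int:
    # tokens still available for the prompt after reserving the completion budget
    prompt_tokens = sum(count_tokens(m) for m in messages)
    return context_size - max_tokens - prompt_tokens

# e.g. rest_tokens(4096, 1500, ["Hello"], lambda s: len(s.split())) == 2595
```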
diff --git a/api/core/agent/agent/openai_function_call.py b/api/core/agent/agent/openai_function_call.py
deleted file mode 100644
index 1f2d5f24b3..0000000000
--- a/api/core/agent/agent/openai_function_call.py
+++ /dev/null
@@ -1,361 +0,0 @@
-from collections.abc import Sequence
-from typing import Any, Optional, Union
-
-from langchain.agents import BaseSingleActionAgent, OpenAIFunctionsAgent
-from langchain.agents.openai_functions_agent.base import _format_intermediate_steps, _parse_ai_message
-from langchain.callbacks.base import BaseCallbackManager
-from langchain.callbacks.manager import Callbacks
-from langchain.chat_models.openai import _convert_message_to_dict, _import_tiktoken
-from langchain.memory.prompt import SUMMARY_PROMPT
-from langchain.prompts.chat import BaseMessagePromptTemplate
-from langchain.schema import (
-    AgentAction,
-    AgentFinish,
-    AIMessage,
-    BaseMessage,
-    HumanMessage,
-    SystemMessage,
-    get_buffer_string,
-)
-from langchain.tools import BaseTool
-from pydantic import root_validator
-
-from core.agent.agent.agent_llm_callback import AgentLLMCallback
-from core.agent.agent.calc_token_mixin import CalcTokenMixin, ExceededLLMTokensLimitError
-from core.chain.llm_chain import LLMChain
-from core.entities.application_entities import ModelConfigEntity
-from core.entities.message_entities import lc_messages_to_prompt_messages
-from core.model_manager import ModelInstance
-from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
-from core.third_party.langchain.llms.fake import FakeLLM
-
-
-class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixin):
-    moving_summary_buffer: str = ""
-    moving_summary_index: int = 0
-    summary_model_config: ModelConfigEntity = None
-    model_config: ModelConfigEntity
-    agent_llm_callback: Optional[AgentLLMCallback] = None
-
-    class Config:
-        """Configuration for this pydantic object."""
-
-        arbitrary_types_allowed = True
-
-    @root_validator
-    def validate_llm(cls, values: dict) -> dict:
-        return values
-
-    @classmethod
-    def from_llm_and_tools(
-            cls,
-            model_config: ModelConfigEntity,
-            tools: Sequence[BaseTool],
-            callback_manager: Optional[BaseCallbackManager] = None,
-            extra_prompt_messages: Optional[list[BaseMessagePromptTemplate]] = None,
-            system_message: Optional[SystemMessage] = SystemMessage(
-                content="You are a helpful AI assistant."
-            ),
-            agent_llm_callback: Optional[AgentLLMCallback] = None,
-            **kwargs: Any,
-    ) -> BaseSingleActionAgent:
-        prompt = cls.create_prompt(
-            extra_prompt_messages=extra_prompt_messages,
-            system_message=system_message,
-        )
-        return cls(
-            model_config=model_config,
-            llm=FakeLLM(response=''),
-            prompt=prompt,
-            tools=tools,
-            callback_manager=callback_manager,
-            agent_llm_callback=agent_llm_callback,
-            **kwargs,
-        )
-
-    def should_use_agent(self, query: str):
-        """
-        return should use agent
-
-        :param query:
-        :return:
-        """
-        original_max_tokens = 0
-        for parameter_rule in self.model_config.model_schema.parameter_rules:
-            if (parameter_rule.name == 'max_tokens'
-                    or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
-                original_max_tokens = (self.model_config.parameters.get(parameter_rule.name)
-                                       or self.model_config.parameters.get(parameter_rule.use_template)) or 0
-
-        self.model_config.parameters['max_tokens'] = 40
-
-        prompt = self.prompt.format_prompt(input=query, agent_scratchpad=[])
-        messages = prompt.to_messages()
-
-        try:
-            prompt_messages = lc_messages_to_prompt_messages(messages)
-            model_instance = ModelInstance(
-                provider_model_bundle=self.model_config.provider_model_bundle,
-                model=self.model_config.model,
-            )
-
-            tools = []
-            for function in self.functions:
-                tool = PromptMessageTool(
-                    **function
-                )
-
-                tools.append(tool)
-
-            result = model_instance.invoke_llm(
-                prompt_messages=prompt_messages,
-                tools=tools,
-                stream=False,
-                model_parameters={
-                    'temperature': 0.2,
-                    'top_p': 0.3,
-                    'max_tokens': 1500
-                }
-            )
-        except Exception as e:
-            raise e
-
-        self.model_config.parameters['max_tokens'] = original_max_tokens
-
-        return True if result.message.tool_calls else False
-
-    def plan(
-            self,
-            intermediate_steps: list[tuple[AgentAction, str]],
-            callbacks: Callbacks = None,
-            **kwargs: Any,
-    ) -> Union[AgentAction, AgentFinish]:
-        """Given input, decided what to do.
-
-        Args:
-            intermediate_steps: Steps the LLM has taken to date, along with observations
-            **kwargs: User inputs.
-
-        Returns:
-            Action specifying what tool to use.
-        """
-        agent_scratchpad = _format_intermediate_steps(intermediate_steps)
-        selected_inputs = {
-            k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
-        }
-        full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad)
-        prompt = self.prompt.format_prompt(**full_inputs)
-        messages = prompt.to_messages()
-
-        prompt_messages = lc_messages_to_prompt_messages(messages)
-
-        # summarize messages if rest_tokens < 0
-        try:
-            prompt_messages = self.summarize_messages_if_needed(prompt_messages, functions=self.functions)
-        except ExceededLLMTokensLimitError as e:
-            return AgentFinish(return_values={"output": str(e)}, log=str(e))
-
-        model_instance = ModelInstance(
-            provider_model_bundle=self.model_config.provider_model_bundle,
-            model=self.model_config.model,
-        )
-
-        tools = []
-        for function in self.functions:
-            tool = PromptMessageTool(
-                **function
-            )
-
-            tools.append(tool)
-
-        result = model_instance.invoke_llm(
-            prompt_messages=prompt_messages,
-            tools=tools,
-            stream=False,
-            callbacks=[self.agent_llm_callback] if self.agent_llm_callback else [],
-            model_parameters={
-                'temperature': 0.2,
-                'top_p': 0.3,
-                'max_tokens': 1500
-            }
-        )
-
-        ai_message = AIMessage(
-            content=result.message.content or "",
-            additional_kwargs={
-                'function_call': {
-                    'id': result.message.tool_calls[0].id,
-                    **result.message.tool_calls[0].function.dict()
-                } if result.message.tool_calls else None
-            }
-        )
-        agent_decision = _parse_ai_message(ai_message)
-
-        if isinstance(agent_decision, AgentAction) and agent_decision.tool == 'dataset':
-            tool_inputs = agent_decision.tool_input
-            if isinstance(tool_inputs, dict) and 'query' in tool_inputs:
-                tool_inputs['query'] = kwargs['input']
-                agent_decision.tool_input = tool_inputs
-
-        return agent_decision
-
-    @classmethod
-    def get_system_message(cls):
-        return SystemMessage(content="You are a helpful AI assistant.\n"
-                                     "The current date or current time you know is wrong.\n"
-                                     "Respond directly if appropriate.")
-
-    def return_stopped_response(
-            self,
-            early_stopping_method: str,
-            intermediate_steps: list[tuple[AgentAction, str]],
-            **kwargs: Any,
-    ) -> AgentFinish:
-        try:
-            return super().return_stopped_response(early_stopping_method, intermediate_steps, **kwargs)
-        except ValueError:
-            return AgentFinish({"output": "I'm sorry, I don't know how to respond to that."}, "")
-
-    def summarize_messages_if_needed(self, messages: list[PromptMessage], **kwargs) -> list[PromptMessage]:
-        # calculate rest tokens and summarize previous function observation messages if rest_tokens < 0
-        rest_tokens = self.get_message_rest_tokens(
-            self.model_config,
-            messages,
-            **kwargs
-        )
-
-        rest_tokens = rest_tokens - 20  # to deal with the inaccuracy of rest_tokens
-        if rest_tokens >= 0:
-            return messages
-
-        system_message = None
-        human_message = None
-        should_summary_messages = []
-        for message in messages:
-            if isinstance(message, SystemMessage):
-                system_message = message
-            elif isinstance(message, HumanMessage):
-                human_message = message
-            else:
-                should_summary_messages.append(message)
-
-        if len(should_summary_messages) > 2:
-            ai_message = should_summary_messages[-2]
-            function_message = should_summary_messages[-1]
-            should_summary_messages = should_summary_messages[self.moving_summary_index:-2]
-            self.moving_summary_index = len(should_summary_messages)
-        else:
-            error_msg = "Exceeded LLM tokens limit, stopped."
-            raise ExceededLLMTokensLimitError(error_msg)
-
-        new_messages = [system_message, human_message]
-
-        if self.moving_summary_index == 0:
-            should_summary_messages.insert(0, human_message)
-
-        self.moving_summary_buffer = self.predict_new_summary(
-            messages=should_summary_messages,
-            existing_summary=self.moving_summary_buffer
-        )
-
-        new_messages.append(AIMessage(content=self.moving_summary_buffer))
-        new_messages.append(ai_message)
-        new_messages.append(function_message)
-
-        return new_messages
-
-    def predict_new_summary(
-            self, messages: list[BaseMessage], existing_summary: str
-    ) -> str:
-        new_lines = get_buffer_string(
-            messages,
-            human_prefix="Human",
-            ai_prefix="AI",
-        )
-
-        chain = LLMChain(model_config=self.summary_model_config, prompt=SUMMARY_PROMPT)
-        return chain.predict(summary=existing_summary, new_lines=new_lines)
-
-    def get_num_tokens_from_messages(self, model_config: ModelConfigEntity, messages: list[BaseMessage], **kwargs) -> int:
-        """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
-
-        Official documentation: https://github.com/openai/openai-cookbook/blob/
-        main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
-        if model_config.provider == 'azure_openai':
-            model = model_config.model
-            model = model.replace("gpt-35", "gpt-3.5")
-        else:
-            model = model_config.credentials.get("base_model_name")
-
-        tiktoken_ = _import_tiktoken()
-        try:
-            encoding = tiktoken_.encoding_for_model(model)
-        except KeyError:
-            model = "cl100k_base"
-            encoding = tiktoken_.get_encoding(model)
-
-        if model.startswith("gpt-3.5-turbo"):
-            # every message follows <im_start>{role/name}\n{content}<im_end>\n
-            tokens_per_message = 4
-            # if there's a name, the role is omitted
-            tokens_per_name = -1
-        elif model.startswith("gpt-4"):
-            tokens_per_message = 3
-            tokens_per_name = 1
-        else:
-            raise NotImplementedError(
-                f"get_num_tokens_from_messages() is not presently implemented "
-                f"for model {model}."
-                "See https://github.com/openai/openai-python/blob/main/chatml.md for "
-                "information on how messages are converted to tokens."
-            )
-        num_tokens = 0
-        for m in messages:
-            message = _convert_message_to_dict(m)
-            num_tokens += tokens_per_message
-            for key, value in message.items():
-                if key == "function_call":
-                    for f_key, f_value in value.items():
-                        num_tokens += len(encoding.encode(f_key))
-                        num_tokens += len(encoding.encode(f_value))
-                else:
-                    num_tokens += len(encoding.encode(value))
-
-                if key == "name":
-                    num_tokens += tokens_per_name
-        # every reply is primed with <im_start>assistant
-        num_tokens += 3
-
-        if kwargs.get('functions'):
-            for function in kwargs.get('functions'):
-                num_tokens += len(encoding.encode('name'))
-                num_tokens += len(encoding.encode(function.get("name")))
-                num_tokens += len(encoding.encode('description'))
-                num_tokens += len(encoding.encode(function.get("description")))
-                parameters = function.get("parameters")
-                num_tokens += len(encoding.encode('parameters'))
-                if 'title' in parameters:
-                    num_tokens += len(encoding.encode('title'))
-                    num_tokens += len(encoding.encode(parameters.get("title")))
-                num_tokens += len(encoding.encode('type'))
-                num_tokens += len(encoding.encode(parameters.get("type")))
-                if 'properties' in parameters:
-                    num_tokens += len(encoding.encode('properties'))
-                    for key, value in parameters.get('properties').items():
-                        num_tokens += len(encoding.encode(key))
-                        for field_key, field_value in value.items():
-                            num_tokens += len(encoding.encode(field_key))
-                            if field_key == 'enum':
-                                for enum_field in field_value:
-                                    num_tokens += 3
-                                    num_tokens += len(encoding.encode(enum_field))
-                            else:
-                                num_tokens += len(encoding.encode(field_key))
-                                num_tokens += len(encoding.encode(str(field_value)))
-                if 'required' in parameters:
-                    num_tokens += len(encoding.encode('required'))
-                    for required_field in parameters['required']:
-                        num_tokens += 3
-                        num_tokens += len(encoding.encode(required_field))
-
-        return num_tokens
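The deleted `get_num_tokens_from_messages` hand-counts chat tokens with tiktoken. A trimmed sketch of the core cookbook scheme — the per-message overhead constants come from the OpenAI cookbook referenced above; real payloads also carry `name`/`function_call` keys, which the full method above counts but this sketch omits:

```python
# Minimal sketch of the same counting scheme, assuming the cookbook overheads.
import tiktoken

def num_tokens(messages: list[dict], model: str = "gpt-3.5-turbo") -> int:
    encoding = tiktoken.encoding_for_model(model)
    # <im_start>{role}\n{content}<im_end>\n framing per message
    tokens_per_message = 4 if model.startswith("gpt-3.5-turbo") else 3
    total = 0
    for message in messages:
        total += tokens_per_message
        for value in message.values():
            total += len(encoding.encode(value))
    return total + 3  # every reply is primed with <im_start>assistant

print(num_tokens([{"role": "user", "content": "Hello"}]))
```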
-Valid "action" values: "Final Answer" or {tool_names} - -Provide only ONE action per $JSON_BLOB, as shown: - -``` -{{{{ - "action": $TOOL_NAME, - "action_input": $INPUT -}}}} -``` - -Follow this format: - -Question: input question to answer -Thought: consider previous and subsequent steps -Action: -``` -$JSON_BLOB -``` -Observation: action result -... (repeat Thought/Action/Observation N times) -Thought: I know what to respond -Action: -``` -{{{{ - "action": "Final Answer", - "action_input": "Final response to human" -}}}} -```""" - - -class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin): - moving_summary_buffer: str = "" - moving_summary_index: int = 0 - summary_model_config: ModelConfigEntity = None - - class Config: - """Configuration for this pydantic object.""" - - arbitrary_types_allowed = True - - def should_use_agent(self, query: str): - """ - return should use agent - Using the ReACT mode to determine whether an agent is needed is costly, - so it's better to just use an Agent for reasoning, which is cheaper. - - :param query: - :return: - """ - return True - - def plan( - self, - intermediate_steps: list[tuple[AgentAction, str]], - callbacks: Callbacks = None, - **kwargs: Any, - ) -> Union[AgentAction, AgentFinish]: - """Given input, decided what to do. - - Args: - intermediate_steps: Steps the LLM has taken to date, - along with observatons - callbacks: Callbacks to run. - **kwargs: User inputs. - - Returns: - Action specifying what tool to use. - """ - full_inputs = self.get_full_inputs(intermediate_steps, **kwargs) - prompts, _ = self.llm_chain.prep_prompts(input_list=[self.llm_chain.prep_inputs(full_inputs)]) - - messages = [] - if prompts: - messages = prompts[0].to_messages() - - prompt_messages = lc_messages_to_prompt_messages(messages) - - rest_tokens = self.get_message_rest_tokens(self.llm_chain.model_config, prompt_messages) - if rest_tokens < 0: - full_inputs = self.summarize_messages(intermediate_steps, **kwargs) - - try: - full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs) - except Exception as e: - raise e - - try: - agent_decision = self.output_parser.parse(full_output) - if isinstance(agent_decision, AgentAction) and agent_decision.tool == 'dataset': - tool_inputs = agent_decision.tool_input - if isinstance(tool_inputs, dict) and 'query' in tool_inputs: - tool_inputs['query'] = kwargs['input'] - agent_decision.tool_input = tool_inputs - return agent_decision - except OutputParserException: - return AgentFinish({"output": "I'm sorry, the answer of model is invalid, " - "I don't know how to respond to that."}, "") - - def summarize_messages(self, intermediate_steps: list[tuple[AgentAction, str]], **kwargs): - if len(intermediate_steps) >= 2 and self.summary_model_config: - should_summary_intermediate_steps = intermediate_steps[self.moving_summary_index:-1] - should_summary_messages = [AIMessage(content=observation) - for _, observation in should_summary_intermediate_steps] - if self.moving_summary_index == 0: - should_summary_messages.insert(0, HumanMessage(content=kwargs.get("input"))) - - self.moving_summary_index = len(intermediate_steps) - else: - error_msg = "Exceeded LLM tokens limit, stopped." 
- raise ExceededLLMTokensLimitError(error_msg) - - if self.moving_summary_buffer and 'chat_history' in kwargs: - kwargs["chat_history"].pop() - - self.moving_summary_buffer = self.predict_new_summary( - messages=should_summary_messages, - existing_summary=self.moving_summary_buffer - ) - - if 'chat_history' in kwargs: - kwargs["chat_history"].append(AIMessage(content=self.moving_summary_buffer)) - - return self.get_full_inputs([intermediate_steps[-1]], **kwargs) - - def predict_new_summary( - self, messages: list[BaseMessage], existing_summary: str - ) -> str: - new_lines = get_buffer_string( - messages, - human_prefix="Human", - ai_prefix="AI", - ) - - chain = LLMChain(model_config=self.summary_model_config, prompt=SUMMARY_PROMPT) - return chain.predict(summary=existing_summary, new_lines=new_lines) - - @classmethod - def create_prompt( - cls, - tools: Sequence[BaseTool], - prefix: str = PREFIX, - suffix: str = SUFFIX, - human_message_template: str = HUMAN_MESSAGE_TEMPLATE, - format_instructions: str = FORMAT_INSTRUCTIONS, - input_variables: Optional[list[str]] = None, - memory_prompts: Optional[list[BasePromptTemplate]] = None, - ) -> BasePromptTemplate: - tool_strings = [] - for tool in tools: - args_schema = re.sub("}", "}}}}", re.sub("{", "{{{{", str(tool.args))) - tool_strings.append(f"{tool.name}: {tool.description}, args: {args_schema}") - formatted_tools = "\n".join(tool_strings) - tool_names = ", ".join([('"' + tool.name + '"') for tool in tools]) - format_instructions = format_instructions.format(tool_names=tool_names) - template = "\n\n".join([prefix, formatted_tools, format_instructions, suffix]) - if input_variables is None: - input_variables = ["input", "agent_scratchpad"] - _memory_prompts = memory_prompts or [] - messages = [ - SystemMessagePromptTemplate.from_template(template), - *_memory_prompts, - HumanMessagePromptTemplate.from_template(human_message_template), - ] - return ChatPromptTemplate(input_variables=input_variables, messages=messages) - - @classmethod - def create_completion_prompt( - cls, - tools: Sequence[BaseTool], - prefix: str = PREFIX, - format_instructions: str = FORMAT_INSTRUCTIONS, - input_variables: Optional[list[str]] = None, - ) -> PromptTemplate: - """Create prompt in the style of the zero shot agent. - - Args: - tools: List of tools the agent will have access to, used to format the - prompt. - prefix: String to put before the list of tools. - input_variables: List of input variables the final prompt will expect. - - Returns: - A PromptTemplate with the template assembled from the pieces here. - """ - suffix = """Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:. 
-Question: {input} -Thought: {agent_scratchpad} -""" - - tool_strings = "\n".join([f"{tool.name}: {tool.description}" for tool in tools]) - tool_names = ", ".join([tool.name for tool in tools]) - format_instructions = format_instructions.format(tool_names=tool_names) - template = "\n\n".join([prefix, tool_strings, format_instructions, suffix]) - if input_variables is None: - input_variables = ["input", "agent_scratchpad"] - return PromptTemplate(template=template, input_variables=input_variables) - - def _construct_scratchpad( - self, intermediate_steps: list[tuple[AgentAction, str]] - ) -> str: - agent_scratchpad = "" - for action, observation in intermediate_steps: - agent_scratchpad += action.log - agent_scratchpad += f"\n{self.observation_prefix}{observation}\n{self.llm_prefix}" - - if not isinstance(agent_scratchpad, str): - raise ValueError("agent_scratchpad should be of type string.") - if agent_scratchpad: - llm_chain = cast(LLMChain, self.llm_chain) - if llm_chain.model_config.mode == "chat": - return ( - f"This was your previous work " - f"(but I haven't seen any of it! I only see what " - f"you return as final answer):\n{agent_scratchpad}" - ) - else: - return agent_scratchpad - else: - return agent_scratchpad - - @classmethod - def from_llm_and_tools( - cls, - model_config: ModelConfigEntity, - tools: Sequence[BaseTool], - callback_manager: Optional[BaseCallbackManager] = None, - output_parser: Optional[AgentOutputParser] = None, - prefix: str = PREFIX, - suffix: str = SUFFIX, - human_message_template: str = HUMAN_MESSAGE_TEMPLATE, - format_instructions: str = FORMAT_INSTRUCTIONS, - input_variables: Optional[list[str]] = None, - memory_prompts: Optional[list[BasePromptTemplate]] = None, - agent_llm_callback: Optional[AgentLLMCallback] = None, - **kwargs: Any, - ) -> Agent: - """Construct an agent from an LLM and tools.""" - cls._validate_tools(tools) - if model_config.mode == "chat": - prompt = cls.create_prompt( - tools, - prefix=prefix, - suffix=suffix, - human_message_template=human_message_template, - format_instructions=format_instructions, - input_variables=input_variables, - memory_prompts=memory_prompts, - ) - else: - prompt = cls.create_completion_prompt( - tools, - prefix=prefix, - format_instructions=format_instructions, - input_variables=input_variables, - ) - llm_chain = LLMChain( - model_config=model_config, - prompt=prompt, - callback_manager=callback_manager, - agent_llm_callback=agent_llm_callback, - parameters={ - 'temperature': 0.2, - 'top_p': 0.3, - 'max_tokens': 1500 - } - ) - tool_names = [tool.name for tool in tools] - _output_parser = output_parser - return cls( - llm_chain=llm_chain, - allowed_tools=tool_names, - output_parser=_output_parser, - **kwargs, - ) diff --git a/api/core/app_runner/assistant_app_runner.py b/api/core/app_runner/assistant_app_runner.py index a4845d0ff1..d9a3447bda 100644 --- a/api/core/app_runner/assistant_app_runner.py +++ b/api/core/app_runner/assistant_app_runner.py @@ -1,4 +1,3 @@ -import json import logging from typing import cast @@ -15,7 +14,7 @@ from core.model_runtime.model_providers.__base.large_language_model import Large from core.moderation.base import ModerationException from core.tools.entities.tool_entities import ToolRuntimeVariablePool from extensions.ext_database import db -from models.model import App, Conversation, Message, MessageAgentThought, MessageChain +from models.model import App, Conversation, Message, MessageAgentThought from models.tools import ToolConversationVariables logger = 
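The deleted ReAct agent depends on pulling the `$JSON_BLOB` out of the model's `Action:` fence. A rough sketch of such an extraction — this is illustrative only, not the project's actual `StructuredChatOutputParser`, which survives under `features/dataset_retrieval/agent/output_parser/`:

```python
# Hypothetical flat-blob parser; nested-brace action_input would need more care.
import json
import re

def parse_action(text: str) -> tuple[str, object]:
    # pull the first ```...``` fence out of the "Action:" section
    match = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", text, re.DOTALL)
    if not match:
        raise ValueError("no JSON blob found in model output")
    blob = json.loads(match.group(1))
    return blob["action"], blob.get("action_input", "")

action, action_input = parse_action(
    'Thought: I know what to respond\nAction:\n```\n'
    '{"action": "Final Answer", "action_input": "Hi!"}\n```'
)
```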
diff --git a/api/core/app_runner/assistant_app_runner.py b/api/core/app_runner/assistant_app_runner.py
index a4845d0ff1..d9a3447bda 100644
--- a/api/core/app_runner/assistant_app_runner.py
+++ b/api/core/app_runner/assistant_app_runner.py
@@ -1,4 +1,3 @@
-import json
 import logging
 from typing import cast
 
@@ -15,7 +14,7 @@ from core.model_runtime.model_providers.__base.large_language_model import Large
 from core.moderation.base import ModerationException
 from core.tools.entities.tool_entities import ToolRuntimeVariablePool
 from extensions.ext_database import db
-from models.model import App, Conversation, Message, MessageAgentThought, MessageChain
+from models.model import App, Conversation, Message, MessageAgentThought
 from models.tools import ToolConversationVariables
 
 logger = logging.getLogger(__name__)
@@ -173,11 +172,6 @@ class AssistantApplicationRunner(AppRunner):
         # convert db variables to tool variables
         tool_variables = self._convert_db_variables_to_tool_variables(tool_conversation_variables)
-
-        message_chain = self._init_message_chain(
-            message=message,
-            query=query
-        )
 
         # init model instance
         model_instance = ModelInstance(
@@ -290,38 +284,6 @@ class AssistantApplicationRunner(AppRunner):
             'pool': db_variables.variables
         })
 
-    def _init_message_chain(self, message: Message, query: str) -> MessageChain:
-        """
-        Init MessageChain
-        :param message: message
-        :param query: query
-        :return:
-        """
-        message_chain = MessageChain(
-            message_id=message.id,
-            type="AgentExecutor",
-            input=json.dumps({
-                "input": query
-            })
-        )
-
-        db.session.add(message_chain)
-        db.session.commit()
-
-        return message_chain
-
-    def _save_message_chain(self, message_chain: MessageChain, output_text: str) -> None:
-        """
-        Save MessageChain
-        :param message_chain: message chain
-        :param output_text: output text
-        :return:
-        """
-        message_chain.output = json.dumps({
-            "output": output_text
-        })
-        db.session.commit()
-
     def _get_usage_of_all_agent_thoughts(self, model_config: ModelConfigEntity,
                                          message: Message) -> LLMUsage:
         """
diff --git a/api/core/app_runner/basic_app_runner.py b/api/core/app_runner/basic_app_runner.py
index e1972efb51..99df249ddf 100644
--- a/api/core/app_runner/basic_app_runner.py
+++ b/api/core/app_runner/basic_app_runner.py
@@ -5,7 +5,7 @@ from core.app_runner.app_runner import AppRunner
 from core.application_queue_manager import ApplicationQueueManager, PublishFrom
 from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
 from core.entities.application_entities import ApplicationGenerateEntity, DatasetEntity, InvokeFrom, ModelConfigEntity
-from core.features.dataset_retrieval import DatasetRetrievalFeature
+from core.features.dataset_retrieval.dataset_retrieval import DatasetRetrievalFeature
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_manager import ModelInstance
 from core.moderation.base import ModerationException
diff --git a/api/core/entities/agent_entities.py b/api/core/entities/agent_entities.py
new file mode 100644
index 0000000000..0cdf8670c4
--- /dev/null
+++ b/api/core/entities/agent_entities.py
@@ -0,0 +1,8 @@
+from enum import Enum
+
+
+class PlanningStrategy(Enum):
+    ROUTER = 'router'
+    REACT_ROUTER = 'react_router'
+    REACT = 'react'
+    FUNCTION_CALL = 'function_call'
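The new module makes `PlanningStrategy` importable without dragging in the executor. The deleted `AgentRunnerFeature` below shows the intended selection rule: prefer `FUNCTION_CALL` when the model schema advertises tool calling, otherwise fall back to `REACT`. A hedged sketch of that rule, with plain strings standing in for the `ModelFeature` enum values:

```python
# Sketch of the strategy pick performed in the deleted AgentRunnerFeature.run()
# below; `features` stands in for model_schema.features.
from core.entities.agent_entities import PlanningStrategy

def choose_strategy(features: list[str]) -> PlanningStrategy:
    # models that emit native tool calls can skip the ReAct text protocol
    if 'tool-call' in features or 'multi-tool-call' in features:
        return PlanningStrategy.FUNCTION_CALL
    return PlanningStrategy.REACT

assert choose_strategy(['multi-tool-call']) is PlanningStrategy.FUNCTION_CALL
```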
diff --git a/api/core/features/agent_runner.py b/api/core/features/agent_runner.py
deleted file mode 100644
index 7412d81281..0000000000
--- a/api/core/features/agent_runner.py
+++ /dev/null
@@ -1,199 +0,0 @@
-import logging
-from typing import Optional, cast
-
-from langchain.tools import BaseTool
-
-from core.agent.agent.agent_llm_callback import AgentLLMCallback
-from core.agent.agent_executor import AgentConfiguration, AgentExecutor, PlanningStrategy
-from core.application_queue_manager import ApplicationQueueManager
-from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
-from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
-from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
-from core.entities.application_entities import (
-    AgentEntity,
-    AppOrchestrationConfigEntity,
-    InvokeFrom,
-    ModelConfigEntity,
-)
-from core.memory.token_buffer_memory import TokenBufferMemory
-from core.model_runtime.entities.model_entities import ModelFeature, ModelType
-from core.model_runtime.model_providers import model_provider_factory
-from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
-from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool
-from extensions.ext_database import db
-from models.dataset import Dataset
-from models.model import Message
-
-logger = logging.getLogger(__name__)
-
-
-class AgentRunnerFeature:
-    def __init__(self, tenant_id: str,
-                 app_orchestration_config: AppOrchestrationConfigEntity,
-                 model_config: ModelConfigEntity,
-                 config: AgentEntity,
-                 queue_manager: ApplicationQueueManager,
-                 message: Message,
-                 user_id: str,
-                 agent_llm_callback: AgentLLMCallback,
-                 callback: AgentLoopGatherCallbackHandler,
-                 memory: Optional[TokenBufferMemory] = None,) -> None:
-        """
-        Agent runner
-        :param tenant_id: tenant id
-        :param app_orchestration_config: app orchestration config
-        :param model_config: model config
-        :param config: dataset config
-        :param queue_manager: queue manager
-        :param message: message
-        :param user_id: user id
-        :param agent_llm_callback: agent llm callback
-        :param callback: callback
-        :param memory: memory
-        """
-        self.tenant_id = tenant_id
-        self.app_orchestration_config = app_orchestration_config
-        self.model_config = model_config
-        self.config = config
-        self.queue_manager = queue_manager
-        self.message = message
-        self.user_id = user_id
-        self.agent_llm_callback = agent_llm_callback
-        self.callback = callback
-        self.memory = memory
-
-    def run(self, query: str,
-            invoke_from: InvokeFrom) -> Optional[str]:
-        """
-        Retrieve agent loop result.
-        :param query: query
-        :param invoke_from: invoke from
-        :return:
-        """
-        provider = self.config.provider
-        model = self.config.model
-        tool_configs = self.config.tools
-
-        # check model is support tool calling
-        provider_instance = model_provider_factory.get_provider_instance(provider=provider)
-        model_type_instance = provider_instance.get_model_instance(ModelType.LLM)
-        model_type_instance = cast(LargeLanguageModel, model_type_instance)
-
-        # get model schema
-        model_schema = model_type_instance.get_model_schema(
-            model=model,
-            credentials=self.model_config.credentials
-        )
-
-        if not model_schema:
-            return None
-
-        planning_strategy = PlanningStrategy.REACT
-        features = model_schema.features
-        if features:
-            if ModelFeature.TOOL_CALL in features \
-                    or ModelFeature.MULTI_TOOL_CALL in features:
-                planning_strategy = PlanningStrategy.FUNCTION_CALL
-
-        tools = self.to_tools(
-            tool_configs=tool_configs,
-            invoke_from=invoke_from,
-            callbacks=[self.callback, DifyStdOutCallbackHandler()],
-        )
-
-        if len(tools) == 0:
-            return None
-
-        agent_configuration = AgentConfiguration(
-            strategy=planning_strategy,
-            model_config=self.model_config,
-            tools=tools,
-            memory=self.memory,
-            max_iterations=10,
-            max_execution_time=400.0,
-            early_stopping_method="generate",
-            agent_llm_callback=self.agent_llm_callback,
-            callbacks=[self.callback, DifyStdOutCallbackHandler()]
-        )
-
-        agent_executor = AgentExecutor(agent_configuration)
-
-        try:
-            # check if should use agent
-            should_use_agent = agent_executor.should_use_agent(query)
-            if not should_use_agent:
-                return None
-
-            result = agent_executor.run(query)
-            return result.output
-        except Exception as ex:
-            logger.exception("agent_executor run failed")
-            return None
-
-    def to_dataset_retriever_tool(self, tool_config: dict,
-                                  invoke_from: InvokeFrom) \
-            -> Optional[BaseTool]:
-        """
-        A dataset tool is a tool that can be used to retrieve information from a dataset
-        :param tool_config: tool config
-        :param invoke_from: invoke from
-        """
-        show_retrieve_source = self.app_orchestration_config.show_retrieve_source
-
-        hit_callback = DatasetIndexToolCallbackHandler(
-            queue_manager=self.queue_manager,
-            app_id=self.message.app_id,
-            message_id=self.message.id,
-            user_id=self.user_id,
-            invoke_from=invoke_from
-        )
-
-        # get dataset from dataset id
-        dataset = db.session.query(Dataset).filter(
-            Dataset.tenant_id == self.tenant_id,
-            Dataset.id == tool_config.get("id")
-        ).first()
-
-        # pass if dataset is not available
-        if not dataset:
-            return None
-
-        # pass if dataset is not available
-        if (dataset and dataset.available_document_count == 0
-                and dataset.available_document_count == 0):
-            return None
-
-        # get retrieval model config
-        default_retrieval_model = {
-            'search_method': 'semantic_search',
-            'reranking_enable': False,
-            'reranking_model': {
-                'reranking_provider_name': '',
-                'reranking_model_name': ''
-            },
-            'top_k': 2,
-            'score_threshold_enabled': False
-        }
-
-        retrieval_model_config = dataset.retrieval_model \
-            if dataset.retrieval_model else default_retrieval_model
-
-        # get top k
-        top_k = retrieval_model_config['top_k']
-
-        # get score threshold
-        score_threshold = None
-        score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled")
-        if score_threshold_enabled:
-            score_threshold = retrieval_model_config.get("score_threshold")
-
-        tool = DatasetRetrieverTool.from_dataset(
-            dataset=dataset,
-            top_k=top_k,
-            score_threshold=score_threshold,
-            hit_callbacks=[hit_callback],
-            return_resource=show_retrieve_source,
-            retriever_from=invoke_from.to_source()
-        )
-
-        return tool
\ No newline at end of file
diff --git a/api/core/third_party/langchain/llms/__init__.py b/api/core/features/dataset_retrieval/__init__.py
similarity index 100%
rename from api/core/third_party/langchain/llms/__init__.py
rename to api/core/features/dataset_retrieval/__init__.py
diff --git a/api/core/third_party/spark/__init__.py b/api/core/features/dataset_retrieval/agent/__init__.py
similarity index 100%
rename from api/core/third_party/spark/__init__.py
rename to api/core/features/dataset_retrieval/agent/__init__.py
diff --git a/api/core/agent/agent/agent_llm_callback.py b/api/core/features/dataset_retrieval/agent/agent_llm_callback.py
similarity index 100%
rename from api/core/agent/agent/agent_llm_callback.py
rename to api/core/features/dataset_retrieval/agent/agent_llm_callback.py
diff --git a/api/core/third_party/langchain/llms/fake.py b/api/core/features/dataset_retrieval/agent/fake_llm.py
similarity index 100%
rename from api/core/third_party/langchain/llms/fake.py
rename to api/core/features/dataset_retrieval/agent/fake_llm.py
diff --git a/api/core/chain/llm_chain.py b/api/core/features/dataset_retrieval/agent/llm_chain.py
similarity index 91%
rename from api/core/chain/llm_chain.py
rename to api/core/features/dataset_retrieval/agent/llm_chain.py
index 86fb156292..e5155e15a0 100644
--- a/api/core/chain/llm_chain.py
+++ b/api/core/features/dataset_retrieval/agent/llm_chain.py
@@ -5,11 +5,11 @@ from langchain.callbacks.manager import CallbackManagerForChainRun
 from langchain.schema import Generation, LLMResult
 from langchain.schema.language_model import BaseLanguageModel
 
-from core.agent.agent.agent_llm_callback import AgentLLMCallback
 from core.entities.application_entities import ModelConfigEntity
 from core.entities.message_entities import lc_messages_to_prompt_messages
+from core.features.dataset_retrieval.agent.agent_llm_callback import AgentLLMCallback
+from core.features.dataset_retrieval.agent.fake_llm import FakeLLM
 from core.model_manager import ModelInstance
-from core.third_party.langchain.llms.fake import FakeLLM
 
 
 class LLMChain(LCLLMChain):
diff --git a/api/core/agent/agent/multi_dataset_router_agent.py b/api/core/features/dataset_retrieval/agent/multi_dataset_router_agent.py
similarity index 98%
rename from api/core/agent/agent/multi_dataset_router_agent.py
rename to api/core/features/dataset_retrieval/agent/multi_dataset_router_agent.py
index eb594c3d21..59923202fd 100644
--- a/api/core/agent/agent/multi_dataset_router_agent.py
+++ b/api/core/features/dataset_retrieval/agent/multi_dataset_router_agent.py
@@ -12,9 +12,9 @@ from pydantic import root_validator
 
 from core.entities.application_entities import ModelConfigEntity
 from core.entities.message_entities import lc_messages_to_prompt_messages
+from core.features.dataset_retrieval.agent.fake_llm import FakeLLM
 from core.model_manager import ModelInstance
 from core.model_runtime.entities.message_entities import PromptMessageTool
-from core.third_party.langchain.llms.fake import FakeLLM
 
 
 class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
diff --git a/api/core/data_loader/file_extractor.py b/api/core/features/dataset_retrieval/agent/output_parser/__init__.py
similarity index 100%
rename from api/core/data_loader/file_extractor.py
rename to api/core/features/dataset_retrieval/agent/output_parser/__init__.py
diff --git a/api/core/agent/agent/output_parser/structured_chat.py b/api/core/features/dataset_retrieval/agent/output_parser/structured_chat.py
similarity index 100%
rename from api/core/agent/agent/output_parser/structured_chat.py
rename to api/core/features/dataset_retrieval/agent/output_parser/structured_chat.py
diff --git a/api/core/agent/agent/structed_multi_dataset_router_agent.py b/api/core/features/dataset_retrieval/agent/structed_multi_dataset_router_agent.py
similarity index 99%
rename from api/core/agent/agent/structed_multi_dataset_router_agent.py
rename to api/core/features/dataset_retrieval/agent/structed_multi_dataset_router_agent.py
index e104bb01f9..e69302bfd6 100644
--- a/api/core/agent/agent/structed_multi_dataset_router_agent.py
+++ b/api/core/features/dataset_retrieval/agent/structed_multi_dataset_router_agent.py
@@ -12,8 +12,8 @@ from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate, Sy
 from langchain.schema import AgentAction, AgentFinish, OutputParserException
 from langchain.tools import BaseTool
 
-from core.chain.llm_chain import LLMChain
 from core.entities.application_entities import ModelConfigEntity
+from core.features.dataset_retrieval.agent.llm_chain import LLMChain
 
 FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
 The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English.
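Both deleted agents funneled rolling summaries through this relocated `LLMChain` via `predict_new_summary`. The same call under the new import path looks roughly like this, assuming `summary_model_config` is a `ModelConfigEntity` prepared elsewhere:

```python
# Sketch of the predict_new_summary pattern from the deleted agents,
# using the relocated import path.
from langchain.memory.prompt import SUMMARY_PROMPT

from core.features.dataset_retrieval.agent.llm_chain import LLMChain

def roll_summary(summary_model_config, existing_summary: str, new_lines: str) -> str:
    # SUMMARY_PROMPT expects `summary` and `new_lines` template variables
    chain = LLMChain(model_config=summary_model_config, prompt=SUMMARY_PROMPT)
    return chain.predict(summary=existing_summary, new_lines=new_lines)
```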
diff --git a/api/core/agent/agent_executor.py b/api/core/features/dataset_retrieval/agent_based_dataset_executor.py
similarity index 69%
rename from api/core/agent/agent_executor.py
rename to api/core/features/dataset_retrieval/agent_based_dataset_executor.py
index 70fe00ee13..588ccc91f5 100644
--- a/api/core/agent/agent_executor.py
+++ b/api/core/features/dataset_retrieval/agent_based_dataset_executor.py
@@ -1,4 +1,3 @@
-import enum
 import logging
 from typing import Optional, Union
 
@@ -8,14 +7,13 @@ from langchain.callbacks.manager import Callbacks
 from langchain.tools import BaseTool
 from pydantic import BaseModel, Extra
 
-from core.agent.agent.agent_llm_callback import AgentLLMCallback
-from core.agent.agent.multi_dataset_router_agent import MultiDatasetRouterAgent
-from core.agent.agent.openai_function_call import AutoSummarizingOpenAIFunctionCallAgent
-from core.agent.agent.output_parser.structured_chat import StructuredChatOutputParser
-from core.agent.agent.structed_multi_dataset_router_agent import StructuredMultiDatasetRouterAgent
-from core.agent.agent.structured_chat import AutoSummarizingStructuredChatAgent
+from core.entities.agent_entities import PlanningStrategy
 from core.entities.application_entities import ModelConfigEntity
 from core.entities.message_entities import prompt_messages_to_lc_messages
+from core.features.dataset_retrieval.agent.agent_llm_callback import AgentLLMCallback
+from core.features.dataset_retrieval.agent.multi_dataset_router_agent import MultiDatasetRouterAgent
+from core.features.dataset_retrieval.agent.output_parser.structured_chat import StructuredChatOutputParser
+from core.features.dataset_retrieval.agent.structed_multi_dataset_router_agent import StructuredMultiDatasetRouterAgent
 from core.helper import moderation
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_runtime.errors.invoke import InvokeError
@@ -23,13 +21,6 @@ from core.tools.tool.dataset_retriever.dataset_multi_retriever_tool import Datas
 from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool
 
 
-class PlanningStrategy(str, enum.Enum):
-    ROUTER = 'router'
-    REACT_ROUTER = 'react_router'
-    REACT = 'react'
-    FUNCTION_CALL = 'function_call'
-
-
 class AgentConfiguration(BaseModel):
     strategy: PlanningStrategy
     model_config: ModelConfigEntity
@@ -62,28 +53,7 @@ class AgentExecutor:
         self.agent = self._init_agent()
 
     def _init_agent(self) -> Union[BaseSingleActionAgent, BaseMultiActionAgent]:
-        if self.configuration.strategy == PlanningStrategy.REACT:
-            agent = AutoSummarizingStructuredChatAgent.from_llm_and_tools(
-                model_config=self.configuration.model_config,
-                tools=self.configuration.tools,
-                output_parser=StructuredChatOutputParser(),
-                summary_model_config=self.configuration.summary_model_config
-                if self.configuration.summary_model_config else None,
-                agent_llm_callback=self.configuration.agent_llm_callback,
-                verbose=True
-            )
-        elif self.configuration.strategy == PlanningStrategy.FUNCTION_CALL:
-            agent = AutoSummarizingOpenAIFunctionCallAgent.from_llm_and_tools(
-                model_config=self.configuration.model_config,
-                tools=self.configuration.tools,
-                extra_prompt_messages=prompt_messages_to_lc_messages(self.configuration.memory.get_history_prompt_messages())
-                if self.configuration.memory else None,  # used for read chat histories memory
-                summary_model_config=self.configuration.summary_model_config
-                if self.configuration.summary_model_config else None,
-                agent_llm_callback=self.configuration.agent_llm_callback,
-                verbose=True
-            )
-        elif self.configuration.strategy == PlanningStrategy.ROUTER:
+        if self.configuration.strategy == PlanningStrategy.ROUTER:
             self.configuration.tools = [t for t in self.configuration.tools
                                         if isinstance(t, DatasetRetrieverTool)
                                         or isinstance(t, DatasetMultiRetrieverTool)]
diff --git a/api/core/features/dataset_retrieval.py b/api/core/features/dataset_retrieval/dataset_retrieval.py
similarity index 97%
rename from api/core/features/dataset_retrieval.py
rename to api/core/features/dataset_retrieval/dataset_retrieval.py
index 488a8ca8d0..3e54d8644d 100644
--- a/api/core/features/dataset_retrieval.py
+++ b/api/core/features/dataset_retrieval/dataset_retrieval.py
@@ -2,9 +2,10 @@ from typing import Optional, cast
 
 from langchain.tools import BaseTool
 
-from core.agent.agent_executor import AgentConfiguration, AgentExecutor, PlanningStrategy
 from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
+from core.entities.agent_entities import PlanningStrategy
 from core.entities.application_entities import DatasetEntity, DatasetRetrieveConfigEntity, InvokeFrom, ModelConfigEntity
+from core.features.dataset_retrieval.agent_based_dataset_executor import AgentConfiguration, AgentExecutor
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_runtime.entities.model_entities import ModelFeature
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
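The `SparkLLMClient` deleted below signs its WebSocket URL with an RFC1123 date and an HMAC-SHA256 signature over the host, date, and request line. A condensed sketch of its `create_url` logic for reference; the host/path/key values are placeholders:

```python
# Condensed from the create_url method of the SparkLLMClient deleted below.
import base64
import hashlib
import hmac
from datetime import datetime
from time import mktime
from urllib.parse import urlencode
from wsgiref.handlers import format_date_time

def sign_ws_url(host: str, path: str, api_key: str, api_secret: str) -> str:
    date = format_date_time(mktime(datetime.now().timetuple()))  # RFC1123 timestamp
    origin = f"host: {host}\ndate: {date}\nGET {path} HTTP/1.1"
    sig = base64.b64encode(
        hmac.new(api_secret.encode(), origin.encode(), hashlib.sha256).digest()
    ).decode()
    auth = (f'api_key="{api_key}", algorithm="hmac-sha256", '
            f'headers="host date request-line", signature="{sig}"')
    query = {
        "authorization": base64.b64encode(auth.encode()).decode(),
        "date": date,
        "host": host,
    }
    return f"wss://{host}{path}?" + urlencode(query)
```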
diff --git a/api/core/third_party/spark/spark_llm.py b/api/core/third_party/spark/spark_llm.py
deleted file mode 100644
index 5c97bba530..0000000000
--- a/api/core/third_party/spark/spark_llm.py
+++ /dev/null
@@ -1,189 +0,0 @@
-import base64
-import hashlib
-import hmac
-import json
-import queue
-import ssl
-from datetime import datetime
-from time import mktime
-from typing import Optional
-from urllib.parse import urlencode, urlparse
-from wsgiref.handlers import format_date_time
-
-import websocket
-
-
-class SparkLLMClient:
-    def __init__(self, model_name: str, app_id: str, api_key: str, api_secret: str, api_domain: Optional[str] = None):
-        domain = 'spark-api.xf-yun.com'
-        endpoint = 'chat'
-        if api_domain:
-            domain = api_domain
-        if model_name == 'spark-v3':
-            endpoint = 'multimodal'
-
-        model_api_configs = {
-            'spark': {
-                'version': 'v1.1',
-                'chat_domain': 'general'
-            },
-            'spark-v2': {
-                'version': 'v2.1',
-                'chat_domain': 'generalv2'
-            },
-            'spark-v3': {
-                'version': 'v3.1',
-                'chat_domain': 'generalv3'
-            },
-            'spark-v3.5': {
-                'version': 'v3.5',
-                'chat_domain': 'generalv3.5'
-            }
-        }
-
-        api_version = model_api_configs[model_name]['version']
-
-        self.chat_domain = model_api_configs[model_name]['chat_domain']
-        self.api_base = f"wss://{domain}/{api_version}/{endpoint}"
-        self.app_id = app_id
-        self.ws_url = self.create_url(
-            urlparse(self.api_base).netloc,
-            urlparse(self.api_base).path,
-            self.api_base,
-            api_key,
-            api_secret
-        )
-
-        self.queue = queue.Queue()
-        self.blocking_message = ''
-
-    def create_url(self, host: str, path: str, api_base: str, api_key: str, api_secret: str) -> str:
-        # generate timestamp by RFC1123
-        now = datetime.now()
-        date = format_date_time(mktime(now.timetuple()))
-
-        signature_origin = "host: " + host + "\n"
-        signature_origin += "date: " + date + "\n"
-        signature_origin += "GET " + path + " HTTP/1.1"
-
-        # encrypt using hmac-sha256
-        signature_sha = hmac.new(api_secret.encode('utf-8'), signature_origin.encode('utf-8'),
-                                 digestmod=hashlib.sha256).digest()
-
-        signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding='utf-8')
-
-        authorization_origin = f'api_key="{api_key}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"'
-
-        authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')
-
-        v = {
-            "authorization": authorization,
-            "date": date,
-            "host": host
-        }
-        # generate url
-        url = api_base + '?' + urlencode(v)
-        return url
-
-    def run(self, messages: list, user_id: str,
-            model_kwargs: Optional[dict] = None, streaming: bool = False):
-        websocket.enableTrace(False)
-        ws = websocket.WebSocketApp(
-            self.ws_url,
-            on_message=self.on_message,
-            on_error=self.on_error,
-            on_close=self.on_close,
-            on_open=self.on_open
-        )
-        ws.messages = messages
-        ws.user_id = user_id
-        ws.model_kwargs = model_kwargs
-        ws.streaming = streaming
-        ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
-
-    def on_error(self, ws, error):
-        self.queue.put({
-            'status_code': error.status_code,
-            'error': error.resp_body.decode('utf-8')
-        })
-        ws.close()
-
-    def on_close(self, ws, close_status_code, close_reason):
-        self.queue.put({'done': True})
-
-    def on_open(self, ws):
-        self.blocking_message = ''
-        data = json.dumps(self.gen_params(
-            messages=ws.messages,
-            user_id=ws.user_id,
-            model_kwargs=ws.model_kwargs
-        ))
-        ws.send(data)
-
-    def on_message(self, ws, message):
-        data = json.loads(message)
-        code = data['header']['code']
-        if code != 0:
-            self.queue.put({
-                'status_code': 400,
-                'error': f"Code: {code}, Error: {data['header']['message']}"
-            })
-            ws.close()
-        else:
-            choices = data["payload"]["choices"]
-            status = choices["status"]
-            content = choices["text"][0]["content"]
-            if ws.streaming:
-                self.queue.put({'data': content})
-            else:
-                self.blocking_message += content
-
-            if status == 2:
-                if not ws.streaming:
-                    self.queue.put({'data': self.blocking_message})
-                ws.close()
-
-    def gen_params(self, messages: list, user_id: str,
-                   model_kwargs: Optional[dict] = None) -> dict:
-        data = {
-            "header": {
-                "app_id": self.app_id,
-                "uid": user_id
-            },
-            "parameter": {
-                "chat": {
-                    "domain": self.chat_domain
-                }
-            },
-            "payload": {
-                "message": {
-                    "text": messages
-                }
-            }
-        }
-
-        if model_kwargs:
-            data['parameter']['chat'].update(model_kwargs)
-
-        return data
-
-    def subscribe(self):
-        while True:
-            content = self.queue.get()
-            if 'error' in content:
-                if content['status_code'] == 401:
-                    raise SparkError('[Spark] The credentials you provided are incorrect. '
-                                     'Please double-check and fill them in again.')
-                elif content['status_code'] == 403:
-                    raise SparkError("[Spark] Sorry, the credentials you provided are access denied. "
-                                     "Please try again after obtaining the necessary permissions.")
-                else:
-                    raise SparkError(f"[Spark] code: {content['status_code']}, error: {content['error']}")
-
-            if 'data' not in content:
-                break
-            yield content
-
-
-class SparkError(Exception):
-    pass
diff --git a/api/core/tool/current_datetime_tool.py b/api/core/tool/current_datetime_tool.py
deleted file mode 100644
index 208490a5bf..0000000000
--- a/api/core/tool/current_datetime_tool.py
+++ /dev/null
@@ -1,24 +0,0 @@
-from datetime import datetime
-
-from langchain.tools import BaseTool
-from pydantic import BaseModel, Field
-
-
-class DatetimeToolInput(BaseModel):
-    type: str = Field(..., description="Type for current time, must be: datetime.")
-
-
-class DatetimeTool(BaseTool):
-    """Tool for querying current datetime."""
-    name: str = "current_datetime"
-    args_schema: type[BaseModel] = DatetimeToolInput
-    description: str = "A tool when you want to get the current date, time, week, month or year, " \
-                       "and the time zone is UTC. Result is \"