From b6de97ad530a4b362f2527b249e97d4d06b69c41 Mon Sep 17 00:00:00 2001 From: Jyong <76649700+JohnJyong@users.noreply.github.com> Date: Wed, 10 Apr 2024 20:37:22 +0800 Subject: [PATCH] Remove langchain dataset retrival agent logic (#3311) --- api/core/app/apps/chat/app_runner.py | 2 + api/core/app/apps/completion/app_runner.py | 2 + api/core/rag/retrieval/agent/fake_llm.py | 59 --- api/core/rag/retrieval/agent/llm_chain.py | 46 --- .../agent/multi_dataset_router_agent.py | 179 -------- .../retrieval/agent/output_parser/__init__.py | 0 .../structed_multi_dataset_router_agent.py | 259 ------------ .../retrieval/agent_based_dataset_executor.py | 117 ------ api/core/rag/retrieval/dataset_retrieval.py | 381 +++++++++++++----- .../{agent => output_parser}/__init__.py | 0 .../output_parser/structured_chat.py | 0 .../multi_dataset_function_call_router.py | 0 .../router}/multi_dataset_react_route.py | 22 +- .../knowledge_retrieval_node.py | 240 ++--------- 14 files changed, 341 insertions(+), 966 deletions(-) delete mode 100644 api/core/rag/retrieval/agent/fake_llm.py delete mode 100644 api/core/rag/retrieval/agent/llm_chain.py delete mode 100644 api/core/rag/retrieval/agent/multi_dataset_router_agent.py delete mode 100644 api/core/rag/retrieval/agent/output_parser/__init__.py delete mode 100644 api/core/rag/retrieval/agent/structed_multi_dataset_router_agent.py delete mode 100644 api/core/rag/retrieval/agent_based_dataset_executor.py rename api/core/rag/retrieval/{agent => output_parser}/__init__.py (100%) rename api/core/rag/retrieval/{agent => }/output_parser/structured_chat.py (100%) rename api/core/{workflow/nodes/knowledge_retrieval => rag/retrieval/router}/multi_dataset_function_call_router.py (100%) rename api/core/{workflow/nodes/knowledge_retrieval => rag/retrieval/router}/multi_dataset_react_route.py (90%) diff --git a/api/core/app/apps/chat/app_runner.py b/api/core/app/apps/chat/app_runner.py index d51f3db540..ba2095076f 100644 --- a/api/core/app/apps/chat/app_runner.py +++ b/api/core/app/apps/chat/app_runner.py @@ -156,6 +156,8 @@ class ChatAppRunner(AppRunner): dataset_retrieval = DatasetRetrieval() context = dataset_retrieval.retrieve( + app_id=app_record.id, + user_id=application_generate_entity.user_id, tenant_id=app_record.tenant_id, model_config=application_generate_entity.model_config, config=app_config.dataset, diff --git a/api/core/app/apps/completion/app_runner.py b/api/core/app/apps/completion/app_runner.py index 649d73d961..40102f8999 100644 --- a/api/core/app/apps/completion/app_runner.py +++ b/api/core/app/apps/completion/app_runner.py @@ -116,6 +116,8 @@ class CompletionAppRunner(AppRunner): dataset_retrieval = DatasetRetrieval() context = dataset_retrieval.retrieve( + app_id=app_record.id, + user_id=application_generate_entity.user_id, tenant_id=app_record.tenant_id, model_config=application_generate_entity.model_config, config=dataset_config, diff --git a/api/core/rag/retrieval/agent/fake_llm.py b/api/core/rag/retrieval/agent/fake_llm.py deleted file mode 100644 index ab5152b38d..0000000000 --- a/api/core/rag/retrieval/agent/fake_llm.py +++ /dev/null @@ -1,59 +0,0 @@ -import time -from collections.abc import Mapping -from typing import Any, Optional - -from langchain.callbacks.manager import CallbackManagerForLLMRun -from langchain.chat_models.base import SimpleChatModel -from langchain.schema import AIMessage, BaseMessage, ChatGeneration, ChatResult - - -class FakeLLM(SimpleChatModel): - """Fake ChatModel for testing purposes.""" - - streaming: bool = False - 
"""Whether to stream the results or not.""" - response: str - - @property - def _llm_type(self) -> str: - return "fake-chat-model" - - def _call( - self, - messages: list[BaseMessage], - stop: Optional[list[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> str: - """First try to lookup in queries, else return 'foo' or 'bar'.""" - return self.response - - @property - def _identifying_params(self) -> Mapping[str, Any]: - return {"response": self.response} - - def get_num_tokens(self, text: str) -> int: - return 0 - - def _generate( - self, - messages: list[BaseMessage], - stop: Optional[list[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> ChatResult: - output_str = self._call(messages, stop=stop, run_manager=run_manager, **kwargs) - if self.streaming: - for token in output_str: - if run_manager: - run_manager.on_llm_new_token(token) - time.sleep(0.01) - - message = AIMessage(content=output_str) - generation = ChatGeneration(message=message) - llm_output = {"token_usage": { - 'prompt_tokens': 0, - 'completion_tokens': 0, - 'total_tokens': 0, - }} - return ChatResult(generations=[generation], llm_output=llm_output) diff --git a/api/core/rag/retrieval/agent/llm_chain.py b/api/core/rag/retrieval/agent/llm_chain.py deleted file mode 100644 index f2c5d4ca33..0000000000 --- a/api/core/rag/retrieval/agent/llm_chain.py +++ /dev/null @@ -1,46 +0,0 @@ -from typing import Any, Optional - -from langchain import LLMChain as LCLLMChain -from langchain.callbacks.manager import CallbackManagerForChainRun -from langchain.schema import Generation, LLMResult -from langchain.schema.language_model import BaseLanguageModel - -from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity -from core.entities.message_entities import lc_messages_to_prompt_messages -from core.model_manager import ModelInstance -from core.rag.retrieval.agent.fake_llm import FakeLLM - - -class LLMChain(LCLLMChain): - model_config: ModelConfigWithCredentialsEntity - """The language model instance to use.""" - llm: BaseLanguageModel = FakeLLM(response="") - parameters: dict[str, Any] = {} - - def generate( - self, - input_list: list[dict[str, Any]], - run_manager: Optional[CallbackManagerForChainRun] = None, - ) -> LLMResult: - """Generate LLM result from inputs.""" - prompts, stop = self.prep_prompts(input_list, run_manager=run_manager) - messages = prompts[0].to_messages() - prompt_messages = lc_messages_to_prompt_messages(messages) - - model_instance = ModelInstance( - provider_model_bundle=self.model_config.provider_model_bundle, - model=self.model_config.model, - ) - - result = model_instance.invoke_llm( - prompt_messages=prompt_messages, - stream=False, - stop=stop, - model_parameters=self.parameters - ) - - generations = [ - [Generation(text=result.message.content)] - ] - - return LLMResult(generations=generations) diff --git a/api/core/rag/retrieval/agent/multi_dataset_router_agent.py b/api/core/rag/retrieval/agent/multi_dataset_router_agent.py deleted file mode 100644 index be24731d46..0000000000 --- a/api/core/rag/retrieval/agent/multi_dataset_router_agent.py +++ /dev/null @@ -1,179 +0,0 @@ -from collections.abc import Sequence -from typing import Any, Optional, Union - -from langchain.agents import BaseSingleActionAgent, OpenAIFunctionsAgent -from langchain.agents.openai_functions_agent.base import _format_intermediate_steps, _parse_ai_message -from langchain.callbacks.base import BaseCallbackManager -from 
langchain.callbacks.manager import Callbacks -from langchain.prompts.chat import BaseMessagePromptTemplate -from langchain.schema import AgentAction, AgentFinish, AIMessage, SystemMessage -from langchain.tools import BaseTool -from pydantic import root_validator - -from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity -from core.entities.message_entities import lc_messages_to_prompt_messages -from core.model_manager import ModelInstance -from core.model_runtime.entities.message_entities import PromptMessageTool -from core.rag.retrieval.agent.fake_llm import FakeLLM - - -class MultiDatasetRouterAgent(OpenAIFunctionsAgent): - """ - An Multi Dataset Retrieve Agent driven by Router. - """ - model_config: ModelConfigWithCredentialsEntity - - class Config: - """Configuration for this pydantic object.""" - - arbitrary_types_allowed = True - - @root_validator - def validate_llm(cls, values: dict) -> dict: - return values - - def should_use_agent(self, query: str): - """ - return should use agent - - :param query: - :return: - """ - return True - - def plan( - self, - intermediate_steps: list[tuple[AgentAction, str]], - callbacks: Callbacks = None, - **kwargs: Any, - ) -> Union[AgentAction, AgentFinish]: - """Given input, decided what to do. - - Args: - intermediate_steps: Steps the LLM has taken to date, along with observations - **kwargs: User inputs. - - Returns: - Action specifying what tool to use. - """ - if len(self.tools) == 0: - return AgentFinish(return_values={"output": ''}, log='') - elif len(self.tools) == 1: - tool = next(iter(self.tools)) - rst = tool.run(tool_input={'query': kwargs['input']}) - # output = '' - # rst_json = json.loads(rst) - # for item in rst_json: - # output += f'{item["content"]}\n' - return AgentFinish(return_values={"output": rst}, log=rst) - - if intermediate_steps: - _, observation = intermediate_steps[-1] - return AgentFinish(return_values={"output": observation}, log=observation) - - try: - agent_decision = self.real_plan(intermediate_steps, callbacks, **kwargs) - if isinstance(agent_decision, AgentAction): - tool_inputs = agent_decision.tool_input - if isinstance(tool_inputs, dict) and 'query' in tool_inputs and 'chat_history' not in kwargs: - tool_inputs['query'] = kwargs['input'] - agent_decision.tool_input = tool_inputs - else: - agent_decision.return_values['output'] = '' - return agent_decision - except Exception as e: - raise e - - def real_plan( - self, - intermediate_steps: list[tuple[AgentAction, str]], - callbacks: Callbacks = None, - **kwargs: Any, - ) -> Union[AgentAction, AgentFinish]: - """Given input, decided what to do. - - Args: - intermediate_steps: Steps the LLM has taken to date, along with observations - **kwargs: User inputs. - - Returns: - Action specifying what tool to use. 
- """ - agent_scratchpad = _format_intermediate_steps(intermediate_steps) - selected_inputs = { - k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad" - } - full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad) - prompt = self.prompt.format_prompt(**full_inputs) - messages = prompt.to_messages() - prompt_messages = lc_messages_to_prompt_messages(messages) - - model_instance = ModelInstance( - provider_model_bundle=self.model_config.provider_model_bundle, - model=self.model_config.model, - ) - - tools = [] - for function in self.functions: - tool = PromptMessageTool( - **function - ) - - tools.append(tool) - - result = model_instance.invoke_llm( - prompt_messages=prompt_messages, - tools=tools, - stream=False, - model_parameters={ - 'temperature': 0.2, - 'top_p': 0.3, - 'max_tokens': 1500 - } - ) - - ai_message = AIMessage( - content=result.message.content or "", - additional_kwargs={ - 'function_call': { - 'id': result.message.tool_calls[0].id, - **result.message.tool_calls[0].function.dict() - } if result.message.tool_calls else None - } - ) - - agent_decision = _parse_ai_message(ai_message) - return agent_decision - - async def aplan( - self, - intermediate_steps: list[tuple[AgentAction, str]], - callbacks: Callbacks = None, - **kwargs: Any, - ) -> Union[AgentAction, AgentFinish]: - raise NotImplementedError() - - @classmethod - def from_llm_and_tools( - cls, - model_config: ModelConfigWithCredentialsEntity, - tools: Sequence[BaseTool], - callback_manager: Optional[BaseCallbackManager] = None, - extra_prompt_messages: Optional[list[BaseMessagePromptTemplate]] = None, - system_message: Optional[SystemMessage] = SystemMessage( - content="You are a helpful AI assistant." - ), - **kwargs: Any, - ) -> BaseSingleActionAgent: - prompt = cls.create_prompt( - extra_prompt_messages=extra_prompt_messages, - system_message=system_message, - ) - return cls( - model_config=model_config, - llm=FakeLLM(response=''), - prompt=prompt, - tools=tools, - callback_manager=callback_manager, - **kwargs, - ) diff --git a/api/core/rag/retrieval/agent/output_parser/__init__.py b/api/core/rag/retrieval/agent/output_parser/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/api/core/rag/retrieval/agent/structed_multi_dataset_router_agent.py b/api/core/rag/retrieval/agent/structed_multi_dataset_router_agent.py deleted file mode 100644 index 7035ec8e2f..0000000000 --- a/api/core/rag/retrieval/agent/structed_multi_dataset_router_agent.py +++ /dev/null @@ -1,259 +0,0 @@ -import re -from collections.abc import Sequence -from typing import Any, Optional, Union, cast - -from langchain import BasePromptTemplate, PromptTemplate -from langchain.agents import Agent, AgentOutputParser, StructuredChatAgent -from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE -from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX -from langchain.callbacks.base import BaseCallbackManager -from langchain.callbacks.manager import Callbacks -from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate, SystemMessagePromptTemplate -from langchain.schema import AgentAction, AgentFinish, OutputParserException -from langchain.tools import BaseTool - -from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity -from core.rag.retrieval.agent.llm_chain import LLMChain - -FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input). 
-The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English. -Valid "action" values: "Final Answer" or {tool_names} - -Provide only ONE action per $JSON_BLOB, as shown: - -``` -{{{{ - "action": $TOOL_NAME, - "action_input": $INPUT -}}}} -``` - -Follow this format: - -Question: input question to answer -Thought: consider previous and subsequent steps -Action: -``` -$JSON_BLOB -``` -Observation: action result -... (repeat Thought/Action/Observation N times) -Thought: I know what to respond -Action: -``` -{{{{ - "action": "Final Answer", - "action_input": "Final response to human" -}}}} -```""" - - -class StructuredMultiDatasetRouterAgent(StructuredChatAgent): - dataset_tools: Sequence[BaseTool] - - class Config: - """Configuration for this pydantic object.""" - - arbitrary_types_allowed = True - - def should_use_agent(self, query: str): - """ - return should use agent - Using the ReACT mode to determine whether an agent is needed is costly, - so it's better to just use an Agent for reasoning, which is cheaper. - - :param query: - :return: - """ - return True - - def plan( - self, - intermediate_steps: list[tuple[AgentAction, str]], - callbacks: Callbacks = None, - **kwargs: Any, - ) -> Union[AgentAction, AgentFinish]: - """Given input, decided what to do. - - Args: - intermediate_steps: Steps the LLM has taken to date, - along with observations - callbacks: Callbacks to run. - **kwargs: User inputs. - - Returns: - Action specifying what tool to use. - """ - if len(self.dataset_tools) == 0: - return AgentFinish(return_values={"output": ''}, log='') - elif len(self.dataset_tools) == 1: - tool = next(iter(self.dataset_tools)) - rst = tool.run(tool_input={'query': kwargs['input']}) - return AgentFinish(return_values={"output": rst}, log=rst) - - if intermediate_steps: - _, observation = intermediate_steps[-1] - return AgentFinish(return_values={"output": observation}, log=observation) - - full_inputs = self.get_full_inputs(intermediate_steps, **kwargs) - - try: - full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs) - except Exception as e: - raise e - - try: - agent_decision = self.output_parser.parse(full_output) - if isinstance(agent_decision, AgentAction): - tool_inputs = agent_decision.tool_input - if isinstance(tool_inputs, dict) and 'query' in tool_inputs: - tool_inputs['query'] = kwargs['input'] - agent_decision.tool_input = tool_inputs - elif isinstance(tool_inputs, str): - agent_decision.tool_input = kwargs['input'] - else: - agent_decision.return_values['output'] = '' - return agent_decision - except OutputParserException: - return AgentFinish({"output": "I'm sorry, the answer of model is invalid, " - "I don't know how to respond to that."}, "") - - @classmethod - def create_prompt( - cls, - tools: Sequence[BaseTool], - prefix: str = PREFIX, - suffix: str = SUFFIX, - human_message_template: str = HUMAN_MESSAGE_TEMPLATE, - format_instructions: str = FORMAT_INSTRUCTIONS, - input_variables: Optional[list[str]] = None, - memory_prompts: Optional[list[BasePromptTemplate]] = None, - ) -> BasePromptTemplate: - tool_strings = [] - for tool in tools: - args_schema = re.sub("}", "}}}}", re.sub("{", "{{{{", str(tool.args))) - tool_strings.append(f"{tool.name}: {tool.description}, args: {args_schema}") - formatted_tools = "\n".join(tool_strings) - unique_tool_names = set(tool.name for tool in tools) - tool_names = ", ".join('"' + name + '"' for name in unique_tool_names) - format_instructions = 
format_instructions.format(tool_names=tool_names) - template = "\n\n".join([prefix, formatted_tools, format_instructions, suffix]) - if input_variables is None: - input_variables = ["input", "agent_scratchpad"] - _memory_prompts = memory_prompts or [] - messages = [ - SystemMessagePromptTemplate.from_template(template), - *_memory_prompts, - HumanMessagePromptTemplate.from_template(human_message_template), - ] - return ChatPromptTemplate(input_variables=input_variables, messages=messages) - - @classmethod - def create_completion_prompt( - cls, - tools: Sequence[BaseTool], - prefix: str = PREFIX, - format_instructions: str = FORMAT_INSTRUCTIONS, - input_variables: Optional[list[str]] = None, - ) -> PromptTemplate: - """Create prompt in the style of the zero shot agent. - - Args: - tools: List of tools the agent will have access to, used to format the - prompt. - prefix: String to put before the list of tools. - input_variables: List of input variables the final prompt will expect. - - Returns: - A PromptTemplate with the template assembled from the pieces here. - """ - suffix = """Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:. -Question: {input} -Thought: {agent_scratchpad} -""" - - tool_strings = "\n".join([f"{tool.name}: {tool.description}" for tool in tools]) - tool_names = ", ".join([tool.name for tool in tools]) - format_instructions = format_instructions.format(tool_names=tool_names) - template = "\n\n".join([prefix, tool_strings, format_instructions, suffix]) - if input_variables is None: - input_variables = ["input", "agent_scratchpad"] - return PromptTemplate(template=template, input_variables=input_variables) - - def _construct_scratchpad( - self, intermediate_steps: list[tuple[AgentAction, str]] - ) -> str: - agent_scratchpad = "" - for action, observation in intermediate_steps: - agent_scratchpad += action.log - agent_scratchpad += f"\n{self.observation_prefix}{observation}\n{self.llm_prefix}" - - if not isinstance(agent_scratchpad, str): - raise ValueError("agent_scratchpad should be of type string.") - if agent_scratchpad: - llm_chain = cast(LLMChain, self.llm_chain) - if llm_chain.model_config.mode == "chat": - return ( - f"This was your previous work " - f"(but I haven't seen any of it! 
I only see what " - f"you return as final answer):\n{agent_scratchpad}" - ) - else: - return agent_scratchpad - else: - return agent_scratchpad - - @classmethod - def from_llm_and_tools( - cls, - model_config: ModelConfigWithCredentialsEntity, - tools: Sequence[BaseTool], - callback_manager: Optional[BaseCallbackManager] = None, - output_parser: Optional[AgentOutputParser] = None, - prefix: str = PREFIX, - suffix: str = SUFFIX, - human_message_template: str = HUMAN_MESSAGE_TEMPLATE, - format_instructions: str = FORMAT_INSTRUCTIONS, - input_variables: Optional[list[str]] = None, - memory_prompts: Optional[list[BasePromptTemplate]] = None, - **kwargs: Any, - ) -> Agent: - """Construct an agent from an LLM and tools.""" - cls._validate_tools(tools) - if model_config.mode == "chat": - prompt = cls.create_prompt( - tools, - prefix=prefix, - suffix=suffix, - human_message_template=human_message_template, - format_instructions=format_instructions, - input_variables=input_variables, - memory_prompts=memory_prompts, - ) - else: - prompt = cls.create_completion_prompt( - tools, - prefix=prefix, - format_instructions=format_instructions, - input_variables=input_variables - ) - - llm_chain = LLMChain( - model_config=model_config, - prompt=prompt, - callback_manager=callback_manager, - parameters={ - 'temperature': 0.2, - 'top_p': 0.3, - 'max_tokens': 1500 - } - ) - tool_names = [tool.name for tool in tools] - _output_parser = output_parser - return cls( - llm_chain=llm_chain, - allowed_tools=tool_names, - output_parser=_output_parser, - dataset_tools=tools, - **kwargs, - ) diff --git a/api/core/rag/retrieval/agent_based_dataset_executor.py b/api/core/rag/retrieval/agent_based_dataset_executor.py deleted file mode 100644 index cb475bcffb..0000000000 --- a/api/core/rag/retrieval/agent_based_dataset_executor.py +++ /dev/null @@ -1,117 +0,0 @@ -import logging -from typing import Optional, Union - -from langchain.agents import AgentExecutor as LCAgentExecutor -from langchain.agents import BaseMultiActionAgent, BaseSingleActionAgent -from langchain.callbacks.manager import Callbacks -from langchain.tools import BaseTool -from pydantic import BaseModel, Extra - -from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity -from core.entities.agent_entities import PlanningStrategy -from core.entities.message_entities import prompt_messages_to_lc_messages -from core.helper import moderation -from core.memory.token_buffer_memory import TokenBufferMemory -from core.model_runtime.errors.invoke import InvokeError -from core.rag.retrieval.agent.multi_dataset_router_agent import MultiDatasetRouterAgent -from core.rag.retrieval.agent.output_parser.structured_chat import StructuredChatOutputParser -from core.rag.retrieval.agent.structed_multi_dataset_router_agent import StructuredMultiDatasetRouterAgent -from core.tools.tool.dataset_retriever.dataset_multi_retriever_tool import DatasetMultiRetrieverTool -from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool - - -class AgentConfiguration(BaseModel): - strategy: PlanningStrategy - model_config: ModelConfigWithCredentialsEntity - tools: list[BaseTool] - summary_model_config: Optional[ModelConfigWithCredentialsEntity] = None - memory: Optional[TokenBufferMemory] = None - callbacks: Callbacks = None - max_iterations: int = 6 - max_execution_time: Optional[float] = None - early_stopping_method: str = "generate" - # `generate` will continue to complete the last inference after reaching the iteration limit or request 
time limit - - class Config: - """Configuration for this pydantic object.""" - - extra = Extra.forbid - arbitrary_types_allowed = True - - -class AgentExecuteResult(BaseModel): - strategy: PlanningStrategy - output: Optional[str] - configuration: AgentConfiguration - - -class AgentExecutor: - def __init__(self, configuration: AgentConfiguration): - self.configuration = configuration - self.agent = self._init_agent() - - def _init_agent(self) -> Union[BaseSingleActionAgent, BaseMultiActionAgent]: - if self.configuration.strategy == PlanningStrategy.ROUTER: - self.configuration.tools = [t for t in self.configuration.tools - if isinstance(t, DatasetRetrieverTool) - or isinstance(t, DatasetMultiRetrieverTool)] - agent = MultiDatasetRouterAgent.from_llm_and_tools( - model_config=self.configuration.model_config, - tools=self.configuration.tools, - extra_prompt_messages=prompt_messages_to_lc_messages(self.configuration.memory.get_history_prompt_messages()) - if self.configuration.memory else None, - verbose=True - ) - elif self.configuration.strategy == PlanningStrategy.REACT_ROUTER: - self.configuration.tools = [t for t in self.configuration.tools - if isinstance(t, DatasetRetrieverTool) - or isinstance(t, DatasetMultiRetrieverTool)] - agent = StructuredMultiDatasetRouterAgent.from_llm_and_tools( - model_config=self.configuration.model_config, - tools=self.configuration.tools, - output_parser=StructuredChatOutputParser(), - verbose=True - ) - else: - raise NotImplementedError(f"Unknown Agent Strategy: {self.configuration.strategy}") - - return agent - - def should_use_agent(self, query: str) -> bool: - return self.agent.should_use_agent(query) - - def run(self, query: str) -> AgentExecuteResult: - moderation_result = moderation.check_moderation( - self.configuration.model_config, - query - ) - - if moderation_result: - return AgentExecuteResult( - output="I apologize for any confusion, but I'm an AI assistant to be helpful, harmless, and honest.", - strategy=self.configuration.strategy, - configuration=self.configuration - ) - - agent_executor = LCAgentExecutor.from_agent_and_tools( - agent=self.agent, - tools=self.configuration.tools, - max_iterations=self.configuration.max_iterations, - max_execution_time=self.configuration.max_execution_time, - early_stopping_method=self.configuration.early_stopping_method, - callbacks=self.configuration.callbacks - ) - - try: - output = agent_executor.run(input=query) - except InvokeError as ex: - raise ex - except Exception as ex: - logging.exception("agent_executor run failed") - output = None - - return AgentExecuteResult( - output=output, - strategy=self.configuration.strategy, - configuration=self.configuration - ) diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index ee72842326..5c2d486656 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -1,23 +1,40 @@ +import threading from typing import Optional, cast -from langchain.tools import BaseTool +from flask import Flask, current_app from core.app.app_config.entities import DatasetEntity, DatasetRetrieveConfigEntity from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.entities.agent_entities import PlanningStrategy from core.memory.token_buffer_memory import TokenBufferMemory -from core.model_runtime.entities.model_entities import ModelFeature +from 
core.model_manager import ModelInstance, ModelManager +from core.model_runtime.entities.message_entities import PromptMessageTool +from core.model_runtime.entities.model_entities import ModelFeature, ModelType from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from core.rag.retrieval.agent_based_dataset_executor import AgentConfiguration, AgentExecutor -from core.tools.tool.dataset_retriever.dataset_multi_retriever_tool import DatasetMultiRetrieverTool -from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool +from core.rag.datasource.retrieval_service import RetrievalService +from core.rag.models.document import Document +from core.rag.retrieval.router.multi_dataset_function_call_router import FunctionCallMultiDatasetRouter +from core.rag.retrieval.router.multi_dataset_react_route import ReactMultiDatasetRouter +from core.rerank.rerank import RerankRunner from extensions.ext_database import db -from models.dataset import Dataset +from models.dataset import Dataset, DatasetQuery, DocumentSegment +from models.dataset import Document as DatasetDocument + +default_retrieval_model = { + 'search_method': 'semantic_search', + 'reranking_enable': False, + 'reranking_model': { + 'reranking_provider_name': '', + 'reranking_model_name': '' + }, + 'top_k': 2, + 'score_threshold_enabled': False +} class DatasetRetrieval: - def retrieve(self, tenant_id: str, + def retrieve(self, app_id: str, user_id: str, tenant_id: str, model_config: ModelConfigWithCredentialsEntity, config: DatasetEntity, query: str, @@ -27,6 +44,8 @@ class DatasetRetrieval: memory: Optional[TokenBufferMemory] = None) -> Optional[str]: """ Retrieve dataset. + :param app_id: app_id + :param user_id: user_id :param tenant_id: tenant id :param model_config: model config :param config: dataset config @@ -38,12 +57,22 @@ class DatasetRetrieval: :return: """ dataset_ids = config.dataset_ids + if len(dataset_ids) == 0: + return None retrieve_config = config.retrieve_config # check model is support tool calling model_type_instance = model_config.provider_model_bundle.model_type_instance model_type_instance = cast(LargeLanguageModel, model_type_instance) + model_manager = ModelManager() + model_instance = model_manager.get_model_instance( + tenant_id=tenant_id, + model_type=ModelType.LLM, + provider=model_config.provider, + model=model_config.model + ) + # get model schema model_schema = model_type_instance.get_model_schema( model=model_config.model, @@ -59,56 +88,6 @@ class DatasetRetrieval: if ModelFeature.TOOL_CALL in features \ or ModelFeature.MULTI_TOOL_CALL in features: planning_strategy = PlanningStrategy.ROUTER - - dataset_retriever_tools = self.to_dataset_retriever_tool( - tenant_id=tenant_id, - dataset_ids=dataset_ids, - retrieve_config=retrieve_config, - return_resource=show_retrieve_source, - invoke_from=invoke_from, - hit_callback=hit_callback - ) - - if len(dataset_retriever_tools) == 0: - return None - - agent_configuration = AgentConfiguration( - strategy=planning_strategy, - model_config=model_config, - tools=dataset_retriever_tools, - memory=memory, - max_iterations=10, - max_execution_time=400.0, - early_stopping_method="generate" - ) - - agent_executor = AgentExecutor(agent_configuration) - - should_use_agent = agent_executor.should_use_agent(query) - if not should_use_agent: - return None - - result = agent_executor.run(query) - - return result.output - - def to_dataset_retriever_tool(self, tenant_id: str, - dataset_ids: list[str], - 
retrieve_config: DatasetRetrieveConfigEntity, - return_resource: bool, - invoke_from: InvokeFrom, - hit_callback: DatasetIndexToolCallbackHandler) \ - -> Optional[list[BaseTool]]: - """ - A dataset tool is a tool that can be used to retrieve information from a dataset - :param tenant_id: tenant id - :param dataset_ids: dataset ids - :param retrieve_config: retrieve config - :param return_resource: return resource - :param invoke_from: invoke from - :param hit_callback: hit callback - """ - tools = [] available_datasets = [] for dataset_id in dataset_ids: # get dataset from dataset id @@ -127,56 +106,270 @@ class DatasetRetrieval: continue available_datasets.append(dataset) - + all_documents = [] + user_from = 'account' if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end_user' if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE: - # get retrieval model config - default_retrieval_model = { - 'search_method': 'semantic_search', - 'reranking_enable': False, - 'reranking_model': { - 'reranking_provider_name': '', - 'reranking_model_name': '' - }, - 'top_k': 2, - 'score_threshold_enabled': False - } + all_documents = self.single_retrieve(app_id, tenant_id, user_id, user_from, available_datasets, query, + model_instance, + model_config, planning_strategy) + elif retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE: + all_documents = self.multiple_retrieve(app_id, tenant_id, user_id, user_from, + available_datasets, query, retrieve_config.top_k, + retrieve_config.score_threshold, + retrieve_config.reranking_model.get('reranking_provider_name'), + retrieve_config.reranking_model.get('reranking_model_name')) - for dataset in available_datasets: + document_score_list = {} + for item in all_documents: + if 'score' in item.metadata and item.metadata['score']: + document_score_list[item.metadata['doc_id']] = item.metadata['score'] + + document_context_list = [] + index_node_ids = [document.metadata['doc_id'] for document in all_documents] + segments = DocumentSegment.query.filter( + DocumentSegment.dataset_id.in_(dataset_ids), + DocumentSegment.completed_at.isnot(None), + DocumentSegment.status == 'completed', + DocumentSegment.enabled == True, + DocumentSegment.index_node_id.in_(index_node_ids) + ).all() + + if segments: + index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)} + sorted_segments = sorted(segments, + key=lambda segment: index_node_id_to_position.get(segment.index_node_id, + float('inf'))) + for segment in sorted_segments: + if segment.answer: + document_context_list.append(f'question:{segment.content} answer:{segment.answer}') + else: + document_context_list.append(segment.content) + if show_retrieve_source: + context_list = [] + resource_number = 1 + for segment in sorted_segments: + dataset = Dataset.query.filter_by( + id=segment.dataset_id + ).first() + document = DatasetDocument.query.filter(DatasetDocument.id == segment.document_id, + DatasetDocument.enabled == True, + DatasetDocument.archived == False, + ).first() + if dataset and document: + source = { + 'position': resource_number, + 'dataset_id': dataset.id, + 'dataset_name': dataset.name, + 'document_id': document.id, + 'document_name': document.name, + 'data_source_type': document.data_source_type, + 'segment_id': segment.id, + 'retriever_from': invoke_from.to_source(), + 'score': document_score_list.get(segment.index_node_id, None) + } + + if invoke_from.to_source() == 'dev': + source['hit_count'] 
= segment.hit_count + source['word_count'] = segment.word_count + source['segment_position'] = segment.position + source['index_node_hash'] = segment.index_node_hash + if segment.answer: + source['content'] = f'question:{segment.content} \nanswer:{segment.answer}' + else: + source['content'] = segment.content + context_list.append(source) + resource_number += 1 + if hit_callback: + hit_callback.return_retriever_resource_info(context_list) + + return str("\n".join(document_context_list)) + return '' + + def single_retrieve(self, app_id: str, + tenant_id: str, + user_id: str, + user_from: str, + available_datasets: list, + query: str, + model_instance: ModelInstance, + model_config: ModelConfigWithCredentialsEntity, + planning_strategy: PlanningStrategy, + ): + tools = [] + for dataset in available_datasets: + description = dataset.description + if not description: + description = 'useful for when you want to answer queries about the ' + dataset.name + + description = description.replace('\n', '').replace('\r', '') + message_tool = PromptMessageTool( + name=dataset.id, + description=description, + parameters={ + "type": "object", + "properties": {}, + "required": [], + } + ) + tools.append(message_tool) + dataset_id = None + if planning_strategy == PlanningStrategy.REACT_ROUTER: + react_multi_dataset_router = ReactMultiDatasetRouter() + dataset_id = react_multi_dataset_router.invoke(query, tools, model_config, model_instance, + user_id, tenant_id) + + elif planning_strategy == PlanningStrategy.ROUTER: + function_call_router = FunctionCallMultiDatasetRouter() + dataset_id = function_call_router.invoke(query, tools, model_config, model_instance) + + if dataset_id: + # get retrieval model config + dataset = db.session.query(Dataset).filter( + Dataset.id == dataset_id + ).first() + if dataset: retrieval_model_config = dataset.retrieval_model \ if dataset.retrieval_model else default_retrieval_model # get top k top_k = retrieval_model_config['top_k'] - + # get retrieval method + if dataset.indexing_technique == "economy": + retrival_method = 'keyword_search' + else: + retrival_method = retrieval_model_config['search_method'] + # get reranking model + reranking_model = retrieval_model_config['reranking_model'] \ + if retrieval_model_config['reranking_enable'] else None # get score threshold - score_threshold = None + score_threshold = .0 score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled") if score_threshold_enabled: score_threshold = retrieval_model_config.get("score_threshold") - tool = DatasetRetrieverTool.from_dataset( - dataset=dataset, - top_k=top_k, - score_threshold=score_threshold, - hit_callbacks=[hit_callback], - return_resource=return_resource, - retriever_from=invoke_from.to_source() - ) + results = RetrievalService.retrieve(retrival_method=retrival_method, dataset_id=dataset.id, + query=query, + top_k=top_k, score_threshold=score_threshold, + reranking_model=reranking_model) + self._on_query(query, [dataset_id], app_id, user_from, user_id) + if results: + self._on_retrival_end(results) + return results + return [] - tools.append(tool) - elif retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE: - tool = DatasetMultiRetrieverTool.from_dataset( - dataset_ids=[dataset.id for dataset in available_datasets], - tenant_id=tenant_id, - top_k=retrieve_config.top_k or 2, - score_threshold=retrieve_config.score_threshold, - hit_callbacks=[hit_callback], - return_resource=return_resource, - retriever_from=invoke_from.to_source(), - 
reranking_provider_name=retrieve_config.reranking_model.get('reranking_provider_name'), - reranking_model_name=retrieve_config.reranking_model.get('reranking_model_name') + def multiple_retrieve(self, + app_id: str, + tenant_id: str, + user_id: str, + user_from: str, + available_datasets: list, + query: str, + top_k: int, + score_threshold: float, + reranking_provider_name: str, + reranking_model_name: str): + threads = [] + all_documents = [] + dataset_ids = [dataset.id for dataset in available_datasets] + for dataset in available_datasets: + retrieval_thread = threading.Thread(target=self._retriever, kwargs={ + 'flask_app': current_app._get_current_object(), + 'dataset_id': dataset.id, + 'query': query, + 'top_k': top_k, + 'all_documents': all_documents, + }) + threads.append(retrieval_thread) + retrieval_thread.start() + for thread in threads: + thread.join() + # do rerank for searched documents + model_manager = ModelManager() + rerank_model_instance = model_manager.get_model_instance( + tenant_id=tenant_id, + provider=reranking_provider_name, + model_type=ModelType.RERANK, + model=reranking_model_name + ) + + rerank_runner = RerankRunner(rerank_model_instance) + all_documents = rerank_runner.run(query, all_documents, + score_threshold, + top_k) + self._on_query(query, dataset_ids, app_id, user_from, user_id) + if all_documents: + self._on_retrival_end(all_documents) + return all_documents + + def _on_retrival_end(self, documents: list[Document]) -> None: + """Handle retrival end.""" + for document in documents: + query = db.session.query(DocumentSegment).filter( + DocumentSegment.index_node_id == document.metadata['doc_id'] ) - tools.append(tool) + # if 'dataset_id' in document.metadata: + if 'dataset_id' in document.metadata: + query = query.filter(DocumentSegment.dataset_id == document.metadata['dataset_id']) - return tools + # add hit count to document segment + query.update( + {DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, + synchronize_session=False + ) + + db.session.commit() + + def _on_query(self, query: str, dataset_ids: list[str], app_id: str, user_from: str, user_id: str) -> None: + """ + Handle query. 
+ """ + if not query: + return + for dataset_id in dataset_ids: + dataset_query = DatasetQuery( + dataset_id=dataset_id, + content=query, + source='app', + source_app_id=app_id, + created_by_role=user_from, + created_by=user_id + ) + db.session.add(dataset_query) + db.session.commit() + + def _retriever(self, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list): + with flask_app.app_context(): + dataset = db.session.query(Dataset).filter( + Dataset.id == dataset_id + ).first() + + if not dataset: + return [] + + # get retrieval model , if the model is not setting , using default + retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model + + if dataset.indexing_technique == "economy": + # use keyword table query + documents = RetrievalService.retrieve(retrival_method='keyword_search', + dataset_id=dataset.id, + query=query, + top_k=top_k + ) + if documents: + all_documents.extend(documents) + else: + if top_k > 0: + # retrieval source + documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'], + dataset_id=dataset.id, + query=query, + top_k=top_k, + score_threshold=retrieval_model['score_threshold'] + if retrieval_model['score_threshold_enabled'] else None, + reranking_model=retrieval_model['reranking_model'] + if retrieval_model['reranking_enable'] else None + ) + + all_documents.extend(documents) diff --git a/api/core/rag/retrieval/agent/__init__.py b/api/core/rag/retrieval/output_parser/__init__.py similarity index 100% rename from api/core/rag/retrieval/agent/__init__.py rename to api/core/rag/retrieval/output_parser/__init__.py diff --git a/api/core/rag/retrieval/agent/output_parser/structured_chat.py b/api/core/rag/retrieval/output_parser/structured_chat.py similarity index 100% rename from api/core/rag/retrieval/agent/output_parser/structured_chat.py rename to api/core/rag/retrieval/output_parser/structured_chat.py diff --git a/api/core/workflow/nodes/knowledge_retrieval/multi_dataset_function_call_router.py b/api/core/rag/retrieval/router/multi_dataset_function_call_router.py similarity index 100% rename from api/core/workflow/nodes/knowledge_retrieval/multi_dataset_function_call_router.py rename to api/core/rag/retrieval/router/multi_dataset_function_call_router.py diff --git a/api/core/workflow/nodes/knowledge_retrieval/multi_dataset_react_route.py b/api/core/rag/retrieval/router/multi_dataset_react_route.py similarity index 90% rename from api/core/workflow/nodes/knowledge_retrieval/multi_dataset_react_route.py rename to api/core/rag/retrieval/router/multi_dataset_react_route.py index a2e3cd71a5..0ec01047b3 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/multi_dataset_react_route.py +++ b/api/core/rag/retrieval/router/multi_dataset_react_route.py @@ -12,8 +12,7 @@ from core.model_runtime.entities.llm_entities import LLMUsage from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool from core.prompt.advanced_prompt_transform import AdvancedPromptTransform from core.prompt.entities.advanced_prompt_entities import ChatModelMessage -from core.rag.retrieval.agent.output_parser.structured_chat import StructuredChatOutputParser -from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData +from core.rag.retrieval.output_parser.structured_chat import StructuredChatOutputParser from core.workflow.nodes.llm.llm_node import LLMNode FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an 
action key (tool name) and an action_input key (tool input). @@ -55,11 +54,10 @@ class ReactMultiDatasetRouter: self, query: str, dataset_tools: list[PromptMessageTool], - node_data: KnowledgeRetrievalNodeData, model_config: ModelConfigWithCredentialsEntity, model_instance: ModelInstance, user_id: str, - tenant_id: str, + tenant_id: str ) -> Union[str, None]: """Given input, decided what to do. @@ -72,7 +70,8 @@ class ReactMultiDatasetRouter: return dataset_tools[0].name try: - return self._react_invoke(query=query, node_data=node_data, model_config=model_config, model_instance=model_instance, + return self._react_invoke(query=query, model_config=model_config, + model_instance=model_instance, tools=dataset_tools, user_id=user_id, tenant_id=tenant_id) except Exception as e: return None @@ -80,7 +79,6 @@ class ReactMultiDatasetRouter: def _react_invoke( self, query: str, - node_data: KnowledgeRetrievalNodeData, model_config: ModelConfigWithCredentialsEntity, model_instance: ModelInstance, tools: Sequence[PromptMessageTool], @@ -121,7 +119,7 @@ class ReactMultiDatasetRouter: model_config=model_config ) result_text, usage = self._invoke_llm( - node_data=node_data, + completion_param=model_config.parameters, model_instance=model_instance, prompt_messages=prompt_messages, stop=stop, @@ -134,10 +132,11 @@ class ReactMultiDatasetRouter: return agent_decision.tool return None - def _invoke_llm(self, node_data: KnowledgeRetrievalNodeData, + def _invoke_llm(self, completion_param: dict, model_instance: ModelInstance, prompt_messages: list[PromptMessage], - stop: list[str], user_id: str, tenant_id: str) -> tuple[str, LLMUsage]: + stop: list[str], user_id: str, tenant_id: str + ) -> tuple[str, LLMUsage]: """ Invoke large language model :param node_data: node data @@ -148,7 +147,7 @@ class ReactMultiDatasetRouter: """ invoke_result = model_instance.invoke_llm( prompt_messages=prompt_messages, - model_parameters=node_data.single_retrieval_config.model.completion_params, + model_parameters=completion_param, stop=stop, stream=True, user=user_id, @@ -203,7 +202,8 @@ class ReactMultiDatasetRouter: ) -> list[ChatModelMessage]: tool_strings = [] for tool in tools: - tool_strings.append(f"{tool.name}: {tool.description}, args: {{'query': {{'title': 'Query', 'description': 'Query for the dataset to be used to retrieve the dataset.', 'type': 'string'}}}}") + tool_strings.append( + f"{tool.name}: {tool.description}, args: {{'query': {{'title': 'Query', 'description': 'Query for the dataset to be used to retrieve the dataset.', 'type': 'string'}}}}") formatted_tools = "\n".join(tool_strings) unique_tool_names = set(tool.name for tool in tools) tool_names = ", ".join('"' + name + '"' for name in unique_tool_names) diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index a1dce8b09d..be3cec9152 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -1,28 +1,21 @@ -import threading from typing import Any, cast -from flask import Flask, current_app - from core.app.app_config.entities import DatasetRetrieveConfigEntity from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.entities.agent_entities import PlanningStrategy from core.entities.model_entities import ModelStatus from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, 
QuotaExceededError from core.model_manager import ModelInstance, ModelManager -from core.model_runtime.entities.message_entities import PromptMessageTool from core.model_runtime.entities.model_entities import ModelFeature, ModelType from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from core.rag.datasource.retrieval_service import RetrievalService -from core.rerank.rerank import RerankRunner +from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeRunResult, NodeType from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData -from core.workflow.nodes.knowledge_retrieval.multi_dataset_function_call_router import FunctionCallMultiDatasetRouter -from core.workflow.nodes.knowledge_retrieval.multi_dataset_react_route import ReactMultiDatasetRouter from extensions.ext_database import db -from models.dataset import Dataset, DatasetQuery, Document, DocumentSegment +from models.dataset import Dataset, Document, DocumentSegment from models.workflow import WorkflowNodeExecutionStatus default_retrieval_model = { @@ -106,10 +99,45 @@ class KnowledgeRetrievalNode(BaseNode): available_datasets.append(dataset) all_documents = [] + dataset_retrieval = DatasetRetrieval() if node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE.value: - all_documents = self._single_retrieve(available_datasets, node_data, query) + # fetch model config + model_instance, model_config = self._fetch_model_config(node_data) + # check model is support tool calling + model_type_instance = model_config.provider_model_bundle.model_type_instance + model_type_instance = cast(LargeLanguageModel, model_type_instance) + # get model schema + model_schema = model_type_instance.get_model_schema( + model=model_config.model, + credentials=model_config.credentials + ) + + if model_schema: + planning_strategy = PlanningStrategy.REACT_ROUTER + features = model_schema.features + if features: + if ModelFeature.TOOL_CALL in features \ + or ModelFeature.MULTI_TOOL_CALL in features: + planning_strategy = PlanningStrategy.ROUTER + all_documents = dataset_retrieval.single_retrieve( + available_datasets=available_datasets, + tenant_id=self.tenant_id, + user_id=self.user_id, + app_id=self.app_id, + user_from=self.user_from.value, + query=query, + model_config=model_config, + model_instance=model_instance, + planning_strategy=planning_strategy + ) elif node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE.value: - all_documents = self._multiple_retrieve(available_datasets, node_data, query) + all_documents = dataset_retrieval.multiple_retrieve(self.app_id, self.tenant_id, self.user_id, + self.user_from.value, + available_datasets, query, + node_data.multiple_retrieval_config.top_k, + node_data.multiple_retrieval_config.score_threshold, + node_data.multiple_retrieval_config.reranking_model.provider, + node_data.multiple_retrieval_config.reranking_model.model) context_list = [] if all_documents: @@ -184,87 +212,6 @@ class KnowledgeRetrievalNode(BaseNode): variable_mapping['query'] = node_data.query_variable_selector return variable_mapping - def _single_retrieve(self, available_datasets, node_data, query): - tools = [] - for dataset in available_datasets: - description = dataset.description 
- if not description: - description = 'useful for when you want to answer queries about the ' + dataset.name - - description = description.replace('\n', '').replace('\r', '') - message_tool = PromptMessageTool( - name=dataset.id, - description=description, - parameters={ - "type": "object", - "properties": {}, - "required": [], - } - ) - tools.append(message_tool) - # fetch model config - model_instance, model_config = self._fetch_model_config(node_data) - # check model is support tool calling - model_type_instance = model_config.provider_model_bundle.model_type_instance - model_type_instance = cast(LargeLanguageModel, model_type_instance) - # get model schema - model_schema = model_type_instance.get_model_schema( - model=model_config.model, - credentials=model_config.credentials - ) - - if not model_schema: - return None - planning_strategy = PlanningStrategy.REACT_ROUTER - features = model_schema.features - if features: - if ModelFeature.TOOL_CALL in features \ - or ModelFeature.MULTI_TOOL_CALL in features: - planning_strategy = PlanningStrategy.ROUTER - dataset_id = None - if planning_strategy == PlanningStrategy.REACT_ROUTER: - react_multi_dataset_router = ReactMultiDatasetRouter() - dataset_id = react_multi_dataset_router.invoke(query, tools, node_data, model_config, model_instance, - self.user_id, self.tenant_id) - - elif planning_strategy == PlanningStrategy.ROUTER: - function_call_router = FunctionCallMultiDatasetRouter() - dataset_id = function_call_router.invoke(query, tools, model_config, model_instance) - if dataset_id: - # get retrieval model config - dataset = db.session.query(Dataset).filter( - Dataset.id == dataset_id - ).first() - if dataset: - retrieval_model_config = dataset.retrieval_model \ - if dataset.retrieval_model else default_retrieval_model - - # get top k - top_k = retrieval_model_config['top_k'] - # get retrieval method - if dataset.indexing_technique == "economy": - retrival_method = 'keyword_search' - else: - retrival_method = retrieval_model_config['search_method'] - # get reranking model - reranking_model=retrieval_model_config['reranking_model'] \ - if retrieval_model_config['reranking_enable'] else None - # get score threshold - score_threshold = .0 - score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled") - if score_threshold_enabled: - score_threshold = retrieval_model_config.get("score_threshold") - - results = RetrievalService.retrieve(retrival_method=retrival_method, dataset_id=dataset.id, - query=query, - top_k=top_k, score_threshold=score_threshold, - reranking_model=reranking_model) - self._on_query(query, [dataset_id]) - if results: - self._on_retrival_end(results) - return results - return [] - def _fetch_model_config(self, node_data: KnowledgeRetrievalNodeData) -> tuple[ ModelInstance, ModelConfigWithCredentialsEntity]: """ @@ -335,112 +282,3 @@ class KnowledgeRetrievalNode(BaseNode): parameters=completion_params, stop=stop, ) - - def _multiple_retrieve(self, available_datasets, node_data, query): - threads = [] - all_documents = [] - dataset_ids = [dataset.id for dataset in available_datasets] - for dataset in available_datasets: - retrieval_thread = threading.Thread(target=self._retriever, kwargs={ - 'flask_app': current_app._get_current_object(), - 'dataset_id': dataset.id, - 'query': query, - 'top_k': node_data.multiple_retrieval_config.top_k, - 'all_documents': all_documents, - }) - threads.append(retrieval_thread) - retrieval_thread.start() - for thread in threads: - thread.join() - # do rerank for searched 
documents - model_manager = ModelManager() - rerank_model_instance = model_manager.get_model_instance( - tenant_id=self.tenant_id, - provider=node_data.multiple_retrieval_config.reranking_model.provider, - model_type=ModelType.RERANK, - model=node_data.multiple_retrieval_config.reranking_model.model - ) - - rerank_runner = RerankRunner(rerank_model_instance) - all_documents = rerank_runner.run(query, all_documents, - node_data.multiple_retrieval_config.score_threshold, - node_data.multiple_retrieval_config.top_k) - self._on_query(query, dataset_ids) - if all_documents: - self._on_retrival_end(all_documents) - return all_documents - - def _on_retrival_end(self, documents: list[Document]) -> None: - """Handle retrival end.""" - for document in documents: - query = db.session.query(DocumentSegment).filter( - DocumentSegment.index_node_id == document.metadata['doc_id'] - ) - - # if 'dataset_id' in document.metadata: - if 'dataset_id' in document.metadata: - query = query.filter(DocumentSegment.dataset_id == document.metadata['dataset_id']) - - # add hit count to document segment - query.update( - {DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, - synchronize_session=False - ) - - db.session.commit() - - def _on_query(self, query: str, dataset_ids: list[str]) -> None: - """ - Handle query. - """ - if not query: - return - for dataset_id in dataset_ids: - dataset_query = DatasetQuery( - dataset_id=dataset_id, - content=query, - source='app', - source_app_id=self.app_id, - created_by_role=self.user_from.value, - created_by=self.user_id - ) - db.session.add(dataset_query) - db.session.commit() - - def _retriever(self, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list): - with flask_app.app_context(): - dataset = db.session.query(Dataset).filter( - Dataset.tenant_id == self.tenant_id, - Dataset.id == dataset_id - ).first() - - if not dataset: - return [] - - # get retrieval model , if the model is not setting , using default - retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model - - if dataset.indexing_technique == "economy": - # use keyword table query - documents = RetrievalService.retrieve(retrival_method='keyword_search', - dataset_id=dataset.id, - query=query, - top_k=top_k - ) - if documents: - all_documents.extend(documents) - else: - if top_k > 0: - # retrieval source - documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'], - dataset_id=dataset.id, - query=query, - top_k=top_k, - score_threshold=retrieval_model['score_threshold'] - if retrieval_model['score_threshold_enabled'] else None, - reranking_model=retrieval_model['reranking_model'] - if retrieval_model['reranking_enable'] else None - ) - - all_documents.extend(documents) -
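Note: below is a minimal caller-side sketch of the refactored DatasetRetrieval.retrieve entry point, mirroring the chat and completion app runner hunks above. The parameter names (app_id, user_id, tenant_id, model_config, config, query, invoke_from, show_retrieve_source, hit_callback, memory) come from this patch; the surrounding variables (app_record, application_generate_entity, app_config, hit_callback, memory, query) are assumed from the runner context and are illustrative only.

    from core.rag.retrieval.dataset_retrieval import DatasetRetrieval

    dataset_retrieval = DatasetRetrieval()

    # app_id and user_id are new required arguments in this patch: retrieval now
    # logs DatasetQuery rows and segment hit counts itself instead of delegating
    # to the removed langchain agent executor.
    context = dataset_retrieval.retrieve(
        app_id=app_record.id,                                  # assumed: current app record
        user_id=application_generate_entity.user_id,           # assumed: invoking user
        tenant_id=app_record.tenant_id,
        model_config=application_generate_entity.model_config,
        config=app_config.dataset,                             # DatasetEntity: dataset_ids + retrieve_config
        query=query,                                           # assumed: user query string
        invoke_from=application_generate_entity.invoke_from,
        show_retrieve_source=show_retrieve_source,             # assumed: bool flag from app features
        hit_callback=hit_callback,                             # DatasetIndexToolCallbackHandler
        memory=memory,                                         # optional TokenBufferMemory
    )

    # retrieve() returns the joined context string (or None/'' when no dataset
    # matches); internally it dispatches to single_retrieve (router/ReAct) or
    # multiple_retrieve (threaded retrieval + rerank) based on retrieve_config.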