diff --git a/api/core/chain/llm_router_chain.py b/api/core/chain/llm_router_chain.py
new file mode 100644
index 0000000000..432cb4b306
--- /dev/null
+++ b/api/core/chain/llm_router_chain.py
@@ -0,0 +1,132 @@
+"""Base classes for LLM-powered router chains."""
+from __future__ import annotations
+
+import json
+from typing import Any, Dict, List, Optional, Type, cast, NamedTuple
+
+from langchain.chains.base import Chain
+from pydantic import root_validator
+
+from langchain.chains import LLMChain
+from langchain.prompts import BasePromptTemplate
+from langchain.schema import BaseOutputParser, OutputParserException, BaseLanguageModel
+
+
+class Route(NamedTuple):
+    destination: Optional[str]
+    next_inputs: Dict[str, Any]
+
+
+class LLMRouterChain(Chain):
+    """A router chain that uses an LLM chain to perform routing."""
+
+    llm_chain: LLMChain
+    """LLM chain used to perform routing"""
+
+    @root_validator()
+    def validate_prompt(cls, values: dict) -> dict:
+        prompt = values["llm_chain"].prompt
+        if prompt.output_parser is None:
+            raise ValueError(
+                "LLMRouterChain requires base llm_chain prompt to have an output"
+                " parser that converts LLM text output to a dictionary with keys"
+                " 'destination' and 'next_inputs'. Received a prompt with no output"
+                " parser."
+            )
+        return values
+
+    @property
+    def input_keys(self) -> List[str]:
+        """Will be whatever keys the LLM chain prompt expects.
+
+        :meta private:
+        """
+        return self.llm_chain.input_keys
+
+    def _validate_outputs(self, outputs: Dict[str, Any]) -> None:
+        super()._validate_outputs(outputs)
+        if not isinstance(outputs["next_inputs"], dict):
+            raise ValueError("Expected 'next_inputs' to be a dict.")
+
+    def _call(
+            self,
+            inputs: Dict[str, Any]
+    ) -> Dict[str, Any]:
+        output = cast(
+            Dict[str, Any],
+            self.llm_chain.predict_and_parse(**inputs),
+        )
+        return output
+
+    @classmethod
+    def from_llm(
+            cls, llm: BaseLanguageModel, prompt: BasePromptTemplate, **kwargs: Any
+    ) -> LLMRouterChain:
+        """Convenience constructor."""
+        llm_chain = LLMChain(llm=llm, prompt=prompt)
+        return cls(llm_chain=llm_chain, **kwargs)
+
+    @property
+    def output_keys(self) -> List[str]:
+        return ["destination", "next_inputs"]
+
+    def route(self, inputs: Dict[str, Any]) -> Route:
+        result = self(inputs)
+        return Route(result["destination"], result["next_inputs"])
+
+
+class RouterOutputParser(BaseOutputParser[Dict[str, str]]):
+    """Parser for output of the router chain in the multi-prompt chain."""
+
+    default_destination: str = "DEFAULT"
+    next_inputs_type: Type = str
+    next_inputs_inner_key: str = "input"
+
+    def parse_json_markdown(self, json_string: str) -> dict:
+        # Remove the triple backticks if present
+        json_string = json_string.replace("```json", "").replace("```", "")
+
+        # Strip whitespace and newlines from the start and end
+        json_string = json_string.strip()
+
+        # Parse the JSON string into a Python dictionary
+        parsed = json.loads(json_string)
+
+        return parsed
+
+    def parse_and_check_json_markdown(self, text: str, expected_keys: List[str]) -> dict:
+        try:
+            json_obj = self.parse_json_markdown(text)
+        except json.JSONDecodeError as e:
+            raise OutputParserException(f"Got invalid JSON object. Error: {e}")
+        for key in expected_keys:
+            if key not in json_obj:
+                raise OutputParserException(
+                    f"Got invalid return object. Expected key `{key}` "
+                    f"to be present, but got {json_obj}"
+                )
+        return json_obj
+
+    def parse(self, text: str) -> Dict[str, Any]:
+        try:
+            expected_keys = ["destination", "next_inputs"]
+            parsed = self.parse_and_check_json_markdown(text, expected_keys)
+            if not isinstance(parsed["destination"], str):
+                raise ValueError("Expected 'destination' to be a string.")
+            if not isinstance(parsed["next_inputs"], self.next_inputs_type):
+                raise ValueError(
+                    f"Expected 'next_inputs' to be {self.next_inputs_type}."
+                )
+            parsed["next_inputs"] = {self.next_inputs_inner_key: parsed["next_inputs"]}
+            if (
+                    parsed["destination"].strip().lower()
+                    == self.default_destination.lower()
+            ):
+                parsed["destination"] = None
+            else:
+                parsed["destination"] = parsed["destination"].strip()
+            return parsed
+        except Exception as e:
+            raise OutputParserException(
+                f"Parsing text\n{text}\n raised following error:\n{e}"
+            )
diff --git a/api/core/chain/main_chain_builder.py b/api/core/chain/main_chain_builder.py
index 5a4ab2214d..4cb6205fcb 100644
--- a/api/core/chain/main_chain_builder.py
+++ b/api/core/chain/main_chain_builder.py
@@ -1,18 +1,18 @@
 from typing import Optional, List
 
-from langchain.callbacks import SharedCallbackManager
+from langchain.callbacks import SharedCallbackManager, CallbackManager
 from langchain.chains import SequentialChain
 from langchain.chains.base import Chain
 from langchain.memory.chat_memory import BaseChatMemory
 
-from core.agent.agent_builder import AgentBuilder
 from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
-from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler
 from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
+from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
 from core.chain.chain_builder import ChainBuilder
-from core.constant import llm_constant
+from core.chain.multi_dataset_router_chain import MultiDatasetRouterChain
 from core.conversation_message_task import ConversationMessageTask
-from core.tool.dataset_tool_builder import DatasetToolBuilder
+from extensions.ext_database import db
+from models.dataset import Dataset
 
 
 class MainChainBuilder:
@@ -31,8 +31,7 @@ class MainChainBuilder:
             tenant_id=tenant_id,
             agent_mode=agent_mode,
             memory=memory,
-            dataset_tool_callback_handler=DatasetToolCallbackHandler(conversation_message_task),
-            agent_loop_gather_callback_handler=chain_callback_handler.agent_loop_gather_callback_handler
+            conversation_message_task=conversation_message_task
         )
         chains += tool_chains
 
@@ -59,15 +58,15 @@ class MainChainBuilder:
     @classmethod
     def get_agent_chains(cls, tenant_id: str, agent_mode: dict,
                          memory: Optional[BaseChatMemory],
-                         dataset_tool_callback_handler: DatasetToolCallbackHandler,
-                         agent_loop_gather_callback_handler: AgentLoopGatherCallbackHandler):
+                         conversation_message_task: ConversationMessageTask):
         # agent mode
         chains = []
        if agent_mode and agent_mode.get('enabled'):
             tools = agent_mode.get('tools', [])
 
             pre_fixed_chains = []
-            agent_tools = []
+            # agent_tools = []
+            datasets = []
             for tool in tools:
                 tool_type = list(tool.keys())[0]
                 tool_config = list(tool.values())[0]
@@ -76,34 +75,27 @@
                     if chain:
                         pre_fixed_chains.append(chain)
                 elif tool_type == "dataset":
-                    dataset_tool = DatasetToolBuilder.build_dataset_tool(
-                        tenant_id=tenant_id,
-                        dataset_id=tool_config.get("id"),
-                        response_mode='no_synthesizer',  # "compact"
-                        callback_handler=dataset_tool_callback_handler
-                    )
+                    # get dataset from dataset id
+                    dataset = db.session.query(Dataset).filter(
+                        Dataset.tenant_id == tenant_id,
+                        Dataset.id == tool_config.get("id")
+                    ).first()
 
-                    if dataset_tool:
-                        agent_tools.append(dataset_tool)
+                    if dataset:
+                        datasets.append(dataset)
 
             # add pre-fixed chains
             chains += pre_fixed_chains
 
-            if len(agent_tools) == 1:
+            if len(datasets) > 0:
                 # tool to chain
-                tool_chain = ChainBuilder.to_tool_chain(tool=agent_tools[0], output_key='tool_output')
-                chains.append(tool_chain)
-            elif len(agent_tools) > 1:
-                # build agent config
-                agent_chain = AgentBuilder.to_agent_chain(
+                multi_dataset_router_chain = MultiDatasetRouterChain.from_datasets(
                     tenant_id=tenant_id,
-                    tools=agent_tools,
-                    memory=memory,
-                    dataset_tool_callback_handler=dataset_tool_callback_handler,
-                    agent_loop_gather_callback_handler=agent_loop_gather_callback_handler
+                    datasets=datasets,
+                    conversation_message_task=conversation_message_task,
+                    callback_manager=CallbackManager([DifyStdOutCallbackHandler()])
                 )
-
-                chains.append(agent_chain)
+                chains.append(multi_dataset_router_chain)
 
         final_output_key = cls.get_chains_output_key(chains)
 
diff --git a/api/core/chain/multi_dataset_router_chain.py b/api/core/chain/multi_dataset_router_chain.py
new file mode 100644
index 0000000000..736bfaa1aa
--- /dev/null
+++ b/api/core/chain/multi_dataset_router_chain.py
@@ -0,0 +1,138 @@
+from typing import Mapping, List, Dict, Any, Optional
+
+from langchain import LLMChain, PromptTemplate, ConversationChain
+from langchain.callbacks import CallbackManager
+from langchain.chains.base import Chain
+from langchain.schema import BaseLanguageModel
+from pydantic import Extra
+
+from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler
+from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
+from core.chain.llm_router_chain import LLMRouterChain, RouterOutputParser
+from core.conversation_message_task import ConversationMessageTask
+from core.llm.llm_builder import LLMBuilder
+from core.tool.dataset_tool_builder import DatasetToolBuilder
+from core.tool.llama_index_tool import EnhanceLlamaIndexTool
+from models.dataset import Dataset
+
+MULTI_PROMPT_ROUTER_TEMPLATE = """
+Given a raw text input to a language model, select the model prompt best suited for \
+the input. You will be given the names of the available prompts and a description of \
+what the prompt is best suited for. You may also revise the original input if you \
+think that revising it will ultimately lead to a better response from the language \
+model.
+
+<< FORMATTING >>
+Return a markdown code snippet with a JSON object formatted to look like:
+```json
+{{{{
+    "destination": string \\ name of the prompt to use or "DEFAULT"
+    "next_inputs": string \\ a potentially modified version of the original input
+}}}}
+```
+
+REMEMBER: "destination" MUST be one of the candidate prompt names specified below OR \
+it can be "DEFAULT" if the input is not well suited for any of the candidate prompts.
+REMEMBER: "next_inputs" can just be the original input if you don't think any \
+modifications are needed.
+
+<< CANDIDATE PROMPTS >>
+{destinations}
+
+<< INPUT >>
+{{input}}
+
+<< OUTPUT >>
+"""
+
+
+class MultiDatasetRouterChain(Chain):
+    """Use a single chain to route an input to one of multiple candidate chains."""
+
+    router_chain: LLMRouterChain
+    """Chain for deciding a destination chain and the input to it."""
+    dataset_tools: Mapping[str, EnhanceLlamaIndexTool]
+    """Map of name to candidate chains that inputs can be routed to."""
+
+    class Config:
+        """Configuration for this pydantic object."""
+
+        extra = Extra.forbid
+        arbitrary_types_allowed = True
+
+    @property
+    def input_keys(self) -> List[str]:
+        """Will be whatever keys the router chain prompt expects.
+
+        :meta private:
+        """
+        return self.router_chain.input_keys
+
+    @property
+    def output_keys(self) -> List[str]:
+        return ["text"]
+
+    @classmethod
+    def from_datasets(
+            cls,
+            tenant_id: str,
+            datasets: List[Dataset],
+            conversation_message_task: ConversationMessageTask,
+            **kwargs: Any,
+    ):
+        """Convenience constructor for instantiating from destination prompts."""
+        llm_callback_manager = CallbackManager([DifyStdOutCallbackHandler()])
+        llm = LLMBuilder.to_llm(
+            tenant_id=tenant_id,
+            model_name='gpt-3.5-turbo',
+            temperature=0,
+            max_tokens=1024,
+            callback_manager=llm_callback_manager
+        )
+
+        destinations = [f"{d.id}: {d.description}" for d in datasets]
+        destinations_str = "\n".join(destinations)
+        router_template = MULTI_PROMPT_ROUTER_TEMPLATE.format(
+            destinations=destinations_str
+        )
+        router_prompt = PromptTemplate(
+            template=router_template,
+            input_variables=["input"],
+            output_parser=RouterOutputParser(),
+        )
+        router_chain = LLMRouterChain.from_llm(llm, router_prompt)
+        dataset_tools = {}
+        for dataset in datasets:
+            dataset_tool = DatasetToolBuilder.build_dataset_tool(
+                dataset=dataset,
+                response_mode='no_synthesizer',  # "compact"
+                callback_handler=DatasetToolCallbackHandler(conversation_message_task)
+            )
+            dataset_tools[dataset.id] = dataset_tool
+        return cls(
+            router_chain=router_chain,
+            dataset_tools=dataset_tools,
+            **kwargs,
+        )
+
+    def _call(
+            self,
+            inputs: Dict[str, Any]
+    ) -> Dict[str, Any]:
+        if len(self.dataset_tools) == 0:
+            return {"text": ''}
+        elif len(self.dataset_tools) == 1:
+            return {"text": next(iter(self.dataset_tools.values())).run(inputs['input'])}
+
+        route = self.router_chain.route(inputs)
+
+        if not route.destination:
+            return {"text": ''}
+        elif route.destination in self.dataset_tools:
+            return {"text": self.dataset_tools[route.destination].run(
+                route.next_inputs['input']
+            )}
+        else:
+            raise ValueError(
+                f"Received invalid destination chain name '{route.destination}'"
+            )
diff --git a/api/core/tool/dataset_tool_builder.py b/api/core/tool/dataset_tool_builder.py
index b31b15511a..aa7a618b50 100644
--- a/api/core/tool/dataset_tool_builder.py
+++ b/api/core/tool/dataset_tool_builder.py
@@ -10,24 +10,14 @@ from core.index.keyword_table_index import KeywordTableIndex
 from core.index.vector_index import VectorIndex
 from core.prompt.prompts import QUERY_KEYWORD_EXTRACT_TEMPLATE
 from core.tool.llama_index_tool import EnhanceLlamaIndexTool
-from extensions.ext_database import db
 from models.dataset import Dataset
 
 
 class DatasetToolBuilder:
     @classmethod
-    def build_dataset_tool(cls, tenant_id: str, dataset_id: str,
+    def build_dataset_tool(cls, dataset: Dataset,
                            response_mode: str = "no_synthesizer",
                            callback_handler: Optional[DatasetToolCallbackHandler] = None):
-        # get dataset from dataset id
-        dataset = db.session.query(Dataset).filter(
-            Dataset.tenant_id == tenant_id,
-            Dataset.id == dataset_id
-        ).first()
-
-        if not dataset:
-            return None
-
         if dataset.indexing_technique == "economy":
             # use keyword table query
             index = KeywordTableIndex(dataset=dataset).query_index
@@ -65,7 +55,7 @@
 
         index_tool_config = IndexToolConfig(
             index=index,
-            name=f"dataset-{dataset_id}",
+            name=f"dataset-{dataset.id}",
             description=description,
             index_query_kwargs=query_kwargs,
             tool_kwargs={
@@ -75,7 +65,7 @@
             # return_direct: Whether to return LLM results directly or process the output data with an Output Parser
         )
 
-        index_callback_handler = DatasetIndexToolCallbackHandler(dataset_id=dataset_id)
+        index_callback_handler = DatasetIndexToolCallbackHandler(dataset_id=dataset.id)
 
         return EnhanceLlamaIndexTool.from_tool_config(
             tool_config=index_tool_config,
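
Note (editor's addition, not part of the diff): the routing handshake above hinges on the JSON contract between `MULTI_PROMPT_ROUTER_TEMPLATE` and `RouterOutputParser`. The minimal sketch below traces it end to end; it assumes the `api` package and its dependencies are importable, and the dataset id and model completion are made up for illustration.

```python
from core.chain.llm_router_chain import RouterOutputParser

parser = RouterOutputParser()

# A typical raw completion from the routing LLM, in the fenced-JSON shape
# the template asks for (the dataset id here is hypothetical).
llm_output = (
    "```json\n"
    "{\n"
    '    "destination": "89a1a77c-3c2b-4a0f-9d1e-0242ac120002",\n'
    '    "next_inputs": "What is our refund policy?"\n'
    "}\n"
    "```"
)

parsed = parser.parse(llm_output)

# parse() strips the code fence, verifies both keys are present, wraps
# next_inputs under the "input" key, and maps "DEFAULT" to None (which
# MultiDatasetRouterChain._call turns into an empty "text" result).
assert parsed["destination"] == "89a1a77c-3c2b-4a0f-9d1e-0242ac120002"
assert parsed["next_inputs"] == {"input": "What is our refund policy?"}
```

`MultiDatasetRouterChain._call` then looks the returned `destination` up in `dataset_tools` (keyed by dataset id, matching the `{d.id}: {d.description}` lines rendered into the prompt) and runs that dataset's `EnhanceLlamaIndexTool` with `next_inputs['input']`.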