mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-14 11:15:57 +08:00
feat: support multi datasets router chain mode (#231)
This commit is contained in:
parent
2c23caacd4
commit
88545184be
132
api/core/chain/llm_router_chain.py
Normal file
132
api/core/chain/llm_router_chain.py
Normal file
@ -0,0 +1,132 @@
|
|||||||
|
"""Base classes for LLM-powered router chains."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from typing import Any, Dict, List, Optional, Type, cast, NamedTuple
|
||||||
|
|
||||||
|
from langchain.chains.base import Chain
|
||||||
|
from pydantic import root_validator
|
||||||
|
|
||||||
|
from langchain.chains import LLMChain
|
||||||
|
from langchain.prompts import BasePromptTemplate
|
||||||
|
from langchain.schema import BaseOutputParser, OutputParserException, BaseLanguageModel
|
||||||
|
|
||||||
|
|
||||||
|
class Route(NamedTuple):
|
||||||
|
destination: Optional[str]
|
||||||
|
next_inputs: Dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
|
class LLMRouterChain(Chain):
|
||||||
|
"""A router chain that uses an LLM chain to perform routing."""
|
||||||
|
|
||||||
|
llm_chain: LLMChain
|
||||||
|
"""LLM chain used to perform routing"""
|
||||||
|
|
||||||
|
@root_validator()
|
||||||
|
def validate_prompt(cls, values: dict) -> dict:
|
||||||
|
prompt = values["llm_chain"].prompt
|
||||||
|
if prompt.output_parser is None:
|
||||||
|
raise ValueError(
|
||||||
|
"LLMRouterChain requires base llm_chain prompt to have an output"
|
||||||
|
" parser that converts LLM text output to a dictionary with keys"
|
||||||
|
" 'destination' and 'next_inputs'. Received a prompt with no output"
|
||||||
|
" parser."
|
||||||
|
)
|
||||||
|
return values
|
||||||
|
|
||||||
|
@property
|
||||||
|
def input_keys(self) -> List[str]:
|
||||||
|
"""Will be whatever keys the LLM chain prompt expects.
|
||||||
|
|
||||||
|
:meta private:
|
||||||
|
"""
|
||||||
|
return self.llm_chain.input_keys
|
||||||
|
|
||||||
|
def _validate_outputs(self, outputs: Dict[str, Any]) -> None:
|
||||||
|
super()._validate_outputs(outputs)
|
||||||
|
if not isinstance(outputs["next_inputs"], dict):
|
||||||
|
raise ValueError
|
||||||
|
|
||||||
|
def _call(
|
||||||
|
self,
|
||||||
|
inputs: Dict[str, Any]
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
output = cast(
|
||||||
|
Dict[str, Any],
|
||||||
|
self.llm_chain.predict_and_parse(**inputs),
|
||||||
|
)
|
||||||
|
return output
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_llm(
|
||||||
|
cls, llm: BaseLanguageModel, prompt: BasePromptTemplate, **kwargs: Any
|
||||||
|
) -> LLMRouterChain:
|
||||||
|
"""Convenience constructor."""
|
||||||
|
llm_chain = LLMChain(llm=llm, prompt=prompt)
|
||||||
|
return cls(llm_chain=llm_chain, **kwargs)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def output_keys(self) -> List[str]:
|
||||||
|
return ["destination", "next_inputs"]
|
||||||
|
|
||||||
|
def route(self, inputs: Dict[str, Any]) -> Route:
|
||||||
|
result = self(inputs)
|
||||||
|
return Route(result["destination"], result["next_inputs"])
|
||||||
|
|
||||||
|
|
||||||
|
class RouterOutputParser(BaseOutputParser[Dict[str, str]]):
|
||||||
|
"""Parser for output of router chain int he multi-prompt chain."""
|
||||||
|
|
||||||
|
default_destination: str = "DEFAULT"
|
||||||
|
next_inputs_type: Type = str
|
||||||
|
next_inputs_inner_key: str = "input"
|
||||||
|
|
||||||
|
def parse_json_markdown(self, json_string: str) -> dict:
|
||||||
|
# Remove the triple backticks if present
|
||||||
|
json_string = json_string.replace("```json", "").replace("```", "")
|
||||||
|
|
||||||
|
# Strip whitespace and newlines from the start and end
|
||||||
|
json_string = json_string.strip()
|
||||||
|
|
||||||
|
# Parse the JSON string into a Python dictionary
|
||||||
|
parsed = json.loads(json_string)
|
||||||
|
|
||||||
|
return parsed
|
||||||
|
|
||||||
|
def parse_and_check_json_markdown(self, text: str, expected_keys: List[str]) -> dict:
|
||||||
|
try:
|
||||||
|
json_obj = self.parse_json_markdown(text)
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
raise OutputParserException(f"Got invalid JSON object. Error: {e}")
|
||||||
|
for key in expected_keys:
|
||||||
|
if key not in json_obj:
|
||||||
|
raise OutputParserException(
|
||||||
|
f"Got invalid return object. Expected key `{key}` "
|
||||||
|
f"to be present, but got {json_obj}"
|
||||||
|
)
|
||||||
|
return json_obj
|
||||||
|
|
||||||
|
def parse(self, text: str) -> Dict[str, Any]:
|
||||||
|
try:
|
||||||
|
expected_keys = ["destination", "next_inputs"]
|
||||||
|
parsed = self.parse_and_check_json_markdown(text, expected_keys)
|
||||||
|
if not isinstance(parsed["destination"], str):
|
||||||
|
raise ValueError("Expected 'destination' to be a string.")
|
||||||
|
if not isinstance(parsed["next_inputs"], self.next_inputs_type):
|
||||||
|
raise ValueError(
|
||||||
|
f"Expected 'next_inputs' to be {self.next_inputs_type}."
|
||||||
|
)
|
||||||
|
parsed["next_inputs"] = {self.next_inputs_inner_key: parsed["next_inputs"]}
|
||||||
|
if (
|
||||||
|
parsed["destination"].strip().lower()
|
||||||
|
== self.default_destination.lower()
|
||||||
|
):
|
||||||
|
parsed["destination"] = None
|
||||||
|
else:
|
||||||
|
parsed["destination"] = parsed["destination"].strip()
|
||||||
|
return parsed
|
||||||
|
except Exception as e:
|
||||||
|
raise OutputParserException(
|
||||||
|
f"Parsing text\n{text}\n raised following error:\n{e}"
|
||||||
|
)
|
@ -1,18 +1,18 @@
|
|||||||
from typing import Optional, List
|
from typing import Optional, List
|
||||||
|
|
||||||
from langchain.callbacks import SharedCallbackManager
|
from langchain.callbacks import SharedCallbackManager, CallbackManager
|
||||||
from langchain.chains import SequentialChain
|
from langchain.chains import SequentialChain
|
||||||
from langchain.chains.base import Chain
|
from langchain.chains.base import Chain
|
||||||
from langchain.memory.chat_memory import BaseChatMemory
|
from langchain.memory.chat_memory import BaseChatMemory
|
||||||
|
|
||||||
from core.agent.agent_builder import AgentBuilder
|
|
||||||
from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
|
from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
|
||||||
from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler
|
|
||||||
from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
|
from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
|
||||||
|
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
|
||||||
from core.chain.chain_builder import ChainBuilder
|
from core.chain.chain_builder import ChainBuilder
|
||||||
from core.constant import llm_constant
|
from core.chain.multi_dataset_router_chain import MultiDatasetRouterChain
|
||||||
from core.conversation_message_task import ConversationMessageTask
|
from core.conversation_message_task import ConversationMessageTask
|
||||||
from core.tool.dataset_tool_builder import DatasetToolBuilder
|
from extensions.ext_database import db
|
||||||
|
from models.dataset import Dataset
|
||||||
|
|
||||||
|
|
||||||
class MainChainBuilder:
|
class MainChainBuilder:
|
||||||
@ -31,8 +31,7 @@ class MainChainBuilder:
|
|||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
agent_mode=agent_mode,
|
agent_mode=agent_mode,
|
||||||
memory=memory,
|
memory=memory,
|
||||||
dataset_tool_callback_handler=DatasetToolCallbackHandler(conversation_message_task),
|
conversation_message_task=conversation_message_task
|
||||||
agent_loop_gather_callback_handler=chain_callback_handler.agent_loop_gather_callback_handler
|
|
||||||
)
|
)
|
||||||
chains += tool_chains
|
chains += tool_chains
|
||||||
|
|
||||||
@ -59,15 +58,15 @@ class MainChainBuilder:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_agent_chains(cls, tenant_id: str, agent_mode: dict, memory: Optional[BaseChatMemory],
|
def get_agent_chains(cls, tenant_id: str, agent_mode: dict, memory: Optional[BaseChatMemory],
|
||||||
dataset_tool_callback_handler: DatasetToolCallbackHandler,
|
conversation_message_task: ConversationMessageTask):
|
||||||
agent_loop_gather_callback_handler: AgentLoopGatherCallbackHandler):
|
|
||||||
# agent mode
|
# agent mode
|
||||||
chains = []
|
chains = []
|
||||||
if agent_mode and agent_mode.get('enabled'):
|
if agent_mode and agent_mode.get('enabled'):
|
||||||
tools = agent_mode.get('tools', [])
|
tools = agent_mode.get('tools', [])
|
||||||
|
|
||||||
pre_fixed_chains = []
|
pre_fixed_chains = []
|
||||||
agent_tools = []
|
# agent_tools = []
|
||||||
|
datasets = []
|
||||||
for tool in tools:
|
for tool in tools:
|
||||||
tool_type = list(tool.keys())[0]
|
tool_type = list(tool.keys())[0]
|
||||||
tool_config = list(tool.values())[0]
|
tool_config = list(tool.values())[0]
|
||||||
@ -76,34 +75,27 @@ class MainChainBuilder:
|
|||||||
if chain:
|
if chain:
|
||||||
pre_fixed_chains.append(chain)
|
pre_fixed_chains.append(chain)
|
||||||
elif tool_type == "dataset":
|
elif tool_type == "dataset":
|
||||||
dataset_tool = DatasetToolBuilder.build_dataset_tool(
|
# get dataset from dataset id
|
||||||
tenant_id=tenant_id,
|
dataset = db.session.query(Dataset).filter(
|
||||||
dataset_id=tool_config.get("id"),
|
Dataset.tenant_id == tenant_id,
|
||||||
response_mode='no_synthesizer', # "compact"
|
Dataset.id == tool_config.get("id")
|
||||||
callback_handler=dataset_tool_callback_handler
|
).first()
|
||||||
)
|
|
||||||
|
|
||||||
if dataset_tool:
|
if dataset:
|
||||||
agent_tools.append(dataset_tool)
|
datasets.append(dataset)
|
||||||
|
|
||||||
# add pre-fixed chains
|
# add pre-fixed chains
|
||||||
chains += pre_fixed_chains
|
chains += pre_fixed_chains
|
||||||
|
|
||||||
if len(agent_tools) == 1:
|
if len(datasets) > 0:
|
||||||
# tool to chain
|
# tool to chain
|
||||||
tool_chain = ChainBuilder.to_tool_chain(tool=agent_tools[0], output_key='tool_output')
|
multi_dataset_router_chain = MultiDatasetRouterChain.from_datasets(
|
||||||
chains.append(tool_chain)
|
|
||||||
elif len(agent_tools) > 1:
|
|
||||||
# build agent config
|
|
||||||
agent_chain = AgentBuilder.to_agent_chain(
|
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
tools=agent_tools,
|
datasets=datasets,
|
||||||
memory=memory,
|
conversation_message_task=conversation_message_task,
|
||||||
dataset_tool_callback_handler=dataset_tool_callback_handler,
|
callback_manager=CallbackManager([DifyStdOutCallbackHandler()])
|
||||||
agent_loop_gather_callback_handler=agent_loop_gather_callback_handler
|
|
||||||
)
|
)
|
||||||
|
chains.append(multi_dataset_router_chain)
|
||||||
chains.append(agent_chain)
|
|
||||||
|
|
||||||
final_output_key = cls.get_chains_output_key(chains)
|
final_output_key = cls.get_chains_output_key(chains)
|
||||||
|
|
||||||
|
138
api/core/chain/multi_dataset_router_chain.py
Normal file
138
api/core/chain/multi_dataset_router_chain.py
Normal file
@ -0,0 +1,138 @@
|
|||||||
|
from typing import Mapping, List, Dict, Any, Optional
|
||||||
|
|
||||||
|
from langchain import LLMChain, PromptTemplate, ConversationChain
|
||||||
|
from langchain.callbacks import CallbackManager
|
||||||
|
from langchain.chains.base import Chain
|
||||||
|
from langchain.schema import BaseLanguageModel
|
||||||
|
from pydantic import Extra
|
||||||
|
|
||||||
|
from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler
|
||||||
|
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
|
||||||
|
from core.chain.llm_router_chain import LLMRouterChain, RouterOutputParser
|
||||||
|
from core.conversation_message_task import ConversationMessageTask
|
||||||
|
from core.llm.llm_builder import LLMBuilder
|
||||||
|
from core.tool.dataset_tool_builder import DatasetToolBuilder
|
||||||
|
from core.tool.llama_index_tool import EnhanceLlamaIndexTool
|
||||||
|
from models.dataset import Dataset
|
||||||
|
|
||||||
|
MULTI_PROMPT_ROUTER_TEMPLATE = """
|
||||||
|
Given a raw text input to a language model select the model prompt best suited for \
|
||||||
|
the input. You will be given the names of the available prompts and a description of \
|
||||||
|
what the prompt is best suited for. You may also revise the original input if you \
|
||||||
|
think that revising it will ultimately lead to a better response from the language \
|
||||||
|
model.
|
||||||
|
|
||||||
|
<< FORMATTING >>
|
||||||
|
Return a markdown code snippet with a JSON object formatted to look like:
|
||||||
|
```json
|
||||||
|
{{{{
|
||||||
|
"destination": string \\ name of the prompt to use or "DEFAULT"
|
||||||
|
"next_inputs": string \\ a potentially modified version of the original input
|
||||||
|
}}}}
|
||||||
|
```
|
||||||
|
|
||||||
|
REMEMBER: "destination" MUST be one of the candidate prompt names specified below OR \
|
||||||
|
it can be "DEFAULT" if the input is not well suited for any of the candidate prompts.
|
||||||
|
REMEMBER: "next_inputs" can just be the original input if you don't think any \
|
||||||
|
modifications are needed.
|
||||||
|
|
||||||
|
<< CANDIDATE PROMPTS >>
|
||||||
|
{destinations}
|
||||||
|
|
||||||
|
<< INPUT >>
|
||||||
|
{{input}}
|
||||||
|
|
||||||
|
<< OUTPUT >>
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class MultiDatasetRouterChain(Chain):
|
||||||
|
"""Use a single chain to route an input to one of multiple candidate chains."""
|
||||||
|
|
||||||
|
router_chain: LLMRouterChain
|
||||||
|
"""Chain for deciding a destination chain and the input to it."""
|
||||||
|
dataset_tools: Mapping[str, EnhanceLlamaIndexTool]
|
||||||
|
"""Map of name to candidate chains that inputs can be routed to."""
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
"""Configuration for this pydantic object."""
|
||||||
|
|
||||||
|
extra = Extra.forbid
|
||||||
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def input_keys(self) -> List[str]:
|
||||||
|
"""Will be whatever keys the router chain prompt expects.
|
||||||
|
|
||||||
|
:meta private:
|
||||||
|
"""
|
||||||
|
return self.router_chain.input_keys
|
||||||
|
|
||||||
|
@property
|
||||||
|
def output_keys(self) -> List[str]:
|
||||||
|
return ["text"]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_datasets(
|
||||||
|
cls,
|
||||||
|
tenant_id: str,
|
||||||
|
datasets: List[Dataset],
|
||||||
|
conversation_message_task: ConversationMessageTask,
|
||||||
|
**kwargs: Any,
|
||||||
|
):
|
||||||
|
"""Convenience constructor for instantiating from destination prompts."""
|
||||||
|
llm_callback_manager = CallbackManager([DifyStdOutCallbackHandler()])
|
||||||
|
llm = LLMBuilder.to_llm(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
model_name='gpt-3.5-turbo',
|
||||||
|
temperature=0,
|
||||||
|
max_tokens=1024,
|
||||||
|
callback_manager=llm_callback_manager
|
||||||
|
)
|
||||||
|
|
||||||
|
destinations = [f"{d.id}: {d.description}" for d in datasets]
|
||||||
|
destinations_str = "\n".join(destinations)
|
||||||
|
router_template = MULTI_PROMPT_ROUTER_TEMPLATE.format(
|
||||||
|
destinations=destinations_str
|
||||||
|
)
|
||||||
|
router_prompt = PromptTemplate(
|
||||||
|
template=router_template,
|
||||||
|
input_variables=["input"],
|
||||||
|
output_parser=RouterOutputParser(),
|
||||||
|
)
|
||||||
|
router_chain = LLMRouterChain.from_llm(llm, router_prompt)
|
||||||
|
dataset_tools = {}
|
||||||
|
for dataset in datasets:
|
||||||
|
dataset_tool = DatasetToolBuilder.build_dataset_tool(
|
||||||
|
dataset=dataset,
|
||||||
|
response_mode='no_synthesizer', # "compact"
|
||||||
|
callback_handler=DatasetToolCallbackHandler(conversation_message_task)
|
||||||
|
)
|
||||||
|
dataset_tools[dataset.id] = dataset_tool
|
||||||
|
return cls(
|
||||||
|
router_chain=router_chain,
|
||||||
|
dataset_tools=dataset_tools,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _call(
|
||||||
|
self,
|
||||||
|
inputs: Dict[str, Any]
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
if len(self.dataset_tools) == 0:
|
||||||
|
return {"text": ''}
|
||||||
|
elif len(self.dataset_tools) == 1:
|
||||||
|
return {"text": next(iter(self.dataset_tools.values())).run(inputs['input'])}
|
||||||
|
|
||||||
|
route = self.router_chain.route(inputs)
|
||||||
|
|
||||||
|
if not route.destination:
|
||||||
|
return {"text": ''}
|
||||||
|
elif route.destination in self.dataset_tools:
|
||||||
|
return {"text": self.dataset_tools[route.destination].run(
|
||||||
|
route.next_inputs['input']
|
||||||
|
)}
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Received invalid destination chain name '{route.destination}'"
|
||||||
|
)
|
@ -10,24 +10,14 @@ from core.index.keyword_table_index import KeywordTableIndex
|
|||||||
from core.index.vector_index import VectorIndex
|
from core.index.vector_index import VectorIndex
|
||||||
from core.prompt.prompts import QUERY_KEYWORD_EXTRACT_TEMPLATE
|
from core.prompt.prompts import QUERY_KEYWORD_EXTRACT_TEMPLATE
|
||||||
from core.tool.llama_index_tool import EnhanceLlamaIndexTool
|
from core.tool.llama_index_tool import EnhanceLlamaIndexTool
|
||||||
from extensions.ext_database import db
|
|
||||||
from models.dataset import Dataset
|
from models.dataset import Dataset
|
||||||
|
|
||||||
|
|
||||||
class DatasetToolBuilder:
|
class DatasetToolBuilder:
|
||||||
@classmethod
|
@classmethod
|
||||||
def build_dataset_tool(cls, tenant_id: str, dataset_id: str,
|
def build_dataset_tool(cls, dataset: Dataset,
|
||||||
response_mode: str = "no_synthesizer",
|
response_mode: str = "no_synthesizer",
|
||||||
callback_handler: Optional[DatasetToolCallbackHandler] = None):
|
callback_handler: Optional[DatasetToolCallbackHandler] = None):
|
||||||
# get dataset from dataset id
|
|
||||||
dataset = db.session.query(Dataset).filter(
|
|
||||||
Dataset.tenant_id == tenant_id,
|
|
||||||
Dataset.id == dataset_id
|
|
||||||
).first()
|
|
||||||
|
|
||||||
if not dataset:
|
|
||||||
return None
|
|
||||||
|
|
||||||
if dataset.indexing_technique == "economy":
|
if dataset.indexing_technique == "economy":
|
||||||
# use keyword table query
|
# use keyword table query
|
||||||
index = KeywordTableIndex(dataset=dataset).query_index
|
index = KeywordTableIndex(dataset=dataset).query_index
|
||||||
@ -65,7 +55,7 @@ class DatasetToolBuilder:
|
|||||||
|
|
||||||
index_tool_config = IndexToolConfig(
|
index_tool_config = IndexToolConfig(
|
||||||
index=index,
|
index=index,
|
||||||
name=f"dataset-{dataset_id}",
|
name=f"dataset-{dataset.id}",
|
||||||
description=description,
|
description=description,
|
||||||
index_query_kwargs=query_kwargs,
|
index_query_kwargs=query_kwargs,
|
||||||
tool_kwargs={
|
tool_kwargs={
|
||||||
@ -75,7 +65,7 @@ class DatasetToolBuilder:
|
|||||||
# return_direct: Whether to return LLM results directly or process the output data with an Output Parser
|
# return_direct: Whether to return LLM results directly or process the output data with an Output Parser
|
||||||
)
|
)
|
||||||
|
|
||||||
index_callback_handler = DatasetIndexToolCallbackHandler(dataset_id=dataset_id)
|
index_callback_handler = DatasetIndexToolCallbackHandler(dataset_id=dataset.id)
|
||||||
|
|
||||||
return EnhanceLlamaIndexTool.from_tool_config(
|
return EnhanceLlamaIndexTool.from_tool_config(
|
||||||
tool_config=index_tool_config,
|
tool_config=index_tool_config,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user