diff --git a/api/core/chain/llm_router_chain.py b/api/core/chain/llm_router_chain.py index 432cb4b306..3108561e6b 100644 --- a/api/core/chain/llm_router_chain.py +++ b/api/core/chain/llm_router_chain.py @@ -84,13 +84,16 @@ class RouterOutputParser(BaseOutputParser[Dict[str, str]]): def parse_json_markdown(self, json_string: str) -> dict: # Remove the triple backticks if present - json_string = json_string.replace("```json", "").replace("```", "") + start_index = json_string.find("```json") + end_index = json_string.find("```", start_index + len("```json")) - # Strip whitespace and newlines from the start and end - json_string = json_string.strip() + if start_index != -1 and end_index != -1: + extracted_content = json_string[start_index + len("```json"):end_index].strip() - # Parse the JSON string into a Python dictionary - parsed = json.loads(json_string) + # Parse the JSON string into a Python dictionary + parsed = json.loads(extracted_content) + else: + raise Exception("Could not find JSON block in the output.") return parsed diff --git a/api/core/chain/multi_dataset_router_chain.py b/api/core/chain/multi_dataset_router_chain.py index 736bfaa1aa..365f68eaf6 100644 --- a/api/core/chain/multi_dataset_router_chain.py +++ b/api/core/chain/multi_dataset_router_chain.py @@ -90,7 +90,7 @@ class MultiDatasetRouterChain(Chain): callback_manager=llm_callback_manager ) - destinations = [f"{d.id}: {d.description}" for d in datasets] + destinations = ["{}: {}".format(d.id, d.description.replace('\n', ' ')) for d in datasets] destinations_str = "\n".join(destinations) router_template = MULTI_PROMPT_ROUTER_TEMPLATE.format( destinations=destinations_str