From c7bddb637b1c2641be1dc8498e69c9c0f41a7469 Mon Sep 17 00:00:00 2001
From: Jyong <76649700+JohnJyong@users.noreply.github.com>
Date: Tue, 4 Jun 2024 20:07:54 +0800
Subject: [PATCH] support instruction in classifier node (#4913)

---
 .../question_classifier_node.py               | 35 ++++++++++++++++++-
 1 file changed, 34 insertions(+), 1 deletion(-)

diff --git a/api/core/workflow/nodes/question_classifier/question_classifier_node.py b/api/core/workflow/nodes/question_classifier/question_classifier_node.py
index af1e68b92a..1f59242e98 100644
--- a/api/core/workflow/nodes/question_classifier/question_classifier_node.py
+++ b/api/core/workflow/nodes/question_classifier/question_classifier_node.py
@@ -12,6 +12,7 @@ from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
 from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
 from core.prompt.simple_prompt_transform import ModelMode
 from core.prompt.utils.prompt_message_util import PromptMessageUtil
+from core.prompt.utils.prompt_template_parser import PromptTemplateParser
 from core.workflow.entities.base_node_data_entities import BaseNodeData
 from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
 from core.workflow.entities.variable_pool import VariablePool
@@ -26,6 +27,7 @@ from core.workflow.nodes.question_classifier.template_prompts import (
     QUESTION_CLASSIFIER_USER_PROMPT_2,
     QUESTION_CLASSIFIER_USER_PROMPT_3,
 )
+from core.workflow.utils.variable_template_parser import VariableTemplateParser
 from libs.json_in_md_parser import parse_and_check_json_markdown
 from models.workflow import WorkflowNodeExecutionStatus
 
@@ -47,6 +49,9 @@ class QuestionClassifierNode(LLMNode):
         model_instance, model_config = self._fetch_model_config(node_data.model)
         # fetch memory
         memory = self._fetch_memory(node_data.memory, variable_pool, model_instance)
+        # fetch instruction
+        instruction = self._format_instruction(node_data.instruction, variable_pool)
+        node_data.instruction = instruction
         # fetch prompt messages
         prompt_messages, stop = self._fetch_prompt(
             node_data=node_data,
@@ -122,6 +127,12 @@ class QuestionClassifierNode(LLMNode):
         node_data = node_data
         node_data = cast(cls._node_data_cls, node_data)
         variable_mapping = {'query': node_data.query_variable_selector}
+        variable_selectors = []
+        if node_data.instruction:
+            variable_template_parser = VariableTemplateParser(template=node_data.instruction)
+            variable_selectors.extend(variable_template_parser.extract_variable_selectors())
+        for variable_selector in variable_selectors:
+            variable_mapping[variable_selector.variable] = variable_selector.value_selector
         return variable_mapping
 
     @classmethod
@@ -269,8 +280,30 @@ class QuestionClassifierNode(LLMNode):
                 text=QUESTION_CLASSIFIER_COMPLETION_PROMPT.format(histories=memory_str,
                                                                   input_text=input_text,
                                                                   categories=json.dumps(categories),
-                                                                  classification_instructions=instruction, ensure_ascii=False)
+                                                                  classification_instructions=instruction,
+                                                                  ensure_ascii=False)
             )
 
         else:
             raise ValueError(f"Model mode {model_mode} not support.")
+
+    def _format_instruction(self, instruction: str, variable_pool: VariablePool) -> str:
+        inputs = {}
+
+        variable_selectors = []
+        variable_template_parser = VariableTemplateParser(template=instruction)
+        variable_selectors.extend(variable_template_parser.extract_variable_selectors())
+        for variable_selector in variable_selectors:
+            variable_value = variable_pool.get_variable_value(variable_selector.value_selector)
+            if variable_value is None:
+                raise ValueError(f'Variable {variable_selector.variable} not found')
+
+            inputs[variable_selector.variable] = variable_value
+
+        prompt_template = PromptTemplateParser(template=instruction, with_variable_tmpl=True)
+        prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
+
+        instruction = prompt_template.format(
+            prompt_inputs
+        )
+        return instruction
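
For reviewers who want to see the effect of the new _format_instruction helper without running a workflow, the sketch below illustrates the substitution it performs: references in the instruction text are resolved against a variable pool and then interpolated. It is a minimal sketch, not Dify's actual API: FakeVariablePool, VARIABLE_PATTERN, and format_instruction are hypothetical stand-ins for VariablePool, VariableTemplateParser, and PromptTemplateParser, whose constructors are not shown in this patch; only the {{#node_id.variable#}} reference syntax and the "extract selectors, resolve each against the pool, substitute" flow mirror the patched code.

# Illustrative sketch only; all names here are stand-ins invented for this
# example, not classes or functions from the Dify codebase.
import re
from typing import Optional

# Assumed reference syntax: {{#node_id.variable#}}
VARIABLE_PATTERN = re.compile(r"\{\{#([A-Za-z0-9_]+)\.([A-Za-z0-9_]+)#\}\}")


class FakeVariablePool:
    """Stand-in for the real VariablePool; maps (node_id, variable) to a value."""

    def __init__(self, values: dict) -> None:
        self._values = values

    def get_variable_value(self, value_selector: list) -> Optional[str]:
        return self._values.get(tuple(value_selector))


def format_instruction(instruction: str, variable_pool: FakeVariablePool) -> str:
    """Resolve every {{#node.field#}} reference found in the instruction."""

    def replace(match: re.Match) -> str:
        value_selector = [match.group(1), match.group(2)]
        value = variable_pool.get_variable_value(value_selector)
        if value is None:
            # The patched node also raises ValueError for unresolved variables.
            raise ValueError(f"Variable {match.group(0)} not found")
        return value

    return VARIABLE_PATTERN.sub(replace, instruction)


if __name__ == "__main__":
    pool = FakeVariablePool({("start", "product"): "Dify"})
    template = "Classify questions about {{#start.product#}} into the given categories."
    print(format_instruction(template, pool))
    # -> Classify questions about Dify into the given categories.

The extra hunk in _extract_variable_selector_to_variable_mapping complements this flow: by reporting the selectors referenced in the instruction, upstream tooling can see which variables the classifier node depends on before it runs.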