support instruction in classifier node (#4913)
parent 4e3b0c5aea
commit c7bddb637b
@@ -12,6 +12,7 @@ from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
 from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
 from core.prompt.simple_prompt_transform import ModelMode
 from core.prompt.utils.prompt_message_util import PromptMessageUtil
+from core.prompt.utils.prompt_template_parser import PromptTemplateParser
 from core.workflow.entities.base_node_data_entities import BaseNodeData
 from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
 from core.workflow.entities.variable_pool import VariablePool
@@ -26,6 +27,7 @@ from core.workflow.nodes.question_classifier.template_prompts import (
     QUESTION_CLASSIFIER_USER_PROMPT_2,
     QUESTION_CLASSIFIER_USER_PROMPT_3,
 )
+from core.workflow.utils.variable_template_parser import VariableTemplateParser
 from libs.json_in_md_parser import parse_and_check_json_markdown
 from models.workflow import WorkflowNodeExecutionStatus
 
@@ -47,6 +49,9 @@ class QuestionClassifierNode(LLMNode):
         model_instance, model_config = self._fetch_model_config(node_data.model)
         # fetch memory
         memory = self._fetch_memory(node_data.memory, variable_pool, model_instance)
+        # fetch instruction
+        instruction = self._format_instruction(node_data.instruction, variable_pool)
+        node_data.instruction = instruction
         # fetch prompt messages
         prompt_messages, stop = self._fetch_prompt(
             node_data=node_data,
@@ -122,6 +127,12 @@ class QuestionClassifierNode(LLMNode):
         node_data = node_data
         node_data = cast(cls._node_data_cls, node_data)
         variable_mapping = {'query': node_data.query_variable_selector}
+        variable_selectors = []
+        if node_data.instruction:
+            variable_template_parser = VariableTemplateParser(template=node_data.instruction)
+            variable_selectors.extend(variable_template_parser.extract_variable_selectors())
+        for variable_selector in variable_selectors:
+            variable_mapping[variable_selector.variable] = variable_selector.value_selector
         return variable_mapping
 
     @classmethod
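For reference, a minimal standalone sketch of the mapping logic added in the hunk above: any variable referenced by the instruction template is extracted and merged into the node's variable mapping next to the query selector. The {{#node_id.field#}} syntax, the regex, and the helper name below are illustrative assumptions, not dify's actual VariableTemplateParser.

import re

# Illustrative only: assume instruction variables look like {{#node_id.field#}}
# (hypothetical syntax standing in for dify's VariableTemplateParser).
SELECTOR_PATTERN = re.compile(r"\{\{#((?:[a-zA-Z0-9_]+\.)+[a-zA-Z0-9_]+)#\}\}")

def extract_variable_mapping(query_selector: list[str], instruction: str | None) -> dict[str, list[str]]:
    """Build variable name -> value selector, mirroring the hunk above."""
    variable_mapping = {'query': query_selector}
    if instruction:
        for name in SELECTOR_PATTERN.findall(instruction):
            # e.g. 'start.topic' -> ['start', 'topic']
            variable_mapping[name] = name.split('.')
    return variable_mapping

# Hypothetical upstream node 'start' exposing a 'topic' output:
print(extract_variable_mapping(['sys', 'query'], 'Classify strictly within {{#start.topic#}}.'))
# -> {'query': ['sys', 'query'], 'start.topic': ['start', 'topic']}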
@@ -269,8 +280,30 @@ class QuestionClassifierNode(LLMNode):
                 text=QUESTION_CLASSIFIER_COMPLETION_PROMPT.format(histories=memory_str,
                                                                   input_text=input_text,
                                                                   categories=json.dumps(categories),
-                                                                  classification_instructions=instruction, ensure_ascii=False)
+                                                                  classification_instructions=instruction,
+                                                                  ensure_ascii=False)
             )
 
         else:
             raise ValueError(f"Model mode {model_mode} not support.")
+
+    def _format_instruction(self, instruction: str, variable_pool: VariablePool) -> str:
+        inputs = {}
+
+        variable_selectors = []
+        variable_template_parser = VariableTemplateParser(template=instruction)
+        variable_selectors.extend(variable_template_parser.extract_variable_selectors())
+        for variable_selector in variable_selectors:
+            variable_value = variable_pool.get_variable_value(variable_selector.value_selector)
+            if variable_value is None:
+                raise ValueError(f'Variable {variable_selector.variable} not found')
+
+            inputs[variable_selector.variable] = variable_value
+
+        prompt_template = PromptTemplateParser(template=instruction, with_variable_tmpl=True)
+        prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
+
+        instruction = prompt_template.format(
+            prompt_inputs
+        )
+        return instruction
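And a rough standalone sketch of what the new _format_instruction does at run time: resolve every selector the instruction references against the variable pool, fail if one is missing, and substitute the values into the text. A plain dict stands in for dify's VariablePool here, and the selector syntax is again an assumption for illustration rather than the real parser classes.

import re

# Hypothetical selector syntax {{#node_id.field#}}; not dify's VariablePool/PromptTemplateParser.
SELECTOR_PATTERN = re.compile(r"\{\{#((?:[a-zA-Z0-9_]+\.)+[a-zA-Z0-9_]+)#\}\}")

def format_instruction(instruction: str, variable_pool: dict[tuple[str, ...], str]) -> str:
    """Substitute each referenced variable's value into the instruction text."""
    def resolve(match: re.Match) -> str:
        selector = tuple(match.group(1).split('.'))
        if selector not in variable_pool:
            raise ValueError(f'Variable {match.group(1)} not found')
        return variable_pool[selector]
    return SELECTOR_PATTERN.sub(resolve, instruction)

# Example with a hypothetical upstream 'start' node providing a 'domain' value:
pool = {('start', 'domain'): 'billing questions'}
print(format_instruction('Only classify queries about {{#start.domain#}}.', pool))
# -> Only classify queries about billing questions.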