diff --git a/api/core/helper/code_executor/jinja2_formatter.py b/api/core/helper/code_executor/jinja2_formatter.py
new file mode 100644
index 0000000000..96f35e3ab2
--- /dev/null
+++ b/api/core/helper/code_executor/jinja2_formatter.py
@@ -0,0 +1,17 @@
+from core.helper.code_executor.code_executor import CodeExecutor
+
+
+class Jinja2Formatter:
+    @classmethod
+    def format(cls, template: str, inputs: dict) -> str:
+        """
+        Render a Jinja2 template with the given inputs.
+        :param template: jinja2 template string
+        :param inputs: template inputs
+        :return: rendered string
+        """
+        result = CodeExecutor.execute_workflow_code_template(
+            language='jinja2', code=template, inputs=inputs
+        )
+
+        return result['result']
\ No newline at end of file
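
For reference, the formatter defers actual rendering to the sandboxed code-executor service rather than evaluating templates in-process. A minimal local sketch of the same contract, assuming the `jinja2` package is available (illustrative only, not the shipped implementation):

```python
# Illustrative sketch only: Jinja2Formatter delegates rendering to the sandboxed
# CodeExecutor service; this local stand-in assumes the `jinja2` package and
# exists purely to show the expected input/output contract.
from jinja2 import Template


def render_locally(template: str, inputs: dict) -> str:
    # Substitute {{ placeholder }} expressions from a dict of values.
    return Template(template).render(**inputs)


print(render_locally('Hello, {{ name }}!', {'name': 'Dify'}))  # Hello, Dify!
```
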
diff --git a/api/core/prompt/advanced_prompt_transform.py b/api/core/prompt/advanced_prompt_transform.py
index 29b516ac02..9952371a82 100644
--- a/api/core/prompt/advanced_prompt_transform.py
+++ b/api/core/prompt/advanced_prompt_transform.py
@@ -2,6 +2,7 @@ from typing import Optional, Union
 
 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
 from core.file.file_obj import FileVar
+from core.helper.code_executor.jinja2_formatter import Jinja2Formatter
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
@@ -80,29 +81,35 @@ class AdvancedPromptTransform(PromptTransform):
 
         prompt_messages = []
 
-        prompt_template = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
-        prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
+        if prompt_template.edition_type == 'basic' or not prompt_template.edition_type:
+            prompt_template = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
+            prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
 
-        prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs)
+            prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs)
 
-        if memory and memory_config:
-            role_prefix = memory_config.role_prefix
-            prompt_inputs = self._set_histories_variable(
-                memory=memory,
-                memory_config=memory_config,
-                raw_prompt=raw_prompt,
-                role_prefix=role_prefix,
-                prompt_template=prompt_template,
-                prompt_inputs=prompt_inputs,
-                model_config=model_config
-            )
+            if memory and memory_config:
+                role_prefix = memory_config.role_prefix
+                prompt_inputs = self._set_histories_variable(
+                    memory=memory,
+                    memory_config=memory_config,
+                    raw_prompt=raw_prompt,
+                    role_prefix=role_prefix,
+                    prompt_template=prompt_template,
+                    prompt_inputs=prompt_inputs,
+                    model_config=model_config
+                )
 
-        if query:
-            prompt_inputs = self._set_query_variable(query, prompt_template, prompt_inputs)
+            if query:
+                prompt_inputs = self._set_query_variable(query, prompt_template, prompt_inputs)
 
-        prompt = prompt_template.format(
-            prompt_inputs
-        )
+            prompt = prompt_template.format(
+                prompt_inputs
+            )
+        else:
+            prompt = raw_prompt
+            prompt_inputs = inputs
+
+            prompt = Jinja2Formatter.format(prompt, prompt_inputs)
 
         if files:
             prompt_message_contents = [TextPromptMessageContent(data=prompt)]
@@ -135,14 +142,22 @@ class AdvancedPromptTransform(PromptTransform):
         for prompt_item in raw_prompt_list:
             raw_prompt = prompt_item.text
 
-            prompt_template = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
-            prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
+            if prompt_item.edition_type == 'basic' or not prompt_item.edition_type:
+                prompt_template = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
+                prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
 
-            prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs)
+                prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs)
 
-            prompt = prompt_template.format(
-                prompt_inputs
-            )
+                prompt = prompt_template.format(
+                    prompt_inputs
+                )
+            elif prompt_item.edition_type == 'jinja2':
+                prompt = raw_prompt
+                prompt_inputs = inputs
+
+                prompt = Jinja2Formatter.format(prompt, prompt_inputs)
+            else:
+                raise ValueError(f'Invalid edition type: {prompt_item.edition_type}')
 
             if prompt_item.role == PromptMessageRole.USER:
                 prompt_messages.append(UserPromptMessage(content=prompt))
diff --git a/api/core/prompt/entities/advanced_prompt_entities.py b/api/core/prompt/entities/advanced_prompt_entities.py
index 2be00bdf0e..23a8602bea 100644
--- a/api/core/prompt/entities/advanced_prompt_entities.py
+++ b/api/core/prompt/entities/advanced_prompt_entities.py
@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Literal, Optional
 
 from pydantic import BaseModel
 
@@ -11,6 +11,7 @@ class ChatModelMessage(BaseModel):
     """
     text: str
     role: PromptMessageRole
+    edition_type: Optional[Literal['basic', 'jinja2']]
 
 
 class CompletionModelPromptTemplate(BaseModel):
@@ -18,6 +19,7 @@ class CompletionModelPromptTemplate(BaseModel):
     Completion Model Prompt Template.
     """
     text: str
+    edition_type: Optional[Literal['basic', 'jinja2']]
 
 
 class MemoryConfig(BaseModel):
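
The `edition_type` field is what routes a prompt through one of the two branches above: `'basic'` (or unset) keeps the legacy `{{#variable#}}` parser, while `'jinja2'` hands the raw text plus all inputs to `Jinja2Formatter`. A construction example with assumed, illustrative values:

```python
# Example payloads for the new field; values are illustrative.
from core.model_runtime.entities.message_entities import PromptMessageRole
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage

basic_msg = ChatModelMessage(
    text='Hello, {{#sys.query#}}',   # legacy {{#...#}} placeholder syntax
    role=PromptMessageRole.USER,
    edition_type='basic',            # None/unset also falls back to this path
)
jinja_msg = ChatModelMessage(
    text='Hello, {{ sys_query }}',   # plain Jinja2 placeholders
    role=PromptMessageRole.USER,
    edition_type='jinja2',
)
```
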
""" model: ModelConfig - prompt_template: Union[list[ChatModelMessage], CompletionModelPromptTemplate] + prompt_template: Union[list[LLMNodeChatModelMessage], LLMNodeCompletionModelPromptTemplate] + prompt_config: Optional[PromptConfig] = None memory: Optional[MemoryConfig] = None context: ContextConfig vision: VisionConfig diff --git a/api/core/workflow/nodes/llm/llm_node.py b/api/core/workflow/nodes/llm/llm_node.py index c8b7f279ab..fef09c1385 100644 --- a/api/core/workflow/nodes/llm/llm_node.py +++ b/api/core/workflow/nodes/llm/llm_node.py @@ -1,4 +1,6 @@ +import json from collections.abc import Generator +from copy import deepcopy from typing import Optional, cast from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity @@ -17,11 +19,15 @@ from core.model_runtime.utils.encoders import jsonable_encoder from core.prompt.advanced_prompt_transform import AdvancedPromptTransform from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig from core.prompt.utils.prompt_message_util import PromptMessageUtil -from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.base_node import BaseNode -from core.workflow.nodes.llm.entities import LLMNodeData, ModelConfig +from core.workflow.nodes.llm.entities import ( + LLMNodeChatModelMessage, + LLMNodeCompletionModelPromptTemplate, + LLMNodeData, + ModelConfig, +) from core.workflow.utils.variable_template_parser import VariableTemplateParser from extensions.ext_database import db from models.model import Conversation @@ -39,16 +45,24 @@ class LLMNode(BaseNode): :param variable_pool: variable pool :return: """ - node_data = self.node_data - node_data = cast(self._node_data_cls, node_data) + node_data = cast(LLMNodeData, deepcopy(self.node_data)) node_inputs = None process_data = None try: + # init messages template + node_data.prompt_template = self._transform_chat_messages(node_data.prompt_template) + # fetch variables and fetch values from variable pool inputs = self._fetch_inputs(node_data, variable_pool) + # fetch jinja2 inputs + jinja_inputs = self._fetch_jinja_inputs(node_data, variable_pool) + + # merge inputs + inputs.update(jinja_inputs) + node_inputs = {} # fetch files @@ -183,6 +197,86 @@ class LLMNode(BaseNode): usage = LLMUsage.empty_usage() return full_text, usage + + def _transform_chat_messages(self, + messages: list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate + ) -> list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate: + """ + Transform chat messages + + :param messages: chat messages + :return: + """ + + if isinstance(messages, LLMNodeCompletionModelPromptTemplate): + if messages.edition_type == 'jinja2': + messages.text = messages.jinja2_text + + return messages + + for message in messages: + if message.edition_type == 'jinja2': + message.text = message.jinja2_text + + return messages + + def _fetch_jinja_inputs(self, node_data: LLMNodeData, variable_pool: VariablePool) -> dict[str, str]: + """ + Fetch jinja inputs + :param node_data: node data + :param variable_pool: variable pool + :return: + """ + variables = {} + + if not node_data.prompt_config: + return variables + + for variable_selector in node_data.prompt_config.jinja2_variables or []: + variable = variable_selector.variable + value = 
diff --git a/api/core/workflow/nodes/llm/llm_node.py b/api/core/workflow/nodes/llm/llm_node.py
index c8b7f279ab..fef09c1385 100644
--- a/api/core/workflow/nodes/llm/llm_node.py
+++ b/api/core/workflow/nodes/llm/llm_node.py
@@ -1,4 +1,6 @@
+import json
 from collections.abc import Generator
+from copy import deepcopy
 from typing import Optional, cast
 
 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
@@ -17,11 +19,15 @@ from core.model_runtime.utils.encoders import jsonable_encoder
 from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
 from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
 from core.prompt.utils.prompt_message_util import PromptMessageUtil
-from core.workflow.entities.base_node_data_entities import BaseNodeData
 from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable
 from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.nodes.base_node import BaseNode
-from core.workflow.nodes.llm.entities import LLMNodeData, ModelConfig
+from core.workflow.nodes.llm.entities import (
+    LLMNodeChatModelMessage,
+    LLMNodeCompletionModelPromptTemplate,
+    LLMNodeData,
+    ModelConfig,
+)
 from core.workflow.utils.variable_template_parser import VariableTemplateParser
 from extensions.ext_database import db
 from models.model import Conversation
@@ -39,16 +45,24 @@ class LLMNode(BaseNode):
         :param variable_pool: variable pool
         :return:
         """
-        node_data = self.node_data
-        node_data = cast(self._node_data_cls, node_data)
+        node_data = cast(LLMNodeData, deepcopy(self.node_data))
 
         node_inputs = None
         process_data = None
 
         try:
+            # init messages template
+            node_data.prompt_template = self._transform_chat_messages(node_data.prompt_template)
+
             # fetch variables and fetch values from variable pool
             inputs = self._fetch_inputs(node_data, variable_pool)
 
+            # fetch jinja2 inputs
+            jinja_inputs = self._fetch_jinja_inputs(node_data, variable_pool)
+
+            # merge inputs
+            inputs.update(jinja_inputs)
+
             node_inputs = {}
 
             # fetch files
@@ -183,6 +197,86 @@ class LLMNode(BaseNode):
             usage = LLMUsage.empty_usage()
 
         return full_text, usage
+
+    def _transform_chat_messages(self,
+        messages: list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
+    ) -> list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate:
+        """
+        Transform chat messages
+
+        :param messages: chat messages
+        :return:
+        """
+
+        if isinstance(messages, LLMNodeCompletionModelPromptTemplate):
+            if messages.edition_type == 'jinja2':
+                messages.text = messages.jinja2_text
+
+            return messages
+
+        for message in messages:
+            if message.edition_type == 'jinja2':
+                message.text = message.jinja2_text
+
+        return messages
+
+    def _fetch_jinja_inputs(self, node_data: LLMNodeData, variable_pool: VariablePool) -> dict[str, str]:
+        """
+        Fetch jinja inputs
+        :param node_data: node data
+        :param variable_pool: variable pool
+        :return:
+        """
+        variables = {}
+
+        if not node_data.prompt_config:
+            return variables
+
+        for variable_selector in node_data.prompt_config.jinja2_variables or []:
+            variable = variable_selector.variable
+            value = variable_pool.get_variable_value(
+                variable_selector=variable_selector.value_selector
+            )
+
+            def parse_dict(d: dict) -> str:
+                """
+                Parse dict into string
+                """
+                # check if it's a context structure
+                if 'metadata' in d and '_source' in d['metadata'] and 'content' in d:
+                    return d['content']
+
+                # else, parse the dict
+                try:
+                    return json.dumps(d, ensure_ascii=False)
+                except Exception:
+                    return str(d)
+
+            if isinstance(value, str):
+                value = value
+            elif isinstance(value, list):
+                result = ''
+                for item in value:
+                    if isinstance(item, dict):
+                        result += parse_dict(item)
+                    elif isinstance(item, str):
+                        result += item
+                    elif isinstance(item, int | float):
+                        result += str(item)
+                    else:
+                        result += str(item)
+                    result += '\n'
+                value = result.strip()
+            elif isinstance(value, dict):
+                value = parse_dict(value)
+            elif isinstance(value, int | float):
+                value = str(value)
+            else:
+                value = str(value)
+
+            variables[variable] = value
+
+        return variables
 
     def _fetch_inputs(self, node_data: LLMNodeData, variable_pool: VariablePool) -> dict[str, str]:
         """
@@ -531,25 +625,25 @@ class LLMNode(BaseNode):
         db.session.commit()
 
     @classmethod
-    def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
+    def _extract_variable_selector_to_variable_mapping(cls, node_data: LLMNodeData) -> dict[str, list[str]]:
         """
         Extract variable selector to variable mapping
         :param node_data: node data
         :return:
         """
-        node_data = node_data
-        node_data = cast(cls._node_data_cls, node_data)
-
         prompt_template = node_data.prompt_template
 
         variable_selectors = []
         if isinstance(prompt_template, list):
             for prompt in prompt_template:
-                variable_template_parser = VariableTemplateParser(template=prompt.text)
-                variable_selectors.extend(variable_template_parser.extract_variable_selectors())
+                if prompt.edition_type != 'jinja2':
+                    variable_template_parser = VariableTemplateParser(template=prompt.text)
+                    variable_selectors.extend(variable_template_parser.extract_variable_selectors())
         else:
-            variable_template_parser = VariableTemplateParser(template=prompt_template.text)
-            variable_selectors = variable_template_parser.extract_variable_selectors()
+            if prompt_template.edition_type != 'jinja2':
+                variable_template_parser = VariableTemplateParser(template=prompt_template.text)
+                variable_selectors = variable_template_parser.extract_variable_selectors()
 
         variable_mapping = {}
         for variable_selector in variable_selectors:
@@ -571,6 +665,22 @@ class LLMNode(BaseNode):
         if node_data.memory:
             variable_mapping['#sys.query#'] = ['sys', SystemVariable.QUERY.value]
 
+        if node_data.prompt_config:
+            enable_jinja = False
+
+            if isinstance(prompt_template, list):
+                for prompt in prompt_template:
+                    if prompt.edition_type == 'jinja2':
+                        enable_jinja = True
+                        break
+            else:
+                if prompt_template.edition_type == 'jinja2':
+                    enable_jinja = True
+
+            if enable_jinja:
+                for variable_selector in node_data.prompt_config.jinja2_variables or []:
+                    variable_mapping[variable_selector.variable] = variable_selector.value_selector
+
         return variable_mapping
 
     @classmethod
@@ -588,7 +698,8 @@ class LLMNode(BaseNode):
                     "prompts": [
                         {
                             "role": "system",
-                            "text": "You are a helpful AI assistant."
+                            "text": "You are a helpful AI assistant.",
+                            "edition_type": "basic"
                         }
                     ]
                 },
@@ -600,7 +711,8 @@ class LLMNode(BaseNode):
                     "prompt": {
                         "text": "Here is the chat histories between human and assistant, inside "
                                 "<histories></histories> XML tags.\n\n<histories>\n{{"
-                                "#histories#}}\n</histories>\n\n\nHuman: {{#sys.query#}}\n\nAssistant:"
+                                "#histories#}}\n</histories>\n\n\nHuman: {{#sys.query#}}\n\nAssistant:",
+                        "edition_type": "basic"
                     },
                     "stop": ["Human:"]
                 }
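
`_fetch_jinja_inputs` has to collapse arbitrary pool values into strings before they reach the template. A condensed restatement of its flattening rules, as an illustrative helper that is not part of the diff: strings pass through, knowledge-retrieval context dicts collapse to their `content`, other dicts become JSON, list items join with newlines, and everything else is `str()`-ed.

```python
import json


# Condensed restatement of the flattening in _fetch_jinja_inputs (illustrative).
def flatten(value) -> str:
    def parse_dict(d: dict) -> str:
        # Knowledge-retrieval context chunks collapse to their text content.
        if 'metadata' in d and '_source' in d['metadata'] and 'content' in d:
            return d['content']
        try:
            return json.dumps(d, ensure_ascii=False)
        except Exception:
            return str(d)

    if isinstance(value, str):
        return value
    if isinstance(value, list):
        return '\n'.join(
            parse_dict(v) if isinstance(v, dict) else str(v) for v in value
        ).strip()
    if isinstance(value, dict):
        return parse_dict(value)
    return str(value)


# e.g. flatten([{'content': 'sunny', 'metadata': {'_source': 'kb'}}, 'warm'])
# -> 'sunny\nwarm'
```
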
+ "text": "You are a helpful AI assistant.", + "edition_type": "basic" } ] }, @@ -600,7 +711,8 @@ class LLMNode(BaseNode): "prompt": { "text": "Here is the chat histories between human and assistant, inside " " XML tags.\n\n\n{{" - "#histories#}}\n\n\n\nHuman: {{#sys.query#}}\n\nAssistant:" + "#histories#}}\n\n\n\nHuman: {{#sys.query#}}\n\nAssistant:", + "edition_type": "basic" }, "stop": ["Human:"] } diff --git a/api/requirements-dev.txt b/api/requirements-dev.txt index 0391ac5969..70b2ce2ef5 100644 --- a/api/requirements-dev.txt +++ b/api/requirements-dev.txt @@ -3,3 +3,4 @@ pytest~=8.1.1 pytest-benchmark~=4.0.0 pytest-env~=1.1.3 pytest-mock~=3.14.0 +jinja2~=3.1.2 \ No newline at end of file diff --git a/api/tests/integration_tests/workflow/nodes/__mock/code_executor.py b/api/tests/integration_tests/workflow/nodes/__mock/code_executor.py index 38517cf448..ef84c92625 100644 --- a/api/tests/integration_tests/workflow/nodes/__mock/code_executor.py +++ b/api/tests/integration_tests/workflow/nodes/__mock/code_executor.py @@ -3,6 +3,7 @@ from typing import Literal import pytest from _pytest.monkeypatch import MonkeyPatch +from jinja2 import Template from core.helper.code_executor.code_executor import CodeExecutor @@ -18,7 +19,7 @@ class MockedCodeExecutor: } elif language == 'jinja2': return { - "result": "3" + "result": Template(code).render(inputs) } @pytest.fixture diff --git a/api/tests/integration_tests/workflow/nodes/test_llm.py b/api/tests/integration_tests/workflow/nodes/test_llm.py index 8a8a58d59f..d04497a187 100644 --- a/api/tests/integration_tests/workflow/nodes/test_llm.py +++ b/api/tests/integration_tests/workflow/nodes/test_llm.py @@ -1,3 +1,4 @@ +import json import os from unittest.mock import MagicMock @@ -19,6 +20,7 @@ from models.workflow import WorkflowNodeExecutionStatus """FOR MOCK FIXTURES, DO NOT REMOVE""" from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock +from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock @pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) @@ -116,3 +118,118 @@ def test_execute_llm(setup_openai_mock): assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert result.outputs['text'] is not None assert result.outputs['usage']['total_tokens'] > 0 + +@pytest.mark.parametrize('setup_code_executor_mock', [['none']], indirect=True) +@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) +def test_execute_llm_with_jinja2(setup_code_executor_mock, setup_openai_mock): + """ + Test execute LLM node with jinja2 + """ + node = LLMNode( + tenant_id='1', + app_id='1', + workflow_id='1', + user_id='1', + user_from=UserFrom.ACCOUNT, + config={ + 'id': 'llm', + 'data': { + 'title': '123', + 'type': 'llm', + 'model': { + 'provider': 'openai', + 'name': 'gpt-3.5-turbo', + 'mode': 'chat', + 'completion_params': {} + }, + 'prompt_config': { + 'jinja2_variables': [{ + 'variable': 'sys_query', + 'value_selector': ['sys', 'query'] + }, { + 'variable': 'output', + 'value_selector': ['abc', 'output'] + }] + }, + 'prompt_template': [ + { + 'role': 'system', + 'text': 'you are a helpful assistant.\ntoday\'s weather is {{#abc.output#}}', + 'jinja2_text': 'you are a helpful assistant.\ntoday\'s weather is {{output}}.', + 'edition_type': 'jinja2' + }, + { + 'role': 'user', + 'text': '{{#sys.query#}}', + 'jinja2_text': '{{sys_query}}', + 'edition_type': 'basic' + } + ], + 'memory': None, + 'context': { + 'enabled': False + }, + 'vision': { + 'enabled': False + 
} + } + } + ) + + # construct variable pool + pool = VariablePool(system_variables={ + SystemVariable.QUERY: 'what\'s the weather today?', + SystemVariable.FILES: [], + SystemVariable.CONVERSATION_ID: 'abababa', + SystemVariable.USER_ID: 'aaa' + }, user_inputs={}) + pool.append_variable(node_id='abc', variable_key_list=['output'], value='sunny') + + credentials = { + 'openai_api_key': os.environ.get('OPENAI_API_KEY') + } + + provider_instance = ModelProviderFactory().get_provider_instance('openai') + model_type_instance = provider_instance.get_model_instance(ModelType.LLM) + provider_model_bundle = ProviderModelBundle( + configuration=ProviderConfiguration( + tenant_id='1', + provider=provider_instance.get_provider_schema(), + preferred_provider_type=ProviderType.CUSTOM, + using_provider_type=ProviderType.CUSTOM, + system_configuration=SystemConfiguration( + enabled=False + ), + custom_configuration=CustomConfiguration( + provider=CustomProviderConfiguration( + credentials=credentials + ) + ) + ), + provider_instance=provider_instance, + model_type_instance=model_type_instance + ) + + model_instance = ModelInstance(provider_model_bundle=provider_model_bundle, model='gpt-3.5-turbo') + + model_config = ModelConfigWithCredentialsEntity( + model='gpt-3.5-turbo', + provider='openai', + mode='chat', + credentials=credentials, + parameters={}, + model_schema=model_type_instance.get_model_schema('gpt-3.5-turbo'), + provider_model_bundle=provider_model_bundle + ) + + # Mock db.session.close() + db.session.close = MagicMock() + + node._fetch_model_config = MagicMock(return_value=tuple([model_instance, model_config])) + + # execute node + result = node.run(pool) + + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert 'sunny' in json.dumps(result.process_data) + assert 'what\'s the weather today?' in json.dumps(result.process_data) \ No newline at end of file