diff --git a/api/core/helper/code_executor/jinja2_formatter.py b/api/core/helper/code_executor/jinja2_formatter.py
new file mode 100644
index 0000000000..96f35e3ab2
--- /dev/null
+++ b/api/core/helper/code_executor/jinja2_formatter.py
@@ -0,0 +1,17 @@
+from core.helper.code_executor.code_executor import CodeExecutor
+
+
+class Jinja2Formatter:
+ @classmethod
+ def format(cls, template: str, inputs: str) -> str:
+ """
+ Format template
+ :param template: template
+ :param inputs: inputs
+ :return:
+ """
+ result = CodeExecutor.execute_workflow_code_template(
+ language='jinja2', code=template, inputs=inputs
+ )
+
+ return result['result']
\ No newline at end of file
diff --git a/api/core/prompt/advanced_prompt_transform.py b/api/core/prompt/advanced_prompt_transform.py
index 29b516ac02..9952371a82 100644
--- a/api/core/prompt/advanced_prompt_transform.py
+++ b/api/core/prompt/advanced_prompt_transform.py
@@ -2,6 +2,7 @@ from typing import Optional, Union
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.file.file_obj import FileVar
+from core.helper.code_executor.jinja2_formatter import Jinja2Formatter
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
@@ -80,29 +81,35 @@ class AdvancedPromptTransform(PromptTransform):
prompt_messages = []
- prompt_template = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
- prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
+ if prompt_template.edition_type == 'basic' or not prompt_template.edition_type:
+ prompt_template = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
+ prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
- prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs)
+ prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs)
- if memory and memory_config:
- role_prefix = memory_config.role_prefix
- prompt_inputs = self._set_histories_variable(
- memory=memory,
- memory_config=memory_config,
- raw_prompt=raw_prompt,
- role_prefix=role_prefix,
- prompt_template=prompt_template,
- prompt_inputs=prompt_inputs,
- model_config=model_config
+ if memory and memory_config:
+ role_prefix = memory_config.role_prefix
+ prompt_inputs = self._set_histories_variable(
+ memory=memory,
+ memory_config=memory_config,
+ raw_prompt=raw_prompt,
+ role_prefix=role_prefix,
+ prompt_template=prompt_template,
+ prompt_inputs=prompt_inputs,
+ model_config=model_config
+ )
+
+ if query:
+ prompt_inputs = self._set_query_variable(query, prompt_template, prompt_inputs)
+
+ prompt = prompt_template.format(
+ prompt_inputs
)
+ else:
+ prompt = raw_prompt
+ prompt_inputs = inputs
- if query:
- prompt_inputs = self._set_query_variable(query, prompt_template, prompt_inputs)
-
- prompt = prompt_template.format(
- prompt_inputs
- )
+ prompt = Jinja2Formatter.format(prompt, prompt_inputs)
if files:
prompt_message_contents = [TextPromptMessageContent(data=prompt)]
@@ -135,14 +142,22 @@ class AdvancedPromptTransform(PromptTransform):
for prompt_item in raw_prompt_list:
raw_prompt = prompt_item.text
- prompt_template = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
- prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
+ if prompt_item.edition_type == 'basic' or not prompt_item.edition_type:
+ prompt_template = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
+ prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
- prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs)
+ prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs)
- prompt = prompt_template.format(
- prompt_inputs
- )
+ prompt = prompt_template.format(
+ prompt_inputs
+ )
+ elif prompt_item.edition_type == 'jinja2':
+ prompt = raw_prompt
+ prompt_inputs = inputs
+
+ prompt = Jinja2Formatter.format(prompt, prompt_inputs)
+ else:
+ raise ValueError(f'Invalid edition type: {prompt_item.edition_type}')
if prompt_item.role == PromptMessageRole.USER:
prompt_messages.append(UserPromptMessage(content=prompt))
diff --git a/api/core/prompt/entities/advanced_prompt_entities.py b/api/core/prompt/entities/advanced_prompt_entities.py
index 2be00bdf0e..23a8602bea 100644
--- a/api/core/prompt/entities/advanced_prompt_entities.py
+++ b/api/core/prompt/entities/advanced_prompt_entities.py
@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Literal, Optional
from pydantic import BaseModel
@@ -11,6 +11,7 @@ class ChatModelMessage(BaseModel):
"""
text: str
role: PromptMessageRole
+ edition_type: Optional[Literal['basic', 'jinja2']]
class CompletionModelPromptTemplate(BaseModel):
@@ -18,6 +19,7 @@ class CompletionModelPromptTemplate(BaseModel):
Completion Model Prompt Template.
"""
text: str
+ edition_type: Optional[Literal['basic', 'jinja2']]
class MemoryConfig(BaseModel):
diff --git a/api/core/workflow/nodes/llm/entities.py b/api/core/workflow/nodes/llm/entities.py
index c390aaf8c9..1e48a10bc7 100644
--- a/api/core/workflow/nodes/llm/entities.py
+++ b/api/core/workflow/nodes/llm/entities.py
@@ -4,6 +4,7 @@ from pydantic import BaseModel
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
from core.workflow.entities.base_node_data_entities import BaseNodeData
+from core.workflow.entities.variable_entities import VariableSelector
class ModelConfig(BaseModel):
@@ -37,13 +38,31 @@ class VisionConfig(BaseModel):
enabled: bool
configs: Optional[Configs] = None
+class PromptConfig(BaseModel):
+ """
+ Prompt Config.
+ """
+ jinja2_variables: Optional[list[VariableSelector]] = None
+
+class LLMNodeChatModelMessage(ChatModelMessage):
+ """
+ LLM Node Chat Model Message.
+ """
+ jinja2_text: Optional[str] = None
+
+class LLMNodeCompletionModelPromptTemplate(CompletionModelPromptTemplate):
+ """
+ LLM Node Chat Model Prompt Template.
+ """
+ jinja2_text: Optional[str] = None
class LLMNodeData(BaseNodeData):
"""
LLM Node Data.
"""
model: ModelConfig
- prompt_template: Union[list[ChatModelMessage], CompletionModelPromptTemplate]
+ prompt_template: Union[list[LLMNodeChatModelMessage], LLMNodeCompletionModelPromptTemplate]
+ prompt_config: Optional[PromptConfig] = None
memory: Optional[MemoryConfig] = None
context: ContextConfig
vision: VisionConfig
diff --git a/api/core/workflow/nodes/llm/llm_node.py b/api/core/workflow/nodes/llm/llm_node.py
index c8b7f279ab..fef09c1385 100644
--- a/api/core/workflow/nodes/llm/llm_node.py
+++ b/api/core/workflow/nodes/llm/llm_node.py
@@ -1,4 +1,6 @@
+import json
from collections.abc import Generator
+from copy import deepcopy
from typing import Optional, cast
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
@@ -17,11 +19,15 @@ from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
from core.prompt.utils.prompt_message_util import PromptMessageUtil
-from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode
-from core.workflow.nodes.llm.entities import LLMNodeData, ModelConfig
+from core.workflow.nodes.llm.entities import (
+ LLMNodeChatModelMessage,
+ LLMNodeCompletionModelPromptTemplate,
+ LLMNodeData,
+ ModelConfig,
+)
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from extensions.ext_database import db
from models.model import Conversation
@@ -39,16 +45,24 @@ class LLMNode(BaseNode):
:param variable_pool: variable pool
:return:
"""
- node_data = self.node_data
- node_data = cast(self._node_data_cls, node_data)
+ node_data = cast(LLMNodeData, deepcopy(self.node_data))
node_inputs = None
process_data = None
try:
+ # init messages template
+ node_data.prompt_template = self._transform_chat_messages(node_data.prompt_template)
+
# fetch variables and fetch values from variable pool
inputs = self._fetch_inputs(node_data, variable_pool)
+ # fetch jinja2 inputs
+ jinja_inputs = self._fetch_jinja_inputs(node_data, variable_pool)
+
+ # merge inputs
+ inputs.update(jinja_inputs)
+
node_inputs = {}
# fetch files
@@ -183,6 +197,86 @@ class LLMNode(BaseNode):
usage = LLMUsage.empty_usage()
return full_text, usage
+
+ def _transform_chat_messages(self,
+ messages: list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
+ ) -> list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate:
+ """
+ Transform chat messages
+
+ :param messages: chat messages
+ :return:
+ """
+
+ if isinstance(messages, LLMNodeCompletionModelPromptTemplate):
+ if messages.edition_type == 'jinja2':
+ messages.text = messages.jinja2_text
+
+ return messages
+
+ for message in messages:
+ if message.edition_type == 'jinja2':
+ message.text = message.jinja2_text
+
+ return messages
+
+ def _fetch_jinja_inputs(self, node_data: LLMNodeData, variable_pool: VariablePool) -> dict[str, str]:
+ """
+ Fetch jinja inputs
+ :param node_data: node data
+ :param variable_pool: variable pool
+ :return:
+ """
+ variables = {}
+
+ if not node_data.prompt_config:
+ return variables
+
+ for variable_selector in node_data.prompt_config.jinja2_variables or []:
+ variable = variable_selector.variable
+ value = variable_pool.get_variable_value(
+ variable_selector=variable_selector.value_selector
+ )
+
+ def parse_dict(d: dict) -> str:
+ """
+ Parse dict into string
+ """
+ # check if it's a context structure
+ if 'metadata' in d and '_source' in d['metadata'] and 'content' in d:
+ return d['content']
+
+ # else, parse the dict
+ try:
+ return json.dumps(d, ensure_ascii=False)
+ except Exception:
+ return str(d)
+
+ if isinstance(value, str):
+ value = value
+ elif isinstance(value, list):
+ result = ''
+ for item in value:
+ if isinstance(item, dict):
+ result += parse_dict(item)
+ elif isinstance(item, str):
+ result += item
+ elif isinstance(item, int | float):
+ result += str(item)
+ else:
+ result += str(item)
+ result += '\n'
+ value = result.strip()
+ elif isinstance(value, dict):
+ value = parse_dict(value)
+ elif isinstance(value, int | float):
+ value = str(value)
+ else:
+ value = str(value)
+
+ variables[variable] = value
+
+ return variables
def _fetch_inputs(self, node_data: LLMNodeData, variable_pool: VariablePool) -> dict[str, str]:
"""
@@ -531,25 +625,25 @@ class LLMNode(BaseNode):
db.session.commit()
@classmethod
- def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
+ def _extract_variable_selector_to_variable_mapping(cls, node_data: LLMNodeData) -> dict[str, list[str]]:
"""
Extract variable selector to variable mapping
:param node_data: node data
:return:
"""
- node_data = node_data
- node_data = cast(cls._node_data_cls, node_data)
prompt_template = node_data.prompt_template
variable_selectors = []
if isinstance(prompt_template, list):
for prompt in prompt_template:
- variable_template_parser = VariableTemplateParser(template=prompt.text)
- variable_selectors.extend(variable_template_parser.extract_variable_selectors())
+ if prompt.edition_type != 'jinja2':
+ variable_template_parser = VariableTemplateParser(template=prompt.text)
+ variable_selectors.extend(variable_template_parser.extract_variable_selectors())
else:
- variable_template_parser = VariableTemplateParser(template=prompt_template.text)
- variable_selectors = variable_template_parser.extract_variable_selectors()
+ if prompt_template.edition_type != 'jinja2':
+ variable_template_parser = VariableTemplateParser(template=prompt_template.text)
+ variable_selectors = variable_template_parser.extract_variable_selectors()
variable_mapping = {}
for variable_selector in variable_selectors:
@@ -571,6 +665,22 @@ class LLMNode(BaseNode):
if node_data.memory:
variable_mapping['#sys.query#'] = ['sys', SystemVariable.QUERY.value]
+ if node_data.prompt_config:
+ enable_jinja = False
+
+ if isinstance(prompt_template, list):
+ for prompt in prompt_template:
+ if prompt.edition_type == 'jinja2':
+ enable_jinja = True
+ break
+ else:
+ if prompt_template.edition_type == 'jinja2':
+ enable_jinja = True
+
+ if enable_jinja:
+ for variable_selector in node_data.prompt_config.jinja2_variables or []:
+ variable_mapping[variable_selector.variable] = variable_selector.value_selector
+
return variable_mapping
@classmethod
@@ -588,7 +698,8 @@ class LLMNode(BaseNode):
"prompts": [
{
"role": "system",
- "text": "You are a helpful AI assistant."
+ "text": "You are a helpful AI assistant.",
+ "edition_type": "basic"
}
]
},
@@ -600,7 +711,8 @@ class LLMNode(BaseNode):
"prompt": {
"text": "Here is the chat histories between human and assistant, inside "
" XML tags.\n\n\n{{"
- "#histories#}}\n\n\n\nHuman: {{#sys.query#}}\n\nAssistant:"
+ "#histories#}}\n\n\n\nHuman: {{#sys.query#}}\n\nAssistant:",
+ "edition_type": "basic"
},
"stop": ["Human:"]
}
diff --git a/api/requirements-dev.txt b/api/requirements-dev.txt
index 0391ac5969..70b2ce2ef5 100644
--- a/api/requirements-dev.txt
+++ b/api/requirements-dev.txt
@@ -3,3 +3,4 @@ pytest~=8.1.1
pytest-benchmark~=4.0.0
pytest-env~=1.1.3
pytest-mock~=3.14.0
+jinja2~=3.1.2
\ No newline at end of file
diff --git a/api/tests/integration_tests/workflow/nodes/__mock/code_executor.py b/api/tests/integration_tests/workflow/nodes/__mock/code_executor.py
index 38517cf448..ef84c92625 100644
--- a/api/tests/integration_tests/workflow/nodes/__mock/code_executor.py
+++ b/api/tests/integration_tests/workflow/nodes/__mock/code_executor.py
@@ -3,6 +3,7 @@ from typing import Literal
import pytest
from _pytest.monkeypatch import MonkeyPatch
+from jinja2 import Template
from core.helper.code_executor.code_executor import CodeExecutor
@@ -18,7 +19,7 @@ class MockedCodeExecutor:
}
elif language == 'jinja2':
return {
- "result": "3"
+ "result": Template(code).render(inputs)
}
@pytest.fixture
diff --git a/api/tests/integration_tests/workflow/nodes/test_llm.py b/api/tests/integration_tests/workflow/nodes/test_llm.py
index 8a8a58d59f..d04497a187 100644
--- a/api/tests/integration_tests/workflow/nodes/test_llm.py
+++ b/api/tests/integration_tests/workflow/nodes/test_llm.py
@@ -1,3 +1,4 @@
+import json
import os
from unittest.mock import MagicMock
@@ -19,6 +20,7 @@ from models.workflow import WorkflowNodeExecutionStatus
"""FOR MOCK FIXTURES, DO NOT REMOVE"""
from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock
+from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock
@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
@@ -116,3 +118,118 @@ def test_execute_llm(setup_openai_mock):
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert result.outputs['text'] is not None
assert result.outputs['usage']['total_tokens'] > 0
+
+@pytest.mark.parametrize('setup_code_executor_mock', [['none']], indirect=True)
+@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
+def test_execute_llm_with_jinja2(setup_code_executor_mock, setup_openai_mock):
+ """
+ Test execute LLM node with jinja2
+ """
+ node = LLMNode(
+ tenant_id='1',
+ app_id='1',
+ workflow_id='1',
+ user_id='1',
+ user_from=UserFrom.ACCOUNT,
+ config={
+ 'id': 'llm',
+ 'data': {
+ 'title': '123',
+ 'type': 'llm',
+ 'model': {
+ 'provider': 'openai',
+ 'name': 'gpt-3.5-turbo',
+ 'mode': 'chat',
+ 'completion_params': {}
+ },
+ 'prompt_config': {
+ 'jinja2_variables': [{
+ 'variable': 'sys_query',
+ 'value_selector': ['sys', 'query']
+ }, {
+ 'variable': 'output',
+ 'value_selector': ['abc', 'output']
+ }]
+ },
+ 'prompt_template': [
+ {
+ 'role': 'system',
+ 'text': 'you are a helpful assistant.\ntoday\'s weather is {{#abc.output#}}',
+ 'jinja2_text': 'you are a helpful assistant.\ntoday\'s weather is {{output}}.',
+ 'edition_type': 'jinja2'
+ },
+ {
+ 'role': 'user',
+ 'text': '{{#sys.query#}}',
+ 'jinja2_text': '{{sys_query}}',
+ 'edition_type': 'basic'
+ }
+ ],
+ 'memory': None,
+ 'context': {
+ 'enabled': False
+ },
+ 'vision': {
+ 'enabled': False
+ }
+ }
+ }
+ )
+
+ # construct variable pool
+ pool = VariablePool(system_variables={
+ SystemVariable.QUERY: 'what\'s the weather today?',
+ SystemVariable.FILES: [],
+ SystemVariable.CONVERSATION_ID: 'abababa',
+ SystemVariable.USER_ID: 'aaa'
+ }, user_inputs={})
+ pool.append_variable(node_id='abc', variable_key_list=['output'], value='sunny')
+
+ credentials = {
+ 'openai_api_key': os.environ.get('OPENAI_API_KEY')
+ }
+
+ provider_instance = ModelProviderFactory().get_provider_instance('openai')
+ model_type_instance = provider_instance.get_model_instance(ModelType.LLM)
+ provider_model_bundle = ProviderModelBundle(
+ configuration=ProviderConfiguration(
+ tenant_id='1',
+ provider=provider_instance.get_provider_schema(),
+ preferred_provider_type=ProviderType.CUSTOM,
+ using_provider_type=ProviderType.CUSTOM,
+ system_configuration=SystemConfiguration(
+ enabled=False
+ ),
+ custom_configuration=CustomConfiguration(
+ provider=CustomProviderConfiguration(
+ credentials=credentials
+ )
+ )
+ ),
+ provider_instance=provider_instance,
+ model_type_instance=model_type_instance
+ )
+
+ model_instance = ModelInstance(provider_model_bundle=provider_model_bundle, model='gpt-3.5-turbo')
+
+ model_config = ModelConfigWithCredentialsEntity(
+ model='gpt-3.5-turbo',
+ provider='openai',
+ mode='chat',
+ credentials=credentials,
+ parameters={},
+ model_schema=model_type_instance.get_model_schema('gpt-3.5-turbo'),
+ provider_model_bundle=provider_model_bundle
+ )
+
+ # Mock db.session.close()
+ db.session.close = MagicMock()
+
+ node._fetch_model_config = MagicMock(return_value=tuple([model_instance, model_config]))
+
+ # execute node
+ result = node.run(pool)
+
+ assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
+ assert 'sunny' in json.dumps(result.process_data)
+ assert 'what\'s the weather today?' in json.dumps(result.process_data)
\ No newline at end of file