mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-12 18:39:06 +08:00
feat: support LLM jinja2 template prompt (#3968)
Co-authored-by: Joel <iamjoel007@gmail.com>
This commit is contained in:
parent
897e07f639
commit
8578ee0864
17
api/core/helper/code_executor/jinja2_formatter.py
Normal file
17
api/core/helper/code_executor/jinja2_formatter.py
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
from core.helper.code_executor.code_executor import CodeExecutor
|
||||||
|
|
||||||
|
|
||||||
|
class Jinja2Formatter:
|
||||||
|
@classmethod
|
||||||
|
def format(cls, template: str, inputs: str) -> str:
|
||||||
|
"""
|
||||||
|
Format template
|
||||||
|
:param template: template
|
||||||
|
:param inputs: inputs
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
result = CodeExecutor.execute_workflow_code_template(
|
||||||
|
language='jinja2', code=template, inputs=inputs
|
||||||
|
)
|
||||||
|
|
||||||
|
return result['result']
|
@ -2,6 +2,7 @@ from typing import Optional, Union
|
|||||||
|
|
||||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||||
from core.file.file_obj import FileVar
|
from core.file.file_obj import FileVar
|
||||||
|
from core.helper.code_executor.jinja2_formatter import Jinja2Formatter
|
||||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||||
from core.model_runtime.entities.message_entities import (
|
from core.model_runtime.entities.message_entities import (
|
||||||
AssistantPromptMessage,
|
AssistantPromptMessage,
|
||||||
@ -80,6 +81,7 @@ class AdvancedPromptTransform(PromptTransform):
|
|||||||
|
|
||||||
prompt_messages = []
|
prompt_messages = []
|
||||||
|
|
||||||
|
if prompt_template.edition_type == 'basic' or not prompt_template.edition_type:
|
||||||
prompt_template = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
|
prompt_template = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
|
||||||
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
|
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
|
||||||
|
|
||||||
@ -103,6 +105,11 @@ class AdvancedPromptTransform(PromptTransform):
|
|||||||
prompt = prompt_template.format(
|
prompt = prompt_template.format(
|
||||||
prompt_inputs
|
prompt_inputs
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
prompt = raw_prompt
|
||||||
|
prompt_inputs = inputs
|
||||||
|
|
||||||
|
prompt = Jinja2Formatter.format(prompt, prompt_inputs)
|
||||||
|
|
||||||
if files:
|
if files:
|
||||||
prompt_message_contents = [TextPromptMessageContent(data=prompt)]
|
prompt_message_contents = [TextPromptMessageContent(data=prompt)]
|
||||||
@ -135,6 +142,7 @@ class AdvancedPromptTransform(PromptTransform):
|
|||||||
for prompt_item in raw_prompt_list:
|
for prompt_item in raw_prompt_list:
|
||||||
raw_prompt = prompt_item.text
|
raw_prompt = prompt_item.text
|
||||||
|
|
||||||
|
if prompt_item.edition_type == 'basic' or not prompt_item.edition_type:
|
||||||
prompt_template = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
|
prompt_template = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
|
||||||
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
|
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
|
||||||
|
|
||||||
@ -143,6 +151,13 @@ class AdvancedPromptTransform(PromptTransform):
|
|||||||
prompt = prompt_template.format(
|
prompt = prompt_template.format(
|
||||||
prompt_inputs
|
prompt_inputs
|
||||||
)
|
)
|
||||||
|
elif prompt_item.edition_type == 'jinja2':
|
||||||
|
prompt = raw_prompt
|
||||||
|
prompt_inputs = inputs
|
||||||
|
|
||||||
|
prompt = Jinja2Formatter.format(prompt, prompt_inputs)
|
||||||
|
else:
|
||||||
|
raise ValueError(f'Invalid edition type: {prompt_item.edition_type}')
|
||||||
|
|
||||||
if prompt_item.role == PromptMessageRole.USER:
|
if prompt_item.role == PromptMessageRole.USER:
|
||||||
prompt_messages.append(UserPromptMessage(content=prompt))
|
prompt_messages.append(UserPromptMessage(content=prompt))
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
from typing import Optional
|
from typing import Literal, Optional
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
@ -11,6 +11,7 @@ class ChatModelMessage(BaseModel):
|
|||||||
"""
|
"""
|
||||||
text: str
|
text: str
|
||||||
role: PromptMessageRole
|
role: PromptMessageRole
|
||||||
|
edition_type: Optional[Literal['basic', 'jinja2']]
|
||||||
|
|
||||||
|
|
||||||
class CompletionModelPromptTemplate(BaseModel):
|
class CompletionModelPromptTemplate(BaseModel):
|
||||||
@ -18,6 +19,7 @@ class CompletionModelPromptTemplate(BaseModel):
|
|||||||
Completion Model Prompt Template.
|
Completion Model Prompt Template.
|
||||||
"""
|
"""
|
||||||
text: str
|
text: str
|
||||||
|
edition_type: Optional[Literal['basic', 'jinja2']]
|
||||||
|
|
||||||
|
|
||||||
class MemoryConfig(BaseModel):
|
class MemoryConfig(BaseModel):
|
||||||
|
@ -4,6 +4,7 @@ from pydantic import BaseModel
|
|||||||
|
|
||||||
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
|
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
|
||||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||||
|
from core.workflow.entities.variable_entities import VariableSelector
|
||||||
|
|
||||||
|
|
||||||
class ModelConfig(BaseModel):
|
class ModelConfig(BaseModel):
|
||||||
@ -37,13 +38,31 @@ class VisionConfig(BaseModel):
|
|||||||
enabled: bool
|
enabled: bool
|
||||||
configs: Optional[Configs] = None
|
configs: Optional[Configs] = None
|
||||||
|
|
||||||
|
class PromptConfig(BaseModel):
|
||||||
|
"""
|
||||||
|
Prompt Config.
|
||||||
|
"""
|
||||||
|
jinja2_variables: Optional[list[VariableSelector]] = None
|
||||||
|
|
||||||
|
class LLMNodeChatModelMessage(ChatModelMessage):
|
||||||
|
"""
|
||||||
|
LLM Node Chat Model Message.
|
||||||
|
"""
|
||||||
|
jinja2_text: Optional[str] = None
|
||||||
|
|
||||||
|
class LLMNodeCompletionModelPromptTemplate(CompletionModelPromptTemplate):
|
||||||
|
"""
|
||||||
|
LLM Node Chat Model Prompt Template.
|
||||||
|
"""
|
||||||
|
jinja2_text: Optional[str] = None
|
||||||
|
|
||||||
class LLMNodeData(BaseNodeData):
|
class LLMNodeData(BaseNodeData):
|
||||||
"""
|
"""
|
||||||
LLM Node Data.
|
LLM Node Data.
|
||||||
"""
|
"""
|
||||||
model: ModelConfig
|
model: ModelConfig
|
||||||
prompt_template: Union[list[ChatModelMessage], CompletionModelPromptTemplate]
|
prompt_template: Union[list[LLMNodeChatModelMessage], LLMNodeCompletionModelPromptTemplate]
|
||||||
|
prompt_config: Optional[PromptConfig] = None
|
||||||
memory: Optional[MemoryConfig] = None
|
memory: Optional[MemoryConfig] = None
|
||||||
context: ContextConfig
|
context: ContextConfig
|
||||||
vision: VisionConfig
|
vision: VisionConfig
|
||||||
|
@ -1,4 +1,6 @@
|
|||||||
|
import json
|
||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
|
from copy import deepcopy
|
||||||
from typing import Optional, cast
|
from typing import Optional, cast
|
||||||
|
|
||||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||||
@ -17,11 +19,15 @@ from core.model_runtime.utils.encoders import jsonable_encoder
|
|||||||
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
|
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
|
||||||
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
|
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
|
||||||
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
||||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
|
||||||
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable
|
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable
|
||||||
from core.workflow.entities.variable_pool import VariablePool
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
from core.workflow.nodes.base_node import BaseNode
|
from core.workflow.nodes.base_node import BaseNode
|
||||||
from core.workflow.nodes.llm.entities import LLMNodeData, ModelConfig
|
from core.workflow.nodes.llm.entities import (
|
||||||
|
LLMNodeChatModelMessage,
|
||||||
|
LLMNodeCompletionModelPromptTemplate,
|
||||||
|
LLMNodeData,
|
||||||
|
ModelConfig,
|
||||||
|
)
|
||||||
from core.workflow.utils.variable_template_parser import VariableTemplateParser
|
from core.workflow.utils.variable_template_parser import VariableTemplateParser
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.model import Conversation
|
from models.model import Conversation
|
||||||
@ -39,16 +45,24 @@ class LLMNode(BaseNode):
|
|||||||
:param variable_pool: variable pool
|
:param variable_pool: variable pool
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
node_data = self.node_data
|
node_data = cast(LLMNodeData, deepcopy(self.node_data))
|
||||||
node_data = cast(self._node_data_cls, node_data)
|
|
||||||
|
|
||||||
node_inputs = None
|
node_inputs = None
|
||||||
process_data = None
|
process_data = None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
# init messages template
|
||||||
|
node_data.prompt_template = self._transform_chat_messages(node_data.prompt_template)
|
||||||
|
|
||||||
# fetch variables and fetch values from variable pool
|
# fetch variables and fetch values from variable pool
|
||||||
inputs = self._fetch_inputs(node_data, variable_pool)
|
inputs = self._fetch_inputs(node_data, variable_pool)
|
||||||
|
|
||||||
|
# fetch jinja2 inputs
|
||||||
|
jinja_inputs = self._fetch_jinja_inputs(node_data, variable_pool)
|
||||||
|
|
||||||
|
# merge inputs
|
||||||
|
inputs.update(jinja_inputs)
|
||||||
|
|
||||||
node_inputs = {}
|
node_inputs = {}
|
||||||
|
|
||||||
# fetch files
|
# fetch files
|
||||||
@ -184,6 +198,86 @@ class LLMNode(BaseNode):
|
|||||||
|
|
||||||
return full_text, usage
|
return full_text, usage
|
||||||
|
|
||||||
|
def _transform_chat_messages(self,
|
||||||
|
messages: list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
|
||||||
|
) -> list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate:
|
||||||
|
"""
|
||||||
|
Transform chat messages
|
||||||
|
|
||||||
|
:param messages: chat messages
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
|
||||||
|
if isinstance(messages, LLMNodeCompletionModelPromptTemplate):
|
||||||
|
if messages.edition_type == 'jinja2':
|
||||||
|
messages.text = messages.jinja2_text
|
||||||
|
|
||||||
|
return messages
|
||||||
|
|
||||||
|
for message in messages:
|
||||||
|
if message.edition_type == 'jinja2':
|
||||||
|
message.text = message.jinja2_text
|
||||||
|
|
||||||
|
return messages
|
||||||
|
|
||||||
|
def _fetch_jinja_inputs(self, node_data: LLMNodeData, variable_pool: VariablePool) -> dict[str, str]:
|
||||||
|
"""
|
||||||
|
Fetch jinja inputs
|
||||||
|
:param node_data: node data
|
||||||
|
:param variable_pool: variable pool
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
variables = {}
|
||||||
|
|
||||||
|
if not node_data.prompt_config:
|
||||||
|
return variables
|
||||||
|
|
||||||
|
for variable_selector in node_data.prompt_config.jinja2_variables or []:
|
||||||
|
variable = variable_selector.variable
|
||||||
|
value = variable_pool.get_variable_value(
|
||||||
|
variable_selector=variable_selector.value_selector
|
||||||
|
)
|
||||||
|
|
||||||
|
def parse_dict(d: dict) -> str:
|
||||||
|
"""
|
||||||
|
Parse dict into string
|
||||||
|
"""
|
||||||
|
# check if it's a context structure
|
||||||
|
if 'metadata' in d and '_source' in d['metadata'] and 'content' in d:
|
||||||
|
return d['content']
|
||||||
|
|
||||||
|
# else, parse the dict
|
||||||
|
try:
|
||||||
|
return json.dumps(d, ensure_ascii=False)
|
||||||
|
except Exception:
|
||||||
|
return str(d)
|
||||||
|
|
||||||
|
if isinstance(value, str):
|
||||||
|
value = value
|
||||||
|
elif isinstance(value, list):
|
||||||
|
result = ''
|
||||||
|
for item in value:
|
||||||
|
if isinstance(item, dict):
|
||||||
|
result += parse_dict(item)
|
||||||
|
elif isinstance(item, str):
|
||||||
|
result += item
|
||||||
|
elif isinstance(item, int | float):
|
||||||
|
result += str(item)
|
||||||
|
else:
|
||||||
|
result += str(item)
|
||||||
|
result += '\n'
|
||||||
|
value = result.strip()
|
||||||
|
elif isinstance(value, dict):
|
||||||
|
value = parse_dict(value)
|
||||||
|
elif isinstance(value, int | float):
|
||||||
|
value = str(value)
|
||||||
|
else:
|
||||||
|
value = str(value)
|
||||||
|
|
||||||
|
variables[variable] = value
|
||||||
|
|
||||||
|
return variables
|
||||||
|
|
||||||
def _fetch_inputs(self, node_data: LLMNodeData, variable_pool: VariablePool) -> dict[str, str]:
|
def _fetch_inputs(self, node_data: LLMNodeData, variable_pool: VariablePool) -> dict[str, str]:
|
||||||
"""
|
"""
|
||||||
Fetch inputs
|
Fetch inputs
|
||||||
@ -531,23 +625,23 @@ class LLMNode(BaseNode):
|
|||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
|
def _extract_variable_selector_to_variable_mapping(cls, node_data: LLMNodeData) -> dict[str, list[str]]:
|
||||||
"""
|
"""
|
||||||
Extract variable selector to variable mapping
|
Extract variable selector to variable mapping
|
||||||
:param node_data: node data
|
:param node_data: node data
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
node_data = node_data
|
|
||||||
node_data = cast(cls._node_data_cls, node_data)
|
|
||||||
|
|
||||||
prompt_template = node_data.prompt_template
|
prompt_template = node_data.prompt_template
|
||||||
|
|
||||||
variable_selectors = []
|
variable_selectors = []
|
||||||
if isinstance(prompt_template, list):
|
if isinstance(prompt_template, list):
|
||||||
for prompt in prompt_template:
|
for prompt in prompt_template:
|
||||||
|
if prompt.edition_type != 'jinja2':
|
||||||
variable_template_parser = VariableTemplateParser(template=prompt.text)
|
variable_template_parser = VariableTemplateParser(template=prompt.text)
|
||||||
variable_selectors.extend(variable_template_parser.extract_variable_selectors())
|
variable_selectors.extend(variable_template_parser.extract_variable_selectors())
|
||||||
else:
|
else:
|
||||||
|
if prompt_template.edition_type != 'jinja2':
|
||||||
variable_template_parser = VariableTemplateParser(template=prompt_template.text)
|
variable_template_parser = VariableTemplateParser(template=prompt_template.text)
|
||||||
variable_selectors = variable_template_parser.extract_variable_selectors()
|
variable_selectors = variable_template_parser.extract_variable_selectors()
|
||||||
|
|
||||||
@ -571,6 +665,22 @@ class LLMNode(BaseNode):
|
|||||||
if node_data.memory:
|
if node_data.memory:
|
||||||
variable_mapping['#sys.query#'] = ['sys', SystemVariable.QUERY.value]
|
variable_mapping['#sys.query#'] = ['sys', SystemVariable.QUERY.value]
|
||||||
|
|
||||||
|
if node_data.prompt_config:
|
||||||
|
enable_jinja = False
|
||||||
|
|
||||||
|
if isinstance(prompt_template, list):
|
||||||
|
for prompt in prompt_template:
|
||||||
|
if prompt.edition_type == 'jinja2':
|
||||||
|
enable_jinja = True
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
if prompt_template.edition_type == 'jinja2':
|
||||||
|
enable_jinja = True
|
||||||
|
|
||||||
|
if enable_jinja:
|
||||||
|
for variable_selector in node_data.prompt_config.jinja2_variables or []:
|
||||||
|
variable_mapping[variable_selector.variable] = variable_selector.value_selector
|
||||||
|
|
||||||
return variable_mapping
|
return variable_mapping
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -588,7 +698,8 @@ class LLMNode(BaseNode):
|
|||||||
"prompts": [
|
"prompts": [
|
||||||
{
|
{
|
||||||
"role": "system",
|
"role": "system",
|
||||||
"text": "You are a helpful AI assistant."
|
"text": "You are a helpful AI assistant.",
|
||||||
|
"edition_type": "basic"
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
@ -600,7 +711,8 @@ class LLMNode(BaseNode):
|
|||||||
"prompt": {
|
"prompt": {
|
||||||
"text": "Here is the chat histories between human and assistant, inside "
|
"text": "Here is the chat histories between human and assistant, inside "
|
||||||
"<histories></histories> XML tags.\n\n<histories>\n{{"
|
"<histories></histories> XML tags.\n\n<histories>\n{{"
|
||||||
"#histories#}}\n</histories>\n\n\nHuman: {{#sys.query#}}\n\nAssistant:"
|
"#histories#}}\n</histories>\n\n\nHuman: {{#sys.query#}}\n\nAssistant:",
|
||||||
|
"edition_type": "basic"
|
||||||
},
|
},
|
||||||
"stop": ["Human:"]
|
"stop": ["Human:"]
|
||||||
}
|
}
|
||||||
|
@ -3,3 +3,4 @@ pytest~=8.1.1
|
|||||||
pytest-benchmark~=4.0.0
|
pytest-benchmark~=4.0.0
|
||||||
pytest-env~=1.1.3
|
pytest-env~=1.1.3
|
||||||
pytest-mock~=3.14.0
|
pytest-mock~=3.14.0
|
||||||
|
jinja2~=3.1.2
|
@ -3,6 +3,7 @@ from typing import Literal
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from _pytest.monkeypatch import MonkeyPatch
|
from _pytest.monkeypatch import MonkeyPatch
|
||||||
|
from jinja2 import Template
|
||||||
|
|
||||||
from core.helper.code_executor.code_executor import CodeExecutor
|
from core.helper.code_executor.code_executor import CodeExecutor
|
||||||
|
|
||||||
@ -18,7 +19,7 @@ class MockedCodeExecutor:
|
|||||||
}
|
}
|
||||||
elif language == 'jinja2':
|
elif language == 'jinja2':
|
||||||
return {
|
return {
|
||||||
"result": "3"
|
"result": Template(code).render(inputs)
|
||||||
}
|
}
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
import json
|
||||||
import os
|
import os
|
||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
@ -19,6 +20,7 @@ from models.workflow import WorkflowNodeExecutionStatus
|
|||||||
|
|
||||||
"""FOR MOCK FIXTURES, DO NOT REMOVE"""
|
"""FOR MOCK FIXTURES, DO NOT REMOVE"""
|
||||||
from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock
|
from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock
|
||||||
|
from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
|
@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
|
||||||
@ -116,3 +118,118 @@ def test_execute_llm(setup_openai_mock):
|
|||||||
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||||
assert result.outputs['text'] is not None
|
assert result.outputs['text'] is not None
|
||||||
assert result.outputs['usage']['total_tokens'] > 0
|
assert result.outputs['usage']['total_tokens'] > 0
|
||||||
|
|
||||||
|
@pytest.mark.parametrize('setup_code_executor_mock', [['none']], indirect=True)
|
||||||
|
@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
|
||||||
|
def test_execute_llm_with_jinja2(setup_code_executor_mock, setup_openai_mock):
|
||||||
|
"""
|
||||||
|
Test execute LLM node with jinja2
|
||||||
|
"""
|
||||||
|
node = LLMNode(
|
||||||
|
tenant_id='1',
|
||||||
|
app_id='1',
|
||||||
|
workflow_id='1',
|
||||||
|
user_id='1',
|
||||||
|
user_from=UserFrom.ACCOUNT,
|
||||||
|
config={
|
||||||
|
'id': 'llm',
|
||||||
|
'data': {
|
||||||
|
'title': '123',
|
||||||
|
'type': 'llm',
|
||||||
|
'model': {
|
||||||
|
'provider': 'openai',
|
||||||
|
'name': 'gpt-3.5-turbo',
|
||||||
|
'mode': 'chat',
|
||||||
|
'completion_params': {}
|
||||||
|
},
|
||||||
|
'prompt_config': {
|
||||||
|
'jinja2_variables': [{
|
||||||
|
'variable': 'sys_query',
|
||||||
|
'value_selector': ['sys', 'query']
|
||||||
|
}, {
|
||||||
|
'variable': 'output',
|
||||||
|
'value_selector': ['abc', 'output']
|
||||||
|
}]
|
||||||
|
},
|
||||||
|
'prompt_template': [
|
||||||
|
{
|
||||||
|
'role': 'system',
|
||||||
|
'text': 'you are a helpful assistant.\ntoday\'s weather is {{#abc.output#}}',
|
||||||
|
'jinja2_text': 'you are a helpful assistant.\ntoday\'s weather is {{output}}.',
|
||||||
|
'edition_type': 'jinja2'
|
||||||
|
},
|
||||||
|
{
|
||||||
|
'role': 'user',
|
||||||
|
'text': '{{#sys.query#}}',
|
||||||
|
'jinja2_text': '{{sys_query}}',
|
||||||
|
'edition_type': 'basic'
|
||||||
|
}
|
||||||
|
],
|
||||||
|
'memory': None,
|
||||||
|
'context': {
|
||||||
|
'enabled': False
|
||||||
|
},
|
||||||
|
'vision': {
|
||||||
|
'enabled': False
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# construct variable pool
|
||||||
|
pool = VariablePool(system_variables={
|
||||||
|
SystemVariable.QUERY: 'what\'s the weather today?',
|
||||||
|
SystemVariable.FILES: [],
|
||||||
|
SystemVariable.CONVERSATION_ID: 'abababa',
|
||||||
|
SystemVariable.USER_ID: 'aaa'
|
||||||
|
}, user_inputs={})
|
||||||
|
pool.append_variable(node_id='abc', variable_key_list=['output'], value='sunny')
|
||||||
|
|
||||||
|
credentials = {
|
||||||
|
'openai_api_key': os.environ.get('OPENAI_API_KEY')
|
||||||
|
}
|
||||||
|
|
||||||
|
provider_instance = ModelProviderFactory().get_provider_instance('openai')
|
||||||
|
model_type_instance = provider_instance.get_model_instance(ModelType.LLM)
|
||||||
|
provider_model_bundle = ProviderModelBundle(
|
||||||
|
configuration=ProviderConfiguration(
|
||||||
|
tenant_id='1',
|
||||||
|
provider=provider_instance.get_provider_schema(),
|
||||||
|
preferred_provider_type=ProviderType.CUSTOM,
|
||||||
|
using_provider_type=ProviderType.CUSTOM,
|
||||||
|
system_configuration=SystemConfiguration(
|
||||||
|
enabled=False
|
||||||
|
),
|
||||||
|
custom_configuration=CustomConfiguration(
|
||||||
|
provider=CustomProviderConfiguration(
|
||||||
|
credentials=credentials
|
||||||
|
)
|
||||||
|
)
|
||||||
|
),
|
||||||
|
provider_instance=provider_instance,
|
||||||
|
model_type_instance=model_type_instance
|
||||||
|
)
|
||||||
|
|
||||||
|
model_instance = ModelInstance(provider_model_bundle=provider_model_bundle, model='gpt-3.5-turbo')
|
||||||
|
|
||||||
|
model_config = ModelConfigWithCredentialsEntity(
|
||||||
|
model='gpt-3.5-turbo',
|
||||||
|
provider='openai',
|
||||||
|
mode='chat',
|
||||||
|
credentials=credentials,
|
||||||
|
parameters={},
|
||||||
|
model_schema=model_type_instance.get_model_schema('gpt-3.5-turbo'),
|
||||||
|
provider_model_bundle=provider_model_bundle
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mock db.session.close()
|
||||||
|
db.session.close = MagicMock()
|
||||||
|
|
||||||
|
node._fetch_model_config = MagicMock(return_value=tuple([model_instance, model_config]))
|
||||||
|
|
||||||
|
# execute node
|
||||||
|
result = node.run(pool)
|
||||||
|
|
||||||
|
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||||
|
assert 'sunny' in json.dumps(result.process_data)
|
||||||
|
assert 'what\'s the weather today?' in json.dumps(result.process_data)
|
Loading…
x
Reference in New Issue
Block a user