feat: add conversation_id and user_id in chatflow/workflow system vars (#3771)

Co-authored-by: Joel <iamjoel007@gmail.com>
This commit is contained in:
takatost 2024-04-24 17:20:01 +08:00 committed by GitHub
parent a34e8cb0bd
commit 3da179f77b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 86 additions and 16 deletions

View File

@ -18,7 +18,7 @@ from core.workflow.entities.node_entities import SystemVariable
from core.workflow.nodes.base_node import UserFrom
from core.workflow.workflow_engine_manager import WorkflowEngineManager
from extensions.ext_database import db
from models.model import App, Conversation, Message
from models.model import App, Conversation, EndUser, Message
from models.workflow import Workflow
logger = logging.getLogger(__name__)
@ -56,6 +56,14 @@ class AdvancedChatAppRunner(AppRunner):
query = application_generate_entity.query
files = application_generate_entity.files
user_id = None
if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
end_user = db.session.query(EndUser).filter(EndUser.id == application_generate_entity.user_id).first()
if end_user:
user_id = end_user.session_id
else:
user_id = application_generate_entity.user_id
# moderation
if self.handle_input_moderation(
queue_manager=queue_manager,
@ -98,7 +106,8 @@ class AdvancedChatAppRunner(AppRunner):
system_inputs={
SystemVariable.QUERY: query,
SystemVariable.FILES: files,
SystemVariable.CONVERSATION: conversation.id,
SystemVariable.CONVERSATION_ID: conversation.id,
SystemVariable.USER_ID: user_id
},
callbacks=workflow_callbacks
)

View File

@ -84,13 +84,19 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
"""
super().__init__(application_generate_entity, queue_manager, user, stream)
if isinstance(self._user, EndUser):
user_id = self._user.session_id
else:
user_id = self._user.id
self._workflow = workflow
self._conversation = conversation
self._message = message
self._workflow_system_variables = {
SystemVariable.QUERY: message.query,
SystemVariable.FILES: application_generate_entity.files,
SystemVariable.CONVERSATION: conversation.id,
SystemVariable.CONVERSATION_ID: conversation.id,
SystemVariable.USER_ID: user_id
}
self._task_state = AdvancedChatTaskState(

View File

@ -14,7 +14,7 @@ from core.workflow.entities.node_entities import SystemVariable
from core.workflow.nodes.base_node import UserFrom
from core.workflow.workflow_engine_manager import WorkflowEngineManager
from extensions.ext_database import db
from models.model import App
from models.model import App, EndUser
from models.workflow import Workflow
logger = logging.getLogger(__name__)
@ -36,6 +36,14 @@ class WorkflowAppRunner:
app_config = application_generate_entity.app_config
app_config = cast(WorkflowAppConfig, app_config)
user_id = None
if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
end_user = db.session.query(EndUser).filter(EndUser.id == application_generate_entity.user_id).first()
if end_user:
user_id = end_user.session_id
else:
user_id = application_generate_entity.user_id
app_record = db.session.query(App).filter(App.id == app_config.app_id).first()
if not app_record:
raise ValueError("App not found")
@ -67,7 +75,8 @@ class WorkflowAppRunner:
else UserFrom.END_USER,
user_inputs=inputs,
system_inputs={
SystemVariable.FILES: files
SystemVariable.FILES: files,
SystemVariable.USER_ID: user_id
},
callbacks=workflow_callbacks
)

View File

@ -71,9 +71,15 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
"""
super().__init__(application_generate_entity, queue_manager, user, stream)
if isinstance(self._user, EndUser):
user_id = self._user.session_id
else:
user_id = self._user.id
self._workflow = workflow
self._workflow_system_variables = {
SystemVariable.FILES: application_generate_entity.files,
SystemVariable.USER_ID: user_id
}
self._task_state = WorkflowTaskState()

View File

@ -43,7 +43,8 @@ class SystemVariable(Enum):
"""
QUERY = 'query'
FILES = 'files'
CONVERSATION = 'conversation'
CONVERSATION_ID = 'conversation_id'
USER_ID = 'user_id'
@classmethod
def value_of(cls, value: str) -> 'SystemVariable':

View File

@ -385,7 +385,7 @@ class LLMNode(BaseNode):
return None
# get conversation id
conversation_id = variable_pool.get_variable_value(['sys', SystemVariable.CONVERSATION.value])
conversation_id = variable_pool.get_variable_value(['sys', SystemVariable.CONVERSATION_ID.value])
if conversation_id is None:
return None
@ -545,6 +545,9 @@ class LLMNode(BaseNode):
if node_data.vision.enabled:
variable_mapping['#files#'] = ['sys', SystemVariable.FILES.value]
if node_data.memory:
variable_mapping['#sys.query#'] = ['sys', SystemVariable.QUERY.value]
return variable_mapping
@classmethod

View File

@ -1,6 +1,6 @@
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType, SystemVariable
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.start.entities import StartNodeData
@ -21,9 +21,6 @@ class StartNode(BaseNode):
cleaned_inputs = variable_pool.user_inputs
for var in variable_pool.system_variables:
if var == SystemVariable.CONVERSATION:
continue
cleaned_inputs['sys.' + var.value] = variable_pool.system_variables[var]
return NodeRunResult(

View File

@ -65,7 +65,8 @@ def test_execute_llm(setup_openai_mock):
pool = VariablePool(system_variables={
SystemVariable.QUERY: 'what\'s the weather today?',
SystemVariable.FILES: [],
SystemVariable.CONVERSATION: 'abababa'
SystemVariable.CONVERSATION_ID: 'abababa',
SystemVariable.USER_ID: 'aaa'
}, user_inputs={})
pool.append_variable(node_id='abc', variable_key_list=['output'], value='sunny')

View File

@ -238,8 +238,8 @@ def test__get_completion_model_prompt_messages():
prompt_rules = prompt_template['prompt_rules']
full_inputs = {**inputs, '#context#': context, '#query#': query, '#histories#': memory.get_history_prompt_text(
max_token_limit=2000,
ai_prefix=prompt_rules['human_prefix'] if 'human_prefix' in prompt_rules else 'Human',
human_prefix=prompt_rules['assistant_prefix'] if 'assistant_prefix' in prompt_rules else 'Assistant'
human_prefix=prompt_rules['human_prefix'] if 'human_prefix' in prompt_rules else 'Human',
ai_prefix=prompt_rules['assistant_prefix'] if 'assistant_prefix' in prompt_rules else 'Assistant'
)}
real_prompt = prompt_template['prompt_template'].format(full_inputs)

View File

@ -28,6 +28,7 @@ def test_execute_answer():
# construct variable pool
pool = VariablePool(system_variables={
SystemVariable.FILES: [],
SystemVariable.USER_ID: 'aaa'
}, user_inputs={})
pool.append_variable(node_id='start', variable_key_list=['weather'], value='sunny')
pool.append_variable(node_id='llm', variable_key_list=['text'], value='You are a helpful AI.')

View File

@ -118,6 +118,7 @@ def test_execute_if_else_result_true():
# construct variable pool
pool = VariablePool(system_variables={
SystemVariable.FILES: [],
SystemVariable.USER_ID: 'aaa'
}, user_inputs={})
pool.append_variable(node_id='start', variable_key_list=['array_contains'], value=['ab', 'def'])
pool.append_variable(node_id='start', variable_key_list=['array_not_contains'], value=['ac', 'def'])
@ -179,6 +180,7 @@ def test_execute_if_else_result_false():
# construct variable pool
pool = VariablePool(system_variables={
SystemVariable.FILES: [],
SystemVariable.USER_ID: 'aaa'
}, user_inputs={})
pool.append_variable(node_id='start', variable_key_list=['array_contains'], value=['1ab', 'def'])
pool.append_variable(node_id='start', variable_key_list=['array_not_contains'], value=['ab', 'def'])

View File

@ -89,7 +89,7 @@ def test__convert_to_http_request_node_for_chatbot(default_variables):
)
]
nodes = workflow_converter._convert_to_http_request_node(
nodes, _ = workflow_converter._convert_to_http_request_node(
app_model=app_model,
variables=default_variables,
external_data_variables=external_data_variables
@ -159,7 +159,7 @@ def test__convert_to_http_request_node_for_workflow_app(default_variables):
)
]
nodes = workflow_converter._convert_to_http_request_node(
nodes, _ = workflow_converter._convert_to_http_request_node(
app_model=app_model,
variables=default_variables,
external_data_variables=external_data_variables

View File

@ -80,7 +80,15 @@ const formatItem = (item: any, isChatMode: boolean, filterVar: (payload: Var, se
variable: 'sys.query',
type: VarType.string,
})
res.vars.push({
variable: 'sys.conversation_id',
type: VarType.string,
})
}
res.vars.push({
variable: 'sys.user_id',
type: VarType.string,
})
res.vars.push({
variable: 'sys.files',
type: VarType.arrayFile,

View File

@ -70,6 +70,7 @@ const Panel: FC<NodePanelProps<StartNodeType>> = ({
}
/>)
}
<VarItem
readonly
payload={{
@ -81,6 +82,32 @@ const Panel: FC<NodePanelProps<StartNodeType>> = ({
</div>
}
/>
{
isChatMode && (
<VarItem
readonly
payload={{
variable: 'sys.conversation_id',
} as any}
rightContent={
<div className='text-xs font-normal text-gray-500'>
String
</div>
}
/>
)
}
<VarItem
readonly
payload={{
variable: 'sys.user_id',
} as any}
rightContent={
<div className='text-xs font-normal text-gray-500'>
String
</div>
}
/>
</div>
</>