From 3da179f77b6d6ac2babf8b0f1aef57a7b5cfd631 Mon Sep 17 00:00:00 2001 From: takatost Date: Wed, 24 Apr 2024 17:20:01 +0800 Subject: [PATCH] feat: add conversation_id and user_id in chatflow/workflow system vars (#3771) Co-authored-by: Joel --- api/core/app/apps/advanced_chat/app_runner.py | 13 +++++++-- .../advanced_chat/generate_task_pipeline.py | 8 +++++- api/core/app/apps/workflow/app_runner.py | 13 +++++++-- .../apps/workflow/generate_task_pipeline.py | 6 +++++ api/core/workflow/entities/node_entities.py | 3 ++- api/core/workflow/nodes/llm/llm_node.py | 5 +++- api/core/workflow/nodes/start/start_node.py | 5 +--- .../workflow/nodes/test_llm.py | 3 ++- .../prompt/test_simple_prompt_transform.py | 4 +-- .../core/workflow/nodes/test_answer.py | 1 + .../core/workflow/nodes/test_if_else.py | 2 ++ .../workflow/test_workflow_converter.py | 4 +-- .../nodes/_base/components/variable/utils.ts | 8 ++++++ .../components/workflow/nodes/start/panel.tsx | 27 +++++++++++++++++++ 14 files changed, 86 insertions(+), 16 deletions(-) diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index 1c54cf3dc5..d858dcac12 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -18,7 +18,7 @@ from core.workflow.entities.node_entities import SystemVariable from core.workflow.nodes.base_node import UserFrom from core.workflow.workflow_engine_manager import WorkflowEngineManager from extensions.ext_database import db -from models.model import App, Conversation, Message +from models.model import App, Conversation, EndUser, Message from models.workflow import Workflow logger = logging.getLogger(__name__) @@ -56,6 +56,14 @@ class AdvancedChatAppRunner(AppRunner): query = application_generate_entity.query files = application_generate_entity.files + user_id = None + if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]: + end_user = db.session.query(EndUser).filter(EndUser.id == application_generate_entity.user_id).first() + if end_user: + user_id = end_user.session_id + else: + user_id = application_generate_entity.user_id + # moderation if self.handle_input_moderation( queue_manager=queue_manager, @@ -98,7 +106,8 @@ class AdvancedChatAppRunner(AppRunner): system_inputs={ SystemVariable.QUERY: query, SystemVariable.FILES: files, - SystemVariable.CONVERSATION: conversation.id, + SystemVariable.CONVERSATION_ID: conversation.id, + SystemVariable.USER_ID: user_id }, callbacks=workflow_callbacks ) diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index 9866db12f6..490cb516c6 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -84,13 +84,19 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc """ super().__init__(application_generate_entity, queue_manager, user, stream) + if isinstance(self._user, EndUser): + user_id = self._user.session_id + else: + user_id = self._user.id + self._workflow = workflow self._conversation = conversation self._message = message self._workflow_system_variables = { SystemVariable.QUERY: message.query, SystemVariable.FILES: application_generate_entity.files, - SystemVariable.CONVERSATION: conversation.id, + SystemVariable.CONVERSATION_ID: conversation.id, + SystemVariable.USER_ID: user_id } self._task_state = AdvancedChatTaskState( diff --git a/api/core/app/apps/workflow/app_runner.py b/api/core/app/apps/workflow/app_runner.py index 4de6f28290..9d854afe35 100644 --- a/api/core/app/apps/workflow/app_runner.py +++ b/api/core/app/apps/workflow/app_runner.py @@ -14,7 +14,7 @@ from core.workflow.entities.node_entities import SystemVariable from core.workflow.nodes.base_node import UserFrom from core.workflow.workflow_engine_manager import WorkflowEngineManager from extensions.ext_database import db -from models.model import App +from models.model import App, EndUser from models.workflow import Workflow logger = logging.getLogger(__name__) @@ -36,6 +36,14 @@ class WorkflowAppRunner: app_config = application_generate_entity.app_config app_config = cast(WorkflowAppConfig, app_config) + user_id = None + if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]: + end_user = db.session.query(EndUser).filter(EndUser.id == application_generate_entity.user_id).first() + if end_user: + user_id = end_user.session_id + else: + user_id = application_generate_entity.user_id + app_record = db.session.query(App).filter(App.id == app_config.app_id).first() if not app_record: raise ValueError("App not found") @@ -67,7 +75,8 @@ class WorkflowAppRunner: else UserFrom.END_USER, user_inputs=inputs, system_inputs={ - SystemVariable.FILES: files + SystemVariable.FILES: files, + SystemVariable.USER_ID: user_id }, callbacks=workflow_callbacks ) diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index f926d75968..68095b0ab6 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -71,9 +71,15 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa """ super().__init__(application_generate_entity, queue_manager, user, stream) + if isinstance(self._user, EndUser): + user_id = self._user.session_id + else: + user_id = self._user.id + self._workflow = workflow self._workflow_system_variables = { SystemVariable.FILES: application_generate_entity.files, + SystemVariable.USER_ID: user_id } self._task_state = WorkflowTaskState() diff --git a/api/core/workflow/entities/node_entities.py b/api/core/workflow/entities/node_entities.py index 72d881d3d5..7eb9488792 100644 --- a/api/core/workflow/entities/node_entities.py +++ b/api/core/workflow/entities/node_entities.py @@ -43,7 +43,8 @@ class SystemVariable(Enum): """ QUERY = 'query' FILES = 'files' - CONVERSATION = 'conversation' + CONVERSATION_ID = 'conversation_id' + USER_ID = 'user_id' @classmethod def value_of(cls, value: str) -> 'SystemVariable': diff --git a/api/core/workflow/nodes/llm/llm_node.py b/api/core/workflow/nodes/llm/llm_node.py index 00999aa1a6..a894e19a61 100644 --- a/api/core/workflow/nodes/llm/llm_node.py +++ b/api/core/workflow/nodes/llm/llm_node.py @@ -385,7 +385,7 @@ class LLMNode(BaseNode): return None # get conversation id - conversation_id = variable_pool.get_variable_value(['sys', SystemVariable.CONVERSATION.value]) + conversation_id = variable_pool.get_variable_value(['sys', SystemVariable.CONVERSATION_ID.value]) if conversation_id is None: return None @@ -545,6 +545,9 @@ class LLMNode(BaseNode): if node_data.vision.enabled: variable_mapping['#files#'] = ['sys', SystemVariable.FILES.value] + if node_data.memory: + variable_mapping['#sys.query#'] = ['sys', SystemVariable.QUERY.value] + return variable_mapping @classmethod diff --git a/api/core/workflow/nodes/start/start_node.py b/api/core/workflow/nodes/start/start_node.py index e32e850a23..fd51a6c476 100644 --- a/api/core/workflow/nodes/start/start_node.py +++ b/api/core/workflow/nodes/start/start_node.py @@ -1,6 +1,6 @@ from core.workflow.entities.base_node_data_entities import BaseNodeData -from core.workflow.entities.node_entities import NodeRunResult, NodeType, SystemVariable +from core.workflow.entities.node_entities import NodeRunResult, NodeType from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.start.entities import StartNodeData @@ -21,9 +21,6 @@ class StartNode(BaseNode): cleaned_inputs = variable_pool.user_inputs for var in variable_pool.system_variables: - if var == SystemVariable.CONVERSATION: - continue - cleaned_inputs['sys.' + var.value] = variable_pool.system_variables[var] return NodeRunResult( diff --git a/api/tests/integration_tests/workflow/nodes/test_llm.py b/api/tests/integration_tests/workflow/nodes/test_llm.py index c0c431912a..8a8a58d59f 100644 --- a/api/tests/integration_tests/workflow/nodes/test_llm.py +++ b/api/tests/integration_tests/workflow/nodes/test_llm.py @@ -65,7 +65,8 @@ def test_execute_llm(setup_openai_mock): pool = VariablePool(system_variables={ SystemVariable.QUERY: 'what\'s the weather today?', SystemVariable.FILES: [], - SystemVariable.CONVERSATION: 'abababa' + SystemVariable.CONVERSATION_ID: 'abababa', + SystemVariable.USER_ID: 'aaa' }, user_inputs={}) pool.append_variable(node_id='abc', variable_key_list=['output'], value='sunny') diff --git a/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py index ad72837ae2..7e32ecbbdb 100644 --- a/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py @@ -238,8 +238,8 @@ def test__get_completion_model_prompt_messages(): prompt_rules = prompt_template['prompt_rules'] full_inputs = {**inputs, '#context#': context, '#query#': query, '#histories#': memory.get_history_prompt_text( max_token_limit=2000, - ai_prefix=prompt_rules['human_prefix'] if 'human_prefix' in prompt_rules else 'Human', - human_prefix=prompt_rules['assistant_prefix'] if 'assistant_prefix' in prompt_rules else 'Assistant' + human_prefix=prompt_rules['human_prefix'] if 'human_prefix' in prompt_rules else 'Human', + ai_prefix=prompt_rules['assistant_prefix'] if 'assistant_prefix' in prompt_rules else 'Assistant' )} real_prompt = prompt_template['prompt_template'].format(full_inputs) diff --git a/api/tests/unit_tests/core/workflow/nodes/test_answer.py b/api/tests/unit_tests/core/workflow/nodes/test_answer.py index e2d5be769c..cf21401eb2 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_answer.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_answer.py @@ -28,6 +28,7 @@ def test_execute_answer(): # construct variable pool pool = VariablePool(system_variables={ SystemVariable.FILES: [], + SystemVariable.USER_ID: 'aaa' }, user_inputs={}) pool.append_variable(node_id='start', variable_key_list=['weather'], value='sunny') pool.append_variable(node_id='llm', variable_key_list=['text'], value='You are a helpful AI.') diff --git a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py index 7b402ad0a0..99413540c5 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py @@ -118,6 +118,7 @@ def test_execute_if_else_result_true(): # construct variable pool pool = VariablePool(system_variables={ SystemVariable.FILES: [], + SystemVariable.USER_ID: 'aaa' }, user_inputs={}) pool.append_variable(node_id='start', variable_key_list=['array_contains'], value=['ab', 'def']) pool.append_variable(node_id='start', variable_key_list=['array_not_contains'], value=['ac', 'def']) @@ -179,6 +180,7 @@ def test_execute_if_else_result_false(): # construct variable pool pool = VariablePool(system_variables={ SystemVariable.FILES: [], + SystemVariable.USER_ID: 'aaa' }, user_inputs={}) pool.append_variable(node_id='start', variable_key_list=['array_contains'], value=['1ab', 'def']) pool.append_variable(node_id='start', variable_key_list=['array_not_contains'], value=['ab', 'def']) diff --git a/api/tests/unit_tests/services/workflow/test_workflow_converter.py b/api/tests/unit_tests/services/workflow/test_workflow_converter.py index 140c84bf47..29d55df8c3 100644 --- a/api/tests/unit_tests/services/workflow/test_workflow_converter.py +++ b/api/tests/unit_tests/services/workflow/test_workflow_converter.py @@ -89,7 +89,7 @@ def test__convert_to_http_request_node_for_chatbot(default_variables): ) ] - nodes = workflow_converter._convert_to_http_request_node( + nodes, _ = workflow_converter._convert_to_http_request_node( app_model=app_model, variables=default_variables, external_data_variables=external_data_variables @@ -159,7 +159,7 @@ def test__convert_to_http_request_node_for_workflow_app(default_variables): ) ] - nodes = workflow_converter._convert_to_http_request_node( + nodes, _ = workflow_converter._convert_to_http_request_node( app_model=app_model, variables=default_variables, external_data_variables=external_data_variables diff --git a/web/app/components/workflow/nodes/_base/components/variable/utils.ts b/web/app/components/workflow/nodes/_base/components/variable/utils.ts index 969eeea976..6d82ed7e09 100644 --- a/web/app/components/workflow/nodes/_base/components/variable/utils.ts +++ b/web/app/components/workflow/nodes/_base/components/variable/utils.ts @@ -80,7 +80,15 @@ const formatItem = (item: any, isChatMode: boolean, filterVar: (payload: Var, se variable: 'sys.query', type: VarType.string, }) + res.vars.push({ + variable: 'sys.conversation_id', + type: VarType.string, + }) } + res.vars.push({ + variable: 'sys.user_id', + type: VarType.string, + }) res.vars.push({ variable: 'sys.files', type: VarType.arrayFile, diff --git a/web/app/components/workflow/nodes/start/panel.tsx b/web/app/components/workflow/nodes/start/panel.tsx index 1ae5abc191..48b5d6b7c2 100644 --- a/web/app/components/workflow/nodes/start/panel.tsx +++ b/web/app/components/workflow/nodes/start/panel.tsx @@ -70,6 +70,7 @@ const Panel: FC> = ({ } />) } + > = ({ } /> + { + isChatMode && ( + + String + + } + /> + ) + } + + String + + } + />