diff --git a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py index 2fb96679e4..d219156026 100644 --- a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py +++ b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py @@ -365,7 +365,7 @@ class ParameterExtractorNode(LLMNode): files=[], context='', memory_config=node_data.memory, - memory=memory, + memory=None, model_config=model_config ) diff --git a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py index 3379e8338d..e5fd2bc1fd 100644 --- a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py +++ b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py @@ -1,5 +1,6 @@ import json import os +from typing import Optional from unittest.mock import MagicMock import pytest @@ -7,6 +8,7 @@ import pytest from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle from core.entities.provider_entities import CustomConfiguration, CustomProviderConfiguration, SystemConfiguration +from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory @@ -61,6 +63,16 @@ def get_mocked_fetch_model_config( return MagicMock(return_value=(model_instance, model_config)) +def get_mocked_fetch_memory(memory_text: str): + class MemoryMock: + def get_history_prompt_text(self, human_prefix: str = "Human", + ai_prefix: str = "Assistant", + max_token_limit: int = 2000, + message_limit: Optional[int] = None): + return memory_text + + return MagicMock(return_value=MemoryMock()) + @pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) def test_function_calling_parameter_extractor(setup_openai_mock): """ @@ -354,4 +366,83 @@ def test_extract_json_response(): hello world. """) - assert result['location'] == 'kawaii' \ No newline at end of file + assert result['location'] == 'kawaii' + +@pytest.mark.parametrize('setup_anthropic_mock', [['none']], indirect=True) +def test_chat_parameter_extractor_with_memory(setup_anthropic_mock): + """ + Test chat parameter extractor with memory. + """ + node = ParameterExtractorNode( + tenant_id='1', + app_id='1', + workflow_id='1', + user_id='1', + invoke_from=InvokeFrom.WEB_APP, + user_from=UserFrom.ACCOUNT, + config={ + 'id': 'llm', + 'data': { + 'title': '123', + 'type': 'parameter-extractor', + 'model': { + 'provider': 'anthropic', + 'name': 'claude-2', + 'mode': 'chat', + 'completion_params': {} + }, + 'query': ['sys', 'query'], + 'parameters': [{ + 'name': 'location', + 'type': 'string', + 'description': 'location', + 'required': True + }], + 'reasoning_mode': 'prompt', + 'instruction': '', + 'memory': { + 'window': { + 'enabled': True, + 'size': 50 + } + }, + } + } + ) + + node._fetch_model_config = get_mocked_fetch_model_config( + provider='anthropic', model='claude-2', mode='chat', credentials={ + 'anthropic_api_key': os.environ.get('ANTHROPIC_API_KEY') + } + ) + node._fetch_memory = get_mocked_fetch_memory('customized memory') + db.session.close = MagicMock() + + # construct variable pool + pool = VariablePool(system_variables={ + SystemVariable.QUERY: 'what\'s the weather in SF', + SystemVariable.FILES: [], + SystemVariable.CONVERSATION_ID: 'abababa', + SystemVariable.USER_ID: 'aaa' + }, user_inputs={}) + + result = node.run(pool) + + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert result.outputs.get('location') == '' + assert result.outputs.get('__reason') == 'Failed to extract result from function call or text response, using empty result.' + prompts = result.process_data.get('prompts') + + latest_role = None + for prompt in prompts: + if prompt.get('role') == 'user': + if '' in prompt.get('text'): + assert '\n{"type": "object"' in prompt.get('text') + elif prompt.get('role') == 'system': + assert 'customized memory' in prompt.get('text') + + if latest_role is not None: + assert latest_role != prompt.get('role') + + if prompt.get('role') in ['user', 'assistant']: + latest_role = prompt.get('role') \ No newline at end of file