diff --git a/api/core/app_runner/assistant_app_runner.py b/api/core/app_runner/assistant_app_runner.py index d0b9bb872c..7e5581008f 100644 --- a/api/core/app_runner/assistant_app_runner.py +++ b/api/core/app_runner/assistant_app_runner.py @@ -222,6 +222,7 @@ class AssistantApplicationRunner(AppRunner): conversation=conversation, message=message, query=query, + inputs=inputs, ) elif agent_entity.strategy == AgentEntity.Strategy.FUNCTION_CALLING: assistant_fc_runner = AssistantFunctionCallApplicationRunner( diff --git a/api/core/features/assistant_cot_runner.py b/api/core/features/assistant_cot_runner.py index c9cb2ba61a..0d64920403 100644 --- a/api/core/features/assistant_cot_runner.py +++ b/api/core/features/assistant_cot_runner.py @@ -20,6 +20,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner): def run(self, conversation: Conversation, message: Message, query: str, + inputs: Dict[str, str], ) -> Union[Generator, LLMResult]: """ Run Cot agent application @@ -35,6 +36,11 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner): if 'Observation' not in app_orchestration_config.model_config.stop: app_orchestration_config.model_config.stop.append('Observation') + # override inputs + inputs = inputs or {} + instruction = self.app_orchestration_config.prompt_template.simple_prompt_template + instruction = self._fill_in_inputs_from_external_data_tools(instruction, inputs) + iteration_step = 1 max_iteration_steps = min(self.app_orchestration_config.agent.max_iteration, 5) + 1 @@ -108,7 +114,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner): tools=prompt_messages_tools, agent_scratchpad=agent_scratchpad, agent_prompt_message=app_orchestration_config.agent.prompt, - instruction=app_orchestration_config.prompt_template.simple_prompt_template, + instruction=instruction, input=query ) @@ -300,6 +306,18 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner): system_fingerprint='' ), PublishFrom.APPLICATION_MANAGER) + def _fill_in_inputs_from_external_data_tools(self, instruction: str, inputs: dict) -> str: + """ + fill in inputs from external data tools + """ + for key, value in inputs.items(): + try: + instruction = instruction.replace(f'{{{{{key}}}}}', str(value)) + except Exception as e: + continue + + return instruction + def _extract_response_scratchpad(self, content: str) -> AgentScratchpadUnit: """ extract response from llm response