diff --git a/api/core/completion.py b/api/core/completion.py index d05d8417da..8e79e03a93 100644 --- a/api/core/completion.py +++ b/api/core/completion.py @@ -177,13 +177,21 @@ Avoid mentioning that you obtained the information from the context. And answer according to the language of the user's question. """ if pre_prompt: - human_inputs.update(inputs) + extra_inputs = {k: inputs[k] for k in + OutLinePromptTemplate.from_template(template=pre_prompt).input_variables + if k in inputs} + if extra_inputs: + human_inputs.update(extra_inputs) human_message_instruction += pre_prompt + "\n" human_message_prompt = human_message_instruction + "Q:{query}\nA:" else: if pre_prompt: - human_inputs.update(inputs) + extra_inputs = {k: inputs[k] for k in + OutLinePromptTemplate.from_template(template=pre_prompt).input_variables + if k in inputs} + if extra_inputs: + human_inputs.update(extra_inputs) human_message_prompt = pre_prompt + "\n" + human_message_prompt # construct main prompt