Feat/optimize chat prompt (#158)

John Wang 2023-05-23 12:26:28 +08:00 committed by GitHub
parent 7722a7c5cd
commit 90150a6ca9


@@ -39,7 +39,8 @@ class Completion:
         memory = cls.get_memory_from_conversation(
             tenant_id=app.tenant_id,
             app_model_config=app_model_config,
-            conversation=conversation
+            conversation=conversation,
+            return_messages=False
         )
 
         inputs = conversation.inputs
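Note: the new return_messages=False flag is what the rest of this diff builds on. With a LangChain-style buffer memory it makes the stored conversation come back as one formatted "Human: ... / AI: ..." string instead of a list of BaseMessage objects, which is the shape the reworked get_main_llm_prompt below splices into a plain-text prompt. A minimal sketch of that behavior, assuming LangChain's ConversationBufferMemory rather than Dify's ReadOnlyConversationTokenDBBufferSharedMemory:

from langchain.memory import ConversationBufferMemory

# return_messages=False -> history is returned as a single prefixed string
memory = ConversationBufferMemory(return_messages=False)
memory.save_context({"input": "What is Dify?"}, {"output": "An LLM app platform."})

print(memory.load_memory_variables({})["history"])
# Human: What is Dify?
# AI: An LLM app platform.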
@@ -119,7 +120,8 @@ class Completion:
         return response
 
     @classmethod
-    def get_main_llm_prompt(cls, mode: str, llm: BaseLanguageModel, pre_prompt: str, query: str, inputs: dict, chain_output: Optional[str],
+    def get_main_llm_prompt(cls, mode: str, llm: BaseLanguageModel, pre_prompt: str, query: str, inputs: dict,
+                            chain_output: Optional[str],
                             memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory]) -> \
             Union[str | List[BaseMessage]]:
         pre_prompt = PromptBuilder.process_template(pre_prompt) if pre_prompt else pre_prompt
@@ -161,11 +163,19 @@ And answer according to the language of the user's question.
             "query": query
         }
 
-        human_message_prompt = "{query}"
+        human_message_prompt = ""
+
+        if pre_prompt:
+            pre_prompt_inputs = {k: inputs[k] for k in
+                                 OutLinePromptTemplate.from_template(template=pre_prompt).input_variables
+                                 if k in inputs}
+
+            if pre_prompt_inputs:
+                human_inputs.update(pre_prompt_inputs)
 
         if chain_output:
             human_inputs['context'] = chain_output
-            human_message_instruction = """Use the following CONTEXT as your learned knowledge.
+            human_message_prompt += """Use the following CONTEXT as your learned knowledge.
 [CONTEXT]
 {context}
 [END CONTEXT]
@@ -176,23 +186,27 @@ When answer to user:
 Avoid mentioning that you obtained the information from the context.
 And answer according to the language of the user's question.
 """
 
-            if pre_prompt:
-                extra_inputs = {k: inputs[k] for k in
-                                OutLinePromptTemplate.from_template(template=pre_prompt).input_variables
-                                if k in inputs}
-                if extra_inputs:
-                    human_inputs.update(extra_inputs)
-                human_message_instruction += pre_prompt + "\n"
-            human_message_prompt = human_message_instruction + "Q:{query}\nA:"
-        else:
-            if pre_prompt:
-                extra_inputs = {k: inputs[k] for k in
-                                OutLinePromptTemplate.from_template(template=pre_prompt).input_variables
-                                if k in inputs}
-                if extra_inputs:
-                    human_inputs.update(extra_inputs)
-                human_message_prompt = pre_prompt + "\n" + human_message_prompt
+        if pre_prompt:
+            human_message_prompt += pre_prompt
+
+        query_prompt = "\nHuman: {query}\nAI: "
+
+        if memory:
+            # append chat histories
+            tmp_human_message = PromptBuilder.to_human_message(
+                prompt_content=human_message_prompt + query_prompt,
+                inputs=human_inputs
+            )
+
+            curr_message_tokens = memory.llm.get_messages_tokens([tmp_human_message])
+            rest_tokens = llm_constant.max_context_token_length[memory.llm.model_name] \
+                          - memory.llm.max_tokens - curr_message_tokens
+            rest_tokens = max(rest_tokens, 0)
+            history_messages = cls.get_history_messages_from_memory(memory, rest_tokens)
+            human_message_prompt += "\n\n" + history_messages
+
+        human_message_prompt += query_prompt
 
         # construct main prompt
         human_message = PromptBuilder.to_human_message(
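Note: the hunk above replaces the old branch-per-mode prompt assembly with one string: an optional CONTEXT block, then the app's pre_prompt, then as much chat history as fits, then "\nHuman: {query}\nAI: ". The history budget is whatever the model's context window leaves after reserving the completion tokens and the prompt built so far. A rough sketch of that arithmetic with illustrative numbers (the real values come from llm_constant and the model config):

max_context_token_length = 4096   # model context window, e.g. gpt-3.5-turbo
max_completion_tokens = 512       # memory.llm.max_tokens reserved for the answer
curr_message_tokens = 900         # tokens already used by context + pre_prompt + query_prompt

rest_tokens = max(max_context_token_length - max_completion_tokens - curr_message_tokens, 0)
print(rest_tokens)  # 2684 tokens left for chat history pulled from memory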
@@ -200,23 +214,14 @@ And answer according to the language of the user's question.
             inputs=human_inputs
         )
 
-        if memory:
-            # append chat histories
-            tmp_messages = messages.copy() + [human_message]
-            curr_message_tokens = memory.llm.get_messages_tokens(tmp_messages)
-            rest_tokens = llm_constant.max_context_token_length[
-                              memory.llm.model_name] - memory.llm.max_tokens - curr_message_tokens
-            rest_tokens = max(rest_tokens, 0)
-            history_messages = cls.get_history_messages_from_memory(memory, rest_tokens)
-            messages += history_messages
-
         messages.append(human_message)
 
         return messages
 
     @classmethod
     def get_llm_callback_manager(cls, llm: Union[StreamableOpenAI, StreamableChatOpenAI],
-                                 streaming: bool, conversation_message_task: ConversationMessageTask) -> CallbackManager:
+                                 streaming: bool,
+                                 conversation_message_task: ConversationMessageTask) -> CallbackManager:
         llm_callback_handler = LLMCallbackHandler(llm, conversation_message_task)
         if streaming:
             callback_handlers = [llm_callback_handler, DifyStreamingStdOutCallbackHandler()]
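Note: with the history folded into the prompt string before PromptBuilder.to_human_message runs, the removed branch that re-counted tokens over messages.copy() + [human_message] and spliced history messages into the list is no longer needed; the message list now ends up holding just the optional system message plus this single human message. Roughly, its content is laid out like this (a sketch with hypothetical placeholder values, not output from the real code):

human_message_content = (
    "Use the following CONTEXT as your learned knowledge.\n"
    "[CONTEXT]\n{context}\n[END CONTEXT]\n..."            # only when chain_output is set
    + "You are a helpful assistant."                      # pre_prompt, if configured
    + "\n\nHuman: earlier question\nAI: earlier answer"   # trimmed history, if memory exists
    + "\nHuman: {query}\nAI: "
)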
@@ -228,7 +233,7 @@ And answer according to the language of the user's question.
     @classmethod
     def get_history_messages_from_memory(cls, memory: ReadOnlyConversationTokenDBBufferSharedMemory,
                                          max_token_limit: int) -> \
-            List[BaseMessage]:
+            str:
         """Get memory messages."""
         memory.max_token_limit = max_token_limit
         memory_key = memory.memory_variables[0]
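Note: the return annotation change from List[BaseMessage] to str mirrors return_messages=False above: memory.load_memory_variables({}) now yields the history as plain text, so the helper can return that string directly. A sketch of what the body amounts to under that assumption (simplified, not necessarily the exact Dify code):

def get_history_messages_from_memory(memory, max_token_limit: int) -> str:
    # trim the buffer to the token budget, then read the history back as one string
    memory.max_token_limit = max_token_limit
    memory_key = memory.memory_variables[0]        # usually "history"
    external_context = memory.load_memory_variables({})
    return external_context[memory_key]            # "Human: ...\nAI: ..."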