Feat/optimize chat prompt (#158)

commit 90150a6ca9
parent 7722a7c5cd
@@ -39,7 +39,8 @@ class Completion:
         memory = cls.get_memory_from_conversation(
             tenant_id=app.tenant_id,
             app_model_config=app_model_config,
-            conversation=conversation
+            conversation=conversation,
+            return_messages=False
         )
 
         inputs = conversation.inputs
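The added return_messages=False makes the shared memory hand back conversation history as one formatted string instead of a list of message objects, which the rewritten prompt builder below splices directly into the prompt text. A minimal sketch of the difference, using a plain LangChain ConversationBufferMemory as a stand-in for Dify's ReadOnlyConversationTokenDBBufferSharedMemory (the example values are made up):

from langchain.memory import ConversationBufferMemory

# String-style history: load_memory_variables() returns one text block.
string_memory = ConversationBufferMemory(return_messages=False)
string_memory.save_context({"input": "What is Dify?"}, {"output": "An LLM app platform."})
print(string_memory.load_memory_variables({})["history"])
# Human: What is Dify?
# AI: An LLM app platform.

# Message-style history: the same call returns message objects instead.
message_memory = ConversationBufferMemory(return_messages=True)
message_memory.save_context({"input": "What is Dify?"}, {"output": "An LLM app platform."})
print(message_memory.load_memory_variables({})["history"])
# [HumanMessage(content='What is Dify?'), AIMessage(content='An LLM app platform.')]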
@@ -119,7 +120,8 @@ class Completion:
         return response
 
     @classmethod
-    def get_main_llm_prompt(cls, mode: str, llm: BaseLanguageModel, pre_prompt: str, query: str, inputs: dict, chain_output: Optional[str],
+    def get_main_llm_prompt(cls, mode: str, llm: BaseLanguageModel, pre_prompt: str, query: str, inputs: dict,
+                            chain_output: Optional[str],
                             memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory]) -> \
             Union[str | List[BaseMessage]]:
         pre_prompt = PromptBuilder.process_template(pre_prompt) if pre_prompt else pre_prompt
@@ -161,11 +163,19 @@ And answer according to the language of the user's question.
             "query": query
         }
 
-        human_message_prompt = "{query}"
+        human_message_prompt = ""
+
+        if pre_prompt:
+            pre_prompt_inputs = {k: inputs[k] for k in
+                                 OutLinePromptTemplate.from_template(template=pre_prompt).input_variables
+                                 if k in inputs}
+
+            if pre_prompt_inputs:
+                human_inputs.update(pre_prompt_inputs)
 
         if chain_output:
             human_inputs['context'] = chain_output
-            human_message_instruction = """Use the following CONTEXT as your learned knowledge.
+            human_message_prompt += """Use the following CONTEXT as your learned knowledge.
 [CONTEXT]
 {context}
 [END CONTEXT]
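The new pre_prompt_inputs block keeps only the variables that the configured pre_prompt actually declares, so unrelated user inputs never reach the template. A rough standalone illustration of the same filtering, using LangChain's PromptTemplate in place of Dify's OutLinePromptTemplate (the prompt and input values are invented):

from langchain.prompts import PromptTemplate

pre_prompt = "You are a {role}. Always answer in {language}."
inputs = {"role": "travel guide", "language": "French", "unrelated_field": "ignored"}

# Keep only the keys the template really references.
wanted = PromptTemplate.from_template(pre_prompt).input_variables  # ['language', 'role']
pre_prompt_inputs = {k: inputs[k] for k in wanted if k in inputs}

print(pre_prompt.format(**pre_prompt_inputs))
# You are a travel guide. Always answer in French.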
@@ -176,23 +186,27 @@ When answer to user:
 Avoid mentioning that you obtained the information from the context.
 And answer according to the language of the user's question.
 """
-            if pre_prompt:
-                extra_inputs = {k: inputs[k] for k in
-                                OutLinePromptTemplate.from_template(template=pre_prompt).input_variables
-                                if k in inputs}
-                if extra_inputs:
-                    human_inputs.update(extra_inputs)
-                human_message_instruction += pre_prompt + "\n"
 
-            human_message_prompt = human_message_instruction + "Q:{query}\nA:"
-        else:
-            if pre_prompt:
-                extra_inputs = {k: inputs[k] for k in
-                                OutLinePromptTemplate.from_template(template=pre_prompt).input_variables
-                                if k in inputs}
-                if extra_inputs:
-                    human_inputs.update(extra_inputs)
-                human_message_prompt = pre_prompt + "\n" + human_message_prompt
+        if pre_prompt:
+            human_message_prompt += pre_prompt
+
+        query_prompt = "\nHuman: {query}\nAI: "
+
+        if memory:
+            # append chat histories
+            tmp_human_message = PromptBuilder.to_human_message(
+                prompt_content=human_message_prompt + query_prompt,
+                inputs=human_inputs
+            )
+
+            curr_message_tokens = memory.llm.get_messages_tokens([tmp_human_message])
+            rest_tokens = llm_constant.max_context_token_length[memory.llm.model_name] \
+                - memory.llm.max_tokens - curr_message_tokens
+            rest_tokens = max(rest_tokens, 0)
+            history_messages = cls.get_history_messages_from_memory(memory, rest_tokens)
+            human_message_prompt += "\n\n" + history_messages
+
+        human_message_prompt += query_prompt
 
         # construct main prompt
         human_message = PromptBuilder.to_human_message(
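After this hunk the whole chat prompt is built as a single human message in a fixed order: optional CONTEXT block, optional pre_prompt, history trimmed to the remaining token budget, and finally the Human:/AI: query suffix. A simplified sketch of that assembly order, with the token accounting left out (the helper is illustrative, not Dify's actual API):

def build_human_prompt(pre_prompt: str, context: str, history: str, query: str) -> str:
    """Illustrative only: mirrors the ordering used above."""
    prompt = ""
    if context:
        prompt += ("Use the following CONTEXT as your learned knowledge.\n"
                   "[CONTEXT]\n" + context + "\n[END CONTEXT]\n"
                   "When answer to user:\n"
                   "Avoid mentioning that you obtained the information from the context.\n"
                   "And answer according to the language of the user's question.\n")
    if pre_prompt:
        prompt += pre_prompt          # pre_prompt variables are assumed already filled in
    if history:
        prompt += "\n\n" + history    # history text already fits the leftover token budget
    prompt += "\nHuman: " + query + "\nAI: "
    return prompt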
@@ -200,23 +214,14 @@ And answer according to the language of the user's question.
             inputs=human_inputs
         )
 
-        if memory:
-            # append chat histories
-            tmp_messages = messages.copy() + [human_message]
-            curr_message_tokens = memory.llm.get_messages_tokens(tmp_messages)
-            rest_tokens = llm_constant.max_context_token_length[
-                memory.llm.model_name] - memory.llm.max_tokens - curr_message_tokens
-            rest_tokens = max(rest_tokens, 0)
-            history_messages = cls.get_history_messages_from_memory(memory, rest_tokens)
-            messages += history_messages
-
         messages.append(human_message)
 
         return messages
 
     @classmethod
     def get_llm_callback_manager(cls, llm: Union[StreamableOpenAI, StreamableChatOpenAI],
-                                 streaming: bool, conversation_message_task: ConversationMessageTask) -> CallbackManager:
+                                 streaming: bool,
+                                 conversation_message_task: ConversationMessageTask) -> CallbackManager:
         llm_callback_handler = LLMCallbackHandler(llm, conversation_message_task)
         if streaming:
             callback_handlers = [llm_callback_handler, DifyStreamingStdOutCallbackHandler()]
@@ -228,7 +233,7 @@ And answer according to the language of the user's question.
     @classmethod
     def get_history_messages_from_memory(cls, memory: ReadOnlyConversationTokenDBBufferSharedMemory,
                                          max_token_limit: int) -> \
-            List[BaseMessage]:
+            str:
         """Get memory messages."""
         memory.max_token_limit = max_token_limit
         memory_key = memory.memory_variables[0]
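The return annotation changes from List[BaseMessage] to str because history is now appended to the prompt as plain text (human_message_prompt += "\n\n" + history_messages above) rather than as separate chat messages. A hedged, duck-typed sketch of what such a helper ends up doing once the memory was created with return_messages=False (the function name is hypothetical):

from typing import Any

def history_text_from_memory(memory: Any, max_token_limit: int) -> str:
    # With return_messages=False, load_memory_variables() yields the trimmed
    # history as one "Human: ...\nAI: ..." string under the memory key.
    memory.max_token_limit = max_token_limit
    memory_key = memory.memory_variables[0]
    return memory.load_memory_variables({})[memory_key]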