mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-14 09:56:06 +08:00
fix: tool call message role according to credentials (#5625)
Co-authored-by: sunxichen <sun.xc@digitalcnzz.com>
This commit is contained in:
parent
92c56fdf2b
commit
bafc8a0bde
@ -88,7 +88,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
|
|||||||
:param tools: tools for tool calling
|
:param tools: tools for tool calling
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
return self._num_tokens_from_messages(model, prompt_messages, tools)
|
return self._num_tokens_from_messages(model, prompt_messages, tools, credentials)
|
||||||
|
|
||||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||||
"""
|
"""
|
||||||
@ -305,7 +305,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
|
|||||||
|
|
||||||
if completion_type is LLMMode.CHAT:
|
if completion_type is LLMMode.CHAT:
|
||||||
endpoint_url = urljoin(endpoint_url, 'chat/completions')
|
endpoint_url = urljoin(endpoint_url, 'chat/completions')
|
||||||
data['messages'] = [self._convert_prompt_message_to_dict(m) for m in prompt_messages]
|
data['messages'] = [self._convert_prompt_message_to_dict(m, credentials) for m in prompt_messages]
|
||||||
elif completion_type is LLMMode.COMPLETION:
|
elif completion_type is LLMMode.COMPLETION:
|
||||||
endpoint_url = urljoin(endpoint_url, 'completions')
|
endpoint_url = urljoin(endpoint_url, 'completions')
|
||||||
data['prompt'] = prompt_messages[0].content
|
data['prompt'] = prompt_messages[0].content
|
||||||
@ -582,7 +582,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
|
|||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
|
def _convert_prompt_message_to_dict(self, message: PromptMessage, credentials: dict = None) -> dict:
|
||||||
"""
|
"""
|
||||||
Convert PromptMessage to dict for OpenAI API format
|
Convert PromptMessage to dict for OpenAI API format
|
||||||
"""
|
"""
|
||||||
@ -636,7 +636,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
|
|||||||
# "tool_call_id": message.tool_call_id
|
# "tool_call_id": message.tool_call_id
|
||||||
# }
|
# }
|
||||||
message_dict = {
|
message_dict = {
|
||||||
"role": "function",
|
"role": "tool" if credentials and credentials.get('function_calling_type', 'no_call') == 'tool_call' else "function",
|
||||||
"content": message.content,
|
"content": message.content,
|
||||||
"name": message.tool_call_id
|
"name": message.tool_call_id
|
||||||
}
|
}
|
||||||
@ -675,7 +675,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
|
|||||||
return num_tokens
|
return num_tokens
|
||||||
|
|
||||||
def _num_tokens_from_messages(self, model: str, messages: list[PromptMessage],
|
def _num_tokens_from_messages(self, model: str, messages: list[PromptMessage],
|
||||||
tools: Optional[list[PromptMessageTool]] = None) -> int:
|
tools: Optional[list[PromptMessageTool]] = None, credentials: dict = None) -> int:
|
||||||
"""
|
"""
|
||||||
Approximate num tokens with GPT2 tokenizer.
|
Approximate num tokens with GPT2 tokenizer.
|
||||||
"""
|
"""
|
||||||
@ -684,7 +684,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
|
|||||||
tokens_per_name = 1
|
tokens_per_name = 1
|
||||||
|
|
||||||
num_tokens = 0
|
num_tokens = 0
|
||||||
messages_dict = [self._convert_prompt_message_to_dict(m) for m in messages]
|
messages_dict = [self._convert_prompt_message_to_dict(m, credentials) for m in messages]
|
||||||
for message in messages_dict:
|
for message in messages_dict:
|
||||||
num_tokens += tokens_per_message
|
num_tokens += tokens_per_message
|
||||||
for key, value in message.items():
|
for key, value in message.items():
|
||||||
|
Loading…
x
Reference in New Issue
Block a user