mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-13 04:28:58 +08:00
feat: support configurate openai compatible stream tool call (#3467)
This commit is contained in:
parent
5b3133f9fc
commit
8f8e9de601
@ -170,13 +170,14 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
|
|||||||
features = []
|
features = []
|
||||||
|
|
||||||
function_calling_type = credentials.get('function_calling_type', 'no_call')
|
function_calling_type = credentials.get('function_calling_type', 'no_call')
|
||||||
if function_calling_type == 'function_call':
|
if function_calling_type in ['function_call']:
|
||||||
features.append(ModelFeature.TOOL_CALL)
|
features.append(ModelFeature.TOOL_CALL)
|
||||||
endpoint_url = credentials["endpoint_url"]
|
elif function_calling_type in ['tool_call']:
|
||||||
# if not endpoint_url.endswith('/'):
|
features.append(ModelFeature.MULTI_TOOL_CALL)
|
||||||
# endpoint_url += '/'
|
|
||||||
# if 'https://api.openai.com/v1/' == endpoint_url:
|
stream_function_calling = credentials.get('stream_function_calling', 'supported')
|
||||||
# features.append(ModelFeature.STREAM_TOOL_CALL)
|
if stream_function_calling == 'supported':
|
||||||
|
features.append(ModelFeature.STREAM_TOOL_CALL)
|
||||||
|
|
||||||
vision_support = credentials.get('vision_support', 'not_support')
|
vision_support = credentials.get('vision_support', 'not_support')
|
||||||
if vision_support == 'support':
|
if vision_support == 'support':
|
||||||
@ -386,30 +387,38 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
|
|||||||
|
|
||||||
def increase_tool_call(new_tool_calls: list[AssistantPromptMessage.ToolCall]):
|
def increase_tool_call(new_tool_calls: list[AssistantPromptMessage.ToolCall]):
|
||||||
def get_tool_call(tool_call_id: str):
|
def get_tool_call(tool_call_id: str):
|
||||||
tool_call = next(
|
if not tool_call_id:
|
||||||
(tool_call for tool_call in tools_calls if tool_call.id == tool_call_id), None
|
return tools_calls[-1]
|
||||||
)
|
|
||||||
|
tool_call = next((tool_call for tool_call in tools_calls if tool_call.id == tool_call_id), None)
|
||||||
if tool_call is None:
|
if tool_call is None:
|
||||||
tool_call = AssistantPromptMessage.ToolCall(
|
tool_call = AssistantPromptMessage.ToolCall(
|
||||||
id='',
|
id=tool_call_id,
|
||||||
type='function',
|
type="function",
|
||||||
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
|
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
|
||||||
name='',
|
name="",
|
||||||
arguments=''
|
arguments=""
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
tools_calls.append(tool_call)
|
tools_calls.append(tool_call)
|
||||||
|
|
||||||
return tool_call
|
return tool_call
|
||||||
|
|
||||||
for new_tool_call in new_tool_calls:
|
for new_tool_call in new_tool_calls:
|
||||||
# get tool call
|
# get tool call
|
||||||
tool_call = get_tool_call(new_tool_call.id)
|
tool_call = get_tool_call(new_tool_call.function.name)
|
||||||
# update tool call
|
# update tool call
|
||||||
|
if new_tool_call.id:
|
||||||
tool_call.id = new_tool_call.id
|
tool_call.id = new_tool_call.id
|
||||||
|
if new_tool_call.type:
|
||||||
tool_call.type = new_tool_call.type
|
tool_call.type = new_tool_call.type
|
||||||
|
if new_tool_call.function.name:
|
||||||
tool_call.function.name = new_tool_call.function.name
|
tool_call.function.name = new_tool_call.function.name
|
||||||
|
if new_tool_call.function.arguments:
|
||||||
tool_call.function.arguments += new_tool_call.function.arguments
|
tool_call.function.arguments += new_tool_call.function.arguments
|
||||||
|
|
||||||
|
finish_reason = 'Unknown'
|
||||||
|
|
||||||
for chunk in response.iter_lines(decode_unicode=True, delimiter=delimiter):
|
for chunk in response.iter_lines(decode_unicode=True, delimiter=delimiter):
|
||||||
if chunk:
|
if chunk:
|
||||||
# ignore sse comments
|
# ignore sse comments
|
||||||
@ -438,7 +447,17 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
|
|||||||
delta = choice['delta']
|
delta = choice['delta']
|
||||||
delta_content = delta.get('content')
|
delta_content = delta.get('content')
|
||||||
|
|
||||||
|
assistant_message_tool_calls = None
|
||||||
|
|
||||||
|
if 'tool_calls' in delta and credentials.get('function_calling_type', 'no_call') == 'tool_call':
|
||||||
assistant_message_tool_calls = delta.get('tool_calls', None)
|
assistant_message_tool_calls = delta.get('tool_calls', None)
|
||||||
|
elif 'function_call' in delta and credentials.get('function_calling_type', 'no_call') == 'function_call':
|
||||||
|
assistant_message_tool_calls = [{
|
||||||
|
'id': 'tool_call_id',
|
||||||
|
'type': 'function',
|
||||||
|
'function': delta.get('function_call', {})
|
||||||
|
}]
|
||||||
|
|
||||||
# assistant_message_function_call = delta.delta.function_call
|
# assistant_message_function_call = delta.delta.function_call
|
||||||
|
|
||||||
# extract tool calls from response
|
# extract tool calls from response
|
||||||
@ -449,15 +468,13 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
|
|||||||
if delta_content is None or delta_content == '':
|
if delta_content is None or delta_content == '':
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# function_call = self._extract_response_function_call(assistant_message_function_call)
|
|
||||||
# tool_calls = [function_call] if function_call else []
|
|
||||||
|
|
||||||
# transform assistant message to prompt message
|
# transform assistant message to prompt message
|
||||||
assistant_prompt_message = AssistantPromptMessage(
|
assistant_prompt_message = AssistantPromptMessage(
|
||||||
content=delta_content,
|
content=delta_content,
|
||||||
tool_calls=tool_calls if assistant_message_tool_calls else []
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# reset tool calls
|
||||||
|
tool_calls = []
|
||||||
full_assistant_content += delta_content
|
full_assistant_content += delta_content
|
||||||
elif 'text' in choice:
|
elif 'text' in choice:
|
||||||
choice_text = choice.get('text', '')
|
choice_text = choice.get('text', '')
|
||||||
@ -470,26 +487,6 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
|
|||||||
else:
|
else:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# check payload indicator for completion
|
|
||||||
if finish_reason is not None:
|
|
||||||
yield LLMResultChunk(
|
|
||||||
model=model,
|
|
||||||
prompt_messages=prompt_messages,
|
|
||||||
delta=LLMResultChunkDelta(
|
|
||||||
index=chunk_index,
|
|
||||||
message=AssistantPromptMessage(
|
|
||||||
tool_calls=tools_calls,
|
|
||||||
),
|
|
||||||
finish_reason=finish_reason
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
yield create_final_llm_result_chunk(
|
|
||||||
index=chunk_index,
|
|
||||||
message=assistant_prompt_message,
|
|
||||||
finish_reason=finish_reason
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
yield LLMResultChunk(
|
yield LLMResultChunk(
|
||||||
model=model,
|
model=model,
|
||||||
prompt_messages=prompt_messages,
|
prompt_messages=prompt_messages,
|
||||||
@ -501,6 +498,25 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
|
|||||||
|
|
||||||
chunk_index += 1
|
chunk_index += 1
|
||||||
|
|
||||||
|
if tools_calls:
|
||||||
|
yield LLMResultChunk(
|
||||||
|
model=model,
|
||||||
|
prompt_messages=prompt_messages,
|
||||||
|
delta=LLMResultChunkDelta(
|
||||||
|
index=chunk_index,
|
||||||
|
message=AssistantPromptMessage(
|
||||||
|
tool_calls=tools_calls,
|
||||||
|
content=""
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
yield create_final_llm_result_chunk(
|
||||||
|
index=chunk_index,
|
||||||
|
message=AssistantPromptMessage(content=""),
|
||||||
|
finish_reason=finish_reason
|
||||||
|
)
|
||||||
|
|
||||||
def _handle_generate_response(self, model: str, credentials: dict, response: requests.Response,
|
def _handle_generate_response(self, model: str, credentials: dict, response: requests.Response,
|
||||||
prompt_messages: list[PromptMessage]) -> LLMResult:
|
prompt_messages: list[PromptMessage]) -> LLMResult:
|
||||||
|
|
||||||
@ -757,13 +773,13 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
|
|||||||
if response_tool_calls:
|
if response_tool_calls:
|
||||||
for response_tool_call in response_tool_calls:
|
for response_tool_call in response_tool_calls:
|
||||||
function = AssistantPromptMessage.ToolCall.ToolCallFunction(
|
function = AssistantPromptMessage.ToolCall.ToolCallFunction(
|
||||||
name=response_tool_call["function"]["name"],
|
name=response_tool_call.get("function", {}).get("name", ""),
|
||||||
arguments=response_tool_call["function"]["arguments"]
|
arguments=response_tool_call.get("function", {}).get("arguments", "")
|
||||||
)
|
)
|
||||||
|
|
||||||
tool_call = AssistantPromptMessage.ToolCall(
|
tool_call = AssistantPromptMessage.ToolCall(
|
||||||
id=response_tool_call["id"],
|
id=response_tool_call.get("id", ""),
|
||||||
type=response_tool_call["type"],
|
type=response_tool_call.get("type", ""),
|
||||||
function=function
|
function=function
|
||||||
)
|
)
|
||||||
tool_calls.append(tool_call)
|
tool_calls.append(tool_call)
|
||||||
@ -781,12 +797,12 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
|
|||||||
tool_call = None
|
tool_call = None
|
||||||
if response_function_call:
|
if response_function_call:
|
||||||
function = AssistantPromptMessage.ToolCall.ToolCallFunction(
|
function = AssistantPromptMessage.ToolCall.ToolCallFunction(
|
||||||
name=response_function_call['name'],
|
name=response_function_call.get('name', ''),
|
||||||
arguments=response_function_call['arguments']
|
arguments=response_function_call.get('arguments', '')
|
||||||
)
|
)
|
||||||
|
|
||||||
tool_call = AssistantPromptMessage.ToolCall(
|
tool_call = AssistantPromptMessage.ToolCall(
|
||||||
id=response_function_call['name'],
|
id=response_function_call.get('id', ''),
|
||||||
type="function",
|
type="function",
|
||||||
function=function
|
function=function
|
||||||
)
|
)
|
||||||
|
@ -86,14 +86,32 @@ model_credential_schema:
|
|||||||
default: no_call
|
default: no_call
|
||||||
options:
|
options:
|
||||||
- value: function_call
|
- value: function_call
|
||||||
|
label:
|
||||||
|
en_US: Function Call
|
||||||
|
zh_Hans: Function Call
|
||||||
|
- value: tool_call
|
||||||
|
label:
|
||||||
|
en_US: Tool Call
|
||||||
|
zh_Hans: Tool Call
|
||||||
|
- value: no_call
|
||||||
|
label:
|
||||||
|
en_US: Not Support
|
||||||
|
zh_Hans: 不支持
|
||||||
|
- variable: stream_function_calling
|
||||||
|
show_on:
|
||||||
|
- variable: __model_type
|
||||||
|
value: llm
|
||||||
|
label:
|
||||||
|
en_US: Stream function calling
|
||||||
|
type: select
|
||||||
|
required: false
|
||||||
|
default: not_supported
|
||||||
|
options:
|
||||||
|
- value: supported
|
||||||
label:
|
label:
|
||||||
en_US: Support
|
en_US: Support
|
||||||
zh_Hans: 支持
|
zh_Hans: 支持
|
||||||
# - value: tool_call
|
- value: not_supported
|
||||||
# label:
|
|
||||||
# en_US: Tool Call
|
|
||||||
# zh_Hans: Tool Call
|
|
||||||
- value: no_call
|
|
||||||
label:
|
label:
|
||||||
en_US: Not Support
|
en_US: Not Support
|
||||||
zh_Hans: 不支持
|
zh_Hans: 不支持
|
||||||
|
Loading…
x
Reference in New Issue
Block a user