Mirror of https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git (synced 2025-08-14 01:56:04 +08:00)

Feat/open ai compatible functioncall (#2783)

Co-authored-by: jyong <jyong@dify.ai>

parent f8951d7f57 · commit e54c9cd401
@@ -472,7 +472,7 @@ class CohereLargeLanguageModel(LargeLanguageModel):
         else:
             raise ValueError(f"Got unknown type {message}")
 
-        if message.name is not None:
+        if message.name:
             message_dict["user_name"] = message.name
 
         return message_dict
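The Cohere converter switches from `is not None` to a plain truthiness test, so an empty-string name no longer produces a `user_name` field. A two-line illustration (plain Python; the empty name is a made-up edge case):

    name = ""
    name is not None   # True  -> old check would still set user_name
    bool(name)         # False -> new check skips it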
@@ -24,7 +24,7 @@ parameter_rules:
     min: 1
     max: 8000
   - name: safe_prompt
-    defulat: false
+    default: false
     type: boolean
     help:
      en_US: Whether to inject a safety prompt before all conversations.
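The same one-character fix (`defulat` → `default`) repeats in the four hunks below, which differ only in their `max` token value (8000 or 2048); judging by the `safe_prompt` parameter these are per-model parameter YAMLs, presumably the Mistral configs. With the misspelled key the rule effectively had no default, so the intended `false` value was silently ignored.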
@@ -24,7 +24,7 @@ parameter_rules:
     min: 1
     max: 8000
   - name: safe_prompt
-    defulat: false
+    default: false
     type: boolean
     help:
      en_US: Whether to inject a safety prompt before all conversations.
@@ -24,7 +24,7 @@ parameter_rules:
     min: 1
     max: 8000
   - name: safe_prompt
-    defulat: false
+    default: false
    type: boolean
    help:
      en_US: Whether to inject a safety prompt before all conversations.
@@ -24,7 +24,7 @@ parameter_rules:
     min: 1
     max: 2048
   - name: safe_prompt
-    defulat: false
+    default: false
    type: boolean
    help:
      en_US: Whether to inject a safety prompt before all conversations.
@@ -24,7 +24,7 @@ parameter_rules:
     min: 1
     max: 8000
   - name: safe_prompt
-    defulat: false
+    default: false
    type: boolean
    help:
      en_US: Whether to inject a safety prompt before all conversations.
@@ -25,6 +25,7 @@ from core.model_runtime.entities.model_entities import (
     AIModelEntity,
     DefaultParameterName,
     FetchFrom,
+    ModelFeature,
     ModelPropertyKey,
     ModelType,
     ParameterRule,
@@ -166,11 +167,23 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
         """
         generate custom model entities from credentials
         """
+        support_function_call = False
+        features = []
+        function_calling_type = credentials.get('function_calling_type', 'no_call')
+        if function_calling_type == 'function_call':
+            features = [ModelFeature.TOOL_CALL]
+            support_function_call = True
+        endpoint_url = credentials["endpoint_url"]
+        # if not endpoint_url.endswith('/'):
+        #     endpoint_url += '/'
+        # if 'https://api.openai.com/v1/' == endpoint_url:
+        #     features = [ModelFeature.STREAM_TOOL_CALL]
         entity = AIModelEntity(
             model=model,
             label=I18nObject(en_US=model),
             model_type=ModelType.LLM,
             fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
+            features=features if support_function_call else [],
             model_properties={
                 ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', "4096")),
                 ModelPropertyKey.MODE: credentials.get('mode'),
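The custom-model entity now advertises `ModelFeature.TOOL_CALL` only when the new `function_calling_type` credential is set to `function_call`; endpoint-based auto-detection stays commented out. Note that `features if support_function_call else []` is equivalent to plain `features` here, since `features` is only non-empty when `support_function_call` is True. A minimal standalone sketch of the mapping (the enum stand-in and function name are illustrative, not dify's API):

    from enum import Enum

    class ModelFeature(Enum):        # stand-in for dify's ModelFeature
        TOOL_CALL = "tool-call"

    def features_from_credentials(credentials: dict) -> list:
        # 'no_call' mirrors the schema default added at the end of this diff
        if credentials.get('function_calling_type', 'no_call') == 'function_call':
            return [ModelFeature.TOOL_CALL]
        return []

    assert features_from_credentials({}) == []
    assert features_from_credentials(
        {'function_calling_type': 'function_call'}) == [ModelFeature.TOOL_CALL]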
@@ -194,14 +207,6 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
                     max=1,
                     precision=2
                 ),
-                ParameterRule(
-                    name="top_k",
-                    label=I18nObject(en_US="Top K"),
-                    type=ParameterType.INT,
-                    default=int(credentials.get('top_k', 1)),
-                    min=1,
-                    max=100
-                ),
                 ParameterRule(
                     name=DefaultParameterName.FREQUENCY_PENALTY.value,
                     label=I18nObject(en_US="Frequency Penalty"),
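The hard-coded `top_k` rule is dropped from the generated parameter list, presumably because OpenAI's own chat/completions API accepts no `top_k` parameter, so it cannot be assumed for every OpenAI-compatible backend.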
@@ -232,7 +237,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
                 output=Decimal(credentials.get('output_price', 0)),
                 unit=Decimal(credentials.get('unit', 0)),
                 currency=credentials.get('currency', "USD")
-            )
+            ),
         )
 
         if credentials['mode'] == 'chat':
@@ -292,14 +297,22 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
             raise ValueError("Unsupported completion type for model configuration.")
 
         # annotate tools with names, descriptions, etc.
+        function_calling_type = credentials.get('function_calling_type', 'no_call')
         formatted_tools = []
         if tools:
-            data["tool_choice"] = "auto"
+            if function_calling_type == 'function_call':
+                data['functions'] = [{
+                    "name": tool.name,
+                    "description": tool.description,
+                    "parameters": tool.parameters
+                } for tool in tools]
+            elif function_calling_type == 'tool_call':
+                data["tool_choice"] = "auto"
 
                 for tool in tools:
                     formatted_tools.append(helper.dump_model(PromptMessageFunction(function=tool)))
 
                 data["tools"] = formatted_tools
 
         if stop:
             data["stop"] = stop
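Request construction now branches on the credential: legacy function calling sends a `functions` array, while `tool_call` keeps the newer `tools` + `tool_choice: "auto"` fields. A minimal sketch of the two payload shapes, approximating `helper.dump_model(PromptMessageFunction(...))` with explicit dicts and assuming a tool object with `name`, `description`, and `parameters` attributes like dify's `PromptMessageTool`:

    def attach_tools(data: dict, tools: list, function_calling_type: str) -> None:
        if function_calling_type == 'function_call':
            # legacy OpenAI function calling
            data['functions'] = [{
                "name": tool.name,
                "description": tool.description,
                "parameters": tool.parameters,
            } for tool in tools]
        elif function_calling_type == 'tool_call':
            # newer OpenAI tools API
            data["tool_choice"] = "auto"
            data["tools"] = [{
                "type": "function",
                "function": {
                    "name": tool.name,
                    "description": tool.description,
                    "parameters": tool.parameters,
                },
            } for tool in tools]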
@@ -367,9 +380,9 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
 
         for chunk in response.iter_lines(decode_unicode=True, delimiter=delimiter):
             if chunk:
-                #ignore sse comments
+                # ignore sse comments
                 if chunk.startswith(':'):
                     continue
                 decoded_chunk = chunk.strip().lstrip('data: ').lstrip()
                 chunk_json = None
                 try:
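Only comment spacing changes here; lines starting with `:` are SSE comments (keep-alives) and are skipped. One caveat worth flagging in the surrounding, unchanged line: `str.lstrip('data: ')` strips a character set, not a prefix, so a payload beginning with any of `d`, `a`, `t`, `:`, or a space can lose characters:

    >>> 'data: data-value'.lstrip('data: ')
    '-value'

A prefix-safe alternative on Python 3.9+ would be `chunk.removeprefix('data: ')`.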
@@ -452,10 +465,13 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
 
         response_content = ''
         tool_calls = None
+        function_calling_type = credentials.get('function_calling_type', 'no_call')
         if completion_type is LLMMode.CHAT:
             response_content = output.get('message', {})['content']
-            tool_calls = output.get('message', {}).get('tool_calls')
+            if function_calling_type == 'tool_call':
+                tool_calls = output.get('message', {}).get('tool_calls')
+            elif function_calling_type == 'function_call':
+                tool_calls = output.get('message', {}).get('function_call')
 
         elif completion_type is LLMMode.COMPLETION:
             response_content = output['text']
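Response parsing mirrors the request side: `tool_call` backends put a `tool_calls` list on the message, while `function_call` backends return a single `function_call` object. Illustrative fragments of the two message shapes (field layout follows the public OpenAI chat-completions format; the values are made up):

    # newer tools API: a list of calls, each with its own id
    message_tool_call = {
        "content": None,
        "tool_calls": [{
            "id": "call_abc123",
            "type": "function",
            "function": {"name": "get_weather", "arguments": '{"city": "Berlin"}'},
        }],
    }

    # legacy function calling: a single call, no id
    message_function_call = {
        "content": None,
        "function_call": {"name": "get_weather", "arguments": '{"city": "Berlin"}'},
    }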
@@ -463,7 +479,10 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
         assistant_message = AssistantPromptMessage(content=response_content, tool_calls=[])
 
         if tool_calls:
-            assistant_message.tool_calls = self._extract_response_tool_calls(tool_calls)
+            if function_calling_type == 'tool_call':
+                assistant_message.tool_calls = self._extract_response_tool_calls(tool_calls)
+            elif function_calling_type == 'function_call':
+                assistant_message.tool_calls = [self._extract_response_function_call(tool_calls)]
 
         usage = response_json.get("usage")
         if usage:
@@ -522,33 +541,34 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
             message = cast(AssistantPromptMessage, message)
             message_dict = {"role": "assistant", "content": message.content}
             if message.tool_calls:
-                message_dict["tool_calls"] = [helper.dump_model(PromptMessageFunction(function=tool_call)) for tool_call
-                                              in
-                                              message.tool_calls]
-                # function_call = message.tool_calls[0]
-                # message_dict["function_call"] = {
-                #     "name": function_call.function.name,
-                #     "arguments": function_call.function.arguments,
-                # }
+                # message_dict["tool_calls"] = [helper.dump_model(PromptMessageFunction(function=tool_call)) for tool_call
+                #                               in
+                #                               message.tool_calls]
+                function_call = message.tool_calls[0]
+                message_dict["function_call"] = {
+                    "name": function_call.function.name,
+                    "arguments": function_call.function.arguments,
+                }
         elif isinstance(message, SystemPromptMessage):
             message = cast(SystemPromptMessage, message)
             message_dict = {"role": "system", "content": message.content}
         elif isinstance(message, ToolPromptMessage):
             message = cast(ToolPromptMessage, message)
-            message_dict = {
-                "role": "tool",
-                "content": message.content,
-                "tool_call_id": message.tool_call_id
-            }
             # message_dict = {
-            #     "role": "function",
+            #     "role": "tool",
             #     "content": message.content,
-            #     "name": message.tool_call_id
+            #     "tool_call_id": message.tool_call_id
             # }
+            message_dict = {
+                "role": "function",
+                "content": message.content,
+                "name": message.tool_call_id
+            }
         else:
             raise ValueError(f"Got unknown type {message}")
 
-        if message.name is not None:
+        if message.name:
             message_dict["name"] = message.name
 
         return message_dict
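Outgoing messages are switched wholesale to the legacy format: only `message.tool_calls[0]` is serialized (parallel tool calls are dropped), and tool results go out as `role: "function"` with a `name` field rather than `role: "tool"` with `tool_call_id`; the newer forms survive only as comments. Since `_extract_response_function_call` below reuses the function name as the call id, `message.tool_call_id` here is in fact the function name. A sketch of the two resulting messages for one hypothetical call:

    assistant_msg = {
        "role": "assistant",
        "content": None,
        "function_call": {"name": "get_weather", "arguments": '{"city": "Berlin"}'},
    }

    function_result_msg = {
        "role": "function",
        "content": '{"temp_c": 21}',
        "name": "get_weather",   # tool_call_id, which doubles as the function name
    }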
@@ -693,3 +713,26 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
             tool_calls.append(tool_call)
 
         return tool_calls
+
+    def _extract_response_function_call(self, response_function_call) \
+            -> AssistantPromptMessage.ToolCall:
+        """
+        Extract function call from response
+
+        :param response_function_call: response function call
+        :return: tool call
+        """
+        tool_call = None
+        if response_function_call:
+            function = AssistantPromptMessage.ToolCall.ToolCallFunction(
+                name=response_function_call['name'],
+                arguments=response_function_call['arguments']
+            )
+
+            tool_call = AssistantPromptMessage.ToolCall(
+                id=response_function_call['name'],
+                type="function",
+                function=function
+            )
+
+        return tool_call
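Because the legacy payload carries no call id, the function's `name` is reused as `ToolCall.id`, which is what later lets the message converter above echo it back via `message.tool_call_id`. A quick usage sketch (`model` is a hypothetical OAIAPICompatLargeLanguageModel instance; the payload is made up):

    call = model._extract_response_function_call(
        {"name": "get_weather", "arguments": '{"city": "Berlin"}'})
    assert call.id == "get_weather" and call.type == "function"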
@@ -75,6 +75,28 @@ model_credential_schema:
           value: llm
       default: '4096'
       type: text-input
+    - variable: function_calling_type
+      show_on:
+        - variable: __model_type
+          value: llm
+      label:
+        en_US: Function calling
+      type: select
+      required: false
+      default: no_call
+      options:
+        - value: function_call
+          label:
+            en_US: Support
+            zh_Hans: 支持
+#        - value: tool_call
+#          label:
+#            en_US: Tool Call
+#            zh_Hans: Tool Call
+        - value: no_call
+          label:
+            en_US: Not Support
+            zh_Hans: 不支持
     - variable: stream_mode_delimiter
       label:
         zh_Hans: 流模式返回结果的分隔符
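Note that the `tool_call` option is left commented out in the schema, so the UI only offers `Support` (legacy function calling) or `Not Support`; the `tool_call` branches added above are unreachable from this form until that option is uncommented.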