mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-13 16:59:05 +08:00
Feat/open ai compatible functioncall (#2783)
Co-authored-by: jyong <jyong@dify.ai>
This commit is contained in:
parent
f8951d7f57
commit
e54c9cd401
@ -472,7 +472,7 @@ class CohereLargeLanguageModel(LargeLanguageModel):
|
||||
else:
|
||||
raise ValueError(f"Got unknown type {message}")
|
||||
|
||||
if message.name is not None:
|
||||
if message.name:
|
||||
message_dict["user_name"] = message.name
|
||||
|
||||
return message_dict
|
||||
|
@ -24,7 +24,7 @@ parameter_rules:
|
||||
min: 1
|
||||
max: 8000
|
||||
- name: safe_prompt
|
||||
defulat: false
|
||||
default: false
|
||||
type: boolean
|
||||
help:
|
||||
en_US: Whether to inject a safety prompt before all conversations.
|
||||
|
@ -24,7 +24,7 @@ parameter_rules:
|
||||
min: 1
|
||||
max: 8000
|
||||
- name: safe_prompt
|
||||
defulat: false
|
||||
default: false
|
||||
type: boolean
|
||||
help:
|
||||
en_US: Whether to inject a safety prompt before all conversations.
|
||||
|
@ -24,7 +24,7 @@ parameter_rules:
|
||||
min: 1
|
||||
max: 8000
|
||||
- name: safe_prompt
|
||||
defulat: false
|
||||
default: false
|
||||
type: boolean
|
||||
help:
|
||||
en_US: Whether to inject a safety prompt before all conversations.
|
||||
|
@ -24,7 +24,7 @@ parameter_rules:
|
||||
min: 1
|
||||
max: 2048
|
||||
- name: safe_prompt
|
||||
defulat: false
|
||||
default: false
|
||||
type: boolean
|
||||
help:
|
||||
en_US: Whether to inject a safety prompt before all conversations.
|
||||
|
@ -24,7 +24,7 @@ parameter_rules:
|
||||
min: 1
|
||||
max: 8000
|
||||
- name: safe_prompt
|
||||
defulat: false
|
||||
default: false
|
||||
type: boolean
|
||||
help:
|
||||
en_US: Whether to inject a safety prompt before all conversations.
|
||||
|
@ -25,6 +25,7 @@ from core.model_runtime.entities.model_entities import (
|
||||
AIModelEntity,
|
||||
DefaultParameterName,
|
||||
FetchFrom,
|
||||
ModelFeature,
|
||||
ModelPropertyKey,
|
||||
ModelType,
|
||||
ParameterRule,
|
||||
@ -166,11 +167,23 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
|
||||
"""
|
||||
generate custom model entities from credentials
|
||||
"""
|
||||
support_function_call = False
|
||||
features = []
|
||||
function_calling_type = credentials.get('function_calling_type', 'no_call')
|
||||
if function_calling_type == 'function_call':
|
||||
features = [ModelFeature.TOOL_CALL]
|
||||
support_function_call = True
|
||||
endpoint_url = credentials["endpoint_url"]
|
||||
# if not endpoint_url.endswith('/'):
|
||||
# endpoint_url += '/'
|
||||
# if 'https://api.openai.com/v1/' == endpoint_url:
|
||||
# features = [ModelFeature.STREAM_TOOL_CALL]
|
||||
entity = AIModelEntity(
|
||||
model=model,
|
||||
label=I18nObject(en_US=model),
|
||||
model_type=ModelType.LLM,
|
||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||
features=features if support_function_call else [],
|
||||
model_properties={
|
||||
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', "4096")),
|
||||
ModelPropertyKey.MODE: credentials.get('mode'),
|
||||
@ -194,14 +207,6 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
|
||||
max=1,
|
||||
precision=2
|
||||
),
|
||||
ParameterRule(
|
||||
name="top_k",
|
||||
label=I18nObject(en_US="Top K"),
|
||||
type=ParameterType.INT,
|
||||
default=int(credentials.get('top_k', 1)),
|
||||
min=1,
|
||||
max=100
|
||||
),
|
||||
ParameterRule(
|
||||
name=DefaultParameterName.FREQUENCY_PENALTY.value,
|
||||
label=I18nObject(en_US="Frequency Penalty"),
|
||||
@ -232,7 +237,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
|
||||
output=Decimal(credentials.get('output_price', 0)),
|
||||
unit=Decimal(credentials.get('unit', 0)),
|
||||
currency=credentials.get('currency', "USD")
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
if credentials['mode'] == 'chat':
|
||||
@ -292,14 +297,22 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
|
||||
raise ValueError("Unsupported completion type for model configuration.")
|
||||
|
||||
# annotate tools with names, descriptions, etc.
|
||||
function_calling_type = credentials.get('function_calling_type', 'no_call')
|
||||
formatted_tools = []
|
||||
if tools:
|
||||
data["tool_choice"] = "auto"
|
||||
if function_calling_type == 'function_call':
|
||||
data['functions'] = [{
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"parameters": tool.parameters
|
||||
} for tool in tools]
|
||||
elif function_calling_type == 'tool_call':
|
||||
data["tool_choice"] = "auto"
|
||||
|
||||
for tool in tools:
|
||||
formatted_tools.append(helper.dump_model(PromptMessageFunction(function=tool)))
|
||||
for tool in tools:
|
||||
formatted_tools.append(helper.dump_model(PromptMessageFunction(function=tool)))
|
||||
|
||||
data["tools"] = formatted_tools
|
||||
data["tools"] = formatted_tools
|
||||
|
||||
if stop:
|
||||
data["stop"] = stop
|
||||
@ -367,9 +380,9 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
|
||||
|
||||
for chunk in response.iter_lines(decode_unicode=True, delimiter=delimiter):
|
||||
if chunk:
|
||||
#ignore sse comments
|
||||
# ignore sse comments
|
||||
if chunk.startswith(':'):
|
||||
continue
|
||||
continue
|
||||
decoded_chunk = chunk.strip().lstrip('data: ').lstrip()
|
||||
chunk_json = None
|
||||
try:
|
||||
@ -452,10 +465,13 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
|
||||
|
||||
response_content = ''
|
||||
tool_calls = None
|
||||
|
||||
function_calling_type = credentials.get('function_calling_type', 'no_call')
|
||||
if completion_type is LLMMode.CHAT:
|
||||
response_content = output.get('message', {})['content']
|
||||
tool_calls = output.get('message', {}).get('tool_calls')
|
||||
if function_calling_type == 'tool_call':
|
||||
tool_calls = output.get('message', {}).get('tool_calls')
|
||||
elif function_calling_type == 'function_call':
|
||||
tool_calls = output.get('message', {}).get('function_call')
|
||||
|
||||
elif completion_type is LLMMode.COMPLETION:
|
||||
response_content = output['text']
|
||||
@ -463,7 +479,10 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
|
||||
assistant_message = AssistantPromptMessage(content=response_content, tool_calls=[])
|
||||
|
||||
if tool_calls:
|
||||
assistant_message.tool_calls = self._extract_response_tool_calls(tool_calls)
|
||||
if function_calling_type == 'tool_call':
|
||||
assistant_message.tool_calls = self._extract_response_tool_calls(tool_calls)
|
||||
elif function_calling_type == 'function_call':
|
||||
assistant_message.tool_calls = [self._extract_response_function_call(tool_calls)]
|
||||
|
||||
usage = response_json.get("usage")
|
||||
if usage:
|
||||
@ -522,33 +541,34 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
|
||||
message = cast(AssistantPromptMessage, message)
|
||||
message_dict = {"role": "assistant", "content": message.content}
|
||||
if message.tool_calls:
|
||||
message_dict["tool_calls"] = [helper.dump_model(PromptMessageFunction(function=tool_call)) for tool_call
|
||||
in
|
||||
message.tool_calls]
|
||||
# function_call = message.tool_calls[0]
|
||||
# message_dict["function_call"] = {
|
||||
# "name": function_call.function.name,
|
||||
# "arguments": function_call.function.arguments,
|
||||
# }
|
||||
# message_dict["tool_calls"] = [helper.dump_model(PromptMessageFunction(function=tool_call)) for tool_call
|
||||
# in
|
||||
# message.tool_calls]
|
||||
|
||||
function_call = message.tool_calls[0]
|
||||
message_dict["function_call"] = {
|
||||
"name": function_call.function.name,
|
||||
"arguments": function_call.function.arguments,
|
||||
}
|
||||
elif isinstance(message, SystemPromptMessage):
|
||||
message = cast(SystemPromptMessage, message)
|
||||
message_dict = {"role": "system", "content": message.content}
|
||||
elif isinstance(message, ToolPromptMessage):
|
||||
message = cast(ToolPromptMessage, message)
|
||||
message_dict = {
|
||||
"role": "tool",
|
||||
"content": message.content,
|
||||
"tool_call_id": message.tool_call_id
|
||||
}
|
||||
# message_dict = {
|
||||
# "role": "function",
|
||||
# "role": "tool",
|
||||
# "content": message.content,
|
||||
# "name": message.tool_call_id
|
||||
# "tool_call_id": message.tool_call_id
|
||||
# }
|
||||
message_dict = {
|
||||
"role": "function",
|
||||
"content": message.content,
|
||||
"name": message.tool_call_id
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"Got unknown type {message}")
|
||||
|
||||
if message.name is not None:
|
||||
if message.name:
|
||||
message_dict["name"] = message.name
|
||||
|
||||
return message_dict
|
||||
@ -693,3 +713,26 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
|
||||
tool_calls.append(tool_call)
|
||||
|
||||
return tool_calls
|
||||
|
||||
def _extract_response_function_call(self, response_function_call) \
|
||||
-> AssistantPromptMessage.ToolCall:
|
||||
"""
|
||||
Extract function call from response
|
||||
|
||||
:param response_function_call: response function call
|
||||
:return: tool call
|
||||
"""
|
||||
tool_call = None
|
||||
if response_function_call:
|
||||
function = AssistantPromptMessage.ToolCall.ToolCallFunction(
|
||||
name=response_function_call['name'],
|
||||
arguments=response_function_call['arguments']
|
||||
)
|
||||
|
||||
tool_call = AssistantPromptMessage.ToolCall(
|
||||
id=response_function_call['name'],
|
||||
type="function",
|
||||
function=function
|
||||
)
|
||||
|
||||
return tool_call
|
||||
|
@ -75,6 +75,28 @@ model_credential_schema:
|
||||
value: llm
|
||||
default: '4096'
|
||||
type: text-input
|
||||
- variable: function_calling_type
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: llm
|
||||
label:
|
||||
en_US: Function calling
|
||||
type: select
|
||||
required: false
|
||||
default: no_call
|
||||
options:
|
||||
- value: function_call
|
||||
label:
|
||||
en_US: Support
|
||||
zh_Hans: 支持
|
||||
# - value: tool_call
|
||||
# label:
|
||||
# en_US: Tool Call
|
||||
# zh_Hans: Tool Call
|
||||
- value: no_call
|
||||
label:
|
||||
en_US: Not Support
|
||||
zh_Hans: 不支持
|
||||
- variable: stream_mode_delimiter
|
||||
label:
|
||||
zh_Hans: 流模式返回结果的分隔符
|
||||
|
Loading…
x
Reference in New Issue
Block a user