feat: support function call for ollama block chat api (#10784)
parent 7e66e5a713
commit fbfc811a44
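This commit threads an optional `tools` parameter through the blocking (non-streaming) chat path of the Ollama model runtime: tools are serialized into the `api/chat` request body, `tool_calls` returned on the response message are parsed back into `AssistantPromptMessage.ToolCall` objects, `ToolPromptMessage` is converted to a `role: tool` message, and a new `function_call_support` credential toggle advertises the `TOOL_CALL` and `MULTI_TOOL_CALL` model features. The streaming handler is left unchanged, matching the "block chat api" scope in the title.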
@@ -22,6 +22,7 @@ from core.model_runtime.entities.message_entities import (
     PromptMessageTool,
     SystemPromptMessage,
     TextPromptMessageContent,
+    ToolPromptMessage,
     UserPromptMessage,
 )
 from core.model_runtime.entities.model_entities import (
@@ -86,6 +87,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
             credentials=credentials,
             prompt_messages=prompt_messages,
             model_parameters=model_parameters,
+            tools=tools,
             stop=stop,
             stream=stream,
             user=user,
@@ -153,6 +155,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
         credentials: dict,
         prompt_messages: list[PromptMessage],
         model_parameters: dict,
+        tools: Optional[list[PromptMessageTool]] = None,
         stop: Optional[list[str]] = None,
         stream: bool = True,
         user: Optional[str] = None,
@@ -196,6 +199,8 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
         if completion_type is LLMMode.CHAT:
             endpoint_url = urljoin(endpoint_url, "api/chat")
             data["messages"] = [self._convert_prompt_message_to_dict(m) for m in prompt_messages]
+            if tools:
+                data["tools"] = [self._convert_prompt_message_tool_to_dict(tool) for tool in tools]
         else:
             endpoint_url = urljoin(endpoint_url, "api/generate")
             first_prompt_message = prompt_messages[0]
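For orientation, a request body assembled by this branch might look as follows. This is a sketch, not taken from the diff: the model name `llama3.1`, the `get_weather` tool, and all values are made-up examples.

```python
# Sketch of the JSON body POSTed to {base_url}/api/chat when tools are set.
data = {
    "model": "llama3.1",
    "messages": [
        {"role": "user", "content": "What is the weather in Berlin?"},
    ],
    "tools": [
        {
            "type": "function",
            "function": {
                "name": "get_weather",
                "description": "Get the current weather for a city",
                "parameters": {
                    "type": "object",
                    "properties": {"city": {"type": "string"}},
                    "required": ["city"],
                },
            },
        },
    ],
    "stream": False,  # tool calls are handled on the blocking path only
}
```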
@@ -232,7 +237,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
         if stream:
             return self._handle_generate_stream_response(model, credentials, completion_type, response, prompt_messages)
 
-        return self._handle_generate_response(model, credentials, completion_type, response, prompt_messages)
+        return self._handle_generate_response(model, credentials, completion_type, response, prompt_messages, tools)
 
     def _handle_generate_response(
         self,
@@ -241,6 +246,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
         completion_type: LLMMode,
         response: requests.Response,
         prompt_messages: list[PromptMessage],
+        tools: Optional[list[PromptMessageTool]],
     ) -> LLMResult:
         """
         Handle llm completion response
@@ -253,14 +259,16 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
         :return: llm result
         """
         response_json = response.json()
+        tool_calls = []
         if completion_type is LLMMode.CHAT:
             message = response_json.get("message", {})
             response_content = message.get("content", "")
+            response_tool_calls = message.get("tool_calls", [])
+            tool_calls = [self._extract_response_tool_call(tool_call) for tool_call in response_tool_calls]
         else:
             response_content = response_json["response"]
 
-        assistant_message = AssistantPromptMessage(content=response_content)
+        assistant_message = AssistantPromptMessage(content=response_content, tool_calls=tool_calls)
 
         if "prompt_eval_count" in response_json and "eval_count" in response_json:
             # transform usage
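On the blocking path, a tool-calling response from Ollama carries the call under `message.tool_calls`, with `arguments` as a JSON object rather than a string, which is why `_extract_response_tool_call` further down serializes dict arguments with `json.dumps`. A hypothetical response body, keeping only the keys the parsing code above reads:

```python
# Hypothetical api/chat response (abridged); values are illustrative.
response_json = {
    "model": "llama3.1",
    "message": {
        "role": "assistant",
        "content": "",
        "tool_calls": [
            {"function": {"name": "get_weather", "arguments": {"city": "Berlin"}}},
        ],
    },
    "done": True,
    "prompt_eval_count": 26,  # consumed by the usage accounting below
    "eval_count": 17,
}
```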
@@ -405,9 +413,28 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
 
                 chunk_index += 1
 
+    def _convert_prompt_message_tool_to_dict(self, tool: PromptMessageTool) -> dict:
+        """
+        Convert PromptMessageTool to dict for Ollama API
+
+        :param tool: tool
+        :return: tool dict
+        """
+        return {
+            "type": "function",
+            "function": {
+                "name": tool.name,
+                "description": tool.description,
+                "parameters": tool.parameters,
+            },
+        }
+
     def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
         """
         Convert PromptMessage to dict for Ollama API
+
+        :param message: prompt message
+        :return: message dict
         """
         if isinstance(message, UserPromptMessage):
             message = cast(UserPromptMessage, message)
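The new helper emits the OpenAI-style function schema that Ollama's chat endpoint accepts. A minimal sketch of a call, assuming a hypothetical `get_weather` tool:

```python
from core.model_runtime.entities.message_entities import PromptMessageTool

tool = PromptMessageTool(
    name="get_weather",
    description="Get the current weather for a city",
    parameters={"type": "object", "properties": {"city": {"type": "string"}}},
)
# _convert_prompt_message_tool_to_dict(tool) ->
# {"type": "function",
#  "function": {"name": "get_weather",
#               "description": "Get the current weather for a city",
#               "parameters": {"type": "object", ...}}}
```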
@@ -432,6 +459,9 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
         elif isinstance(message, SystemPromptMessage):
             message = cast(SystemPromptMessage, message)
             message_dict = {"role": "system", "content": message.content}
+        elif isinstance(message, ToolPromptMessage):
+            message = cast(ToolPromptMessage, message)
+            message_dict = {"role": "tool", "content": message.content}
         else:
             raise ValueError(f"Got unknown type {message}")
 
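Note that the converted tool message carries only `role` and `content`, with no call id, so the tool result is associated with the assistant's preceding `tool_calls` turn by its position in the message list. A hypothetical follow-up turn after executing the tool locally:

```python
# 'messages' is the hypothetical request message list from the earlier sketch.
messages.append({"role": "tool", "content": '{"temperature_c": 21, "sky": "clear"}'})
```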
@@ -452,6 +482,29 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
 
         return num_tokens
 
+    def _extract_response_tool_call(self, response_tool_call: dict) -> AssistantPromptMessage.ToolCall:
+        """
+        Extract response tool call
+        """
+        tool_call = None
+        if response_tool_call and "function" in response_tool_call:
+            # Convert arguments to JSON string if it's a dict
+            arguments = response_tool_call.get("function").get("arguments")
+            if isinstance(arguments, dict):
+                arguments = json.dumps(arguments)
+
+            function = AssistantPromptMessage.ToolCall.ToolCallFunction(
+                name=response_tool_call.get("function").get("name"),
+                arguments=arguments,
+            )
+            tool_call = AssistantPromptMessage.ToolCall(
+                id=response_tool_call.get("function").get("name"),
+                type="function",
+                function=function,
+            )
+
+        return tool_call
+
     def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
         """
         Get customizable model schema.
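A quick usage sketch for the extractor, reusing the hypothetical payload from above; note that `id` is filled with the function name, since the Ollama payload parsed here carries no call id of its own:

```python
raw = {"function": {"name": "get_weather", "arguments": {"city": "Berlin"}}}
call = llm._extract_response_tool_call(raw)  # llm: an OllamaLargeLanguageModel
assert call.type == "function"
assert call.id == "get_weather"                         # falls back to the name
assert call.function.arguments == '{"city": "Berlin"}'  # dict -> JSON string
```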
@@ -461,10 +514,15 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
 
         :return: model schema
         """
-        extras = {}
+        extras = {
+            "features": [],
+        }
 
         if "vision_support" in credentials and credentials["vision_support"] == "true":
-            extras["features"] = [ModelFeature.VISION]
+            extras["features"].append(ModelFeature.VISION)
+        if "function_call_support" in credentials and credentials["function_call_support"] == "true":
+            extras["features"].append(ModelFeature.TOOL_CALL)
+            extras["features"].append(ModelFeature.MULTI_TOOL_CALL)
 
         entity = AIModelEntity(
             model=model,
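With both toggles enabled, the customizable model entity advertises vision plus single- and multi-tool calling. A sketch of the gating, assuming string-valued credentials as stored by the form schema below:

```python
credentials = {"vision_support": "true", "function_call_support": "true"}
# get_customizable_model_schema(...) then builds:
# extras["features"] == [ModelFeature.VISION,
#                        ModelFeature.TOOL_CALL,
#                        ModelFeature.MULTI_TOOL_CALL]
```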
@@ -96,3 +96,22 @@ model_credential_schema:
         label:
           en_US: 'No'
           zh_Hans: 否
+    - variable: function_call_support
+      label:
+        zh_Hans: 是否支持函数调用
+        en_US: Function call support
+      show_on:
+        - variable: __model_type
+          value: llm
+      default: 'false'
+      type: radio
+      required: false
+      options:
+        - value: 'true'
+          label:
+            en_US: 'Yes'
+            zh_Hans: 是
+        - value: 'false'
+          label:
+            en_US: 'No'
+            zh_Hans: 否
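The radio control persists the literal strings 'true'/'false' (default 'false'), never a boolean, which is why the Python side string-compares instead of testing truthiness. A hypothetical stored credential:

```python
credentials = {"function_call_support": "true"}  # always a string, never a bool
supports_tools = credentials.get("function_call_support") == "true"
```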