mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-14 20:55:53 +08:00
feat: support doubao llm function calling (#5100)
This commit is contained in:
parent
25b0a97851
commit
0ce97e6315
@ -7,7 +7,9 @@ from core.model_runtime.entities.message_entities import (
|
|||||||
ImagePromptMessageContent,
|
ImagePromptMessageContent,
|
||||||
PromptMessage,
|
PromptMessage,
|
||||||
PromptMessageContentType,
|
PromptMessageContentType,
|
||||||
|
PromptMessageTool,
|
||||||
SystemPromptMessage,
|
SystemPromptMessage,
|
||||||
|
ToolPromptMessage,
|
||||||
UserPromptMessage,
|
UserPromptMessage,
|
||||||
)
|
)
|
||||||
from core.model_runtime.model_providers.volcengine_maas.errors import wrap_error
|
from core.model_runtime.model_providers.volcengine_maas.errors import wrap_error
|
||||||
@ -36,10 +38,11 @@ class MaaSClient(MaasService):
|
|||||||
client.set_sk(sk)
|
client.set_sk(sk)
|
||||||
return client
|
return client
|
||||||
|
|
||||||
def chat(self, params: dict, messages: list[PromptMessage], stream=False) -> Generator | dict:
|
def chat(self, params: dict, messages: list[PromptMessage], stream=False, **extra_model_kwargs) -> Generator | dict:
|
||||||
req = {
|
req = {
|
||||||
'parameters': params,
|
'parameters': params,
|
||||||
'messages': [self.convert_prompt_message_to_maas_message(prompt) for prompt in messages]
|
'messages': [self.convert_prompt_message_to_maas_message(prompt) for prompt in messages],
|
||||||
|
**extra_model_kwargs,
|
||||||
}
|
}
|
||||||
if not stream:
|
if not stream:
|
||||||
return super().chat(
|
return super().chat(
|
||||||
@ -89,10 +92,22 @@ class MaaSClient(MaasService):
|
|||||||
message = cast(AssistantPromptMessage, message)
|
message = cast(AssistantPromptMessage, message)
|
||||||
message_dict = {'role': ChatRole.ASSISTANT,
|
message_dict = {'role': ChatRole.ASSISTANT,
|
||||||
'content': message.content}
|
'content': message.content}
|
||||||
|
if message.tool_calls:
|
||||||
|
message_dict['tool_calls'] = [
|
||||||
|
{
|
||||||
|
'name': call.function.name,
|
||||||
|
'arguments': call.function.arguments
|
||||||
|
} for call in message.tool_calls
|
||||||
|
]
|
||||||
elif isinstance(message, SystemPromptMessage):
|
elif isinstance(message, SystemPromptMessage):
|
||||||
message = cast(SystemPromptMessage, message)
|
message = cast(SystemPromptMessage, message)
|
||||||
message_dict = {'role': ChatRole.SYSTEM,
|
message_dict = {'role': ChatRole.SYSTEM,
|
||||||
'content': message.content}
|
'content': message.content}
|
||||||
|
elif isinstance(message, ToolPromptMessage):
|
||||||
|
message = cast(ToolPromptMessage, message)
|
||||||
|
message_dict = {'role': ChatRole.FUNCTION,
|
||||||
|
'content': message.content,
|
||||||
|
'name': message.tool_call_id}
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Got unknown PromptMessage type {message}")
|
raise ValueError(f"Got unknown PromptMessage type {message}")
|
||||||
|
|
||||||
@ -106,3 +121,14 @@ class MaaSClient(MaasService):
|
|||||||
raise wrap_error(e)
|
raise wrap_error(e)
|
||||||
|
|
||||||
return resp
|
return resp
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def transform_tool_prompt_to_maas_config(tool: PromptMessageTool):
|
||||||
|
return {
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": tool.name,
|
||||||
|
"description": tool.description,
|
||||||
|
"parameters": tool.parameters,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -119,8 +119,15 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
|
|||||||
if stop:
|
if stop:
|
||||||
req_params['stop'] = stop
|
req_params['stop'] = stop
|
||||||
|
|
||||||
|
extra_model_kwargs = {}
|
||||||
|
|
||||||
|
if tools:
|
||||||
|
extra_model_kwargs['tools'] = [
|
||||||
|
MaaSClient.transform_tool_prompt_to_maas_config(tool) for tool in tools
|
||||||
|
]
|
||||||
|
|
||||||
resp = MaaSClient.wrap_exception(
|
resp = MaaSClient.wrap_exception(
|
||||||
lambda: client.chat(req_params, prompt_messages, stream))
|
lambda: client.chat(req_params, prompt_messages, stream, **extra_model_kwargs))
|
||||||
if not stream:
|
if not stream:
|
||||||
return self._handle_chat_response(model, credentials, prompt_messages, resp)
|
return self._handle_chat_response(model, credentials, prompt_messages, resp)
|
||||||
return self._handle_stream_chat_response(model, credentials, prompt_messages, resp)
|
return self._handle_stream_chat_response(model, credentials, prompt_messages, resp)
|
||||||
@ -156,12 +163,26 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
|
|||||||
choice = choices[0]
|
choice = choices[0]
|
||||||
message = choice['message']
|
message = choice['message']
|
||||||
|
|
||||||
|
# parse tool calls
|
||||||
|
tool_calls = []
|
||||||
|
if message['tool_calls']:
|
||||||
|
for call in message['tool_calls']:
|
||||||
|
tool_call = AssistantPromptMessage.ToolCall(
|
||||||
|
id=call['function']['name'],
|
||||||
|
type=call['type'],
|
||||||
|
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
|
||||||
|
name=call['function']['name'],
|
||||||
|
arguments=call['function']['arguments']
|
||||||
|
)
|
||||||
|
)
|
||||||
|
tool_calls.append(tool_call)
|
||||||
|
|
||||||
return LLMResult(
|
return LLMResult(
|
||||||
model=model,
|
model=model,
|
||||||
prompt_messages=prompt_messages,
|
prompt_messages=prompt_messages,
|
||||||
message=AssistantPromptMessage(
|
message=AssistantPromptMessage(
|
||||||
content=message['content'] if message['content'] else '',
|
content=message['content'] if message['content'] else '',
|
||||||
tool_calls=[],
|
tool_calls=tool_calls,
|
||||||
),
|
),
|
||||||
usage=self._calc_usage(model, credentials, resp['usage']),
|
usage=self._calc_usage(model, credentials, resp['usage']),
|
||||||
)
|
)
|
||||||
@ -252,6 +273,10 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
|
|||||||
if credentials.get('context_size'):
|
if credentials.get('context_size'):
|
||||||
model_properties[ModelPropertyKey.CONTEXT_SIZE] = int(
|
model_properties[ModelPropertyKey.CONTEXT_SIZE] = int(
|
||||||
credentials.get('context_size', 4096))
|
credentials.get('context_size', 4096))
|
||||||
|
|
||||||
|
model_features = ModelConfigs.get(
|
||||||
|
credentials['base_model_name'], {}).get('features', [])
|
||||||
|
|
||||||
entity = AIModelEntity(
|
entity = AIModelEntity(
|
||||||
model=model,
|
model=model,
|
||||||
label=I18nObject(
|
label=I18nObject(
|
||||||
@ -260,7 +285,8 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
|
|||||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||||
model_type=ModelType.LLM,
|
model_type=ModelType.LLM,
|
||||||
model_properties=model_properties,
|
model_properties=model_properties,
|
||||||
parameter_rules=rules
|
parameter_rules=rules,
|
||||||
|
features=model_features,
|
||||||
)
|
)
|
||||||
|
|
||||||
return entity
|
return entity
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
from core.model_runtime.entities.model_entities import ModelFeature
|
||||||
|
|
||||||
ModelConfigs = {
|
ModelConfigs = {
|
||||||
'Doubao-pro-4k': {
|
'Doubao-pro-4k': {
|
||||||
'req_params': {
|
'req_params': {
|
||||||
@ -7,7 +9,10 @@ ModelConfigs = {
|
|||||||
'model_properties': {
|
'model_properties': {
|
||||||
'context_size': 4096,
|
'context_size': 4096,
|
||||||
'mode': 'chat',
|
'mode': 'chat',
|
||||||
}
|
},
|
||||||
|
'features': [
|
||||||
|
ModelFeature.TOOL_CALL
|
||||||
|
],
|
||||||
},
|
},
|
||||||
'Doubao-lite-4k': {
|
'Doubao-lite-4k': {
|
||||||
'req_params': {
|
'req_params': {
|
||||||
@ -17,7 +22,10 @@ ModelConfigs = {
|
|||||||
'model_properties': {
|
'model_properties': {
|
||||||
'context_size': 4096,
|
'context_size': 4096,
|
||||||
'mode': 'chat',
|
'mode': 'chat',
|
||||||
}
|
},
|
||||||
|
'features': [
|
||||||
|
ModelFeature.TOOL_CALL
|
||||||
|
],
|
||||||
},
|
},
|
||||||
'Doubao-pro-32k': {
|
'Doubao-pro-32k': {
|
||||||
'req_params': {
|
'req_params': {
|
||||||
@ -27,7 +35,10 @@ ModelConfigs = {
|
|||||||
'model_properties': {
|
'model_properties': {
|
||||||
'context_size': 32768,
|
'context_size': 32768,
|
||||||
'mode': 'chat',
|
'mode': 'chat',
|
||||||
}
|
},
|
||||||
|
'features': [
|
||||||
|
ModelFeature.TOOL_CALL
|
||||||
|
],
|
||||||
},
|
},
|
||||||
'Doubao-lite-32k': {
|
'Doubao-lite-32k': {
|
||||||
'req_params': {
|
'req_params': {
|
||||||
@ -37,7 +48,10 @@ ModelConfigs = {
|
|||||||
'model_properties': {
|
'model_properties': {
|
||||||
'context_size': 32768,
|
'context_size': 32768,
|
||||||
'mode': 'chat',
|
'mode': 'chat',
|
||||||
}
|
},
|
||||||
|
'features': [
|
||||||
|
ModelFeature.TOOL_CALL
|
||||||
|
],
|
||||||
},
|
},
|
||||||
'Doubao-pro-128k': {
|
'Doubao-pro-128k': {
|
||||||
'req_params': {
|
'req_params': {
|
||||||
@ -47,7 +61,10 @@ ModelConfigs = {
|
|||||||
'model_properties': {
|
'model_properties': {
|
||||||
'context_size': 131072,
|
'context_size': 131072,
|
||||||
'mode': 'chat',
|
'mode': 'chat',
|
||||||
}
|
},
|
||||||
|
'features': [
|
||||||
|
ModelFeature.TOOL_CALL
|
||||||
|
],
|
||||||
},
|
},
|
||||||
'Doubao-lite-128k': {
|
'Doubao-lite-128k': {
|
||||||
'req_params': {
|
'req_params': {
|
||||||
@ -57,7 +74,10 @@ ModelConfigs = {
|
|||||||
'model_properties': {
|
'model_properties': {
|
||||||
'context_size': 131072,
|
'context_size': 131072,
|
||||||
'mode': 'chat',
|
'mode': 'chat',
|
||||||
}
|
},
|
||||||
|
'features': [
|
||||||
|
ModelFeature.TOOL_CALL
|
||||||
|
],
|
||||||
},
|
},
|
||||||
'Skylark2-pro-4k': {
|
'Skylark2-pro-4k': {
|
||||||
'req_params': {
|
'req_params': {
|
||||||
@ -67,26 +87,29 @@ ModelConfigs = {
|
|||||||
'model_properties': {
|
'model_properties': {
|
||||||
'context_size': 4096,
|
'context_size': 4096,
|
||||||
'mode': 'chat',
|
'mode': 'chat',
|
||||||
}
|
},
|
||||||
|
'features': [],
|
||||||
},
|
},
|
||||||
'Llama3-8B': {
|
'Llama3-8B': {
|
||||||
'req_params': {
|
'req_params': {
|
||||||
'max_prompt_tokens': 8192,
|
'max_prompt_tokens': 8192,
|
||||||
'max_new_tokens': 8192,
|
'max_new_tokens': 8192,
|
||||||
},
|
},
|
||||||
'model_properties': {
|
'model_properties': {
|
||||||
'context_size': 8192,
|
'context_size': 8192,
|
||||||
'mode': 'chat',
|
'mode': 'chat',
|
||||||
}
|
},
|
||||||
|
'features': [],
|
||||||
},
|
},
|
||||||
'Llama3-70B': {
|
'Llama3-70B': {
|
||||||
'req_params': {
|
'req_params': {
|
||||||
'max_prompt_tokens': 8192,
|
'max_prompt_tokens': 8192,
|
||||||
'max_new_tokens': 8192,
|
'max_new_tokens': 8192,
|
||||||
},
|
},
|
||||||
'model_properties': {
|
'model_properties': {
|
||||||
'context_size': 8192,
|
'context_size': 8192,
|
||||||
'mode': 'chat',
|
'mode': 'chat',
|
||||||
}
|
},
|
||||||
|
'features': [],
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user