mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-14 00:45:53 +08:00
feat: support doubao llm function calling (#5100)
This commit is contained in:
parent
25b0a97851
commit
0ce97e6315
@ -7,7 +7,9 @@ from core.model_runtime.entities.message_entities import (
|
||||
ImagePromptMessageContent,
|
||||
PromptMessage,
|
||||
PromptMessageContentType,
|
||||
PromptMessageTool,
|
||||
SystemPromptMessage,
|
||||
ToolPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.model_providers.volcengine_maas.errors import wrap_error
|
||||
@ -36,10 +38,11 @@ class MaaSClient(MaasService):
|
||||
client.set_sk(sk)
|
||||
return client
|
||||
|
||||
def chat(self, params: dict, messages: list[PromptMessage], stream=False) -> Generator | dict:
|
||||
def chat(self, params: dict, messages: list[PromptMessage], stream=False, **extra_model_kwargs) -> Generator | dict:
|
||||
req = {
|
||||
'parameters': params,
|
||||
'messages': [self.convert_prompt_message_to_maas_message(prompt) for prompt in messages]
|
||||
'messages': [self.convert_prompt_message_to_maas_message(prompt) for prompt in messages],
|
||||
**extra_model_kwargs,
|
||||
}
|
||||
if not stream:
|
||||
return super().chat(
|
||||
@ -89,10 +92,22 @@ class MaaSClient(MaasService):
|
||||
message = cast(AssistantPromptMessage, message)
|
||||
message_dict = {'role': ChatRole.ASSISTANT,
|
||||
'content': message.content}
|
||||
if message.tool_calls:
|
||||
message_dict['tool_calls'] = [
|
||||
{
|
||||
'name': call.function.name,
|
||||
'arguments': call.function.arguments
|
||||
} for call in message.tool_calls
|
||||
]
|
||||
elif isinstance(message, SystemPromptMessage):
|
||||
message = cast(SystemPromptMessage, message)
|
||||
message_dict = {'role': ChatRole.SYSTEM,
|
||||
'content': message.content}
|
||||
elif isinstance(message, ToolPromptMessage):
|
||||
message = cast(ToolPromptMessage, message)
|
||||
message_dict = {'role': ChatRole.FUNCTION,
|
||||
'content': message.content,
|
||||
'name': message.tool_call_id}
|
||||
else:
|
||||
raise ValueError(f"Got unknown PromptMessage type {message}")
|
||||
|
||||
@ -106,3 +121,14 @@ class MaaSClient(MaasService):
|
||||
raise wrap_error(e)
|
||||
|
||||
return resp
|
||||
|
||||
@staticmethod
|
||||
def transform_tool_prompt_to_maas_config(tool: PromptMessageTool):
|
||||
return {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"parameters": tool.parameters,
|
||||
}
|
||||
}
|
||||
|
@ -119,8 +119,15 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
|
||||
if stop:
|
||||
req_params['stop'] = stop
|
||||
|
||||
extra_model_kwargs = {}
|
||||
|
||||
if tools:
|
||||
extra_model_kwargs['tools'] = [
|
||||
MaaSClient.transform_tool_prompt_to_maas_config(tool) for tool in tools
|
||||
]
|
||||
|
||||
resp = MaaSClient.wrap_exception(
|
||||
lambda: client.chat(req_params, prompt_messages, stream))
|
||||
lambda: client.chat(req_params, prompt_messages, stream, **extra_model_kwargs))
|
||||
if not stream:
|
||||
return self._handle_chat_response(model, credentials, prompt_messages, resp)
|
||||
return self._handle_stream_chat_response(model, credentials, prompt_messages, resp)
|
||||
@ -156,12 +163,26 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
|
||||
choice = choices[0]
|
||||
message = choice['message']
|
||||
|
||||
# parse tool calls
|
||||
tool_calls = []
|
||||
if message['tool_calls']:
|
||||
for call in message['tool_calls']:
|
||||
tool_call = AssistantPromptMessage.ToolCall(
|
||||
id=call['function']['name'],
|
||||
type=call['type'],
|
||||
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
|
||||
name=call['function']['name'],
|
||||
arguments=call['function']['arguments']
|
||||
)
|
||||
)
|
||||
tool_calls.append(tool_call)
|
||||
|
||||
return LLMResult(
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
message=AssistantPromptMessage(
|
||||
content=message['content'] if message['content'] else '',
|
||||
tool_calls=[],
|
||||
tool_calls=tool_calls,
|
||||
),
|
||||
usage=self._calc_usage(model, credentials, resp['usage']),
|
||||
)
|
||||
@ -252,6 +273,10 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
|
||||
if credentials.get('context_size'):
|
||||
model_properties[ModelPropertyKey.CONTEXT_SIZE] = int(
|
||||
credentials.get('context_size', 4096))
|
||||
|
||||
model_features = ModelConfigs.get(
|
||||
credentials['base_model_name'], {}).get('features', [])
|
||||
|
||||
entity = AIModelEntity(
|
||||
model=model,
|
||||
label=I18nObject(
|
||||
@ -260,7 +285,8 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
|
||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||
model_type=ModelType.LLM,
|
||||
model_properties=model_properties,
|
||||
parameter_rules=rules
|
||||
parameter_rules=rules,
|
||||
features=model_features,
|
||||
)
|
||||
|
||||
return entity
|
||||
|
@ -1,3 +1,5 @@
|
||||
from core.model_runtime.entities.model_entities import ModelFeature
|
||||
|
||||
ModelConfigs = {
|
||||
'Doubao-pro-4k': {
|
||||
'req_params': {
|
||||
@ -7,7 +9,10 @@ ModelConfigs = {
|
||||
'model_properties': {
|
||||
'context_size': 4096,
|
||||
'mode': 'chat',
|
||||
}
|
||||
},
|
||||
'features': [
|
||||
ModelFeature.TOOL_CALL
|
||||
],
|
||||
},
|
||||
'Doubao-lite-4k': {
|
||||
'req_params': {
|
||||
@ -17,7 +22,10 @@ ModelConfigs = {
|
||||
'model_properties': {
|
||||
'context_size': 4096,
|
||||
'mode': 'chat',
|
||||
}
|
||||
},
|
||||
'features': [
|
||||
ModelFeature.TOOL_CALL
|
||||
],
|
||||
},
|
||||
'Doubao-pro-32k': {
|
||||
'req_params': {
|
||||
@ -27,7 +35,10 @@ ModelConfigs = {
|
||||
'model_properties': {
|
||||
'context_size': 32768,
|
||||
'mode': 'chat',
|
||||
}
|
||||
},
|
||||
'features': [
|
||||
ModelFeature.TOOL_CALL
|
||||
],
|
||||
},
|
||||
'Doubao-lite-32k': {
|
||||
'req_params': {
|
||||
@ -37,7 +48,10 @@ ModelConfigs = {
|
||||
'model_properties': {
|
||||
'context_size': 32768,
|
||||
'mode': 'chat',
|
||||
}
|
||||
},
|
||||
'features': [
|
||||
ModelFeature.TOOL_CALL
|
||||
],
|
||||
},
|
||||
'Doubao-pro-128k': {
|
||||
'req_params': {
|
||||
@ -47,7 +61,10 @@ ModelConfigs = {
|
||||
'model_properties': {
|
||||
'context_size': 131072,
|
||||
'mode': 'chat',
|
||||
}
|
||||
},
|
||||
'features': [
|
||||
ModelFeature.TOOL_CALL
|
||||
],
|
||||
},
|
||||
'Doubao-lite-128k': {
|
||||
'req_params': {
|
||||
@ -57,7 +74,10 @@ ModelConfigs = {
|
||||
'model_properties': {
|
||||
'context_size': 131072,
|
||||
'mode': 'chat',
|
||||
}
|
||||
},
|
||||
'features': [
|
||||
ModelFeature.TOOL_CALL
|
||||
],
|
||||
},
|
||||
'Skylark2-pro-4k': {
|
||||
'req_params': {
|
||||
@ -67,26 +87,29 @@ ModelConfigs = {
|
||||
'model_properties': {
|
||||
'context_size': 4096,
|
||||
'mode': 'chat',
|
||||
}
|
||||
},
|
||||
'features': [],
|
||||
},
|
||||
'Llama3-8B': {
|
||||
'req_params': {
|
||||
'req_params': {
|
||||
'max_prompt_tokens': 8192,
|
||||
'max_new_tokens': 8192,
|
||||
},
|
||||
'model_properties': {
|
||||
'context_size': 8192,
|
||||
'mode': 'chat',
|
||||
}
|
||||
},
|
||||
'features': [],
|
||||
},
|
||||
'Llama3-70B': {
|
||||
'req_params': {
|
||||
'req_params': {
|
||||
'max_prompt_tokens': 8192,
|
||||
'max_new_tokens': 8192,
|
||||
},
|
||||
'model_properties': {
|
||||
'context_size': 8192,
|
||||
'mode': 'chat',
|
||||
}
|
||||
},
|
||||
'features': [],
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user