feat: Introduce Ark SDK v3 and ensure compatibility with SDK v2 models (#7579)

Co-authored-by: crazywoola <427733928@qq.com>
sino 2024-08-24 19:29:45 +08:00 committed by GitHub
parent b035c02f78
commit efc136cce5
17 changed files with 604 additions and 192 deletions
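
The commit keeps the legacy v2 MaaSClient (moved under legacy/) and adds ArkClientV3 on top of volcenginesdkarkruntime; the credentials decide which path is taken. A minimal sketch of that routing, based on the is_legacy()/is_compatible_with_legacy() checks added below (the helper name uses_legacy_v2 is illustrative, not part of the diff):

# Sketch only, not part of the diff: credential routing between the legacy
# v2 client and the new Ark v3 client, mirroring ArkClientV3.is_legacy().
LEGACY_V2_ENDPOINT = "maas-api.ml-platform-cn-beijing.volces.com"

def uses_legacy_v2(credentials: dict) -> bool:
    # The old MaaS endpoint with no explicit sdk_version is treated as
    # v3-compatible and gets rewritten to the Ark base URL.
    if credentials.get("sdk_version") is None and \
            credentials.get("api_endpoint_host") == LEGACY_V2_ENDPOINT:
        return False
    # Everything else stays on v2 unless sdk_version is explicitly "v3".
    return credentials.get("sdk_version", "v2") != "v3"

assert uses_legacy_v2({"sdk_version": "v3"}) is False
assert uses_legacy_v2({"api_endpoint_host": LEGACY_V2_ENDPOINT}) is False
assert uses_legacy_v2({"api_endpoint_host": "some-other-host"}) is True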

View File

@ -1,6 +1,25 @@
import re
from collections.abc import Callable, Generator
from typing import cast
from collections.abc import Generator
from typing import Optional, cast
from volcenginesdkarkruntime import Ark
from volcenginesdkarkruntime.types.chat import (
ChatCompletion,
ChatCompletionAssistantMessageParam,
ChatCompletionChunk,
ChatCompletionContentPartImageParam,
ChatCompletionContentPartTextParam,
ChatCompletionMessageParam,
ChatCompletionMessageToolCallParam,
ChatCompletionSystemMessageParam,
ChatCompletionToolMessageParam,
ChatCompletionToolParam,
ChatCompletionUserMessageParam,
)
from volcenginesdkarkruntime.types.chat.chat_completion_content_part_image_param import ImageURL
from volcenginesdkarkruntime.types.chat.chat_completion_message_tool_call_param import Function
from volcenginesdkarkruntime.types.create_embedding_response import CreateEmbeddingResponse
from volcenginesdkarkruntime.types.shared_params import FunctionDefinition
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
@ -12,123 +31,171 @@ from core.model_runtime.entities.message_entities import (
ToolPromptMessage,
UserPromptMessage,
)
from core.model_runtime.model_providers.volcengine_maas.errors import wrap_error
from core.model_runtime.model_providers.volcengine_maas.volc_sdk import ChatRole, MaasException, MaasService
class MaaSClient(MaasService):
def __init__(self, host: str, region: str):
class ArkClientV3:
endpoint_id: Optional[str] = None
ark: Optional[Ark] = None
def __init__(self, *args, **kwargs):
self.ark = Ark(*args, **kwargs)
self.endpoint_id = None
super().__init__(host, region)
def set_endpoint_id(self, endpoint_id: str):
self.endpoint_id = endpoint_id
@classmethod
def from_credential(cls, credentials: dict) -> 'MaaSClient':
host = credentials['api_endpoint_host']
region = credentials['volc_region']
ak = credentials['volc_access_key_id']
sk = credentials['volc_secret_access_key']
endpoint_id = credentials['endpoint_id']
client = cls(host, region)
client.set_endpoint_id(endpoint_id)
client.set_ak(ak)
client.set_sk(sk)
return client
def chat(self, params: dict, messages: list[PromptMessage], stream=False, **extra_model_kwargs) -> Generator | dict:
req = {
'parameters': params,
'messages': [self.convert_prompt_message_to_maas_message(prompt) for prompt in messages],
**extra_model_kwargs,
}
if not stream:
return super().chat(
self.endpoint_id,
req,
)
return super().stream_chat(
self.endpoint_id,
req,
)
def embeddings(self, texts: list[str]) -> dict:
req = {
'input': texts
}
return super().embeddings(self.endpoint_id, req)
@staticmethod
def convert_prompt_message_to_maas_message(message: PromptMessage) -> dict:
def is_legacy(credentials: dict) -> bool:
if ArkClientV3.is_compatible_with_legacy(credentials):
return False
sdk_version = credentials.get("sdk_version", "v2")
return sdk_version != "v3"
@staticmethod
def is_compatible_with_legacy(credentials: dict) -> bool:
sdk_version = credentials.get("sdk_version")
endpoint = credentials.get("api_endpoint_host")
return sdk_version is None and endpoint == "maas-api.ml-platform-cn-beijing.volces.com"
@classmethod
def from_credentials(cls, credentials):
"""Initialize the client using the credentials provided."""
args = {
"base_url": credentials['api_endpoint_host'],
"region": credentials['volc_region'],
"ak": credentials['volc_access_key_id'],
"sk": credentials['volc_secret_access_key'],
}
if cls.is_compatible_with_legacy(credentials):
args["base_url"] = "https://ark.cn-beijing.volces.com/api/v3"
client = ArkClientV3(
**args
)
client.endpoint_id = credentials['endpoint_id']
return client
@staticmethod
def convert_prompt_message(message: PromptMessage) -> ChatCompletionMessageParam:
"""Converts a PromptMessage to a ChatCompletionMessageParam"""
if isinstance(message, UserPromptMessage):
message = cast(UserPromptMessage, message)
if isinstance(message.content, str):
message_dict = {"role": ChatRole.USER,
"content": message.content}
content = message.content
else:
content = []
for message_content in message.content:
if message_content.type == PromptMessageContentType.TEXT:
raise ValueError(
'Content object type only support image_url')
content.append(ChatCompletionContentPartTextParam(
text=message_content.text,
type='text',
))
elif message_content.type == PromptMessageContentType.IMAGE:
message_content = cast(
ImagePromptMessageContent, message_content)
image_data = re.sub(
r'^data:image\/[a-zA-Z]+;base64,', '', message_content.data)
content.append({
'type': 'image_url',
'image_url': {
'url': '',
'image_bytes': image_data,
'detail': message_content.detail,
}
})
message_dict = {'role': ChatRole.USER, 'content': content}
content.append(ChatCompletionContentPartImageParam(
image_url=ImageURL(
url=image_data,
detail=message_content.detail.value,
),
type='image_url',
))
message_dict = ChatCompletionUserMessageParam(
role='user',
content=content
)
elif isinstance(message, AssistantPromptMessage):
message = cast(AssistantPromptMessage, message)
message_dict = {'role': ChatRole.ASSISTANT,
'content': message.content}
if message.tool_calls:
message_dict['tool_calls'] = [
{
'name': call.function.name,
'arguments': call.function.arguments
} for call in message.tool_calls
message_dict = ChatCompletionAssistantMessageParam(
content=message.content,
role='assistant',
tool_calls=None if not message.tool_calls else [
ChatCompletionMessageToolCallParam(
id=call.id,
function=Function(
name=call.function.name,
arguments=call.function.arguments
),
type='function'
) for call in message.tool_calls
]
)
elif isinstance(message, SystemPromptMessage):
message = cast(SystemPromptMessage, message)
message_dict = {'role': ChatRole.SYSTEM,
'content': message.content}
message_dict = ChatCompletionSystemMessageParam(
content=message.content,
role='system'
)
elif isinstance(message, ToolPromptMessage):
message = cast(ToolPromptMessage, message)
message_dict = {'role': ChatRole.FUNCTION,
'content': message.content,
'name': message.tool_call_id}
message_dict = ChatCompletionToolMessageParam(
content=message.content,
role='tool',
tool_call_id=message.tool_call_id
)
else:
raise ValueError(f"Got unknown PromptMessage type {message}")
return message_dict
@staticmethod
def wrap_exception(fn: Callable[[], dict | Generator]) -> dict | Generator:
try:
resp = fn()
except MaasException as e:
raise wrap_error(e)
def _convert_tool_prompt(message: PromptMessageTool) -> ChatCompletionToolParam:
return ChatCompletionToolParam(
type='function',
function=FunctionDefinition(
name=message.name,
description=message.description,
parameters=message.parameters,
)
)
return resp
def chat(self, messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
frequency_penalty: Optional[float] = None,
max_tokens: Optional[int] = None,
presence_penalty: Optional[float] = None,
top_p: Optional[float] = None,
temperature: Optional[float] = None,
) -> ChatCompletion:
"""Block chat"""
return self.ark.chat.completions.create(
model=self.endpoint_id,
messages=[self.convert_prompt_message(message) for message in messages],
tools=[self._convert_tool_prompt(tool) for tool in tools] if tools else None,
stop=stop,
frequency_penalty=frequency_penalty,
max_tokens=max_tokens,
presence_penalty=presence_penalty,
top_p=top_p,
temperature=temperature,
)
@staticmethod
def transform_tool_prompt_to_maas_config(tool: PromptMessageTool):
return {
"type": "function",
"function": {
"name": tool.name,
"description": tool.description,
"parameters": tool.parameters,
}
}
def stream_chat(self, messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
frequency_penalty: Optional[float] = None,
max_tokens: Optional[int] = None,
presence_penalty: Optional[float] = None,
top_p: Optional[float] = None,
temperature: Optional[float] = None,
) -> Generator[ChatCompletionChunk]:
"""Stream chat"""
chunks = self.ark.chat.completions.create(
stream=True,
model=self.endpoint_id,
messages=[self.convert_prompt_message(message) for message in messages],
tools=[self._convert_tool_prompt(tool) for tool in tools] if tools else None,
stop=stop,
frequency_penalty=frequency_penalty,
max_tokens=max_tokens,
presence_penalty=presence_penalty,
top_p=top_p,
temperature=temperature,
)
for chunk in chunks:
if not chunk.choices:
continue
yield chunk
def embeddings(self, texts: list[str]) -> CreateEmbeddingResponse:
return self.ark.embeddings.create(model=self.endpoint_id, input=texts)
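
A hedged usage sketch of the ArkClientV3 surface above (blocking and streaming chat); the credential values are placeholders and the base URL/region are assumptions for illustration:

# Usage sketch only, not part of the diff. Placeholder credentials.
from core.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage
from core.model_runtime.model_providers.volcengine_maas.client import ArkClientV3

client = ArkClientV3.from_credentials({
    "api_endpoint_host": "https://ark.cn-beijing.volces.com/api/v3",
    "volc_region": "cn-beijing",
    "volc_access_key_id": "<ak>",
    "volc_secret_access_key": "<sk>",
    "endpoint_id": "<endpoint-id>",
})

messages = [
    SystemPromptMessage(content="You are a helpful assistant."),
    UserPromptMessage(content="ping"),
]

# Blocking call returns an SDK ChatCompletion.
completion = client.chat(messages, temperature=0.7, max_tokens=64)
print(completion.choices[0].message.content)

# Streaming yields ChatCompletionChunk objects; empty-choice chunks are skipped.
for chunk in client.stream_chat(messages, max_tokens=64):
    print(chunk.choices[0].delta.content or "", end="")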

View File

@ -0,0 +1,134 @@
import re
from collections.abc import Callable, Generator
from typing import cast
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
ImagePromptMessageContent,
PromptMessage,
PromptMessageContentType,
PromptMessageTool,
SystemPromptMessage,
ToolPromptMessage,
UserPromptMessage,
)
from core.model_runtime.model_providers.volcengine_maas.legacy.errors import wrap_error
from core.model_runtime.model_providers.volcengine_maas.legacy.volc_sdk import ChatRole, MaasException, MaasService
class MaaSClient(MaasService):
def __init__(self, host: str, region: str):
self.endpoint_id = None
super().__init__(host, region)
def set_endpoint_id(self, endpoint_id: str):
self.endpoint_id = endpoint_id
@classmethod
def from_credential(cls, credentials: dict) -> 'MaaSClient':
host = credentials['api_endpoint_host']
region = credentials['volc_region']
ak = credentials['volc_access_key_id']
sk = credentials['volc_secret_access_key']
endpoint_id = credentials['endpoint_id']
client = cls(host, region)
client.set_endpoint_id(endpoint_id)
client.set_ak(ak)
client.set_sk(sk)
return client
def chat(self, params: dict, messages: list[PromptMessage], stream=False, **extra_model_kwargs) -> Generator | dict:
req = {
'parameters': params,
'messages': [self.convert_prompt_message_to_maas_message(prompt) for prompt in messages],
**extra_model_kwargs,
}
if not stream:
return super().chat(
self.endpoint_id,
req,
)
return super().stream_chat(
self.endpoint_id,
req,
)
def embeddings(self, texts: list[str]) -> dict:
req = {
'input': texts
}
return super().embeddings(self.endpoint_id, req)
@staticmethod
def convert_prompt_message_to_maas_message(message: PromptMessage) -> dict:
if isinstance(message, UserPromptMessage):
message = cast(UserPromptMessage, message)
if isinstance(message.content, str):
message_dict = {"role": ChatRole.USER,
"content": message.content}
else:
content = []
for message_content in message.content:
if message_content.type == PromptMessageContentType.TEXT:
raise ValueError(
'Content object type only support image_url')
elif message_content.type == PromptMessageContentType.IMAGE:
message_content = cast(
ImagePromptMessageContent, message_content)
image_data = re.sub(
r'^data:image\/[a-zA-Z]+;base64,', '', message_content.data)
content.append({
'type': 'image_url',
'image_url': {
'url': '',
'image_bytes': image_data,
'detail': message_content.detail,
}
})
message_dict = {'role': ChatRole.USER, 'content': content}
elif isinstance(message, AssistantPromptMessage):
message = cast(AssistantPromptMessage, message)
message_dict = {'role': ChatRole.ASSISTANT,
'content': message.content}
if message.tool_calls:
message_dict['tool_calls'] = [
{
'name': call.function.name,
'arguments': call.function.arguments
} for call in message.tool_calls
]
elif isinstance(message, SystemPromptMessage):
message = cast(SystemPromptMessage, message)
message_dict = {'role': ChatRole.SYSTEM,
'content': message.content}
elif isinstance(message, ToolPromptMessage):
message = cast(ToolPromptMessage, message)
message_dict = {'role': ChatRole.FUNCTION,
'content': message.content,
'name': message.tool_call_id}
else:
raise ValueError(f"Got unknown PromptMessage type {message}")
return message_dict
@staticmethod
def wrap_exception(fn: Callable[[], dict | Generator]) -> dict | Generator:
try:
resp = fn()
except MaasException as e:
raise wrap_error(e)
return resp
@staticmethod
def transform_tool_prompt_to_maas_config(tool: PromptMessageTool):
return {
"type": "function",
"function": {
"name": tool.name,
"description": tool.description,
"parameters": tool.parameters,
}
}

View File

@ -1,4 +1,4 @@
from core.model_runtime.model_providers.volcengine_maas.volc_sdk import MaasException
from core.model_runtime.model_providers.volcengine_maas.legacy.volc_sdk import MaasException
class ClientSDKRequestError(MaasException):

View File

@ -1,8 +1,10 @@
import logging
from collections.abc import Generator
from volcenginesdkarkruntime.types.chat import ChatCompletion, ChatCompletionChunk
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessage,
@ -27,19 +29,21 @@ from core.model_runtime.errors.invoke import (
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.model_providers.volcengine_maas.client import MaaSClient
from core.model_runtime.model_providers.volcengine_maas.errors import (
from core.model_runtime.model_providers.volcengine_maas.client import ArkClientV3
from core.model_runtime.model_providers.volcengine_maas.legacy.client import MaaSClient
from core.model_runtime.model_providers.volcengine_maas.legacy.errors import (
AuthErrors,
BadRequestErrors,
ConnectionErrors,
MaasException,
RateLimitErrors,
ServerUnavailableErrors,
)
from core.model_runtime.model_providers.volcengine_maas.llm.models import (
get_model_config,
get_v2_req_params,
get_v3_req_params,
)
from core.model_runtime.model_providers.volcengine_maas.volc_sdk import MaasException
logger = logging.getLogger(__name__)
@ -49,13 +53,20 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
model_parameters: dict, tools: list[PromptMessageTool] | None = None,
stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
-> LLMResult | Generator:
return self._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
if ArkClientV3.is_legacy(credentials):
return self._generate_v2(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
return self._generate_v3(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate credentials
"""
# ping
if ArkClientV3.is_legacy(credentials):
return self._validate_credentials_v2(credentials)
return self._validate_credentials_v3(credentials)
@staticmethod
def _validate_credentials_v2(credentials: dict) -> None:
client = MaaSClient.from_credential(credentials)
try:
client.chat(
@ -70,18 +81,24 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
except MaasException as e:
raise CredentialsValidateFailedError(e.message)
@staticmethod
def _validate_credentials_v3(credentials: dict) -> None:
client = ArkClientV3.from_credentials(credentials)
try:
client.chat(max_tokens=16, temperature=0.7, top_p=0.9,
messages=[UserPromptMessage(content='ping\nAnswer: ')], )
except Exception as e:
raise CredentialsValidateFailedError(e)
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: list[PromptMessageTool] | None = None) -> int:
if len(prompt_messages) == 0:
if ArkClientV3.is_legacy(credentials):
return self._get_num_tokens_v2(prompt_messages)
return self._get_num_tokens_v3(prompt_messages)
def _get_num_tokens_v2(self, messages: list[PromptMessage]) -> int:
if len(messages) == 0:
return 0
return self._num_tokens_from_messages(prompt_messages)
def _num_tokens_from_messages(self, messages: list[PromptMessage]) -> int:
"""
Calculate num tokens.
:param messages: messages
"""
num_tokens = 0
messages_dict = [
MaaSClient.convert_prompt_message_to_maas_message(m) for m in messages]
@ -92,9 +109,22 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
return num_tokens
def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
model_parameters: dict, tools: list[PromptMessageTool] | None = None,
stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
def _get_num_tokens_v3(self, messages: list[PromptMessage]) -> int:
if len(messages) == 0:
return 0
num_tokens = 0
messages_dict = [
ArkClientV3.convert_prompt_message(m) for m in messages]
for message in messages_dict:
for key, value in message.items():
num_tokens += self._get_num_tokens_by_gpt2(str(key))
num_tokens += self._get_num_tokens_by_gpt2(str(value))
return num_tokens
def _generate_v2(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
model_parameters: dict, tools: list[PromptMessageTool] | None = None,
stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
-> LLMResult | Generator:
client = MaaSClient.from_credential(credentials)
@ -106,77 +136,151 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
]
resp = MaaSClient.wrap_exception(
lambda: client.chat(req_params, prompt_messages, stream, **extra_model_kwargs))
if not stream:
return self._handle_chat_response(model, credentials, prompt_messages, resp)
return self._handle_stream_chat_response(model, credentials, prompt_messages, resp)
def _handle_stream_chat_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], resp: Generator) -> Generator:
for index, r in enumerate(resp):
choices = r['choices']
def _handle_stream_chat_response() -> Generator:
for index, r in enumerate(resp):
choices = r['choices']
if not choices:
continue
choice = choices[0]
message = choice['message']
usage = None
if r.get('usage'):
usage = self._calc_response_usage(model=model, credentials=credentials,
prompt_tokens=r['usage']['prompt_tokens'],
completion_tokens=r['usage']['completion_tokens']
)
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=index,
message=AssistantPromptMessage(
content=message['content'] if message['content'] else '',
tool_calls=[]
),
usage=usage,
finish_reason=choice.get('finish_reason'),
),
)
def _handle_chat_response() -> LLMResult:
choices = resp['choices']
if not choices:
continue
raise ValueError("No choices found")
choice = choices[0]
message = choice['message']
usage = None
if r.get('usage'):
usage = self._calc_usage(model, credentials, r['usage'])
yield LLMResultChunk(
# parse tool calls
tool_calls = []
if message['tool_calls']:
for call in message['tool_calls']:
tool_call = AssistantPromptMessage.ToolCall(
id=call['function']['name'],
type=call['type'],
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
name=call['function']['name'],
arguments=call['function']['arguments']
)
)
tool_calls.append(tool_call)
usage = resp['usage']
return LLMResult(
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=index,
message=AssistantPromptMessage(
content=message['content'] if message['content'] else '',
tool_calls=[]
),
usage=usage,
finish_reason=choice.get('finish_reason'),
message=AssistantPromptMessage(
content=message['content'] if message['content'] else '',
tool_calls=tool_calls,
),
usage=self._calc_response_usage(model=model, credentials=credentials,
prompt_tokens=usage['prompt_tokens'],
completion_tokens=usage['completion_tokens']
),
)
def _handle_chat_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], resp: dict) -> LLMResult:
choices = resp['choices']
if not choices:
return
choice = choices[0]
message = choice['message']
if not stream:
return _handle_chat_response()
return _handle_stream_chat_response()
# parse tool calls
tool_calls = []
if message['tool_calls']:
for call in message['tool_calls']:
tool_call = AssistantPromptMessage.ToolCall(
id=call['function']['name'],
type=call['type'],
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
name=call['function']['name'],
arguments=call['function']['arguments']
)
def _generate_v3(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
model_parameters: dict, tools: list[PromptMessageTool] | None = None,
stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
-> LLMResult | Generator:
client = ArkClientV3.from_credentials(credentials)
req_params = get_v3_req_params(credentials, model_parameters, stop)
if tools:
req_params['tools'] = tools
def _handle_stream_chat_response(chunks: Generator[ChatCompletionChunk]) -> Generator:
for chunk in chunks:
if not chunk.choices:
continue
choice = chunk.choices[0]
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=choice.index,
message=AssistantPromptMessage(
content=choice.delta.content,
tool_calls=[]
),
usage=self._calc_response_usage(model=model, credentials=credentials,
prompt_tokens=chunk.usage.prompt_tokens,
completion_tokens=chunk.usage.completion_tokens
) if chunk.usage else None,
finish_reason=choice.finish_reason,
),
)
tool_calls.append(tool_call)
return LLMResult(
model=model,
prompt_messages=prompt_messages,
message=AssistantPromptMessage(
content=message['content'] if message['content'] else '',
tool_calls=tool_calls,
),
usage=self._calc_usage(model, credentials, resp['usage']),
)
def _handle_chat_response(resp: ChatCompletion) -> LLMResult:
choice = resp.choices[0]
message = choice.message
# parse tool calls
tool_calls = []
if message.tool_calls:
for call in message.tool_calls:
tool_call = AssistantPromptMessage.ToolCall(
id=call.id,
type=call.type,
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
name=call.function.name,
arguments=call.function.arguments
)
)
tool_calls.append(tool_call)
def _calc_usage(self, model: str, credentials: dict, usage: dict) -> LLMUsage:
return self._calc_response_usage(model=model, credentials=credentials,
prompt_tokens=usage['prompt_tokens'],
completion_tokens=usage['completion_tokens']
)
usage = resp.usage
return LLMResult(
model=model,
prompt_messages=prompt_messages,
message=AssistantPromptMessage(
content=message.content if message.content else "",
tool_calls=tool_calls,
),
usage=self._calc_response_usage(model=model, credentials=credentials,
prompt_tokens=usage.prompt_tokens,
completion_tokens=usage.completion_tokens
),
)
if not stream:
resp = client.chat(prompt_messages, **req_params)
return _handle_chat_response(resp)
chunks = client.stream_chat(prompt_messages, **req_params)
return _handle_stream_chat_response(chunks)
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
"""
used to define customizable model schema
"""
model_config = get_model_config(credentials)
rules = [
ParameterRule(
name='temperature',
@ -212,7 +316,7 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
use_template='presence_penalty',
label=I18nObject(
en_US='Presence Penalty',
zh_Hans= '存在惩罚',
zh_Hans='存在惩罚',
),
min=-2.0,
max=2.0,
@ -222,8 +326,8 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
type=ParameterType.FLOAT,
use_template='frequency_penalty',
label=I18nObject(
en_US= 'Frequency Penalty',
zh_Hans= '频率惩罚',
en_US='Frequency Penalty',
zh_Hans='频率惩罚',
),
min=-2.0,
max=2.0,
@ -245,7 +349,7 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
model_properties = {}
model_properties[ModelPropertyKey.CONTEXT_SIZE] = model_config.properties.context_size
model_properties[ModelPropertyKey.MODE] = model_config.properties.mode.value
entity = AIModelEntity(
model=model,
label=I18nObject(
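
A direct-instantiation sketch of the v3 path through VolcengineMaaSLargeLanguageModel above. The module path and credential values are assumptions for illustration; sdk_version="v3" forces the ArkClientV3 branch of the dispatch shown in _invoke:

# Sketch only, not part of the diff. Module path and credentials are assumed.
from core.model_runtime.entities.message_entities import UserPromptMessage
from core.model_runtime.model_providers.volcengine_maas.llm.llm import VolcengineMaaSLargeLanguageModel

model = VolcengineMaaSLargeLanguageModel()
credentials = {
    "sdk_version": "v3",  # routes _invoke to _generate_v3
    "api_endpoint_host": "https://ark.cn-beijing.volces.com/api/v3",
    "volc_region": "cn-beijing",
    "volc_access_key_id": "<ak>",
    "volc_secret_access_key": "<sk>",
    "endpoint_id": "<endpoint-id>",
    "base_model_name": "Doubao-pro-32k",
}

result = model._invoke(
    model="Doubao-pro-32k",
    credentials=credentials,
    prompt_messages=[UserPromptMessage(content="ping")],
    model_parameters={"temperature": 0.7, "max_tokens": 64},
    stream=False,
)
print(result.message.content)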

View File

@ -5,10 +5,11 @@ from core.model_runtime.entities.model_entities import ModelFeature
class ModelProperties(BaseModel):
context_size: int
max_tokens: int
context_size: int
max_tokens: int
mode: LLMMode
class ModelConfig(BaseModel):
properties: ModelProperties
features: list[ModelFeature]
@ -24,23 +25,23 @@ configs: dict[str, ModelConfig] = {
features=[ModelFeature.TOOL_CALL]
),
'Doubao-pro-32k': ModelConfig(
properties=ModelProperties(context_size=32768, max_tokens=32768, mode=LLMMode.CHAT),
properties=ModelProperties(context_size=32768, max_tokens=4096, mode=LLMMode.CHAT),
features=[ModelFeature.TOOL_CALL]
),
'Doubao-lite-32k': ModelConfig(
properties=ModelProperties(context_size=32768, max_tokens=32768, mode=LLMMode.CHAT),
properties=ModelProperties(context_size=32768, max_tokens=4096, mode=LLMMode.CHAT),
features=[ModelFeature.TOOL_CALL]
),
'Doubao-pro-128k': ModelConfig(
properties=ModelProperties(context_size=131072, max_tokens=131072, mode=LLMMode.CHAT),
properties=ModelProperties(context_size=131072, max_tokens=4096, mode=LLMMode.CHAT),
features=[ModelFeature.TOOL_CALL]
),
'Doubao-lite-128k': ModelConfig(
properties=ModelProperties(context_size=131072, max_tokens=131072, mode=LLMMode.CHAT),
properties=ModelProperties(context_size=131072, max_tokens=4096, mode=LLMMode.CHAT),
features=[ModelFeature.TOOL_CALL]
),
'Skylark2-pro-4k': ModelConfig(
properties=ModelProperties(context_size=4096, max_tokens=4000, mode=LLMMode.CHAT),
properties=ModelProperties(context_size=4096, max_tokens=4096, mode=LLMMode.CHAT),
features=[]
),
'Llama3-8B': ModelConfig(
@ -77,23 +78,24 @@ configs: dict[str, ModelConfig] = {
)
}
def get_model_config(credentials: dict)->ModelConfig:
def get_model_config(credentials: dict) -> ModelConfig:
base_model = credentials.get('base_model_name', '')
model_configs = configs.get(base_model)
if not model_configs:
return ModelConfig(
properties=ModelProperties(
properties=ModelProperties(
context_size=int(credentials.get('context_size', 0)),
max_tokens=int(credentials.get('max_tokens', 0)),
mode= LLMMode.value_of(credentials.get('mode', 'chat')),
mode=LLMMode.value_of(credentials.get('mode', 'chat')),
),
features=[]
)
return model_configs
def get_v2_req_params(credentials: dict, model_parameters: dict,
stop: list[str] | None=None):
def get_v2_req_params(credentials: dict, model_parameters: dict,
stop: list[str] | None = None):
req_params = {}
# predefined properties
model_configs = get_model_config(credentials)
@ -116,8 +118,36 @@ def get_v2_req_params(credentials: dict, model_parameters: dict,
if model_parameters.get('frequency_penalty'):
req_params['frequency_penalty'] = model_parameters.get(
'frequency_penalty')
if stop:
req_params['stop'] = stop
return req_params
return req_params
def get_v3_req_params(credentials: dict, model_parameters: dict,
stop: list[str] | None = None):
req_params = {}
# predefined properties
model_configs = get_model_config(credentials)
if model_configs:
req_params['max_tokens'] = model_configs.properties.max_tokens
# model parameters
if model_parameters.get('max_tokens'):
req_params['max_tokens'] = model_parameters.get('max_tokens')
if model_parameters.get('temperature'):
req_params['temperature'] = model_parameters.get('temperature')
if model_parameters.get('top_p'):
req_params['top_p'] = model_parameters.get('top_p')
if model_parameters.get('presence_penalty'):
req_params['presence_penalty'] = model_parameters.get(
'presence_penalty')
if model_parameters.get('frequency_penalty'):
req_params['frequency_penalty'] = model_parameters.get(
'frequency_penalty')
if stop:
req_params['stop'] = stop
return req_params
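
A small sketch of what get_v3_req_params above produces for a predefined Doubao model with example parameter values; the defaults come from the model config table earlier in this file:

# Sketch only, not part of the diff.
from core.model_runtime.model_providers.volcengine_maas.llm.models import get_v3_req_params

credentials = {"base_model_name": "Doubao-pro-32k"}
model_parameters = {"temperature": 0.7, "top_p": 0.9, "max_tokens": 1024}

params = get_v3_req_params(credentials, model_parameters, stop=["Observation"])
# max_tokens starts from the model config (4096) and is then overridden by
# model_parameters; temperature/top_p/stop pass straight through:
# {'max_tokens': 1024, 'temperature': 0.7, 'top_p': 0.9, 'stop': ['Observation']}
print(params)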

View File

@ -2,26 +2,29 @@ from pydantic import BaseModel
class ModelProperties(BaseModel):
context_size: int
max_chunks: int
context_size: int
max_chunks: int
class ModelConfig(BaseModel):
properties: ModelProperties
ModelConfigs = {
'Doubao-embedding': ModelConfig(
properties=ModelProperties(context_size=4096, max_chunks=1)
properties=ModelProperties(context_size=4096, max_chunks=32)
),
}
def get_model_config(credentials: dict)->ModelConfig:
def get_model_config(credentials: dict) -> ModelConfig:
base_model = credentials.get('base_model_name', '')
model_configs = ModelConfigs.get(base_model)
if not model_configs:
return ModelConfig(
properties=ModelProperties(
properties=ModelProperties(
context_size=int(credentials.get('context_size', 0)),
max_chunks=int(credentials.get('max_chunks', 0)),
)
)
return model_configs
return model_configs

View File

@ -22,16 +22,17 @@ from core.model_runtime.errors.invoke import (
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.model_runtime.model_providers.volcengine_maas.client import MaaSClient
from core.model_runtime.model_providers.volcengine_maas.errors import (
from core.model_runtime.model_providers.volcengine_maas.client import ArkClientV3
from core.model_runtime.model_providers.volcengine_maas.legacy.client import MaaSClient
from core.model_runtime.model_providers.volcengine_maas.legacy.errors import (
AuthErrors,
BadRequestErrors,
ConnectionErrors,
MaasException,
RateLimitErrors,
ServerUnavailableErrors,
)
from core.model_runtime.model_providers.volcengine_maas.text_embedding.models import get_model_config
from core.model_runtime.model_providers.volcengine_maas.volc_sdk import MaasException
class VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel):
@ -51,6 +52,14 @@ class VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel):
:param user: unique user id
:return: embeddings result
"""
if ArkClientV3.is_legacy(credentials):
return self._generate_v2(model, credentials, texts, user)
return self._generate_v3(model, credentials, texts, user)
def _generate_v2(self, model: str, credentials: dict,
texts: list[str], user: Optional[str] = None) \
-> TextEmbeddingResult:
client = MaaSClient.from_credential(credentials)
resp = MaaSClient.wrap_exception(lambda: client.embeddings(texts))
@ -65,6 +74,23 @@ class VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel):
return result
def _generate_v3(self, model: str, credentials: dict,
texts: list[str], user: Optional[str] = None) \
-> TextEmbeddingResult:
client = ArkClientV3.from_credentials(credentials)
resp = client.embeddings(texts)
usage = self._calc_response_usage(
model=model, credentials=credentials, tokens=resp.usage.total_tokens)
result = TextEmbeddingResult(
model=model,
embeddings=[v.embedding for v in resp.data],
usage=usage
)
return result
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
"""
Get number of tokens for given prompt messages
@ -88,11 +114,22 @@ class VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel):
:param credentials: model credentials
:return:
"""
if ArkClientV3.is_legacy(credentials):
return self._validate_credentials_v2(model, credentials)
return self._validate_credentials_v3(model, credentials)
def _validate_credentials_v2(self, model: str, credentials: dict) -> None:
try:
self._invoke(model=model, credentials=credentials, texts=['ping'])
except MaasException as e:
raise CredentialsValidateFailedError(e.message)
def _validate_credentials_v3(self, model: str, credentials: dict) -> None:
try:
self._invoke(model=model, credentials=credentials, texts=['ping'])
except Exception as e:
raise CredentialsValidateFailedError(e)
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
"""
@ -116,9 +153,10 @@ class VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel):
generate custom model entities from credentials
"""
model_config = get_model_config(credentials)
model_properties = {}
model_properties[ModelPropertyKey.CONTEXT_SIZE] = model_config.properties.context_size
model_properties[ModelPropertyKey.MAX_CHUNKS] = model_config.properties.max_chunks
model_properties = {
ModelPropertyKey.CONTEXT_SIZE: model_config.properties.context_size,
ModelPropertyKey.MAX_CHUNKS: model_config.properties.max_chunks
}
entity = AIModelEntity(
model=model,
label=I18nObject(en_US=model),
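
A hedged sketch of the new v3 embedding path above, driven directly through ArkClientV3; credentials are placeholders and the endpoint_id must point at an embedding model such as Doubao-embedding:

# Sketch only, not part of the diff. Placeholder credentials.
from core.model_runtime.model_providers.volcengine_maas.client import ArkClientV3

client = ArkClientV3.from_credentials({
    "api_endpoint_host": "https://ark.cn-beijing.volces.com/api/v3",
    "volc_region": "cn-beijing",
    "volc_access_key_id": "<ak>",
    "volc_secret_access_key": "<sk>",
    "endpoint_id": "<doubao-embedding-endpoint-id>",
})

resp = client.embeddings(["hello", "world"])      # CreateEmbeddingResponse
vectors = [item.embedding for item in resp.data]  # one vector per input text
print(len(vectors), resp.usage.total_tokens)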

api/poetry.lock (generated, 37 changed lines)
View File

@ -6143,6 +6143,19 @@ files = [
{file = "pyarrow-17.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:392bc9feabc647338e6c89267635e111d71edad5fcffba204425a7c8d13610d7"},
{file = "pyarrow-17.0.0-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:af5ff82a04b2171415f1410cff7ebb79861afc5dae50be73ce06d6e870615204"},
{file = "pyarrow-17.0.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:edca18eaca89cd6382dfbcff3dd2d87633433043650c07375d095cd3517561d8"},
{file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7c7916bff914ac5d4a8fe25b7a25e432ff921e72f6f2b7547d1e325c1ad9d155"},
{file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f553ca691b9e94b202ff741bdd40f6ccb70cdd5fbf65c187af132f1317de6145"},
{file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:0cdb0e627c86c373205a2f94a510ac4376fdc523f8bb36beab2e7f204416163c"},
{file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:d7d192305d9d8bc9082d10f361fc70a73590a4c65cf31c3e6926cd72b76bc35c"},
{file = "pyarrow-17.0.0-cp38-cp38-win_amd64.whl", hash = "sha256:02dae06ce212d8b3244dd3e7d12d9c4d3046945a5933d28026598e9dbbda1fca"},
{file = "pyarrow-17.0.0-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:13d7a460b412f31e4c0efa1148e1d29bdf18ad1411eb6757d38f8fbdcc8645fb"},
{file = "pyarrow-17.0.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9b564a51fbccfab5a04a80453e5ac6c9954a9c5ef2890d1bcf63741909c3f8df"},
{file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:32503827abbc5aadedfa235f5ece8c4f8f8b0a3cf01066bc8d29de7539532687"},
{file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a155acc7f154b9ffcc85497509bcd0d43efb80d6f733b0dc3bb14e281f131c8b"},
{file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:dec8d129254d0188a49f8a1fc99e0560dc1b85f60af729f47de4046015f9b0a5"},
{file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:a48ddf5c3c6a6c505904545c25a4ae13646ae1f8ba703c4df4a1bfe4f4006bda"},
{file = "pyarrow-17.0.0-cp39-cp39-win_amd64.whl", hash = "sha256:42bf93249a083aca230ba7e2786c5f673507fa97bbd9725a1e2754715151a204"},
{file = "pyarrow-17.0.0.tar.gz", hash = "sha256:4beca9521ed2c0921c1023e68d097d0299b62c362639ea315572a58f3f50fd28"},
]
[package.dependencies]
@ -8854,6 +8867,28 @@ files = [
{file = "vine-5.1.0.tar.gz", hash = "sha256:8b62e981d35c41049211cf62a0a1242d8c1ee9bd15bb196ce38aefd6799e61e0"},
]
[[package]]
name = "volcengine-python-sdk"
version = "1.0.98"
description = "Volcengine SDK for Python"
optional = false
python-versions = "*"
files = [
{file = "volcengine-python-sdk-1.0.98.tar.gz", hash = "sha256:1515e8d46cdcda387f9b45abbcaf0b04b982f7be68068de83f1e388281441784"},
]
[package.dependencies]
anyio = {version = ">=3.5.0,<5", optional = true, markers = "extra == \"ark\""}
certifi = ">=2017.4.17"
httpx = {version = ">=0.23.0,<1", optional = true, markers = "extra == \"ark\""}
pydantic = {version = ">=1.9.0,<3", optional = true, markers = "extra == \"ark\""}
python-dateutil = ">=2.1"
six = ">=1.10"
urllib3 = ">=1.23"
[package.extras]
ark = ["anyio (>=3.5.0,<5)", "cached-property", "httpx (>=0.23.0,<1)", "pydantic (>=1.9.0,<3)"]
[[package]]
name = "watchfiles"
version = "0.23.0"
@ -9634,4 +9669,4 @@ cffi = ["cffi (>=1.11)"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.10,<3.13"
content-hash = "d7336115709114c2a4ff09b392f717e9c3547ae82b6a111d0c885c7a44269f02"
content-hash = "04f970820de691f40fc9fb30f5ff0618b0f1a04d3315b14467fb88e475fa1243"

View File

@ -191,6 +191,7 @@ zhipuai = "1.0.7"
# Related transparent dependencies with pinned version
# required by main implementations
############################################################
volcengine-python-sdk = {extras = ["ark"], version = "^1.0.98"}
[tool.poetry.group.indriect.dependencies]
kaleido = "0.2.1"
rank-bm25 = "~0.2.2"
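
The new runtime dependency is volcengine-python-sdk with the "ark" extra, which provides the volcenginesdkarkruntime package used by ArkClientV3. A quick import check after installing; the constructor arguments mirror what from_credentials passes and the values are placeholders:

# Sketch only, not part of the diff: verify the "ark" extra is available.
from volcenginesdkarkruntime import Ark

client = Ark(
    base_url="https://ark.cn-beijing.volces.com/api/v3",
    region="cn-beijing",
    ak="<ak>",
    sk="<sk>",
)
print(type(client).__name__)  # "Ark"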