From efc136cce5469899c4582bf8db260c74bcf74830 Mon Sep 17 00:00:00 2001
From: sino
Date: Sat, 24 Aug 2024 19:29:45 +0800
Subject: [PATCH] feat: Introduce Ark SDK v3 and ensure compatibility with models of SDK v2 (#7579)

Co-authored-by: crazywoola <427733928@qq.com>
---
 .../model_providers/volcengine_maas/client.py | 247 ++++++++++-------
 .../volcengine_maas/legacy/__init__.py        |   0
 .../volcengine_maas/legacy/client.py          | 134 ++++++++++
 .../volcengine_maas/{ => legacy}/errors.py    |   2 +-
 .../{ => legacy}/volc_sdk/__init__.py         |   0
 .../{ => legacy}/volc_sdk/base/__init__.py    |   0
 .../{ => legacy}/volc_sdk/base/auth.py        |   0
 .../{ => legacy}/volc_sdk/base/service.py     |   0
 .../{ => legacy}/volc_sdk/base/util.py        |   0
 .../{ => legacy}/volc_sdk/common.py           |   0
 .../{ => legacy}/volc_sdk/maas.py             |   0
 .../volcengine_maas/llm/llm.py                | 252 +++++++++++++-----
 .../volcengine_maas/llm/models.py             |  58 +++-
 .../volcengine_maas/text_embedding/models.py  |  15 +-
 .../text_embedding/text_embedding.py          |  50 +++-
 api/poetry.lock                               |  37 ++-
 api/pyproject.toml                            |   1 +
 17 files changed, 604 insertions(+), 192 deletions(-)
 create mode 100644 api/core/model_runtime/model_providers/volcengine_maas/legacy/__init__.py
 create mode 100644 api/core/model_runtime/model_providers/volcengine_maas/legacy/client.py
 rename api/core/model_runtime/model_providers/volcengine_maas/{ => legacy}/errors.py (97%)
 rename api/core/model_runtime/model_providers/volcengine_maas/{ => legacy}/volc_sdk/__init__.py (100%)
 rename api/core/model_runtime/model_providers/volcengine_maas/{ => legacy}/volc_sdk/base/__init__.py (100%)
 rename api/core/model_runtime/model_providers/volcengine_maas/{ => legacy}/volc_sdk/base/auth.py (100%)
 rename api/core/model_runtime/model_providers/volcengine_maas/{ => legacy}/volc_sdk/base/service.py (100%)
 rename api/core/model_runtime/model_providers/volcengine_maas/{ => legacy}/volc_sdk/base/util.py (100%)
 rename api/core/model_runtime/model_providers/volcengine_maas/{ => legacy}/volc_sdk/common.py (100%)
 rename api/core/model_runtime/model_providers/volcengine_maas/{ => legacy}/volc_sdk/maas.py (100%)

diff --git a/api/core/model_runtime/model_providers/volcengine_maas/client.py b/api/core/model_runtime/model_providers/volcengine_maas/client.py
index 471cb3c94e..5100494e58 100644
--- a/api/core/model_runtime/model_providers/volcengine_maas/client.py
+++ b/api/core/model_runtime/model_providers/volcengine_maas/client.py
@@ -1,6 +1,25 @@
 import re
-from collections.abc import Callable, Generator
-from typing import cast
+from collections.abc import Generator
+from typing import Optional, cast
+
+from volcenginesdkarkruntime import Ark
+from volcenginesdkarkruntime.types.chat import (
+    ChatCompletion,
+    ChatCompletionAssistantMessageParam,
+    ChatCompletionChunk,
+    ChatCompletionContentPartImageParam,
+    ChatCompletionContentPartTextParam,
+    ChatCompletionMessageParam,
+    ChatCompletionMessageToolCallParam,
+    ChatCompletionSystemMessageParam,
+    ChatCompletionToolMessageParam,
+    ChatCompletionToolParam,
+    ChatCompletionUserMessageParam,
+)
+from volcenginesdkarkruntime.types.chat.chat_completion_content_part_image_param import ImageURL
+from volcenginesdkarkruntime.types.chat.chat_completion_message_tool_call_param import Function
+from volcenginesdkarkruntime.types.create_embedding_response import CreateEmbeddingResponse
+from volcenginesdkarkruntime.types.shared_params import FunctionDefinition
 
 from core.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
@@ -12,123 +31,171 @@ from core.model_runtime.entities.message_entities import (
     ToolPromptMessage,
     UserPromptMessage,
 )
-from core.model_runtime.model_providers.volcengine_maas.errors import wrap_error
-from core.model_runtime.model_providers.volcengine_maas.volc_sdk import ChatRole, MaasException, MaasService
 
 
-class MaaSClient(MaasService):
-    def __init__(self, host: str, region: str):
+class ArkClientV3:
+    endpoint_id: Optional[str] = None
+    ark: Optional[Ark] = None
+
+    def __init__(self, *args, **kwargs):
+        self.ark = Ark(*args, **kwargs)
         self.endpoint_id = None
-        super().__init__(host, region)
-
-    def set_endpoint_id(self, endpoint_id: str):
-        self.endpoint_id = endpoint_id
-
-    @classmethod
-    def from_credential(cls, credentials: dict) -> 'MaaSClient':
-        host = credentials['api_endpoint_host']
-        region = credentials['volc_region']
-        ak = credentials['volc_access_key_id']
-        sk = credentials['volc_secret_access_key']
-        endpoint_id = credentials['endpoint_id']
-
-        client = cls(host, region)
-        client.set_endpoint_id(endpoint_id)
-        client.set_ak(ak)
-        client.set_sk(sk)
-        return client
-
-    def chat(self, params: dict, messages: list[PromptMessage], stream=False, **extra_model_kwargs) -> Generator | dict:
-        req = {
-            'parameters': params,
-            'messages': [self.convert_prompt_message_to_maas_message(prompt) for prompt in messages],
-            **extra_model_kwargs,
-        }
-        if not stream:
-            return super().chat(
-                self.endpoint_id,
-                req,
-            )
-        return super().stream_chat(
-            self.endpoint_id,
-            req,
-        )
-
-    def embeddings(self, texts: list[str]) -> dict:
-        req = {
-            'input': texts
-        }
-        return super().embeddings(self.endpoint_id, req)
 
     @staticmethod
-    def convert_prompt_message_to_maas_message(message: PromptMessage) -> dict:
+    def is_legacy(credentials: dict) -> bool:
+        if ArkClientV3.is_compatible_with_legacy(credentials):
+            return False
+        sdk_version = credentials.get("sdk_version", "v2")
+        return sdk_version != "v3"
+
+    @staticmethod
+    def is_compatible_with_legacy(credentials: dict) -> bool:
+        sdk_version = credentials.get("sdk_version")
+        endpoint = credentials.get("api_endpoint_host")
+        return sdk_version is None and endpoint == "maas-api.ml-platform-cn-beijing.volces.com"
+
+    @classmethod
+    def from_credentials(cls, credentials):
+        """Initialize the client using the credentials provided."""
+        args = {
+            "base_url": credentials['api_endpoint_host'],
+            "region": credentials['volc_region'],
+            "ak": credentials['volc_access_key_id'],
+            "sk": credentials['volc_secret_access_key'],
+        }
+        if cls.is_compatible_with_legacy(credentials):
+            args["base_url"] = "https://ark.cn-beijing.volces.com/api/v3"
+
+        client = ArkClientV3(
+            **args
+        )
+        client.endpoint_id = credentials['endpoint_id']
+        return client
+
+    @staticmethod
+    def convert_prompt_message(message: PromptMessage) -> ChatCompletionMessageParam:
+        """Converts a PromptMessage to a ChatCompletionMessageParam"""
         if isinstance(message, UserPromptMessage):
             message = cast(UserPromptMessage, message)
             if isinstance(message.content, str):
-                message_dict = {"role": ChatRole.USER,
-                                "content": message.content}
+                content = message.content
             else:
                 content = []
                 for message_content in message.content:
                     if message_content.type == PromptMessageContentType.TEXT:
-                        raise ValueError(
-                            'Content object type only support image_url')
+                        content.append(ChatCompletionContentPartTextParam(
+                            text=message_content.text,
+                            type='text',
+                        ))
                     elif message_content.type == PromptMessageContentType.IMAGE:
                         message_content = cast(
                             ImagePromptMessageContent, message_content)
                         image_data = re.sub(
                             r'^data:image\/[a-zA-Z]+;base64,', '', message_content.data)
-                        content.append({
-                            'type': 'image_url',
-                            'image_url': {
-                                'url': '',
-                                'image_bytes': image_data,
-                                'detail': message_content.detail,
-                            }
-                        })
-
-            message_dict = {'role': ChatRole.USER, 'content': content}
+                        content.append(ChatCompletionContentPartImageParam(
+                            image_url=ImageURL(
+                                url=image_data,
+                                detail=message_content.detail.value,
+                            ),
+                            type='image_url',
+                        ))
+            message_dict = ChatCompletionUserMessageParam(
+                role='user',
+                content=content
+            )
         elif isinstance(message, AssistantPromptMessage):
             message = cast(AssistantPromptMessage, message)
-            message_dict = {'role': ChatRole.ASSISTANT,
-                            'content': message.content}
-            if message.tool_calls:
-                message_dict['tool_calls'] = [
-                    {
-                        'name': call.function.name,
-                        'arguments': call.function.arguments
-                    } for call in message.tool_calls
+            message_dict = ChatCompletionAssistantMessageParam(
+                content=message.content,
+                role='assistant',
+                tool_calls=None if not message.tool_calls else [
+                    ChatCompletionMessageToolCallParam(
+                        id=call.id,
+                        function=Function(
+                            name=call.function.name,
+                            arguments=call.function.arguments
+                        ),
+                        type='function'
+                    ) for call in message.tool_calls
                 ]
+            )
         elif isinstance(message, SystemPromptMessage):
             message = cast(SystemPromptMessage, message)
-            message_dict = {'role': ChatRole.SYSTEM,
-                            'content': message.content}
+            message_dict = ChatCompletionSystemMessageParam(
+                content=message.content,
+                role='system'
+            )
         elif isinstance(message, ToolPromptMessage):
             message = cast(ToolPromptMessage, message)
-            message_dict = {'role': ChatRole.FUNCTION,
-                            'content': message.content,
-                            'name': message.tool_call_id}
+            message_dict = ChatCompletionToolMessageParam(
+                content=message.content,
+                role='tool',
+                tool_call_id=message.tool_call_id
+            )
         else:
             raise ValueError(f"Got unknown PromptMessage type {message}")
 
         return message_dict
 
     @staticmethod
-    def wrap_exception(fn: Callable[[], dict | Generator]) -> dict | Generator:
-        try:
-            resp = fn()
-        except MaasException as e:
-            raise wrap_error(e)
+    def _convert_tool_prompt(message: PromptMessageTool) -> ChatCompletionToolParam:
+        return ChatCompletionToolParam(
+            type='function',
+            function=FunctionDefinition(
+                name=message.name,
+                description=message.description,
+                parameters=message.parameters,
+            )
+        )
 
-        return resp
+    def chat(self, messages: list[PromptMessage],
+             tools: Optional[list[PromptMessageTool]] = None,
+             stop: Optional[list[str]] = None,
+             frequency_penalty: Optional[float] = None,
+             max_tokens: Optional[int] = None,
+             presence_penalty: Optional[float] = None,
+             top_p: Optional[float] = None,
+             temperature: Optional[float] = None,
+             ) -> ChatCompletion:
+        """Blocking chat"""
+        return self.ark.chat.completions.create(
+            model=self.endpoint_id,
+            messages=[self.convert_prompt_message(message) for message in messages],
+            tools=[self._convert_tool_prompt(tool) for tool in tools] if tools else None,
+            stop=stop,
+            frequency_penalty=frequency_penalty,
+            max_tokens=max_tokens,
+            presence_penalty=presence_penalty,
+            top_p=top_p,
+            temperature=temperature,
+        )
 
-    @staticmethod
-    def transform_tool_prompt_to_maas_config(tool: PromptMessageTool):
-        return {
-            "type": "function",
-            "function": {
-                "name": tool.name,
-                "description": tool.description,
-                "parameters": tool.parameters,
-            }
-        }
+    def stream_chat(self, messages: list[PromptMessage],
+                    tools: Optional[list[PromptMessageTool]] = None,
+                    stop: Optional[list[str]] = None,
+                    frequency_penalty: Optional[float] = None,
+                    max_tokens: Optional[int] = None,
+                    presence_penalty: Optional[float] = None,
+                    top_p: Optional[float] = None,
+                    temperature: Optional[float] = None,
+                    ) -> Generator[ChatCompletionChunk, None, None]:
+        """Streaming chat"""
+        chunks = self.ark.chat.completions.create(
+            stream=True,
+            model=self.endpoint_id,
+            messages=[self.convert_prompt_message(message) for message in messages],
+            tools=[self._convert_tool_prompt(tool) for tool in tools] if tools else None,
+            stop=stop,
+            frequency_penalty=frequency_penalty,
+            max_tokens=max_tokens,
+            presence_penalty=presence_penalty,
+            top_p=top_p,
+            temperature=temperature,
+        )
+        for chunk in chunks:
+            if not chunk.choices:
+                continue
+            yield chunk
+
+    def embeddings(self, texts: list[str]) -> CreateEmbeddingResponse:
+        return self.ark.embeddings.create(model=self.endpoint_id, input=texts)
diff --git a/api/core/model_runtime/model_providers/volcengine_maas/legacy/__init__.py b/api/core/model_runtime/model_providers/volcengine_maas/legacy/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/model_runtime/model_providers/volcengine_maas/legacy/client.py b/api/core/model_runtime/model_providers/volcengine_maas/legacy/client.py
new file mode 100644
index 0000000000..1978c11680
--- /dev/null
+++ b/api/core/model_runtime/model_providers/volcengine_maas/legacy/client.py
@@ -0,0 +1,134 @@
+import re
+from collections.abc import Callable, Generator
+from typing import cast
+
+from core.model_runtime.entities.message_entities import (
+    AssistantPromptMessage,
+    ImagePromptMessageContent,
+    PromptMessage,
+    PromptMessageContentType,
+    PromptMessageTool,
+    SystemPromptMessage,
+    ToolPromptMessage,
+    UserPromptMessage,
+)
+from core.model_runtime.model_providers.volcengine_maas.legacy.errors import wrap_error
+from core.model_runtime.model_providers.volcengine_maas.legacy.volc_sdk import ChatRole, MaasException, MaasService
+
+
+class MaaSClient(MaasService):
+    def __init__(self, host: str, region: str):
+        self.endpoint_id = None
+        super().__init__(host, region)
+
+    def set_endpoint_id(self, endpoint_id: str):
+        self.endpoint_id = endpoint_id
+
+    @classmethod
+    def from_credential(cls, credentials: dict) -> 'MaaSClient':
+        host = credentials['api_endpoint_host']
+        region = credentials['volc_region']
+        ak = credentials['volc_access_key_id']
+        sk = credentials['volc_secret_access_key']
+        endpoint_id = credentials['endpoint_id']
+
+        client = cls(host, region)
+        client.set_endpoint_id(endpoint_id)
+        client.set_ak(ak)
+        client.set_sk(sk)
+        return client
+
+    def chat(self, params: dict, messages: list[PromptMessage], stream=False, **extra_model_kwargs) -> Generator | dict:
+        req = {
+            'parameters': params,
+            'messages': [self.convert_prompt_message_to_maas_message(prompt) for prompt in messages],
+            **extra_model_kwargs,
+        }
+        if not stream:
+            return super().chat(
+                self.endpoint_id,
+                req,
+            )
+        return super().stream_chat(
+            self.endpoint_id,
+            req,
+        )
+
+    def embeddings(self, texts: list[str]) -> dict:
+        req = {
+            'input': texts
+        }
+        return super().embeddings(self.endpoint_id, req)
+
+    @staticmethod
+    def convert_prompt_message_to_maas_message(message: PromptMessage) -> dict:
+        if isinstance(message, UserPromptMessage):
+            message = cast(UserPromptMessage, message)
+            if isinstance(message.content, str):
+                message_dict = {"role": ChatRole.USER,
+                                "content": message.content}
+            else:
+                content = []
+                for message_content in message.content:
+                    if message_content.type == PromptMessageContentType.TEXT:
+                        raise ValueError(
+                            'Content object type only support image_url')
+                    elif message_content.type == PromptMessageContentType.IMAGE:
+                        message_content = cast(
+                            ImagePromptMessageContent, message_content)
+                        image_data = re.sub(
+                            r'^data:image\/[a-zA-Z]+;base64,', '', message_content.data)
+                        content.append({
+                            'type': 'image_url',
+                            'image_url': {
+                                'url': '',
+                                'image_bytes': image_data,
+                                'detail': message_content.detail,
+                            }
+                        })
+
+            message_dict = {'role': ChatRole.USER, 'content': content}
+        elif isinstance(message, AssistantPromptMessage):
+            message = cast(AssistantPromptMessage, message)
+            message_dict = {'role': ChatRole.ASSISTANT,
+                            'content': message.content}
+            if message.tool_calls:
+                message_dict['tool_calls'] = [
+                    {
+                        'name': call.function.name,
+                        'arguments': call.function.arguments
+                    } for call in message.tool_calls
+                ]
+        elif isinstance(message, SystemPromptMessage):
+            message = cast(SystemPromptMessage, message)
+            message_dict = {'role': ChatRole.SYSTEM,
+                            'content': message.content}
+        elif isinstance(message, ToolPromptMessage):
+            message = cast(ToolPromptMessage, message)
+            message_dict = {'role': ChatRole.FUNCTION,
+                            'content': message.content,
+                            'name': message.tool_call_id}
+        else:
+            raise ValueError(f"Got unknown PromptMessage type {message}")
+
+        return message_dict
+
+    @staticmethod
+    def wrap_exception(fn: Callable[[], dict | Generator]) -> dict | Generator:
+        try:
+            resp = fn()
+        except MaasException as e:
+            raise wrap_error(e)
+
+        return resp
+
+    @staticmethod
+    def transform_tool_prompt_to_maas_config(tool: PromptMessageTool):
+        return {
+            "type": "function",
+            "function": {
+                "name": tool.name,
+                "description": tool.description,
+                "parameters": tool.parameters,
+            }
+        }
diff --git a/api/core/model_runtime/model_providers/volcengine_maas/errors.py b/api/core/model_runtime/model_providers/volcengine_maas/legacy/errors.py
similarity index 97%
rename from api/core/model_runtime/model_providers/volcengine_maas/errors.py
rename to api/core/model_runtime/model_providers/volcengine_maas/legacy/errors.py
index 63397a456e..21ffaf1258 100644
--- a/api/core/model_runtime/model_providers/volcengine_maas/errors.py
+++ b/api/core/model_runtime/model_providers/volcengine_maas/legacy/errors.py
@@ -1,4 +1,4 @@
-from core.model_runtime.model_providers.volcengine_maas.volc_sdk import MaasException
+from core.model_runtime.model_providers.volcengine_maas.legacy.volc_sdk import MaasException
 
 
 class ClientSDKRequestError(MaasException):
diff --git a/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/__init__.py b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/__init__.py
similarity index 100%
rename from api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/__init__.py
rename to api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/__init__.py
diff --git a/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/__init__.py b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/__init__.py
similarity index 100%
rename from api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/__init__.py
rename to api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/__init__.py
diff --git a/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/auth.py b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/auth.py
similarity index 100%
rename from api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/auth.py
rename to api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/auth.py
diff --git a/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/service.py b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/service.py
similarity index 100%
rename from api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/service.py
rename to api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/service.py
diff --git a/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/util.py b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/util.py
similarity index 100%
rename from api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/util.py
rename to api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/util.py
diff --git a/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/common.py b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/common.py
similarity index 100%
rename from api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/common.py
rename to api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/common.py
diff --git a/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/maas.py b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/maas.py
similarity index 100%
rename from api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/maas.py
rename to api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/maas.py
diff --git a/api/core/model_runtime/model_providers/volcengine_maas/llm/llm.py b/api/core/model_runtime/model_providers/volcengine_maas/llm/llm.py
index add5822bef..996c66e604 100644
--- a/api/core/model_runtime/model_providers/volcengine_maas/llm/llm.py
+++ b/api/core/model_runtime/model_providers/volcengine_maas/llm/llm.py
@@ -1,8 +1,10 @@
 import logging
 from collections.abc import Generator
 
+from volcenginesdkarkruntime.types.chat import ChatCompletion, ChatCompletionChunk
+
 from core.model_runtime.entities.common_entities import I18nObject
-from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
+from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
     PromptMessage,
@@ -27,19 +29,21 @@ from core.model_runtime.errors.invoke import (
 )
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
-from core.model_runtime.model_providers.volcengine_maas.client import MaaSClient
-from core.model_runtime.model_providers.volcengine_maas.errors import (
+from core.model_runtime.model_providers.volcengine_maas.client import ArkClientV3
+from core.model_runtime.model_providers.volcengine_maas.legacy.client import MaaSClient
+from core.model_runtime.model_providers.volcengine_maas.legacy.errors import (
     AuthErrors,
     BadRequestErrors,
     ConnectionErrors,
+    MaasException,
     RateLimitErrors,
     ServerUnavailableErrors,
 )
 from core.model_runtime.model_providers.volcengine_maas.llm.models import (
     get_model_config,
     get_v2_req_params,
+    get_v3_req_params,
 )
-from core.model_runtime.model_providers.volcengine_maas.volc_sdk import MaasException
 
 logger = logging.getLogger(__name__)
 
@@ -49,13 +53,20 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
                 model_parameters: dict, tools: list[PromptMessageTool] | None = None,
                 stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
             -> LLMResult | Generator:
-        return self._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
+        if ArkClientV3.is_legacy(credentials):
+            return self._generate_v2(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
+        return self._generate_v3(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
 
     def validate_credentials(self, model: str, credentials: dict) -> None:
         """
         Validate credentials
         """
-        # ping
+        if ArkClientV3.is_legacy(credentials):
+            return self._validate_credentials_v2(credentials)
+        return self._validate_credentials_v3(credentials)
+
+    @staticmethod
+    def _validate_credentials_v2(credentials: dict) -> None:
         client = MaaSClient.from_credential(credentials)
         try:
             client.chat(
@@ -70,18 +81,24 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
         except MaasException as e:
             raise CredentialsValidateFailedError(e.message)
 
+    @staticmethod
+    def _validate_credentials_v3(credentials: dict) -> None:
+        client = ArkClientV3.from_credentials(credentials)
+        try:
+            client.chat(max_tokens=16, temperature=0.7, top_p=0.9,
+                        messages=[UserPromptMessage(content='ping\nAnswer: ')], )
+        except Exception as e:
+            raise CredentialsValidateFailedError(e)
+
     def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                        tools: list[PromptMessageTool] | None = None) -> int:
-        if len(prompt_messages) == 0:
+        if ArkClientV3.is_legacy(credentials):
+            return self._get_num_tokens_v2(prompt_messages)
+        return self._get_num_tokens_v3(prompt_messages)
+
+    def _get_num_tokens_v2(self, messages: list[PromptMessage]) -> int:
+        if len(messages) == 0:
             return 0
-        return self._num_tokens_from_messages(prompt_messages)
-
-    def _num_tokens_from_messages(self, messages: list[PromptMessage]) -> int:
-        """
-        Calculate num tokens.
-
-        :param messages: messages
-        """
         num_tokens = 0
         messages_dict = [
             MaaSClient.convert_prompt_message_to_maas_message(m) for m in messages]
@@ -92,9 +109,22 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
 
         return num_tokens
 
-    def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
-                  model_parameters: dict, tools: list[PromptMessageTool] | None = None,
-                  stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
+    def _get_num_tokens_v3(self, messages: list[PromptMessage]) -> int:
+        if len(messages) == 0:
+            return 0
+        num_tokens = 0
+        messages_dict = [
+            ArkClientV3.convert_prompt_message(m) for m in messages]
+        for message in messages_dict:
+            for key, value in message.items():
+                num_tokens += self._get_num_tokens_by_gpt2(str(key))
+                num_tokens += self._get_num_tokens_by_gpt2(str(value))
+
+        return num_tokens
+
+    def _generate_v2(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
+                     model_parameters: dict, tools: list[PromptMessageTool] | None = None,
+                     stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
             -> LLMResult | Generator:
 
         client = MaaSClient.from_credential(credentials)
@@ -106,77 +136,151 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
         ]
         resp = MaaSClient.wrap_exception(
             lambda: client.chat(req_params, prompt_messages, stream, **extra_model_kwargs))
-        if not stream:
-            return self._handle_chat_response(model, credentials, prompt_messages, resp)
-        return self._handle_stream_chat_response(model, credentials, prompt_messages, resp)
 
-    def _handle_stream_chat_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], resp: Generator) -> Generator:
-        for index, r in enumerate(resp):
-            choices = r['choices']
+        def _handle_stream_chat_response() -> Generator:
+            for index, r in enumerate(resp):
+                choices = r['choices']
+                if not choices:
+                    continue
+                choice = choices[0]
+                message = choice['message']
+                usage = None
+                if r.get('usage'):
+                    usage = self._calc_response_usage(model=model, credentials=credentials,
+                                                      prompt_tokens=r['usage']['prompt_tokens'],
+                                                      completion_tokens=r['usage']['completion_tokens']
+                                                      )
+                yield LLMResultChunk(
+                    model=model,
+                    prompt_messages=prompt_messages,
+                    delta=LLMResultChunkDelta(
+                        index=index,
+                        message=AssistantPromptMessage(
+                            content=message['content'] if message['content'] else '',
+                            tool_calls=[]
+                        ),
+                        usage=usage,
+                        finish_reason=choice.get('finish_reason'),
+                    ),
+                )
+
+        def _handle_chat_response() -> LLMResult:
+            choices = resp['choices']
             if not choices:
-                continue
+                raise ValueError("No choices found")
+
             choice = choices[0]
             message = choice['message']
-            usage = None
-            if r.get('usage'):
-                usage = self._calc_usage(model, credentials, r['usage'])
-            yield LLMResultChunk(
+
+            # parse tool calls
+            tool_calls = []
+            if message['tool_calls']:
+                for call in message['tool_calls']:
+                    tool_call = AssistantPromptMessage.ToolCall(
+                        id=call['function']['name'],
+                        type=call['type'],
+                        function=AssistantPromptMessage.ToolCall.ToolCallFunction(
+                            name=call['function']['name'],
+                            arguments=call['function']['arguments']
+                        )
+                    )
+                    tool_calls.append(tool_call)
+
+            usage = resp['usage']
+            return LLMResult(
                 model=model,
                 prompt_messages=prompt_messages,
-                delta=LLMResultChunkDelta(
-                    index=index,
-                    message=AssistantPromptMessage(
-                        content=message['content'] if message['content'] else '',
-                        tool_calls=[]
-                    ),
-                    usage=usage,
-                    finish_reason=choice.get('finish_reason'),
+                message=AssistantPromptMessage(
+                    content=message['content'] if message['content'] else '',
+                    tool_calls=tool_calls,
                 ),
+                usage=self._calc_response_usage(model=model, credentials=credentials,
+                                                prompt_tokens=usage['prompt_tokens'],
+                                                completion_tokens=usage['completion_tokens']
+                                                ),
             )
 
-    def _handle_chat_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], resp: dict) -> LLMResult:
-        choices = resp['choices']
-        if not choices:
-            return
-        choice = choices[0]
-        message = choice['message']
+        if not stream:
+            return _handle_chat_response()
+        return _handle_stream_chat_response()
 
-        # parse tool calls
-        tool_calls = []
-        if message['tool_calls']:
-            for call in message['tool_calls']:
-                tool_call = AssistantPromptMessage.ToolCall(
-                    id=call['function']['name'],
-                    type=call['type'],
-                    function=AssistantPromptMessage.ToolCall.ToolCallFunction(
-                        name=call['function']['name'],
-                        arguments=call['function']['arguments']
-                    )
+    def _generate_v3(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
+                     model_parameters: dict, tools: list[PromptMessageTool] | None = None,
+                     stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
+            -> LLMResult | Generator:
+
+        client = ArkClientV3.from_credentials(credentials)
+        req_params = get_v3_req_params(credentials, model_parameters, stop)
+        if tools:
+            req_params['tools'] = tools
+
+        def _handle_stream_chat_response(chunks: Generator[ChatCompletionChunk, None, None]) -> Generator:
+            for chunk in chunks:
+                if not chunk.choices:
+                    continue
+                choice = chunk.choices[0]
+
+                yield LLMResultChunk(
+                    model=model,
+                    prompt_messages=prompt_messages,
+                    delta=LLMResultChunkDelta(
+                        index=choice.index,
+                        message=AssistantPromptMessage(
+                            content=choice.delta.content or '',
+                            tool_calls=[]
+                        ),
+                        usage=self._calc_response_usage(model=model, credentials=credentials,
+                                                        prompt_tokens=chunk.usage.prompt_tokens,
+                                                        completion_tokens=chunk.usage.completion_tokens
+                                                        ) if chunk.usage else None,
+                        finish_reason=choice.finish_reason,
+                    ),
                 )
-                tool_calls.append(tool_call)
 
-        return LLMResult(
-            model=model,
-            prompt_messages=prompt_messages,
-            message=AssistantPromptMessage(
-                content=message['content'] if message['content'] else '',
-                tool_calls=tool_calls,
-            ),
-            usage=self._calc_usage(model, credentials, resp['usage']),
-        )
+        def _handle_chat_response(resp: ChatCompletion) -> LLMResult:
+            choice = resp.choices[0]
+            message = choice.message
+
+            # parse tool calls
+            tool_calls = []
+            if message.tool_calls:
+                for call in message.tool_calls:
+                    tool_call = AssistantPromptMessage.ToolCall(
+                        id=call.id,
+                        type=call.type,
+                        function=AssistantPromptMessage.ToolCall.ToolCallFunction(
+                            name=call.function.name,
+                            arguments=call.function.arguments
+                        )
+                    )
+                    tool_calls.append(tool_call)
 
-    def _calc_usage(self, model: str, credentials: dict, usage: dict) -> LLMUsage:
-        return self._calc_response_usage(model=model, credentials=credentials,
-                                         prompt_tokens=usage['prompt_tokens'],
-                                         completion_tokens=usage['completion_tokens']
-                                         )
+            usage = resp.usage
+            return LLMResult(
+                model=model,
+                prompt_messages=prompt_messages,
+                message=AssistantPromptMessage(
+                    content=message.content if message.content else "",
+                    tool_calls=tool_calls,
+                ),
+                usage=self._calc_response_usage(model=model, credentials=credentials,
+                                                prompt_tokens=usage.prompt_tokens,
+                                                completion_tokens=usage.completion_tokens
+                                                ),
+            )
+
+        if not stream:
+            resp = client.chat(prompt_messages, **req_params)
+            return _handle_chat_response(resp)
+
+        chunks = client.stream_chat(prompt_messages, **req_params)
+        return _handle_stream_chat_response(chunks)
 
     def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
         """
             used to define customizable model schema
         """
         model_config = get_model_config(credentials)
-        
+
         rules = [
             ParameterRule(
                 name='temperature',
@@ -212,7 +316,7 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
                 use_template='presence_penalty',
                 label=I18nObject(
                     en_US='Presence Penalty',
-                    zh_Hans= '存在惩罚',
+                    zh_Hans='存在惩罚',
                 ),
                 min=-2.0,
                 max=2.0,
@@ -222,8 +326,8 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
                 type=ParameterType.FLOAT,
                 use_template='frequency_penalty',
                 label=I18nObject(
-                    en_US= 'Frequency Penalty',
-                    zh_Hans= '频率惩罚',
+                    en_US='Frequency Penalty',
+                    zh_Hans='频率惩罚',
                 ),
                 min=-2.0,
                 max=2.0,
@@ -245,7 +349,7 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
         model_properties = {}
         model_properties[ModelPropertyKey.CONTEXT_SIZE] = model_config.properties.context_size
         model_properties[ModelPropertyKey.MODE] = model_config.properties.mode.value
-        
+
         entity = AIModelEntity(
             model=model,
             label=I18nObject(
diff --git a/api/core/model_runtime/model_providers/volcengine_maas/llm/models.py b/api/core/model_runtime/model_providers/volcengine_maas/llm/models.py
index c5e53d8955..4e2c66a066 100644
--- a/api/core/model_runtime/model_providers/volcengine_maas/llm/models.py
+++ b/api/core/model_runtime/model_providers/volcengine_maas/llm/models.py
@@ -5,10 +5,11 @@ from core.model_runtime.entities.model_entities import ModelFeature
 
 
 class ModelProperties(BaseModel):
-    context_size: int 
-    max_tokens: int 
+    context_size: int
+    max_tokens: int
     mode: LLMMode
 
+
 class ModelConfig(BaseModel):
     properties: ModelProperties
     features: list[ModelFeature]
@@ -24,23 +25,23 @@ configs: dict[str, ModelConfig] = {
         features=[ModelFeature.TOOL_CALL]
     ),
     'Doubao-pro-32k': ModelConfig(
-        properties=ModelProperties(context_size=32768, max_tokens=32768, mode=LLMMode.CHAT),
+        properties=ModelProperties(context_size=32768, max_tokens=4096, mode=LLMMode.CHAT),
         features=[ModelFeature.TOOL_CALL]
     ),
     'Doubao-lite-32k': ModelConfig(
-        properties=ModelProperties(context_size=32768, max_tokens=32768, mode=LLMMode.CHAT),
+        properties=ModelProperties(context_size=32768, max_tokens=4096, mode=LLMMode.CHAT),
         features=[ModelFeature.TOOL_CALL]
     ),
     'Doubao-pro-128k': ModelConfig(
-        properties=ModelProperties(context_size=131072, max_tokens=131072, mode=LLMMode.CHAT),
+        properties=ModelProperties(context_size=131072, max_tokens=4096, mode=LLMMode.CHAT),
         features=[ModelFeature.TOOL_CALL]
     ),
     'Doubao-lite-128k': ModelConfig(
-        properties=ModelProperties(context_size=131072, max_tokens=131072, mode=LLMMode.CHAT),
+        properties=ModelProperties(context_size=131072, max_tokens=4096, mode=LLMMode.CHAT),
         features=[ModelFeature.TOOL_CALL]
     ),
     'Skylark2-pro-4k': ModelConfig(
-        properties=ModelProperties(context_size=4096, max_tokens=4000, mode=LLMMode.CHAT),
+        properties=ModelProperties(context_size=4096, max_tokens=4096, mode=LLMMode.CHAT),
         features=[]
     ),
     'Llama3-8B': ModelConfig(
@@ -77,23 +78,24 @@ configs: dict[str, ModelConfig] = {
     )
 }
 
-def get_model_config(credentials: dict)->ModelConfig:
+
+def get_model_config(credentials: dict) -> ModelConfig:
     base_model = credentials.get('base_model_name', '')
     model_configs = configs.get(base_model)
     if not model_configs:
         return ModelConfig(
-            properties=ModelProperties( 
+            properties=ModelProperties(
                 context_size=int(credentials.get('context_size', 0)),
                 max_tokens=int(credentials.get('max_tokens', 0)),
-                mode= LLMMode.value_of(credentials.get('mode', 'chat')),
+                mode=LLMMode.value_of(credentials.get('mode', 'chat')),
             ),
             features=[]
         )
     return model_configs
 
-def get_v2_req_params(credentials: dict, model_parameters: dict, 
-                  stop: list[str] | None=None):
+
+def get_v2_req_params(credentials: dict, model_parameters: dict,
+                      stop: list[str] | None = None):
     req_params = {}
     # predefined properties
     model_configs = get_model_config(credentials)
@@ -116,8 +118,36 @@ def get_v2_req_params(credentials: dict, model_parameters: dict,
     if model_parameters.get('frequency_penalty'):
         req_params['frequency_penalty'] = model_parameters.get(
             'frequency_penalty')
-    
+
     if stop:
         req_params['stop'] = stop
 
-    return req_params
\ No newline at end of file
+    return req_params
+
+
+def get_v3_req_params(credentials: dict, model_parameters: dict,
+                      stop: list[str] | None = None):
+    req_params = {}
+    # predefined properties
+    model_configs = get_model_config(credentials)
+    if model_configs:
+        req_params['max_tokens'] = model_configs.properties.max_tokens
+
+    # model parameters
+    if model_parameters.get('max_tokens'):
+        req_params['max_tokens'] = model_parameters.get('max_tokens')
+    if model_parameters.get('temperature'):
+        req_params['temperature'] = model_parameters.get('temperature')
+    if model_parameters.get('top_p'):
+        req_params['top_p'] = model_parameters.get('top_p')
+    if model_parameters.get('presence_penalty'):
+        req_params['presence_penalty'] = model_parameters.get(
+            'presence_penalty')
+    if model_parameters.get('frequency_penalty'):
+        req_params['frequency_penalty'] = model_parameters.get(
+            'frequency_penalty')
+
+    if stop:
+        req_params['stop'] = stop
+
+    return req_params
diff --git a/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/models.py b/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/models.py
index 2d8f972b94..74cf26247c 100644
--- a/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/models.py
+++ b/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/models.py
@@ -2,26 +2,29 @@ from pydantic import BaseModel
 
 
 class ModelProperties(BaseModel):
-    context_size: int 
-    max_chunks: int 
+    context_size: int
+    max_chunks: int
+
 
 class ModelConfig(BaseModel):
     properties: ModelProperties
 
+
 ModelConfigs = {
     'Doubao-embedding': ModelConfig(
-        properties=ModelProperties(context_size=4096, max_chunks=1)
+        properties=ModelProperties(context_size=4096, max_chunks=32)
     ),
 }
 
-def get_model_config(credentials: dict)->ModelConfig:
+
+def get_model_config(credentials: dict) -> ModelConfig:
     base_model = credentials.get('base_model_name', '')
     model_configs = ModelConfigs.get(base_model)
     if not model_configs:
         return ModelConfig(
-            properties=ModelProperties( 
+            properties=ModelProperties(
                 context_size=int(credentials.get('context_size', 0)),
                 max_chunks=int(credentials.get('max_chunks', 0)),
             )
         )
-    return model_configs
\ No newline at end of file
+    return model_configs
diff --git a/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/text_embedding.py
index 8ac632369e..d54aeeb0b1 100644
--- a/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/text_embedding.py
+++ b/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/text_embedding.py
@@ -22,16 +22,17 @@ from core.model_runtime.errors.invoke import (
 )
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
-from core.model_runtime.model_providers.volcengine_maas.client import MaaSClient
-from core.model_runtime.model_providers.volcengine_maas.errors import (
+from core.model_runtime.model_providers.volcengine_maas.client import ArkClientV3
+from core.model_runtime.model_providers.volcengine_maas.legacy.client import MaaSClient
+from core.model_runtime.model_providers.volcengine_maas.legacy.errors import (
     AuthErrors,
     BadRequestErrors,
     ConnectionErrors,
+    MaasException,
    RateLimitErrors,
     ServerUnavailableErrors,
 )
 from core.model_runtime.model_providers.volcengine_maas.text_embedding.models import get_model_config
-from core.model_runtime.model_providers.volcengine_maas.volc_sdk import MaasException
 
 
 class VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel):
@@ -51,6 +52,14 @@ class VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel):
         :param user: unique user id
         :return: embeddings result
         """
+        if ArkClientV3.is_legacy(credentials):
+            return self._generate_v2(model, credentials, texts, user)
+
+        return self._generate_v3(model, credentials, texts, user)
+
+    def _generate_v2(self, model: str, credentials: dict,
+                     texts: list[str], user: Optional[str] = None) \
+            -> TextEmbeddingResult:
         client = MaaSClient.from_credential(credentials)
         resp = MaaSClient.wrap_exception(lambda: client.embeddings(texts))
 
@@ -65,6 +74,23 @@ class VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel):
 
         return result
 
+    def _generate_v3(self, model: str, credentials: dict,
+                     texts: list[str], user: Optional[str] = None) \
+            -> TextEmbeddingResult:
+        client = ArkClientV3.from_credentials(credentials)
+        resp = client.embeddings(texts)
+
+        usage = self._calc_response_usage(
+            model=model, credentials=credentials, tokens=resp.usage.total_tokens)
+
+        result = TextEmbeddingResult(
+            model=model,
+            embeddings=[v.embedding for v in resp.data],
+            usage=usage
+        )
+
+        return result
+
     def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
         """
         Get number of tokens for given prompt messages
@@ -88,11 +114,22 @@ class VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel):
         :param credentials: model credentials
         :return:
         """
+        if ArkClientV3.is_legacy(credentials):
+            return self._validate_credentials_v2(model, credentials)
+        return self._validate_credentials_v3(model, credentials)
+
+    def _validate_credentials_v2(self, model: str, credentials: dict) -> None:
         try:
             self._invoke(model=model, credentials=credentials, texts=['ping'])
         except MaasException as e:
             raise CredentialsValidateFailedError(e.message)
 
+    def _validate_credentials_v3(self, model: str, credentials: dict) -> None:
+        try:
+            self._invoke(model=model, credentials=credentials, texts=['ping'])
+        except Exception as e:
+            raise CredentialsValidateFailedError(e)
+
     @property
     def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
         """
@@ -116,9 +153,10 @@ class VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel):
         generate custom model entities from credentials
         """
         model_config = get_model_config(credentials)
-        model_properties = {}
-        model_properties[ModelPropertyKey.CONTEXT_SIZE] = model_config.properties.context_size
-        model_properties[ModelPropertyKey.MAX_CHUNKS] = model_config.properties.max_chunks
+        model_properties = {
+            ModelPropertyKey.CONTEXT_SIZE: model_config.properties.context_size,
+            ModelPropertyKey.MAX_CHUNKS: model_config.properties.max_chunks
+        }
         entity = AIModelEntity(
             model=model,
             label=I18nObject(en_US=model),
diff --git a/api/poetry.lock b/api/poetry.lock
index 0a8919a30a..a273b3c4d9 100644
--- a/api/poetry.lock
+++ b/api/poetry.lock
@@ -6143,6 +6143,19 @@ files = [
     {file = "pyarrow-17.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:392bc9feabc647338e6c89267635e111d71edad5fcffba204425a7c8d13610d7"},
     {file = "pyarrow-17.0.0-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:af5ff82a04b2171415f1410cff7ebb79861afc5dae50be73ce06d6e870615204"},
     {file = "pyarrow-17.0.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:edca18eaca89cd6382dfbcff3dd2d87633433043650c07375d095cd3517561d8"},
+    {file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7c7916bff914ac5d4a8fe25b7a25e432ff921e72f6f2b7547d1e325c1ad9d155"},
+    {file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f553ca691b9e94b202ff741bdd40f6ccb70cdd5fbf65c187af132f1317de6145"},
+    {file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:0cdb0e627c86c373205a2f94a510ac4376fdc523f8bb36beab2e7f204416163c"},
+    {file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:d7d192305d9d8bc9082d10f361fc70a73590a4c65cf31c3e6926cd72b76bc35c"},
+    {file = "pyarrow-17.0.0-cp38-cp38-win_amd64.whl", hash = "sha256:02dae06ce212d8b3244dd3e7d12d9c4d3046945a5933d28026598e9dbbda1fca"},
+    {file = "pyarrow-17.0.0-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:13d7a460b412f31e4c0efa1148e1d29bdf18ad1411eb6757d38f8fbdcc8645fb"},
+    {file = "pyarrow-17.0.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9b564a51fbccfab5a04a80453e5ac6c9954a9c5ef2890d1bcf63741909c3f8df"},
+    {file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:32503827abbc5aadedfa235f5ece8c4f8f8b0a3cf01066bc8d29de7539532687"},
+    {file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a155acc7f154b9ffcc85497509bcd0d43efb80d6f733b0dc3bb14e281f131c8b"},
+    {file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:dec8d129254d0188a49f8a1fc99e0560dc1b85f60af729f47de4046015f9b0a5"},
+    {file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:a48ddf5c3c6a6c505904545c25a4ae13646ae1f8ba703c4df4a1bfe4f4006bda"},
+    {file = "pyarrow-17.0.0-cp39-cp39-win_amd64.whl", hash = "sha256:42bf93249a083aca230ba7e2786c5f673507fa97bbd9725a1e2754715151a204"},
+    {file = "pyarrow-17.0.0.tar.gz", hash = "sha256:4beca9521ed2c0921c1023e68d097d0299b62c362639ea315572a58f3f50fd28"},
 ]
 
 [package.dependencies]
@@ -8854,6 +8867,28 @@ files = [
     {file = "vine-5.1.0.tar.gz", hash = "sha256:8b62e981d35c41049211cf62a0a1242d8c1ee9bd15bb196ce38aefd6799e61e0"},
 ]
 
+[[package]]
+name = "volcengine-python-sdk"
+version = "1.0.98"
+description = "Volcengine SDK for Python"
+optional = false
+python-versions = "*"
+files = [
+    {file = "volcengine-python-sdk-1.0.98.tar.gz", hash = "sha256:1515e8d46cdcda387f9b45abbcaf0b04b982f7be68068de83f1e388281441784"},
+]
+
+[package.dependencies]
+anyio = {version = ">=3.5.0,<5", optional = true, markers = "extra == \"ark\""}
+certifi = ">=2017.4.17"
+httpx = {version = ">=0.23.0,<1", optional = true, markers = "extra == \"ark\""}
+pydantic = {version = ">=1.9.0,<3", optional = true, markers = "extra == \"ark\""}
+python-dateutil = ">=2.1"
+six = ">=1.10"
+urllib3 = ">=1.23"
+
+[package.extras]
+ark = ["anyio (>=3.5.0,<5)", "cached-property", "httpx (>=0.23.0,<1)", "pydantic (>=1.9.0,<3)"]
+
 [[package]]
 name = "watchfiles"
 version = "0.23.0"
@@ -9634,4 +9669,4 @@ cffi = ["cffi (>=1.11)"]
 
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.10,<3.13"
-content-hash = "d7336115709114c2a4ff09b392f717e9c3547ae82b6a111d0c885c7a44269f02"
+content-hash = "04f970820de691f40fc9fb30f5ff0618b0f1a04d3315b14467fb88e475fa1243"
diff --git a/api/pyproject.toml b/api/pyproject.toml
index e05a51dc13..0b2369663e 100644
--- a/api/pyproject.toml
+++ b/api/pyproject.toml
@@ -191,6 +191,7 @@ zhipuai = "1.0.7"
 # Related transparent dependencies with pinned verion
 # required by main implementations
 ############################################################
+volcengine-python-sdk = {extras = ["ark"], version = "^1.0.98"}
 [tool.poetry.group.indriect.dependencies]
 kaleido = "0.2.1"
 rank-bm25 = "~0.2.2"
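
---
Reviewer note (illustrative only, not part of the diff): a minimal sketch of how the new ArkClientV3 introduced above is driven end to end. All credential values, access keys, and endpoint ids below are placeholders; the message entities and client methods are the ones added in this patch.

    from core.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage
    from core.model_runtime.model_providers.volcengine_maas.client import ArkClientV3

    credentials = {
        'api_endpoint_host': 'https://ark.cn-beijing.volces.com/api/v3',
        'volc_region': 'cn-beijing',
        'volc_access_key_id': 'AK...',       # placeholder
        'volc_secret_access_key': 'SK...',   # placeholder
        'endpoint_id': 'ep-xxxxxxxx',        # placeholder endpoint id from the Volcengine console
        'sdk_version': 'v3',                 # explicit opt-in to the v3 path
    }
    client = ArkClientV3.from_credentials(credentials)

    # Blocking call: returns a ChatCompletion object.
    resp = client.chat(
        messages=[
            SystemPromptMessage(content='You are a helpful assistant.'),
            UserPromptMessage(content='ping'),
        ],
        max_tokens=16,
        temperature=0.7,
    )
    print(resp.choices[0].message.content)

    # Streaming call: yields ChatCompletionChunk objects; empty-choice
    # keep-alive chunks are already filtered inside stream_chat.
    for chunk in client.stream_chat(messages=[UserPromptMessage(content='ping')], max_tokens=16):
        print(chunk.choices[0].delta.content or '', end='')

    # Embeddings reuse the same client, but endpoint_id must point at an embedding endpoint.
    vector = client.embeddings(['ping']).data[0].embedding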
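Reviewer note (illustrative only): the legacy-vs-v3 routing that every model class in this patch relies on reduces to the checks below. These assertions follow directly from is_legacy/is_compatible_with_legacy in client.py and are pure dict inspection, so they run without network access.

    from core.model_runtime.model_providers.volcengine_maas.client import ArkClientV3

    # An explicit v3 opt-in routes to the new Ark SDK.
    assert not ArkClientV3.is_legacy({'sdk_version': 'v3'})

    # No sdk_version but the retired MaaS host: treated as v3-compatible,
    # and from_credentials rewrites base_url to https://ark.cn-beijing.volces.com/api/v3.
    assert not ArkClientV3.is_legacy(
        {'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com'})

    # Anything else without an explicit 'v3' falls back to the legacy v2 client,
    # including the MaaS host when sdk_version is set to 'v2'.
    assert ArkClientV3.is_legacy({'sdk_version': 'v2'})
    assert ArkClientV3.is_legacy({'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com',
                                  'sdk_version': 'v2'})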
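Reviewer note (illustrative only): get_v3_req_params seeds max_tokens from the predefined model table (note the Doubao entries are now capped at 4096 rather than the full context size) and then lets explicit model parameters win. A quick sanity check of that precedence, again pure and offline:

    from core.model_runtime.model_providers.volcengine_maas.llm.models import get_v3_req_params

    # Known base model: the configured ceiling seeds max_tokens.
    params = get_v3_req_params({'base_model_name': 'Doubao-pro-32k'}, {})
    assert params['max_tokens'] == 4096

    # Explicit model parameters override the seeded values.
    params = get_v3_req_params({'base_model_name': 'Doubao-pro-32k'},
                               {'max_tokens': 512, 'temperature': 0.1})
    assert params['max_tokens'] == 512 and params['temperature'] == 0.1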