mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-13 04:18:58 +08:00
feat: Introduce Ark SDK v3 and ensure compatibility with models of SDK v2 (#7579)
Co-authored-by: crazywoola <427733928@qq.com>
This commit is contained in:
parent
b035c02f78
commit
efc136cce5
@ -1,6 +1,25 @@
|
|||||||
import re
|
import re
|
||||||
from collections.abc import Callable, Generator
|
from collections.abc import Generator
|
||||||
from typing import cast
|
from typing import Optional, cast
|
||||||
|
|
||||||
|
from volcenginesdkarkruntime import Ark
|
||||||
|
from volcenginesdkarkruntime.types.chat import (
|
||||||
|
ChatCompletion,
|
||||||
|
ChatCompletionAssistantMessageParam,
|
||||||
|
ChatCompletionChunk,
|
||||||
|
ChatCompletionContentPartImageParam,
|
||||||
|
ChatCompletionContentPartTextParam,
|
||||||
|
ChatCompletionMessageParam,
|
||||||
|
ChatCompletionMessageToolCallParam,
|
||||||
|
ChatCompletionSystemMessageParam,
|
||||||
|
ChatCompletionToolMessageParam,
|
||||||
|
ChatCompletionToolParam,
|
||||||
|
ChatCompletionUserMessageParam,
|
||||||
|
)
|
||||||
|
from volcenginesdkarkruntime.types.chat.chat_completion_content_part_image_param import ImageURL
|
||||||
|
from volcenginesdkarkruntime.types.chat.chat_completion_message_tool_call_param import Function
|
||||||
|
from volcenginesdkarkruntime.types.create_embedding_response import CreateEmbeddingResponse
|
||||||
|
from volcenginesdkarkruntime.types.shared_params import FunctionDefinition
|
||||||
|
|
||||||
from core.model_runtime.entities.message_entities import (
|
from core.model_runtime.entities.message_entities import (
|
||||||
AssistantPromptMessage,
|
AssistantPromptMessage,
|
||||||
@ -12,123 +31,171 @@ from core.model_runtime.entities.message_entities import (
|
|||||||
ToolPromptMessage,
|
ToolPromptMessage,
|
||||||
UserPromptMessage,
|
UserPromptMessage,
|
||||||
)
|
)
|
||||||
from core.model_runtime.model_providers.volcengine_maas.errors import wrap_error
|
|
||||||
from core.model_runtime.model_providers.volcengine_maas.volc_sdk import ChatRole, MaasException, MaasService
|
|
||||||
|
|
||||||
|
|
||||||
class MaaSClient(MaasService):
|
class ArkClientV3:
|
||||||
def __init__(self, host: str, region: str):
|
endpoint_id: Optional[str] = None
|
||||||
|
ark: Optional[Ark] = None
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
self.ark = Ark(*args, **kwargs)
|
||||||
self.endpoint_id = None
|
self.endpoint_id = None
|
||||||
super().__init__(host, region)
|
|
||||||
|
|
||||||
def set_endpoint_id(self, endpoint_id: str):
|
|
||||||
self.endpoint_id = endpoint_id
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_credential(cls, credentials: dict) -> 'MaaSClient':
|
|
||||||
host = credentials['api_endpoint_host']
|
|
||||||
region = credentials['volc_region']
|
|
||||||
ak = credentials['volc_access_key_id']
|
|
||||||
sk = credentials['volc_secret_access_key']
|
|
||||||
endpoint_id = credentials['endpoint_id']
|
|
||||||
|
|
||||||
client = cls(host, region)
|
|
||||||
client.set_endpoint_id(endpoint_id)
|
|
||||||
client.set_ak(ak)
|
|
||||||
client.set_sk(sk)
|
|
||||||
return client
|
|
||||||
|
|
||||||
def chat(self, params: dict, messages: list[PromptMessage], stream=False, **extra_model_kwargs) -> Generator | dict:
|
|
||||||
req = {
|
|
||||||
'parameters': params,
|
|
||||||
'messages': [self.convert_prompt_message_to_maas_message(prompt) for prompt in messages],
|
|
||||||
**extra_model_kwargs,
|
|
||||||
}
|
|
||||||
if not stream:
|
|
||||||
return super().chat(
|
|
||||||
self.endpoint_id,
|
|
||||||
req,
|
|
||||||
)
|
|
||||||
return super().stream_chat(
|
|
||||||
self.endpoint_id,
|
|
||||||
req,
|
|
||||||
)
|
|
||||||
|
|
||||||
def embeddings(self, texts: list[str]) -> dict:
|
|
||||||
req = {
|
|
||||||
'input': texts
|
|
||||||
}
|
|
||||||
return super().embeddings(self.endpoint_id, req)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def convert_prompt_message_to_maas_message(message: PromptMessage) -> dict:
|
def is_legacy(credentials: dict) -> bool:
|
||||||
|
if ArkClientV3.is_compatible_with_legacy(credentials):
|
||||||
|
return False
|
||||||
|
sdk_version = credentials.get("sdk_version", "v2")
|
||||||
|
return sdk_version != "v3"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def is_compatible_with_legacy(credentials: dict) -> bool:
|
||||||
|
sdk_version = credentials.get("sdk_version")
|
||||||
|
endpoint = credentials.get("api_endpoint_host")
|
||||||
|
return sdk_version is None and endpoint == "maas-api.ml-platform-cn-beijing.volces.com"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_credentials(cls, credentials):
|
||||||
|
"""Initialize the client using the credentials provided."""
|
||||||
|
args = {
|
||||||
|
"base_url": credentials['api_endpoint_host'],
|
||||||
|
"region": credentials['volc_region'],
|
||||||
|
"ak": credentials['volc_access_key_id'],
|
||||||
|
"sk": credentials['volc_secret_access_key'],
|
||||||
|
}
|
||||||
|
if cls.is_compatible_with_legacy(credentials):
|
||||||
|
args["base_url"] = "https://ark.cn-beijing.volces.com/api/v3"
|
||||||
|
|
||||||
|
client = ArkClientV3(
|
||||||
|
**args
|
||||||
|
)
|
||||||
|
client.endpoint_id = credentials['endpoint_id']
|
||||||
|
return client
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def convert_prompt_message(message: PromptMessage) -> ChatCompletionMessageParam:
|
||||||
|
"""Converts a PromptMessage to a ChatCompletionMessageParam"""
|
||||||
if isinstance(message, UserPromptMessage):
|
if isinstance(message, UserPromptMessage):
|
||||||
message = cast(UserPromptMessage, message)
|
message = cast(UserPromptMessage, message)
|
||||||
if isinstance(message.content, str):
|
if isinstance(message.content, str):
|
||||||
message_dict = {"role": ChatRole.USER,
|
content = message.content
|
||||||
"content": message.content}
|
|
||||||
else:
|
else:
|
||||||
content = []
|
content = []
|
||||||
for message_content in message.content:
|
for message_content in message.content:
|
||||||
if message_content.type == PromptMessageContentType.TEXT:
|
if message_content.type == PromptMessageContentType.TEXT:
|
||||||
raise ValueError(
|
content.append(ChatCompletionContentPartTextParam(
|
||||||
'Content object type only support image_url')
|
text=message_content.text,
|
||||||
|
type='text',
|
||||||
|
))
|
||||||
elif message_content.type == PromptMessageContentType.IMAGE:
|
elif message_content.type == PromptMessageContentType.IMAGE:
|
||||||
message_content = cast(
|
message_content = cast(
|
||||||
ImagePromptMessageContent, message_content)
|
ImagePromptMessageContent, message_content)
|
||||||
image_data = re.sub(
|
image_data = re.sub(
|
||||||
r'^data:image\/[a-zA-Z]+;base64,', '', message_content.data)
|
r'^data:image\/[a-zA-Z]+;base64,', '', message_content.data)
|
||||||
content.append({
|
content.append(ChatCompletionContentPartImageParam(
|
||||||
'type': 'image_url',
|
image_url=ImageURL(
|
||||||
'image_url': {
|
url=image_data,
|
||||||
'url': '',
|
detail=message_content.detail.value,
|
||||||
'image_bytes': image_data,
|
),
|
||||||
'detail': message_content.detail,
|
type='image_url',
|
||||||
}
|
))
|
||||||
})
|
message_dict = ChatCompletionUserMessageParam(
|
||||||
|
role='user',
|
||||||
message_dict = {'role': ChatRole.USER, 'content': content}
|
content=content
|
||||||
|
)
|
||||||
elif isinstance(message, AssistantPromptMessage):
|
elif isinstance(message, AssistantPromptMessage):
|
||||||
message = cast(AssistantPromptMessage, message)
|
message = cast(AssistantPromptMessage, message)
|
||||||
message_dict = {'role': ChatRole.ASSISTANT,
|
message_dict = ChatCompletionAssistantMessageParam(
|
||||||
'content': message.content}
|
content=message.content,
|
||||||
if message.tool_calls:
|
role='assistant',
|
||||||
message_dict['tool_calls'] = [
|
tool_calls=None if not message.tool_calls else [
|
||||||
{
|
ChatCompletionMessageToolCallParam(
|
||||||
'name': call.function.name,
|
id=call.id,
|
||||||
'arguments': call.function.arguments
|
function=Function(
|
||||||
} for call in message.tool_calls
|
name=call.function.name,
|
||||||
|
arguments=call.function.arguments
|
||||||
|
),
|
||||||
|
type='function'
|
||||||
|
) for call in message.tool_calls
|
||||||
]
|
]
|
||||||
|
)
|
||||||
elif isinstance(message, SystemPromptMessage):
|
elif isinstance(message, SystemPromptMessage):
|
||||||
message = cast(SystemPromptMessage, message)
|
message = cast(SystemPromptMessage, message)
|
||||||
message_dict = {'role': ChatRole.SYSTEM,
|
message_dict = ChatCompletionSystemMessageParam(
|
||||||
'content': message.content}
|
content=message.content,
|
||||||
|
role='system'
|
||||||
|
)
|
||||||
elif isinstance(message, ToolPromptMessage):
|
elif isinstance(message, ToolPromptMessage):
|
||||||
message = cast(ToolPromptMessage, message)
|
message = cast(ToolPromptMessage, message)
|
||||||
message_dict = {'role': ChatRole.FUNCTION,
|
message_dict = ChatCompletionToolMessageParam(
|
||||||
'content': message.content,
|
content=message.content,
|
||||||
'name': message.tool_call_id}
|
role='tool',
|
||||||
|
tool_call_id=message.tool_call_id
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Got unknown PromptMessage type {message}")
|
raise ValueError(f"Got unknown PromptMessage type {message}")
|
||||||
|
|
||||||
return message_dict
|
return message_dict
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def wrap_exception(fn: Callable[[], dict | Generator]) -> dict | Generator:
|
def _convert_tool_prompt(message: PromptMessageTool) -> ChatCompletionToolParam:
|
||||||
try:
|
return ChatCompletionToolParam(
|
||||||
resp = fn()
|
type='function',
|
||||||
except MaasException as e:
|
function=FunctionDefinition(
|
||||||
raise wrap_error(e)
|
name=message.name,
|
||||||
|
description=message.description,
|
||||||
|
parameters=message.parameters,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
return resp
|
def chat(self, messages: list[PromptMessage],
|
||||||
|
tools: Optional[list[PromptMessageTool]] = None,
|
||||||
|
stop: Optional[list[str]] = None,
|
||||||
|
frequency_penalty: Optional[float] = None,
|
||||||
|
max_tokens: Optional[int] = None,
|
||||||
|
presence_penalty: Optional[float] = None,
|
||||||
|
top_p: Optional[float] = None,
|
||||||
|
temperature: Optional[float] = None,
|
||||||
|
) -> ChatCompletion:
|
||||||
|
"""Block chat"""
|
||||||
|
return self.ark.chat.completions.create(
|
||||||
|
model=self.endpoint_id,
|
||||||
|
messages=[self.convert_prompt_message(message) for message in messages],
|
||||||
|
tools=[self._convert_tool_prompt(tool) for tool in tools] if tools else None,
|
||||||
|
stop=stop,
|
||||||
|
frequency_penalty=frequency_penalty,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
presence_penalty=presence_penalty,
|
||||||
|
top_p=top_p,
|
||||||
|
temperature=temperature,
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
def stream_chat(self, messages: list[PromptMessage],
|
||||||
def transform_tool_prompt_to_maas_config(tool: PromptMessageTool):
|
tools: Optional[list[PromptMessageTool]] = None,
|
||||||
return {
|
stop: Optional[list[str]] = None,
|
||||||
"type": "function",
|
frequency_penalty: Optional[float] = None,
|
||||||
"function": {
|
max_tokens: Optional[int] = None,
|
||||||
"name": tool.name,
|
presence_penalty: Optional[float] = None,
|
||||||
"description": tool.description,
|
top_p: Optional[float] = None,
|
||||||
"parameters": tool.parameters,
|
temperature: Optional[float] = None,
|
||||||
}
|
) -> Generator[ChatCompletionChunk]:
|
||||||
}
|
"""Stream chat"""
|
||||||
|
chunks = self.ark.chat.completions.create(
|
||||||
|
stream=True,
|
||||||
|
model=self.endpoint_id,
|
||||||
|
messages=[self.convert_prompt_message(message) for message in messages],
|
||||||
|
tools=[self._convert_tool_prompt(tool) for tool in tools] if tools else None,
|
||||||
|
stop=stop,
|
||||||
|
frequency_penalty=frequency_penalty,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
presence_penalty=presence_penalty,
|
||||||
|
top_p=top_p,
|
||||||
|
temperature=temperature,
|
||||||
|
)
|
||||||
|
for chunk in chunks:
|
||||||
|
if not chunk.choices:
|
||||||
|
continue
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
def embeddings(self, texts: list[str]) -> CreateEmbeddingResponse:
|
||||||
|
return self.ark.embeddings.create(model=self.endpoint_id, input=texts)
|
||||||
|
@ -0,0 +1,134 @@
|
|||||||
|
import re
|
||||||
|
from collections.abc import Callable, Generator
|
||||||
|
from typing import cast
|
||||||
|
|
||||||
|
from core.model_runtime.entities.message_entities import (
|
||||||
|
AssistantPromptMessage,
|
||||||
|
ImagePromptMessageContent,
|
||||||
|
PromptMessage,
|
||||||
|
PromptMessageContentType,
|
||||||
|
PromptMessageTool,
|
||||||
|
SystemPromptMessage,
|
||||||
|
ToolPromptMessage,
|
||||||
|
UserPromptMessage,
|
||||||
|
)
|
||||||
|
from core.model_runtime.model_providers.volcengine_maas.legacy.errors import wrap_error
|
||||||
|
from core.model_runtime.model_providers.volcengine_maas.legacy.volc_sdk import ChatRole, MaasException, MaasService
|
||||||
|
|
||||||
|
|
||||||
|
class MaaSClient(MaasService):
|
||||||
|
def __init__(self, host: str, region: str):
|
||||||
|
self.endpoint_id = None
|
||||||
|
super().__init__(host, region)
|
||||||
|
|
||||||
|
def set_endpoint_id(self, endpoint_id: str):
|
||||||
|
self.endpoint_id = endpoint_id
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_credential(cls, credentials: dict) -> 'MaaSClient':
|
||||||
|
host = credentials['api_endpoint_host']
|
||||||
|
region = credentials['volc_region']
|
||||||
|
ak = credentials['volc_access_key_id']
|
||||||
|
sk = credentials['volc_secret_access_key']
|
||||||
|
endpoint_id = credentials['endpoint_id']
|
||||||
|
|
||||||
|
client = cls(host, region)
|
||||||
|
client.set_endpoint_id(endpoint_id)
|
||||||
|
client.set_ak(ak)
|
||||||
|
client.set_sk(sk)
|
||||||
|
return client
|
||||||
|
|
||||||
|
def chat(self, params: dict, messages: list[PromptMessage], stream=False, **extra_model_kwargs) -> Generator | dict:
|
||||||
|
req = {
|
||||||
|
'parameters': params,
|
||||||
|
'messages': [self.convert_prompt_message_to_maas_message(prompt) for prompt in messages],
|
||||||
|
**extra_model_kwargs,
|
||||||
|
}
|
||||||
|
if not stream:
|
||||||
|
return super().chat(
|
||||||
|
self.endpoint_id,
|
||||||
|
req,
|
||||||
|
)
|
||||||
|
return super().stream_chat(
|
||||||
|
self.endpoint_id,
|
||||||
|
req,
|
||||||
|
)
|
||||||
|
|
||||||
|
def embeddings(self, texts: list[str]) -> dict:
|
||||||
|
req = {
|
||||||
|
'input': texts
|
||||||
|
}
|
||||||
|
return super().embeddings(self.endpoint_id, req)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def convert_prompt_message_to_maas_message(message: PromptMessage) -> dict:
|
||||||
|
if isinstance(message, UserPromptMessage):
|
||||||
|
message = cast(UserPromptMessage, message)
|
||||||
|
if isinstance(message.content, str):
|
||||||
|
message_dict = {"role": ChatRole.USER,
|
||||||
|
"content": message.content}
|
||||||
|
else:
|
||||||
|
content = []
|
||||||
|
for message_content in message.content:
|
||||||
|
if message_content.type == PromptMessageContentType.TEXT:
|
||||||
|
raise ValueError(
|
||||||
|
'Content object type only support image_url')
|
||||||
|
elif message_content.type == PromptMessageContentType.IMAGE:
|
||||||
|
message_content = cast(
|
||||||
|
ImagePromptMessageContent, message_content)
|
||||||
|
image_data = re.sub(
|
||||||
|
r'^data:image\/[a-zA-Z]+;base64,', '', message_content.data)
|
||||||
|
content.append({
|
||||||
|
'type': 'image_url',
|
||||||
|
'image_url': {
|
||||||
|
'url': '',
|
||||||
|
'image_bytes': image_data,
|
||||||
|
'detail': message_content.detail,
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
message_dict = {'role': ChatRole.USER, 'content': content}
|
||||||
|
elif isinstance(message, AssistantPromptMessage):
|
||||||
|
message = cast(AssistantPromptMessage, message)
|
||||||
|
message_dict = {'role': ChatRole.ASSISTANT,
|
||||||
|
'content': message.content}
|
||||||
|
if message.tool_calls:
|
||||||
|
message_dict['tool_calls'] = [
|
||||||
|
{
|
||||||
|
'name': call.function.name,
|
||||||
|
'arguments': call.function.arguments
|
||||||
|
} for call in message.tool_calls
|
||||||
|
]
|
||||||
|
elif isinstance(message, SystemPromptMessage):
|
||||||
|
message = cast(SystemPromptMessage, message)
|
||||||
|
message_dict = {'role': ChatRole.SYSTEM,
|
||||||
|
'content': message.content}
|
||||||
|
elif isinstance(message, ToolPromptMessage):
|
||||||
|
message = cast(ToolPromptMessage, message)
|
||||||
|
message_dict = {'role': ChatRole.FUNCTION,
|
||||||
|
'content': message.content,
|
||||||
|
'name': message.tool_call_id}
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Got unknown PromptMessage type {message}")
|
||||||
|
|
||||||
|
return message_dict
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def wrap_exception(fn: Callable[[], dict | Generator]) -> dict | Generator:
|
||||||
|
try:
|
||||||
|
resp = fn()
|
||||||
|
except MaasException as e:
|
||||||
|
raise wrap_error(e)
|
||||||
|
|
||||||
|
return resp
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def transform_tool_prompt_to_maas_config(tool: PromptMessageTool):
|
||||||
|
return {
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": tool.name,
|
||||||
|
"description": tool.description,
|
||||||
|
"parameters": tool.parameters,
|
||||||
|
}
|
||||||
|
}
|
@ -1,4 +1,4 @@
|
|||||||
from core.model_runtime.model_providers.volcengine_maas.volc_sdk import MaasException
|
from core.model_runtime.model_providers.volcengine_maas.legacy.volc_sdk import MaasException
|
||||||
|
|
||||||
|
|
||||||
class ClientSDKRequestError(MaasException):
|
class ClientSDKRequestError(MaasException):
|
@ -1,8 +1,10 @@
|
|||||||
import logging
|
import logging
|
||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
|
|
||||||
|
from volcenginesdkarkruntime.types.chat import ChatCompletion, ChatCompletionChunk
|
||||||
|
|
||||||
from core.model_runtime.entities.common_entities import I18nObject
|
from core.model_runtime.entities.common_entities import I18nObject
|
||||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
|
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
|
||||||
from core.model_runtime.entities.message_entities import (
|
from core.model_runtime.entities.message_entities import (
|
||||||
AssistantPromptMessage,
|
AssistantPromptMessage,
|
||||||
PromptMessage,
|
PromptMessage,
|
||||||
@ -27,19 +29,21 @@ from core.model_runtime.errors.invoke import (
|
|||||||
)
|
)
|
||||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||||
from core.model_runtime.model_providers.volcengine_maas.client import MaaSClient
|
from core.model_runtime.model_providers.volcengine_maas.client import ArkClientV3
|
||||||
from core.model_runtime.model_providers.volcengine_maas.errors import (
|
from core.model_runtime.model_providers.volcengine_maas.legacy.client import MaaSClient
|
||||||
|
from core.model_runtime.model_providers.volcengine_maas.legacy.errors import (
|
||||||
AuthErrors,
|
AuthErrors,
|
||||||
BadRequestErrors,
|
BadRequestErrors,
|
||||||
ConnectionErrors,
|
ConnectionErrors,
|
||||||
|
MaasException,
|
||||||
RateLimitErrors,
|
RateLimitErrors,
|
||||||
ServerUnavailableErrors,
|
ServerUnavailableErrors,
|
||||||
)
|
)
|
||||||
from core.model_runtime.model_providers.volcengine_maas.llm.models import (
|
from core.model_runtime.model_providers.volcengine_maas.llm.models import (
|
||||||
get_model_config,
|
get_model_config,
|
||||||
get_v2_req_params,
|
get_v2_req_params,
|
||||||
|
get_v3_req_params,
|
||||||
)
|
)
|
||||||
from core.model_runtime.model_providers.volcengine_maas.volc_sdk import MaasException
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -49,13 +53,20 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
|
|||||||
model_parameters: dict, tools: list[PromptMessageTool] | None = None,
|
model_parameters: dict, tools: list[PromptMessageTool] | None = None,
|
||||||
stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
|
stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
|
||||||
-> LLMResult | Generator:
|
-> LLMResult | Generator:
|
||||||
return self._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
|
if ArkClientV3.is_legacy(credentials):
|
||||||
|
return self._generate_v2(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
|
||||||
|
return self._generate_v3(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
|
||||||
|
|
||||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||||
"""
|
"""
|
||||||
Validate credentials
|
Validate credentials
|
||||||
"""
|
"""
|
||||||
# ping
|
if ArkClientV3.is_legacy(credentials):
|
||||||
|
return self._validate_credentials_v2(credentials)
|
||||||
|
return self._validate_credentials_v3(credentials)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _validate_credentials_v2(credentials: dict) -> None:
|
||||||
client = MaaSClient.from_credential(credentials)
|
client = MaaSClient.from_credential(credentials)
|
||||||
try:
|
try:
|
||||||
client.chat(
|
client.chat(
|
||||||
@ -70,18 +81,24 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
|
|||||||
except MaasException as e:
|
except MaasException as e:
|
||||||
raise CredentialsValidateFailedError(e.message)
|
raise CredentialsValidateFailedError(e.message)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _validate_credentials_v3(credentials: dict) -> None:
|
||||||
|
client = ArkClientV3.from_credentials(credentials)
|
||||||
|
try:
|
||||||
|
client.chat(max_tokens=16, temperature=0.7, top_p=0.9,
|
||||||
|
messages=[UserPromptMessage(content='ping\nAnswer: ')], )
|
||||||
|
except Exception as e:
|
||||||
|
raise CredentialsValidateFailedError(e)
|
||||||
|
|
||||||
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
|
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
|
||||||
tools: list[PromptMessageTool] | None = None) -> int:
|
tools: list[PromptMessageTool] | None = None) -> int:
|
||||||
if len(prompt_messages) == 0:
|
if ArkClientV3.is_legacy(credentials):
|
||||||
|
return self._get_num_tokens_v2(prompt_messages)
|
||||||
|
return self._get_num_tokens_v3(prompt_messages)
|
||||||
|
|
||||||
|
def _get_num_tokens_v2(self, messages: list[PromptMessage]) -> int:
|
||||||
|
if len(messages) == 0:
|
||||||
return 0
|
return 0
|
||||||
return self._num_tokens_from_messages(prompt_messages)
|
|
||||||
|
|
||||||
def _num_tokens_from_messages(self, messages: list[PromptMessage]) -> int:
|
|
||||||
"""
|
|
||||||
Calculate num tokens.
|
|
||||||
|
|
||||||
:param messages: messages
|
|
||||||
"""
|
|
||||||
num_tokens = 0
|
num_tokens = 0
|
||||||
messages_dict = [
|
messages_dict = [
|
||||||
MaaSClient.convert_prompt_message_to_maas_message(m) for m in messages]
|
MaaSClient.convert_prompt_message_to_maas_message(m) for m in messages]
|
||||||
@ -92,7 +109,20 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
|
|||||||
|
|
||||||
return num_tokens
|
return num_tokens
|
||||||
|
|
||||||
def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
|
def _get_num_tokens_v3(self, messages: list[PromptMessage]) -> int:
|
||||||
|
if len(messages) == 0:
|
||||||
|
return 0
|
||||||
|
num_tokens = 0
|
||||||
|
messages_dict = [
|
||||||
|
ArkClientV3.convert_prompt_message(m) for m in messages]
|
||||||
|
for message in messages_dict:
|
||||||
|
for key, value in message.items():
|
||||||
|
num_tokens += self._get_num_tokens_by_gpt2(str(key))
|
||||||
|
num_tokens += self._get_num_tokens_by_gpt2(str(value))
|
||||||
|
|
||||||
|
return num_tokens
|
||||||
|
|
||||||
|
def _generate_v2(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
|
||||||
model_parameters: dict, tools: list[PromptMessageTool] | None = None,
|
model_parameters: dict, tools: list[PromptMessageTool] | None = None,
|
||||||
stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
|
stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
|
||||||
-> LLMResult | Generator:
|
-> LLMResult | Generator:
|
||||||
@ -106,11 +136,8 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
|
|||||||
]
|
]
|
||||||
resp = MaaSClient.wrap_exception(
|
resp = MaaSClient.wrap_exception(
|
||||||
lambda: client.chat(req_params, prompt_messages, stream, **extra_model_kwargs))
|
lambda: client.chat(req_params, prompt_messages, stream, **extra_model_kwargs))
|
||||||
if not stream:
|
|
||||||
return self._handle_chat_response(model, credentials, prompt_messages, resp)
|
|
||||||
return self._handle_stream_chat_response(model, credentials, prompt_messages, resp)
|
|
||||||
|
|
||||||
def _handle_stream_chat_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], resp: Generator) -> Generator:
|
def _handle_stream_chat_response() -> Generator:
|
||||||
for index, r in enumerate(resp):
|
for index, r in enumerate(resp):
|
||||||
choices = r['choices']
|
choices = r['choices']
|
||||||
if not choices:
|
if not choices:
|
||||||
@ -119,7 +146,10 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
|
|||||||
message = choice['message']
|
message = choice['message']
|
||||||
usage = None
|
usage = None
|
||||||
if r.get('usage'):
|
if r.get('usage'):
|
||||||
usage = self._calc_usage(model, credentials, r['usage'])
|
usage = self._calc_response_usage(model=model, credentials=credentials,
|
||||||
|
prompt_tokens=r['usage']['prompt_tokens'],
|
||||||
|
completion_tokens=r['usage']['completion_tokens']
|
||||||
|
)
|
||||||
yield LLMResultChunk(
|
yield LLMResultChunk(
|
||||||
model=model,
|
model=model,
|
||||||
prompt_messages=prompt_messages,
|
prompt_messages=prompt_messages,
|
||||||
@ -134,10 +164,11 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
def _handle_chat_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], resp: dict) -> LLMResult:
|
def _handle_chat_response() -> LLMResult:
|
||||||
choices = resp['choices']
|
choices = resp['choices']
|
||||||
if not choices:
|
if not choices:
|
||||||
return
|
raise ValueError("No choices found")
|
||||||
|
|
||||||
choice = choices[0]
|
choice = choices[0]
|
||||||
message = choice['message']
|
message = choice['message']
|
||||||
|
|
||||||
@ -155,6 +186,7 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
|
|||||||
)
|
)
|
||||||
tool_calls.append(tool_call)
|
tool_calls.append(tool_call)
|
||||||
|
|
||||||
|
usage = resp['usage']
|
||||||
return LLMResult(
|
return LLMResult(
|
||||||
model=model,
|
model=model,
|
||||||
prompt_messages=prompt_messages,
|
prompt_messages=prompt_messages,
|
||||||
@ -162,15 +194,87 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
|
|||||||
content=message['content'] if message['content'] else '',
|
content=message['content'] if message['content'] else '',
|
||||||
tool_calls=tool_calls,
|
tool_calls=tool_calls,
|
||||||
),
|
),
|
||||||
usage=self._calc_usage(model, credentials, resp['usage']),
|
usage=self._calc_response_usage(model=model, credentials=credentials,
|
||||||
)
|
|
||||||
|
|
||||||
def _calc_usage(self, model: str, credentials: dict, usage: dict) -> LLMUsage:
|
|
||||||
return self._calc_response_usage(model=model, credentials=credentials,
|
|
||||||
prompt_tokens=usage['prompt_tokens'],
|
prompt_tokens=usage['prompt_tokens'],
|
||||||
completion_tokens=usage['completion_tokens']
|
completion_tokens=usage['completion_tokens']
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if not stream:
|
||||||
|
return _handle_chat_response()
|
||||||
|
return _handle_stream_chat_response()
|
||||||
|
|
||||||
|
def _generate_v3(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
|
||||||
|
model_parameters: dict, tools: list[PromptMessageTool] | None = None,
|
||||||
|
stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
|
||||||
|
-> LLMResult | Generator:
|
||||||
|
|
||||||
|
client = ArkClientV3.from_credentials(credentials)
|
||||||
|
req_params = get_v3_req_params(credentials, model_parameters, stop)
|
||||||
|
if tools:
|
||||||
|
req_params['tools'] = tools
|
||||||
|
|
||||||
|
def _handle_stream_chat_response(chunks: Generator[ChatCompletionChunk]) -> Generator:
|
||||||
|
for chunk in chunks:
|
||||||
|
if not chunk.choices:
|
||||||
|
continue
|
||||||
|
choice = chunk.choices[0]
|
||||||
|
|
||||||
|
yield LLMResultChunk(
|
||||||
|
model=model,
|
||||||
|
prompt_messages=prompt_messages,
|
||||||
|
delta=LLMResultChunkDelta(
|
||||||
|
index=choice.index,
|
||||||
|
message=AssistantPromptMessage(
|
||||||
|
content=choice.delta.content,
|
||||||
|
tool_calls=[]
|
||||||
|
),
|
||||||
|
usage=self._calc_response_usage(model=model, credentials=credentials,
|
||||||
|
prompt_tokens=chunk.usage.prompt_tokens,
|
||||||
|
completion_tokens=chunk.usage.completion_tokens
|
||||||
|
) if chunk.usage else None,
|
||||||
|
finish_reason=choice.finish_reason,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _handle_chat_response(resp: ChatCompletion) -> LLMResult:
|
||||||
|
choice = resp.choices[0]
|
||||||
|
message = choice.message
|
||||||
|
# parse tool calls
|
||||||
|
tool_calls = []
|
||||||
|
if message.tool_calls:
|
||||||
|
for call in message.tool_calls:
|
||||||
|
tool_call = AssistantPromptMessage.ToolCall(
|
||||||
|
id=call.id,
|
||||||
|
type=call.type,
|
||||||
|
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
|
||||||
|
name=call.function.name,
|
||||||
|
arguments=call.function.arguments
|
||||||
|
)
|
||||||
|
)
|
||||||
|
tool_calls.append(tool_call)
|
||||||
|
|
||||||
|
usage = resp.usage
|
||||||
|
return LLMResult(
|
||||||
|
model=model,
|
||||||
|
prompt_messages=prompt_messages,
|
||||||
|
message=AssistantPromptMessage(
|
||||||
|
content=message.content if message.content else "",
|
||||||
|
tool_calls=tool_calls,
|
||||||
|
),
|
||||||
|
usage=self._calc_response_usage(model=model, credentials=credentials,
|
||||||
|
prompt_tokens=usage.prompt_tokens,
|
||||||
|
completion_tokens=usage.completion_tokens
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
if not stream:
|
||||||
|
resp = client.chat(prompt_messages, **req_params)
|
||||||
|
return _handle_chat_response(resp)
|
||||||
|
|
||||||
|
chunks = client.stream_chat(prompt_messages, **req_params)
|
||||||
|
return _handle_stream_chat_response(chunks)
|
||||||
|
|
||||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
|
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
|
||||||
"""
|
"""
|
||||||
used to define customizable model schema
|
used to define customizable model schema
|
||||||
@ -212,7 +316,7 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
|
|||||||
use_template='presence_penalty',
|
use_template='presence_penalty',
|
||||||
label=I18nObject(
|
label=I18nObject(
|
||||||
en_US='Presence Penalty',
|
en_US='Presence Penalty',
|
||||||
zh_Hans= '存在惩罚',
|
zh_Hans='存在惩罚',
|
||||||
),
|
),
|
||||||
min=-2.0,
|
min=-2.0,
|
||||||
max=2.0,
|
max=2.0,
|
||||||
@ -222,8 +326,8 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
|
|||||||
type=ParameterType.FLOAT,
|
type=ParameterType.FLOAT,
|
||||||
use_template='frequency_penalty',
|
use_template='frequency_penalty',
|
||||||
label=I18nObject(
|
label=I18nObject(
|
||||||
en_US= 'Frequency Penalty',
|
en_US='Frequency Penalty',
|
||||||
zh_Hans= '频率惩罚',
|
zh_Hans='频率惩罚',
|
||||||
),
|
),
|
||||||
min=-2.0,
|
min=-2.0,
|
||||||
max=2.0,
|
max=2.0,
|
||||||
|
@ -9,6 +9,7 @@ class ModelProperties(BaseModel):
|
|||||||
max_tokens: int
|
max_tokens: int
|
||||||
mode: LLMMode
|
mode: LLMMode
|
||||||
|
|
||||||
|
|
||||||
class ModelConfig(BaseModel):
|
class ModelConfig(BaseModel):
|
||||||
properties: ModelProperties
|
properties: ModelProperties
|
||||||
features: list[ModelFeature]
|
features: list[ModelFeature]
|
||||||
@ -24,23 +25,23 @@ configs: dict[str, ModelConfig] = {
|
|||||||
features=[ModelFeature.TOOL_CALL]
|
features=[ModelFeature.TOOL_CALL]
|
||||||
),
|
),
|
||||||
'Doubao-pro-32k': ModelConfig(
|
'Doubao-pro-32k': ModelConfig(
|
||||||
properties=ModelProperties(context_size=32768, max_tokens=32768, mode=LLMMode.CHAT),
|
properties=ModelProperties(context_size=32768, max_tokens=4096, mode=LLMMode.CHAT),
|
||||||
features=[ModelFeature.TOOL_CALL]
|
features=[ModelFeature.TOOL_CALL]
|
||||||
),
|
),
|
||||||
'Doubao-lite-32k': ModelConfig(
|
'Doubao-lite-32k': ModelConfig(
|
||||||
properties=ModelProperties(context_size=32768, max_tokens=32768, mode=LLMMode.CHAT),
|
properties=ModelProperties(context_size=32768, max_tokens=4096, mode=LLMMode.CHAT),
|
||||||
features=[ModelFeature.TOOL_CALL]
|
features=[ModelFeature.TOOL_CALL]
|
||||||
),
|
),
|
||||||
'Doubao-pro-128k': ModelConfig(
|
'Doubao-pro-128k': ModelConfig(
|
||||||
properties=ModelProperties(context_size=131072, max_tokens=131072, mode=LLMMode.CHAT),
|
properties=ModelProperties(context_size=131072, max_tokens=4096, mode=LLMMode.CHAT),
|
||||||
features=[ModelFeature.TOOL_CALL]
|
features=[ModelFeature.TOOL_CALL]
|
||||||
),
|
),
|
||||||
'Doubao-lite-128k': ModelConfig(
|
'Doubao-lite-128k': ModelConfig(
|
||||||
properties=ModelProperties(context_size=131072, max_tokens=131072, mode=LLMMode.CHAT),
|
properties=ModelProperties(context_size=131072, max_tokens=4096, mode=LLMMode.CHAT),
|
||||||
features=[ModelFeature.TOOL_CALL]
|
features=[ModelFeature.TOOL_CALL]
|
||||||
),
|
),
|
||||||
'Skylark2-pro-4k': ModelConfig(
|
'Skylark2-pro-4k': ModelConfig(
|
||||||
properties=ModelProperties(context_size=4096, max_tokens=4000, mode=LLMMode.CHAT),
|
properties=ModelProperties(context_size=4096, max_tokens=4096, mode=LLMMode.CHAT),
|
||||||
features=[]
|
features=[]
|
||||||
),
|
),
|
||||||
'Llama3-8B': ModelConfig(
|
'Llama3-8B': ModelConfig(
|
||||||
@ -77,7 +78,8 @@ configs: dict[str, ModelConfig] = {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
def get_model_config(credentials: dict)->ModelConfig:
|
|
||||||
|
def get_model_config(credentials: dict) -> ModelConfig:
|
||||||
base_model = credentials.get('base_model_name', '')
|
base_model = credentials.get('base_model_name', '')
|
||||||
model_configs = configs.get(base_model)
|
model_configs = configs.get(base_model)
|
||||||
if not model_configs:
|
if not model_configs:
|
||||||
@ -85,7 +87,7 @@ def get_model_config(credentials: dict)->ModelConfig:
|
|||||||
properties=ModelProperties(
|
properties=ModelProperties(
|
||||||
context_size=int(credentials.get('context_size', 0)),
|
context_size=int(credentials.get('context_size', 0)),
|
||||||
max_tokens=int(credentials.get('max_tokens', 0)),
|
max_tokens=int(credentials.get('max_tokens', 0)),
|
||||||
mode= LLMMode.value_of(credentials.get('mode', 'chat')),
|
mode=LLMMode.value_of(credentials.get('mode', 'chat')),
|
||||||
),
|
),
|
||||||
features=[]
|
features=[]
|
||||||
)
|
)
|
||||||
@ -93,7 +95,7 @@ def get_model_config(credentials: dict)->ModelConfig:
|
|||||||
|
|
||||||
|
|
||||||
def get_v2_req_params(credentials: dict, model_parameters: dict,
|
def get_v2_req_params(credentials: dict, model_parameters: dict,
|
||||||
stop: list[str] | None=None):
|
stop: list[str] | None = None):
|
||||||
req_params = {}
|
req_params = {}
|
||||||
# predefined properties
|
# predefined properties
|
||||||
model_configs = get_model_config(credentials)
|
model_configs = get_model_config(credentials)
|
||||||
@ -121,3 +123,31 @@ def get_v2_req_params(credentials: dict, model_parameters: dict,
|
|||||||
req_params['stop'] = stop
|
req_params['stop'] = stop
|
||||||
|
|
||||||
return req_params
|
return req_params
|
||||||
|
|
||||||
|
|
||||||
|
def get_v3_req_params(credentials: dict, model_parameters: dict,
|
||||||
|
stop: list[str] | None = None):
|
||||||
|
req_params = {}
|
||||||
|
# predefined properties
|
||||||
|
model_configs = get_model_config(credentials)
|
||||||
|
if model_configs:
|
||||||
|
req_params['max_tokens'] = model_configs.properties.max_tokens
|
||||||
|
|
||||||
|
# model parameters
|
||||||
|
if model_parameters.get('max_tokens'):
|
||||||
|
req_params['max_tokens'] = model_parameters.get('max_tokens')
|
||||||
|
if model_parameters.get('temperature'):
|
||||||
|
req_params['temperature'] = model_parameters.get('temperature')
|
||||||
|
if model_parameters.get('top_p'):
|
||||||
|
req_params['top_p'] = model_parameters.get('top_p')
|
||||||
|
if model_parameters.get('presence_penalty'):
|
||||||
|
req_params['presence_penalty'] = model_parameters.get(
|
||||||
|
'presence_penalty')
|
||||||
|
if model_parameters.get('frequency_penalty'):
|
||||||
|
req_params['frequency_penalty'] = model_parameters.get(
|
||||||
|
'frequency_penalty')
|
||||||
|
|
||||||
|
if stop:
|
||||||
|
req_params['stop'] = stop
|
||||||
|
|
||||||
|
return req_params
|
||||||
|
@ -5,16 +5,19 @@ class ModelProperties(BaseModel):
|
|||||||
context_size: int
|
context_size: int
|
||||||
max_chunks: int
|
max_chunks: int
|
||||||
|
|
||||||
|
|
||||||
class ModelConfig(BaseModel):
|
class ModelConfig(BaseModel):
|
||||||
properties: ModelProperties
|
properties: ModelProperties
|
||||||
|
|
||||||
|
|
||||||
ModelConfigs = {
|
ModelConfigs = {
|
||||||
'Doubao-embedding': ModelConfig(
|
'Doubao-embedding': ModelConfig(
|
||||||
properties=ModelProperties(context_size=4096, max_chunks=1)
|
properties=ModelProperties(context_size=4096, max_chunks=32)
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
def get_model_config(credentials: dict)->ModelConfig:
|
|
||||||
|
def get_model_config(credentials: dict) -> ModelConfig:
|
||||||
base_model = credentials.get('base_model_name', '')
|
base_model = credentials.get('base_model_name', '')
|
||||||
model_configs = ModelConfigs.get(base_model)
|
model_configs = ModelConfigs.get(base_model)
|
||||||
if not model_configs:
|
if not model_configs:
|
||||||
|
@ -22,16 +22,17 @@ from core.model_runtime.errors.invoke import (
|
|||||||
)
|
)
|
||||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||||
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||||
from core.model_runtime.model_providers.volcengine_maas.client import MaaSClient
|
from core.model_runtime.model_providers.volcengine_maas.client import ArkClientV3
|
||||||
from core.model_runtime.model_providers.volcengine_maas.errors import (
|
from core.model_runtime.model_providers.volcengine_maas.legacy.client import MaaSClient
|
||||||
|
from core.model_runtime.model_providers.volcengine_maas.legacy.errors import (
|
||||||
AuthErrors,
|
AuthErrors,
|
||||||
BadRequestErrors,
|
BadRequestErrors,
|
||||||
ConnectionErrors,
|
ConnectionErrors,
|
||||||
|
MaasException,
|
||||||
RateLimitErrors,
|
RateLimitErrors,
|
||||||
ServerUnavailableErrors,
|
ServerUnavailableErrors,
|
||||||
)
|
)
|
||||||
from core.model_runtime.model_providers.volcengine_maas.text_embedding.models import get_model_config
|
from core.model_runtime.model_providers.volcengine_maas.text_embedding.models import get_model_config
|
||||||
from core.model_runtime.model_providers.volcengine_maas.volc_sdk import MaasException
|
|
||||||
|
|
||||||
|
|
||||||
class VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel):
|
class VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel):
|
||||||
@ -51,6 +52,14 @@ class VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel):
|
|||||||
:param user: unique user id
|
:param user: unique user id
|
||||||
:return: embeddings result
|
:return: embeddings result
|
||||||
"""
|
"""
|
||||||
|
if ArkClientV3.is_legacy(credentials):
|
||||||
|
return self._generate_v2(model, credentials, texts, user)
|
||||||
|
|
||||||
|
return self._generate_v3(model, credentials, texts, user)
|
||||||
|
|
||||||
|
def _generate_v2(self, model: str, credentials: dict,
|
||||||
|
texts: list[str], user: Optional[str] = None) \
|
||||||
|
-> TextEmbeddingResult:
|
||||||
client = MaaSClient.from_credential(credentials)
|
client = MaaSClient.from_credential(credentials)
|
||||||
resp = MaaSClient.wrap_exception(lambda: client.embeddings(texts))
|
resp = MaaSClient.wrap_exception(lambda: client.embeddings(texts))
|
||||||
|
|
||||||
@ -65,6 +74,23 @@ class VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel):
|
|||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
def _generate_v3(self, model: str, credentials: dict,
|
||||||
|
texts: list[str], user: Optional[str] = None) \
|
||||||
|
-> TextEmbeddingResult:
|
||||||
|
client = ArkClientV3.from_credentials(credentials)
|
||||||
|
resp = client.embeddings(texts)
|
||||||
|
|
||||||
|
usage = self._calc_response_usage(
|
||||||
|
model=model, credentials=credentials, tokens=resp.usage.total_tokens)
|
||||||
|
|
||||||
|
result = TextEmbeddingResult(
|
||||||
|
model=model,
|
||||||
|
embeddings=[v.embedding for v in resp.data],
|
||||||
|
usage=usage
|
||||||
|
)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
|
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
|
||||||
"""
|
"""
|
||||||
Get number of tokens for given prompt messages
|
Get number of tokens for given prompt messages
|
||||||
@ -88,11 +114,22 @@ class VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel):
|
|||||||
:param credentials: model credentials
|
:param credentials: model credentials
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
|
if ArkClientV3.is_legacy(credentials):
|
||||||
|
return self._validate_credentials_v2(model, credentials)
|
||||||
|
return self._validate_credentials_v3(model, credentials)
|
||||||
|
|
||||||
|
def _validate_credentials_v2(self, model: str, credentials: dict) -> None:
|
||||||
try:
|
try:
|
||||||
self._invoke(model=model, credentials=credentials, texts=['ping'])
|
self._invoke(model=model, credentials=credentials, texts=['ping'])
|
||||||
except MaasException as e:
|
except MaasException as e:
|
||||||
raise CredentialsValidateFailedError(e.message)
|
raise CredentialsValidateFailedError(e.message)
|
||||||
|
|
||||||
|
def _validate_credentials_v3(self, model: str, credentials: dict) -> None:
|
||||||
|
try:
|
||||||
|
self._invoke(model=model, credentials=credentials, texts=['ping'])
|
||||||
|
except Exception as e:
|
||||||
|
raise CredentialsValidateFailedError(e)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||||
"""
|
"""
|
||||||
@ -116,9 +153,10 @@ class VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel):
|
|||||||
generate custom model entities from credentials
|
generate custom model entities from credentials
|
||||||
"""
|
"""
|
||||||
model_config = get_model_config(credentials)
|
model_config = get_model_config(credentials)
|
||||||
model_properties = {}
|
model_properties = {
|
||||||
model_properties[ModelPropertyKey.CONTEXT_SIZE] = model_config.properties.context_size
|
ModelPropertyKey.CONTEXT_SIZE: model_config.properties.context_size,
|
||||||
model_properties[ModelPropertyKey.MAX_CHUNKS] = model_config.properties.max_chunks
|
ModelPropertyKey.MAX_CHUNKS: model_config.properties.max_chunks
|
||||||
|
}
|
||||||
entity = AIModelEntity(
|
entity = AIModelEntity(
|
||||||
model=model,
|
model=model,
|
||||||
label=I18nObject(en_US=model),
|
label=I18nObject(en_US=model),
|
||||||
|
37
api/poetry.lock
generated
37
api/poetry.lock
generated
@ -6143,6 +6143,19 @@ files = [
|
|||||||
{file = "pyarrow-17.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:392bc9feabc647338e6c89267635e111d71edad5fcffba204425a7c8d13610d7"},
|
{file = "pyarrow-17.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:392bc9feabc647338e6c89267635e111d71edad5fcffba204425a7c8d13610d7"},
|
||||||
{file = "pyarrow-17.0.0-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:af5ff82a04b2171415f1410cff7ebb79861afc5dae50be73ce06d6e870615204"},
|
{file = "pyarrow-17.0.0-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:af5ff82a04b2171415f1410cff7ebb79861afc5dae50be73ce06d6e870615204"},
|
||||||
{file = "pyarrow-17.0.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:edca18eaca89cd6382dfbcff3dd2d87633433043650c07375d095cd3517561d8"},
|
{file = "pyarrow-17.0.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:edca18eaca89cd6382dfbcff3dd2d87633433043650c07375d095cd3517561d8"},
|
||||||
|
{file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7c7916bff914ac5d4a8fe25b7a25e432ff921e72f6f2b7547d1e325c1ad9d155"},
|
||||||
|
{file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f553ca691b9e94b202ff741bdd40f6ccb70cdd5fbf65c187af132f1317de6145"},
|
||||||
|
{file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:0cdb0e627c86c373205a2f94a510ac4376fdc523f8bb36beab2e7f204416163c"},
|
||||||
|
{file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:d7d192305d9d8bc9082d10f361fc70a73590a4c65cf31c3e6926cd72b76bc35c"},
|
||||||
|
{file = "pyarrow-17.0.0-cp38-cp38-win_amd64.whl", hash = "sha256:02dae06ce212d8b3244dd3e7d12d9c4d3046945a5933d28026598e9dbbda1fca"},
|
||||||
|
{file = "pyarrow-17.0.0-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:13d7a460b412f31e4c0efa1148e1d29bdf18ad1411eb6757d38f8fbdcc8645fb"},
|
||||||
|
{file = "pyarrow-17.0.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9b564a51fbccfab5a04a80453e5ac6c9954a9c5ef2890d1bcf63741909c3f8df"},
|
||||||
|
{file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:32503827abbc5aadedfa235f5ece8c4f8f8b0a3cf01066bc8d29de7539532687"},
|
||||||
|
{file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a155acc7f154b9ffcc85497509bcd0d43efb80d6f733b0dc3bb14e281f131c8b"},
|
||||||
|
{file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:dec8d129254d0188a49f8a1fc99e0560dc1b85f60af729f47de4046015f9b0a5"},
|
||||||
|
{file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:a48ddf5c3c6a6c505904545c25a4ae13646ae1f8ba703c4df4a1bfe4f4006bda"},
|
||||||
|
{file = "pyarrow-17.0.0-cp39-cp39-win_amd64.whl", hash = "sha256:42bf93249a083aca230ba7e2786c5f673507fa97bbd9725a1e2754715151a204"},
|
||||||
|
{file = "pyarrow-17.0.0.tar.gz", hash = "sha256:4beca9521ed2c0921c1023e68d097d0299b62c362639ea315572a58f3f50fd28"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
@ -8854,6 +8867,28 @@ files = [
|
|||||||
{file = "vine-5.1.0.tar.gz", hash = "sha256:8b62e981d35c41049211cf62a0a1242d8c1ee9bd15bb196ce38aefd6799e61e0"},
|
{file = "vine-5.1.0.tar.gz", hash = "sha256:8b62e981d35c41049211cf62a0a1242d8c1ee9bd15bb196ce38aefd6799e61e0"},
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "volcengine-python-sdk"
|
||||||
|
version = "1.0.98"
|
||||||
|
description = "Volcengine SDK for Python"
|
||||||
|
optional = false
|
||||||
|
python-versions = "*"
|
||||||
|
files = [
|
||||||
|
{file = "volcengine-python-sdk-1.0.98.tar.gz", hash = "sha256:1515e8d46cdcda387f9b45abbcaf0b04b982f7be68068de83f1e388281441784"},
|
||||||
|
]
|
||||||
|
|
||||||
|
[package.dependencies]
|
||||||
|
anyio = {version = ">=3.5.0,<5", optional = true, markers = "extra == \"ark\""}
|
||||||
|
certifi = ">=2017.4.17"
|
||||||
|
httpx = {version = ">=0.23.0,<1", optional = true, markers = "extra == \"ark\""}
|
||||||
|
pydantic = {version = ">=1.9.0,<3", optional = true, markers = "extra == \"ark\""}
|
||||||
|
python-dateutil = ">=2.1"
|
||||||
|
six = ">=1.10"
|
||||||
|
urllib3 = ">=1.23"
|
||||||
|
|
||||||
|
[package.extras]
|
||||||
|
ark = ["anyio (>=3.5.0,<5)", "cached-property", "httpx (>=0.23.0,<1)", "pydantic (>=1.9.0,<3)"]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "watchfiles"
|
name = "watchfiles"
|
||||||
version = "0.23.0"
|
version = "0.23.0"
|
||||||
@ -9634,4 +9669,4 @@ cffi = ["cffi (>=1.11)"]
|
|||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "2.0"
|
lock-version = "2.0"
|
||||||
python-versions = ">=3.10,<3.13"
|
python-versions = ">=3.10,<3.13"
|
||||||
content-hash = "d7336115709114c2a4ff09b392f717e9c3547ae82b6a111d0c885c7a44269f02"
|
content-hash = "04f970820de691f40fc9fb30f5ff0618b0f1a04d3315b14467fb88e475fa1243"
|
||||||
|
@ -191,6 +191,7 @@ zhipuai = "1.0.7"
|
|||||||
# Related transparent dependencies with pinned verion
|
# Related transparent dependencies with pinned verion
|
||||||
# required by main implementations
|
# required by main implementations
|
||||||
############################################################
|
############################################################
|
||||||
|
volcengine-python-sdk = {extras = ["ark"], version = "^1.0.98"}
|
||||||
[tool.poetry.group.indriect.dependencies]
|
[tool.poetry.group.indriect.dependencies]
|
||||||
kaleido = "0.2.1"
|
kaleido = "0.2.1"
|
||||||
rank-bm25 = "~0.2.2"
|
rank-bm25 = "~0.2.2"
|
||||||
|
Loading…
x
Reference in New Issue
Block a user