diff --git a/api/core/model_runtime/model_providers/hunyuan/hunyuan.yaml b/api/core/model_runtime/model_providers/hunyuan/hunyuan.yaml index 835a7716f7..812b51ddcd 100644 --- a/api/core/model_runtime/model_providers/hunyuan/hunyuan.yaml +++ b/api/core/model_runtime/model_providers/hunyuan/hunyuan.yaml @@ -18,6 +18,7 @@ help: en_US: https://console.cloud.tencent.com/cam/capi supported_model_types: - llm + - text-embedding configurate_methods: - predefined-model provider_credential_schema: diff --git a/api/core/model_runtime/model_providers/hunyuan/text_embedding/__init__.py b/api/core/model_runtime/model_providers/hunyuan/text_embedding/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/model_runtime/model_providers/hunyuan/text_embedding/hunyuan-text-embedding.yaml b/api/core/model_runtime/model_providers/hunyuan/text_embedding/hunyuan-text-embedding.yaml new file mode 100644 index 0000000000..ab014e4344 --- /dev/null +++ b/api/core/model_runtime/model_providers/hunyuan/text_embedding/hunyuan-text-embedding.yaml @@ -0,0 +1,5 @@ +model: hunyuan-embedding +model_type: text-embedding +model_properties: + context_size: 1024 + max_chunks: 1 diff --git a/api/core/model_runtime/model_providers/hunyuan/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/hunyuan/text_embedding/text_embedding.py new file mode 100644 index 0000000000..64d8dcf795 --- /dev/null +++ b/api/core/model_runtime/model_providers/hunyuan/text_embedding/text_embedding.py @@ -0,0 +1,173 @@ +import json +import logging +import time +from typing import Optional + +from tencentcloud.common import credential +from tencentcloud.common.exception import TencentCloudSDKException +from tencentcloud.common.profile.client_profile import ClientProfile +from tencentcloud.common.profile.http_profile import HttpProfile +from tencentcloud.hunyuan.v20230901 import hunyuan_client, models + +from core.model_runtime.entities.model_entities import PriceType +from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult +from core.model_runtime.errors.invoke import ( + InvokeError, +) +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel + +logger = logging.getLogger(__name__) + +class HunyuanTextEmbeddingModel(TextEmbeddingModel): + """ + Model class for Hunyuan text embedding model. + """ + + def _invoke(self, model: str, credentials: dict, + texts: list[str], user: Optional[str] = None) \ + -> TextEmbeddingResult: + """ + Invoke text embedding model + + :param model: model name + :param credentials: model credentials + :param texts: texts to embed + :param user: unique user id + :return: embeddings result + """ + + if model != 'hunyuan-embedding': + raise ValueError('Invalid model name') + + client = self._setup_hunyuan_client(credentials) + + embeddings = [] + token_usage = 0 + + for input in texts: + request = models.GetEmbeddingRequest() + params = { + "Input": input + } + request.from_json_string(json.dumps(params)) + response = client.GetEmbedding(request) + usage = response.Usage.TotalTokens + + embeddings.extend([data.Embedding for data in response.Data]) + token_usage += usage + + result = TextEmbeddingResult( + model=model, + embeddings=embeddings, + usage=self._calc_response_usage( + model=model, + credentials=credentials, + tokens=token_usage + ) + ) + + return result + + def validate_credentials(self, model: str, credentials: dict) -> None: + """ + Validate credentials + """ + try: + client = self._setup_hunyuan_client(credentials) + + req = models.ChatCompletionsRequest() + params = { + "Model": model, + "Messages": [{ + "Role": "user", + "Content": "hello" + }], + "TopP": 1, + "Temperature": 0, + "Stream": False + } + req.from_json_string(json.dumps(params)) + client.ChatCompletions(req) + except Exception as e: + raise CredentialsValidateFailedError(f'Credentials validation failed: {e}') + + def _setup_hunyuan_client(self, credentials): + secret_id = credentials['secret_id'] + secret_key = credentials['secret_key'] + cred = credential.Credential(secret_id, secret_key) + httpProfile = HttpProfile() + httpProfile.endpoint = "hunyuan.tencentcloudapi.com" + clientProfile = ClientProfile() + clientProfile.httpProfile = httpProfile + client = hunyuan_client.HunyuanClient(cred, "", clientProfile) + return client + + def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: + """ + Calculate response usage + + :param model: model name + :param credentials: model credentials + :param tokens: input tokens + :return: usage + """ + # get input price info + input_price_info = self.get_price( + model=model, + credentials=credentials, + price_type=PriceType.INPUT, + tokens=tokens + ) + + # transform usage + usage = EmbeddingUsage( + tokens=tokens, + total_tokens=tokens, + unit_price=input_price_info.unit_price, + price_unit=input_price_info.unit, + total_price=input_price_info.total_amount, + currency=input_price_info.currency, + latency=time.perf_counter() - self.started_at + ) + + return usage + + @property + def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: + """ + Map model invoke error to unified error + The key is the error type thrown to the caller + The value is the error type thrown by the model, + which needs to be converted into a unified error type for the caller. + + :return: Invoke error mapping + """ + return { + InvokeError: [TencentCloudSDKException], + } + + def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: + """ + Get number of tokens for given prompt messages + + :param model: model name + :param credentials: model credentials + :param texts: texts to embed + :return: + """ + # client = self._setup_hunyuan_client(credentials) + + num_tokens = 0 + for text in texts: + num_tokens += self._get_num_tokens_by_gpt2(text) + # use client.GetTokenCount to get num tokens + # request = models.GetTokenCountRequest() + # params = { + # "Prompt": text + # } + # request.from_json_string(json.dumps(params)) + # response = client.GetTokenCount(request) + # num_tokens += response.TokenCount + + return num_tokens \ No newline at end of file diff --git a/api/tests/integration_tests/model_runtime/hunyuan/test_text_embedding.py b/api/tests/integration_tests/model_runtime/hunyuan/test_text_embedding.py new file mode 100644 index 0000000000..7ae6c0e456 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/hunyuan/test_text_embedding.py @@ -0,0 +1,104 @@ +import os + +import pytest + +from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.hunyuan.text_embedding.text_embedding import HunyuanTextEmbeddingModel + + +def test_validate_credentials(): + model = HunyuanTextEmbeddingModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model='hunyuan-embedding', + credentials={ + 'secret_id': 'invalid_key', + 'secret_key': 'invalid_key' + } + ) + + model.validate_credentials( + model='hunyuan-embedding', + credentials={ + 'secret_id': os.environ.get('HUNYUAN_SECRET_ID'), + 'secret_key': os.environ.get('HUNYUAN_SECRET_KEY') + } + ) + + +def test_invoke_model(): + model = HunyuanTextEmbeddingModel() + + result = model.invoke( + model='hunyuan-embedding', + credentials={ + 'secret_id': os.environ.get('HUNYUAN_SECRET_ID'), + 'secret_key': os.environ.get('HUNYUAN_SECRET_KEY') + }, + texts=[ + "hello", + "world" + ], + user="abc-123" + ) + + assert isinstance(result, TextEmbeddingResult) + assert len(result.embeddings) == 2 + assert result.usage.total_tokens == 6 + +def test_get_num_tokens(): + model = HunyuanTextEmbeddingModel() + + num_tokens = model.get_num_tokens( + model='hunyuan-embedding', + credentials={ + 'secret_id': os.environ.get('HUNYUAN_SECRET_ID'), + 'secret_key': os.environ.get('HUNYUAN_SECRET_KEY') + }, + texts=[ + "hello", + "world" + ] + ) + + assert num_tokens == 2 + +def test_max_chunks(): + model = HunyuanTextEmbeddingModel() + + result = model.invoke( + model='hunyuan-embedding', + credentials={ + 'secret_id': os.environ.get('HUNYUAN_SECRET_ID'), + 'secret_key': os.environ.get('HUNYUAN_SECRET_KEY') + }, + texts=[ + "hello", + "world", + "hello", + "world", + "hello", + "world", + "hello", + "world", + "hello", + "world", + "hello", + "world", + "hello", + "world", + "hello", + "world", + "hello", + "world", + "hello", + "world", + "hello", + "world", + ] + ) + + assert isinstance(result, TextEmbeddingResult) + assert len(result.embeddings) == 22 \ No newline at end of file