from langchain.embeddings import XinferenceEmbeddings

from core.model_providers.error import LLMBadRequestError
from core.model_providers.providers.base import BaseModelProvider
from core.model_providers.models.embedding.base import BaseEmbedding


class XinferenceEmbedding(BaseEmbedding):
    """Embedding model backed by a Xinference inference server.

    Credentials (e.g. ``server_url`` and ``model_uid``) are resolved through
    the owning provider and passed straight to langchain's
    ``XinferenceEmbeddings`` client.
    """

    def __init__(self, model_provider: BaseModelProvider, name: str):
        """Look up this model's credentials and build the embedding client.

        :param model_provider: provider that stores the per-model credentials.
        :param name: model name as registered with the provider.
        """
        credentials = model_provider.get_model_credentials(
            model_name=name,
            model_type=self.type
        )

        client = XinferenceEmbeddings(
            **credentials,
        )

        super().__init__(model_provider, client, name)

    def handle_exceptions(self, ex: Exception) -> Exception:
        """Map client errors to the provider-scoped error type.

        The original implementation caught ``replicate.exceptions.ModelError``
        and ``ReplicateError`` — a copy-paste from the Replicate provider that
        could never match errors raised by a Xinference client, and pulled in
        a spurious dependency on the ``replicate`` package.

        The Xinference RESTful client / langchain wrapper surface request and
        model failures as ``ValueError`` or ``RuntimeError`` — NOTE(review):
        confirm against the pinned xinference client version.
        """
        if isinstance(ex, (ValueError, RuntimeError)):
            return LLMBadRequestError(f"Xinference embedding: {str(ex)}")
        else:
            return ex
import json
import os
from unittest.mock import MagicMock, patch

from core.model_providers.models.embedding.xinference_embedding import XinferenceEmbedding
from core.model_providers.models.entity.model_params import ModelType
from core.model_providers.providers.xinference_provider import XinferenceProvider
from models.provider import Provider, ProviderType, ProviderModel


def get_mock_provider():
    """Build a bare custom-type Xinference provider row for the tests."""
    return Provider(
        id='provider_id',
        tenant_id='tenant_id',
        provider_name='xinference',
        provider_type=ProviderType.CUSTOM.value,
        encrypted_config='',
        is_valid=True,
    )


def get_mock_embedding_model(mocker):
    """Wire up a XinferenceEmbedding against a mocked DB credential lookup.

    Server url and model uid come from the environment so the test can be
    pointed at a live Xinference deployment.
    """
    model_name = 'vicuna-v1.3'
    config = {
        'server_url': os.environ['XINFERENCE_SERVER_URL'],
        'model_uid': os.environ['XINFERENCE_MODEL_UID'],
    }

    # The provider resolves credentials via db.session.query(...).filter(...)
    # .first(); stub that chain to hand back a ProviderModel row directly.
    provider_model = ProviderModel(
        provider_name='xinference',
        model_name=model_name,
        model_type=ModelType.EMBEDDINGS.value,
        encrypted_config=json.dumps(config),
        is_valid=True,
    )
    query_mock = MagicMock()
    query_mock.filter.return_value.first.return_value = provider_model
    mocker.patch('extensions.ext_database.db.session.query', return_value=query_mock)

    return XinferenceEmbedding(
        model_provider=XinferenceProvider(provider=get_mock_provider()),
        name=model_name,
    )


def decrypt_side_effect(tenant_id, encrypted_api_key):
    # Tests store credentials in plaintext, so "decryption" is the identity.
    return encrypted_api_key


@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_embed_documents(mock_decrypt, mocker):
    embedding_model = get_mock_embedding_model(mocker)
    rst = embedding_model.client.embed_documents(['test', 'test1'])
    assert isinstance(rst, list)
    assert len(rst) == 2
    assert len(rst[0]) == 4096


@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_embed_query(mock_decrypt, mocker):
    embedding_model = get_mock_embedding_model(mocker)
    rst = embedding_model.client.embed_query('test')
    assert isinstance(rst, list)
    assert len(rst) == 4096