From 5b32f2e0ddba615b089926ac558d82aeb1f641bf Mon Sep 17 00:00:00 2001
From: Yanyi Liu
Date: Fri, 9 Aug 2024 19:12:13 +0800
Subject: [PATCH] Feat: Add model provider Text Embedding Inference for embedding and rerank (#7132)

---
 .../huggingface_tei/__init__.py                    |   0
 .../huggingface_tei/huggingface_tei.py             |  11 +
 .../huggingface_tei/huggingface_tei.yaml           |  36 ++++
 .../huggingface_tei/rerank/__init__.py             |   0
 .../huggingface_tei/rerank/rerank.py               | 137 ++++++++++++
 .../huggingface_tei/tei_helper.py                  | 183 ++++++++++++++++
 .../text_embedding/__init__.py                     |   0
 .../text_embedding/text_embedding.py               | 204 ++++++++++++++++++
 api/pyproject.toml                                 |   2 +
 .../model_runtime/__mock/huggingface_tei.py        |  94 ++++++++
 .../model_runtime/huggingface_tei/__init__.py      |   0
 .../huggingface_tei/test_embeddings.py             |  72 +++++++
 .../huggingface_tei/test_rerank.py                 |  76 +++++++
 13 files changed, 815 insertions(+)
 create mode 100644 api/core/model_runtime/model_providers/huggingface_tei/__init__.py
 create mode 100644 api/core/model_runtime/model_providers/huggingface_tei/huggingface_tei.py
 create mode 100644 api/core/model_runtime/model_providers/huggingface_tei/huggingface_tei.yaml
 create mode 100644 api/core/model_runtime/model_providers/huggingface_tei/rerank/__init__.py
 create mode 100644 api/core/model_runtime/model_providers/huggingface_tei/rerank/rerank.py
 create mode 100644 api/core/model_runtime/model_providers/huggingface_tei/tei_helper.py
 create mode 100644 api/core/model_runtime/model_providers/huggingface_tei/text_embedding/__init__.py
 create mode 100644 api/core/model_runtime/model_providers/huggingface_tei/text_embedding/text_embedding.py
 create mode 100644 api/tests/integration_tests/model_runtime/__mock/huggingface_tei.py
 create mode 100644 api/tests/integration_tests/model_runtime/huggingface_tei/__init__.py
 create mode 100644 api/tests/integration_tests/model_runtime/huggingface_tei/test_embeddings.py
 create mode 100644 api/tests/integration_tests/model_runtime/huggingface_tei/test_rerank.py

diff --git a/api/core/model_runtime/model_providers/huggingface_tei/__init__.py b/api/core/model_runtime/model_providers/huggingface_tei/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/model_runtime/model_providers/huggingface_tei/huggingface_tei.py b/api/core/model_runtime/model_providers/huggingface_tei/huggingface_tei.py
new file mode 100644
index 0000000000..9454466250
--- /dev/null
+++ b/api/core/model_runtime/model_providers/huggingface_tei/huggingface_tei.py
@@ -0,0 +1,11 @@
+import logging
+
+from core.model_runtime.model_providers.__base.model_provider import ModelProvider
+
+logger = logging.getLogger(__name__)
+
+
+class HuggingfaceTeiProvider(ModelProvider):
+
+    def validate_provider_credentials(self, credentials: dict) -> None:
+        pass
diff --git a/api/core/model_runtime/model_providers/huggingface_tei/huggingface_tei.yaml b/api/core/model_runtime/model_providers/huggingface_tei/huggingface_tei.yaml
new file mode 100644
index 0000000000..f3a912d84d
--- /dev/null
+++ b/api/core/model_runtime/model_providers/huggingface_tei/huggingface_tei.yaml
@@ -0,0 +1,36 @@
+provider: huggingface_tei
+label:
+  en_US: Text Embedding Inference
+description:
+  en_US: A blazing fast inference solution for text embeddings models.
+  zh_Hans: 用于文本嵌入模型的超快速推理解决方案。
+background: "#FFF8DC"
+help:
+  title:
+    en_US: How to deploy Text Embedding Inference
+    zh_Hans: 如何部署 Text Embedding Inference
+  url:
+    en_US: https://github.com/huggingface/text-embeddings-inference
+supported_model_types:
+  - text-embedding
+  - rerank
+configurate_methods:
+  - customizable-model
+model_credential_schema:
+  model:
+    label:
+      en_US: Model Name
+      zh_Hans: 模型名称
+    placeholder:
+      en_US: Enter your model name
+      zh_Hans: 输入模型名称
+  credential_form_schemas:
+    - variable: server_url
+      label:
+        zh_Hans: 服务器URL
+        en_US: Server URL
+      type: secret-input
+      required: true
+      placeholder:
+        zh_Hans: 在此输入Text Embedding Inference的服务器地址,如 http://192.168.1.100:8080
+        en_US: Enter the URL of your Text Embedding Inference server, e.g. http://192.168.1.100:8080
diff --git a/api/core/model_runtime/model_providers/huggingface_tei/rerank/__init__.py b/api/core/model_runtime/model_providers/huggingface_tei/rerank/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/model_runtime/model_providers/huggingface_tei/rerank/rerank.py b/api/core/model_runtime/model_providers/huggingface_tei/rerank/rerank.py
new file mode 100644
index 0000000000..34013426de
--- /dev/null
+++ b/api/core/model_runtime/model_providers/huggingface_tei/rerank/rerank.py
@@ -0,0 +1,137 @@
+from typing import Optional
+
+import httpx
+
+from core.model_runtime.entities.common_entities import I18nObject
+from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType
+from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
+from core.model_runtime.errors.invoke import (
+    InvokeAuthorizationError,
+    InvokeBadRequestError,
+    InvokeConnectionError,
+    InvokeError,
+    InvokeRateLimitError,
+    InvokeServerUnavailableError,
+)
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.rerank_model import RerankModel
+from core.model_runtime.model_providers.huggingface_tei.tei_helper import TeiHelper
+
+
+class HuggingfaceTeiRerankModel(RerankModel):
+    """
+    Model class for Text Embedding Inference rerank model.
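+
+    The heavy lifting is delegated to TeiHelper.invoke_rerank, which calls the
+    TEI /rerank endpoint; results are then filtered by score_threshold and top_n.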
+    """
+
+    def _invoke(
+        self,
+        model: str,
+        credentials: dict,
+        query: str,
+        docs: list[str],
+        score_threshold: Optional[float] = None,
+        top_n: Optional[int] = None,
+        user: Optional[str] = None,
+    ) -> RerankResult:
+        """
+        Invoke rerank model
+
+        :param model: model name
+        :param credentials: model credentials
+        :param query: search query
+        :param docs: docs for reranking
+        :param score_threshold: score threshold
+        :param top_n: top n
+        :param user: unique user id
+        :return: rerank result
+        """
+        if len(docs) == 0:
+            return RerankResult(model=model, docs=[])
+        server_url = credentials['server_url']
+
+        if server_url.endswith('/'):
+            server_url = server_url[:-1]
+
+        try:
+            results = TeiHelper.invoke_rerank(server_url, query, docs)
+
+            rerank_documents = []
+            for result in results:
+                rerank_document = RerankDocument(
+                    index=result['index'],
+                    text=result['text'],
+                    score=result['score'],
+                )
+                if score_threshold is None or result['score'] >= score_threshold:
+                    rerank_documents.append(rerank_document)
+                    if top_n is not None and len(rerank_documents) >= top_n:
+                        break
+
+            return RerankResult(model=model, docs=rerank_documents)
+        except httpx.HTTPStatusError as e:
+            raise InvokeServerUnavailableError(str(e))
+
+    def validate_credentials(self, model: str, credentials: dict) -> None:
+        """
+        Validate model credentials
+
+        :param model: model name
+        :param credentials: model credentials
+        :return:
+        """
+        try:
+            server_url = credentials['server_url']
+            extra_args = TeiHelper.get_tei_extra_parameter(server_url, model)
+            if extra_args.model_type != 'reranker':
+                raise CredentialsValidateFailedError('Current model is not a rerank model')
+
+            credentials['context_size'] = extra_args.max_input_length
+
+            self.invoke(
+                model=model,
+                credentials=credentials,
+                query='Who is Kasumi?',
+                docs=[
+                    'Kasumi is a girl\'s name of Japanese origin meaning "mist".',
+                    'Her music is a kawaii bass, a mix of future bass, pop, and kawaii music ',
+                    'and she leads a team named PopiParty.',
+                ],
+                score_threshold=0.8,
+            )
+        except Exception as ex:
+            raise CredentialsValidateFailedError(str(ex))
+
+    @property
+    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+        """
+        Map model invoke error to unified error
+        The key is the error type thrown to the caller
+        The value is the error type thrown by the model,
+        which needs to be converted into a unified error type for the caller.
+
+        :return: Invoke error mapping
+        """
+        return {
+            InvokeConnectionError: [InvokeConnectionError],
+            InvokeServerUnavailableError: [InvokeServerUnavailableError],
+            InvokeRateLimitError: [InvokeRateLimitError],
+            InvokeAuthorizationError: [InvokeAuthorizationError],
+            InvokeBadRequestError: [InvokeBadRequestError, KeyError, ValueError],
+        }
+
+    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
+        """
+        used to define customizable model schema
+        """
+        entity = AIModelEntity(
+            model=model,
+            label=I18nObject(en_US=model),
+            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
+            model_type=ModelType.RERANK,
+            model_properties={
+                ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', 512)),
+            },
+            parameter_rules=[],
+        )
+
+        return entity
diff --git a/api/core/model_runtime/model_providers/huggingface_tei/tei_helper.py b/api/core/model_runtime/model_providers/huggingface_tei/tei_helper.py
new file mode 100644
index 0000000000..2aa785c89d
--- /dev/null
+++ b/api/core/model_runtime/model_providers/huggingface_tei/tei_helper.py
@@ -0,0 +1,183 @@
+from threading import Lock
+from time import time
+from typing import Optional
+
+import httpx
+from requests.adapters import HTTPAdapter
+from requests.exceptions import ConnectionError, MissingSchema, Timeout
+from requests.sessions import Session
+from yarl import URL
+
+
+class TeiModelExtraParameter:
+    model_type: str
+    max_input_length: int
+    max_client_batch_size: int
+
+    def __init__(self, model_type: str, max_input_length: int, max_client_batch_size: Optional[int] = None) -> None:
+        self.model_type = model_type
+        self.max_input_length = max_input_length
+        self.max_client_batch_size = max_client_batch_size
+
+
+cache = {}
+cache_lock = Lock()
+
+
+class TeiHelper:
+    @staticmethod
+    def get_tei_extra_parameter(server_url: str, model_name: str) -> TeiModelExtraParameter:
+        TeiHelper._clean_cache()
+        with cache_lock:
+            if model_name not in cache:
+                cache[model_name] = {
+                    'expires': time() + 300,
+                    'value': TeiHelper._get_tei_extra_parameter(server_url),
+                }
+            return cache[model_name]['value']
+
+    @staticmethod
+    def _clean_cache() -> None:
+        try:
+            with cache_lock:
+                expired_keys = [model_uid for model_uid, model in cache.items() if model['expires'] < time()]
+                for model_uid in expired_keys:
+                    del cache[model_uid]
+        except RuntimeError:
+            pass
+
+    @staticmethod
+    def _get_tei_extra_parameter(server_url: str) -> TeiModelExtraParameter:
+        """
+        get tei model extra parameter like model_type, max_input_length, max_batch_requests
+        """
+
+        url = str(URL(server_url) / 'info')
+
+        # this method is guarded by a lock, and a plain requests call may hang forever, so we mount an HTTPAdapter with max_retries=3
+        session = Session()
+        session.mount('http://', HTTPAdapter(max_retries=3))
+        session.mount('https://', HTTPAdapter(max_retries=3))
+
+        try:
+            response = session.get(url, timeout=10)
+        except (MissingSchema, ConnectionError, Timeout) as e:
+            raise RuntimeError(f'get tei model extra parameter failed, url: {url}, error: {e}')
+        if response.status_code != 200:
+            raise RuntimeError(
+                f'get tei model extra parameter failed, status code: {response.status_code}, response: {response.text}'
+            )
+
+        response_json = response.json()
+
+        model_type = response_json.get('model_type', {})
+        if len(model_type.keys()) < 1:
+            raise RuntimeError('model_type is empty')
+        model_type = list(model_type.keys())[0]
+        if model_type not in ['embedding', 'reranker']:
+            raise RuntimeError(f'invalid model_type: {model_type}')
+
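+        # Illustrative /info payload (an abbreviated sketch of the fields read
+        # below, not the full TEI schema):
+        #   {"model_type": {"embedding": {...}}, "max_input_length": 512,
+        #    "max_client_batch_size": 32, ...}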
+        max_input_length = response_json.get('max_input_length', 512)
+        max_client_batch_size = response_json.get('max_client_batch_size', 1)
+
+        return TeiModelExtraParameter(
+            model_type=model_type,
+            max_input_length=max_input_length,
+            max_client_batch_size=max_client_batch_size
+        )
+
+    @staticmethod
+    def invoke_tokenize(server_url: str, texts: list[str]) -> list[list[dict]]:
+        """
+        Invoke tokenize endpoint
+
+        Example response:
+        [
+            [
+                {
+                    "id": 0,
+                    "text": "<s>",
+                    "special": true,
+                    "start": null,
+                    "stop": null
+                },
+                {
+                    "id": 7704,
+                    "text": "str",
+                    "special": false,
+                    "start": 0,
+                    "stop": 3
+                },
+                < MORE TOKENS >
+            ]
+        ]
+
+        :param server_url: server url
+        :param texts: texts to tokenize
+        """
+        resp = httpx.post(
+            f'{server_url}/tokenize',
+            json={'inputs': texts},
+        )
+        resp.raise_for_status()
+        return resp.json()
+
+    @staticmethod
+    def invoke_embeddings(server_url: str, texts: list[str]) -> dict:
+        """
+        Invoke embeddings endpoint
+
+        Example response:
+        {
+            "object": "list",
+            "data": [
+                {
+                    "object": "embedding",
+                    "embedding": [...],
+                    "index": 0
+                }
+            ],
+            "model": "MODEL_NAME",
+            "usage": {
+                "prompt_tokens": 3,
+                "total_tokens": 3
+            }
+        }
+
+        :param server_url: server url
+        :param texts: texts to embed
+        """
+        # Use the OpenAI-compatible API here, which has usage tracking
+        resp = httpx.post(
+            f'{server_url}/v1/embeddings',
+            json={'input': texts},
+        )
+        resp.raise_for_status()
+        return resp.json()
+
+    @staticmethod
+    def invoke_rerank(server_url: str, query: str, docs: list[str]) -> list[dict]:
+        """
+        Invoke rerank endpoint
+
+        Example response:
+        [
+            {
+                "index": 0,
+                "text": "Deep Learning is ...",
+                "score": 0.9950755
+            }
+        ]
+
+        :param server_url: server url
+        :param query: query text
+        :param docs: documents to rerank
+        """
+        params = {'query': query, 'texts': docs, 'return_text': True}
+
+        response = httpx.post(
+            server_url + '/rerank',
+            json=params,
+        )
+        response.raise_for_status()
+        return response.json()
diff --git a/api/core/model_runtime/model_providers/huggingface_tei/text_embedding/__init__.py b/api/core/model_runtime/model_providers/huggingface_tei/text_embedding/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/model_runtime/model_providers/huggingface_tei/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/huggingface_tei/text_embedding/text_embedding.py
new file mode 100644
index 0000000000..6897b87f6d
--- /dev/null
+++ b/api/core/model_runtime/model_providers/huggingface_tei/text_embedding/text_embedding.py
@@ -0,0 +1,204 @@
+import time
+from typing import Optional
+
+from core.model_runtime.entities.common_entities import I18nObject
+from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType
+from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
+from core.model_runtime.errors.invoke import (
+    InvokeAuthorizationError,
+    InvokeBadRequestError,
+    InvokeConnectionError,
+    InvokeError,
+    InvokeRateLimitError,
+    InvokeServerUnavailableError,
+)
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
+from core.model_runtime.model_providers.huggingface_tei.tei_helper import TeiHelper
+
+
+class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel):
+    """
+    Model class for Text Embedding Inference text embedding model.
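+
+    Texts are tokenized via the TEI /tokenize endpoint and truncated to the
+    model's max_input_length before being embedded through the
+    OpenAI-compatible /v1/embeddings route.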
+    """
+
+    def _invoke(
+        self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
+    ) -> TextEmbeddingResult:
+        """
+        Invoke text embedding model
+
+        credentials should be like:
+        {
+            'server_url': 'server url',
+        }
+
+        :param model: model name
+        :param credentials: model credentials
+        :param texts: texts to embed
+        :param user: unique user id
+        :return: embeddings result
+        """
+        server_url = credentials['server_url']
+
+        if server_url.endswith('/'):
+            server_url = server_url[:-1]
+
+        # get model properties
+        context_size = self._get_context_size(model, credentials)
+        max_chunks = self._get_max_chunks(model, credentials)
+
+        inputs = []
+
+        # get tokenized results from TEI
+        batched_tokenize_result = TeiHelper.invoke_tokenize(server_url, texts)
+
+        for text, tokenize_result in zip(texts, batched_tokenize_result):
+            # check whether the text exceeds the context size
+            num_tokens = len(tokenize_result)
+
+            if num_tokens >= context_size:
+                # count the special tokens that precede the content
+                pre_special_token_count = 0
+                for token in tokenize_result:
+                    if token['special']:
+                        pre_special_token_count += 1
+                    else:
+                        break
+                rest_special_token_count = len([token for token in tokenize_result if token['special']]) - pre_special_token_count
+
+                # calculate the cutoff point, leaving 20 tokens of headroom to avoid exceeding the limit
+                token_cutoff = context_size - rest_special_token_count - 20
+
+                # find the character offset of the cutoff token
+                cutpoint_token = tokenize_result[token_cutoff]
+                cutoff = cutpoint_token['start']
+
+                inputs.append(text[0:cutoff])
+            else:
+                inputs.append(text)
+
+        batched_embeddings = []
+        _iter = range(0, len(inputs), max_chunks)
+
+        try:
+            used_tokens = 0
+            for i in _iter:
+                iter_texts = inputs[i : i + max_chunks]
+                results = TeiHelper.invoke_embeddings(server_url, iter_texts)
+                embeddings = results['data']
+                embeddings = [embedding['embedding'] for embedding in embeddings]
+                batched_embeddings.extend(embeddings)
+
+                usage = results['usage']
+                used_tokens += usage['total_tokens']
+        except RuntimeError as e:
+            raise InvokeServerUnavailableError(str(e))
+
+        usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens)
+
+        result = TextEmbeddingResult(model=model, embeddings=batched_embeddings, usage=usage)
+
+        return result
+
+    def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
+        """
+        Get number of tokens for given prompt messages
+
+        :param model: model name
+        :param credentials: model credentials
+        :param texts: texts to embed
+        :return:
+        """
+        server_url = credentials['server_url']
+
+        if server_url.endswith('/'):
+            server_url = server_url[:-1]
+
+        batch_tokens = TeiHelper.invoke_tokenize(server_url, texts)
+        num_tokens = sum(len(tokens) for tokens in batch_tokens)
+        return num_tokens
+
+    def validate_credentials(self, model: str, credentials: dict) -> None:
+        """
+        Validate model credentials
+
+        :param model: model name
+        :param credentials: model credentials
+        :return:
+        """
+        try:
+            server_url = credentials['server_url']
+            extra_args = TeiHelper.get_tei_extra_parameter(server_url, model)
+            if extra_args.model_type != 'embedding':
+                raise CredentialsValidateFailedError('Current model is not an embedding model')
+
+            # cache the server-reported limits on the credentials for later invocations
+            credentials['context_size'] = extra_args.max_input_length
+            credentials['max_chunks'] = extra_args.max_client_batch_size
+            self._invoke(model=model, credentials=credentials, texts=['ping'])
+        except Exception as ex:
+            raise CredentialsValidateFailedError(str(ex))
+
+    @property
+    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+        return {
+            InvokeConnectionError: [InvokeConnectionError],
+            InvokeServerUnavailableError: [InvokeServerUnavailableError],
+            InvokeRateLimitError: [InvokeRateLimitError],
+            InvokeAuthorizationError: [InvokeAuthorizationError],
+            InvokeBadRequestError: [KeyError],
+        }
+
+    def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
+        """
+        Calculate response usage
+
+        :param model: model name
+        :param credentials: model credentials
+        :param tokens: input tokens
+        :return: usage
+        """
+        # get input price info
+        input_price_info = self.get_price(
+            model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens
+        )
+
+        # transform usage
+        usage = EmbeddingUsage(
+            tokens=tokens,
+            total_tokens=tokens,
+            unit_price=input_price_info.unit_price,
+            price_unit=input_price_info.unit,
+            total_price=input_price_info.total_amount,
+            currency=input_price_info.currency,
+            latency=time.perf_counter() - self.started_at,
+        )
+
+        return usage
+
+    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
+        """
+        used to define customizable model schema
+        """
+
+        entity = AIModelEntity(
+            model=model,
+            label=I18nObject(en_US=model),
+            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
+            model_type=ModelType.TEXT_EMBEDDING,
+            model_properties={
+                ModelPropertyKey.MAX_CHUNKS: int(credentials.get('max_chunks', 1)),
+                ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', 512)),
+            },
+            parameter_rules=[],
+        )
+
+        return entity
diff --git a/api/pyproject.toml b/api/pyproject.toml
index 15f9aab640..c08c109913 100644
--- a/api/pyproject.toml
+++ b/api/pyproject.toml
@@ -93,6 +93,8 @@ CODE_MAX_STRING_LENGTH = "80000"
 CODE_EXECUTION_ENDPOINT = "http://127.0.0.1:8194"
 CODE_EXECUTION_API_KEY = "dify-sandbox"
 FIRECRAWL_API_KEY = "fc-"
+TEI_EMBEDDING_SERVER_URL = "http://a.abc.com:11451"
+TEI_RERANK_SERVER_URL = "http://a.abc.com:11451"
 
 [tool.poetry]
 name = "dify-api"
diff --git a/api/tests/integration_tests/model_runtime/__mock/huggingface_tei.py b/api/tests/integration_tests/model_runtime/__mock/huggingface_tei.py
new file mode 100644
index 0000000000..2f66d707ca
--- /dev/null
+++ b/api/tests/integration_tests/model_runtime/__mock/huggingface_tei.py
@@ -0,0 +1,94 @@
+from core.model_runtime.model_providers.huggingface_tei.tei_helper import TeiModelExtraParameter
+
+
+class MockTEIClass:
+    @staticmethod
+    def get_tei_extra_parameter(server_url: str, model_name: str) -> TeiModelExtraParameter:
+        # during mock, we don't have a real server to query, so we just return a dummy value
+        if 'rerank' in model_name:
+            model_type = 'reranker'
+        else:
+            model_type = 'embedding'
+
+        return TeiModelExtraParameter(model_type=model_type, max_input_length=512, max_client_batch_size=1)
+
+    @staticmethod
+    def invoke_tokenize(server_url: str, texts: list[str]) -> list[list[dict]]:
+        # use space as the token separator and split each text into tokens
+        tokenized_texts = []
+        for text in texts:
+            tokens = text.split(' ')
+            current_index = 0
+            tokenized_text = []
+            for idx, token in enumerate(tokens):
+                s_token = {
+                    'id': idx,
+                    'text': token,
+                    'special': False,
+                    'start': current_index,
+                    'stop': current_index + len(token),
+                }
+                current_index += len(token) + 1
+                tokenized_text.append(s_token)
+            tokenized_texts.append(tokenized_text)
+        return tokenized_texts
+
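+    # Illustrative output of the whitespace tokenizer above, e.g. for 'hello world':
+    #   [{'id': 0, 'text': 'hello', 'special': False, 'start': 0, 'stop': 5},
+    #    {'id': 1, 'text': 'world', 'special': False, 'start': 6, 'stop': 11}]
+    # (Real TEI tokenization is model specific; the mock only needs stable offsets.)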
+    @staticmethod
+    def invoke_embeddings(server_url: str, texts: list[str]) -> dict:
+        # Example response:
+        # {
+        #     "object": "list",
+        #     "data": [
+        #         {
+        #             "object": "embedding",
+        #             "embedding": [...],
+        #             "index": 0
+        #         }
+        #     ],
+        #     "model": "MODEL_NAME",
+        #     "usage": {
+        #         "prompt_tokens": 3,
+        #         "total_tokens": 3
+        #     }
+        # }
+        embeddings = []
+        for idx, text in enumerate(texts):
+            embedding = [0.1] * 768
+            embeddings.append(
+                {
+                    'object': 'embedding',
+                    'embedding': embedding,
+                    'index': idx,
+                }
+            )
+        return {
+            'object': 'list',
+            'data': embeddings,
+            'model': 'MODEL_NAME',
+            'usage': {
+                'prompt_tokens': sum(len(text.split(' ')) for text in texts),
+                'total_tokens': sum(len(text.split(' ')) for text in texts),
+            },
+        }
+
+    @staticmethod
+    def invoke_rerank(server_url: str, query: str, texts: list[str]) -> list[dict]:
+        # Example response:
+        # [
+        #     {
+        #         "index": 0,
+        #         "text": "Deep Learning is ...",
+        #         "score": 0.9950755
+        #     }
+        # ]
+        reranked_docs = []
+        for idx, text in enumerate(texts):
+            reranked_docs.append(
+                {
+                    'index': idx,
+                    'text': text,
+                    'score': 0.9,
+                }
+            )
+            # the mock only returns the first document
+            break
+        return reranked_docs
diff --git a/api/tests/integration_tests/model_runtime/huggingface_tei/__init__.py b/api/tests/integration_tests/model_runtime/huggingface_tei/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/tests/integration_tests/model_runtime/huggingface_tei/test_embeddings.py b/api/tests/integration_tests/model_runtime/huggingface_tei/test_embeddings.py
new file mode 100644
index 0000000000..da65c7dfc7
--- /dev/null
+++ b/api/tests/integration_tests/model_runtime/huggingface_tei/test_embeddings.py
@@ -0,0 +1,72 @@
+import os
+
+import pytest
+
+from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.huggingface_tei.tei_helper import TeiHelper
+from core.model_runtime.model_providers.huggingface_tei.text_embedding.text_embedding import (
+    HuggingfaceTeiTextEmbeddingModel,
+)
+from tests.integration_tests.model_runtime.__mock.huggingface_tei import MockTEIClass
+
+MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true'
+
+
+@pytest.fixture
+def setup_tei_mock(request, monkeypatch: pytest.MonkeyPatch):
+    if MOCK:
+        monkeypatch.setattr(TeiHelper, 'get_tei_extra_parameter', MockTEIClass.get_tei_extra_parameter)
+        monkeypatch.setattr(TeiHelper, 'invoke_tokenize', MockTEIClass.invoke_tokenize)
+        monkeypatch.setattr(TeiHelper, 'invoke_embeddings', MockTEIClass.invoke_embeddings)
+        monkeypatch.setattr(TeiHelper, 'invoke_rerank', MockTEIClass.invoke_rerank)
+    yield
+
+    if MOCK:
+        monkeypatch.undo()
+
+
+@pytest.mark.parametrize('setup_tei_mock', [['none']], indirect=True)
+def test_validate_credentials(setup_tei_mock):
+    model = HuggingfaceTeiTextEmbeddingModel()
+    # model name is only used in mock
+    model_name = 'embedding'
+
+    if MOCK:
+        # The provider derives the model type from the /info endpoint, and a real server always
+        # reports the correct type, so the mismatch case is only checked against the mock.
+        with pytest.raises(CredentialsValidateFailedError):
+            model.validate_credentials(
+                model='reranker',
+                credentials={
+                    'server_url': os.environ.get('TEI_EMBEDDING_SERVER_URL', ''),
+                }
+            )
+
+    model.validate_credentials(
+        model=model_name,
+        credentials={
+            'server_url': os.environ.get('TEI_EMBEDDING_SERVER_URL', ''),
+        }
+    )
+
+
+@pytest.mark.parametrize('setup_tei_mock', [['none']], indirect=True)
+def test_invoke_model(setup_tei_mock):
+    model = HuggingfaceTeiTextEmbeddingModel()
+    model_name = 'embedding'
+
+    result = model.invoke(
+        model=model_name,
+        credentials={
+            'server_url': os.environ.get('TEI_EMBEDDING_SERVER_URL', ''),
+        },
+        texts=['hello', 'world'],
+        user='abc-123',
+    )
+
+    assert isinstance(result, TextEmbeddingResult)
+    assert len(result.embeddings) == 2
+    assert result.usage.total_tokens > 0
diff --git a/api/tests/integration_tests/model_runtime/huggingface_tei/test_rerank.py b/api/tests/integration_tests/model_runtime/huggingface_tei/test_rerank.py
new file mode 100644
index 0000000000..57e229e6be
--- /dev/null
+++ b/api/tests/integration_tests/model_runtime/huggingface_tei/test_rerank.py
@@ -0,0 +1,76 @@
+import os
+
+import pytest
+
+from core.model_runtime.entities.rerank_entities import RerankResult
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.huggingface_tei.rerank.rerank import (
+    HuggingfaceTeiRerankModel,
+)
+from core.model_runtime.model_providers.huggingface_tei.tei_helper import TeiHelper
+from tests.integration_tests.model_runtime.__mock.huggingface_tei import MockTEIClass
+
+MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true'
+
+
+@pytest.fixture
+def setup_tei_mock(request, monkeypatch: pytest.MonkeyPatch):
+    if MOCK:
+        monkeypatch.setattr(TeiHelper, 'get_tei_extra_parameter', MockTEIClass.get_tei_extra_parameter)
+        monkeypatch.setattr(TeiHelper, 'invoke_tokenize', MockTEIClass.invoke_tokenize)
+        monkeypatch.setattr(TeiHelper, 'invoke_embeddings', MockTEIClass.invoke_embeddings)
+        monkeypatch.setattr(TeiHelper, 'invoke_rerank', MockTEIClass.invoke_rerank)
+    yield
+
+    if MOCK:
+        monkeypatch.undo()
+
+
+@pytest.mark.parametrize('setup_tei_mock', [['none']], indirect=True)
+def test_validate_credentials(setup_tei_mock):
+    model = HuggingfaceTeiRerankModel()
+    # model name is only used in mock
+    model_name = 'reranker'
+
+    if MOCK:
+        # The provider derives the model type from the /info endpoint, and a real server always
+        # reports the correct type, so the mismatch case is only checked against the mock.
+        with pytest.raises(CredentialsValidateFailedError):
+            model.validate_credentials(
+                model='embedding',
+                credentials={
+                    'server_url': os.environ.get('TEI_RERANK_SERVER_URL'),
+                }
+            )
+
+    model.validate_credentials(
+        model=model_name,
+        credentials={
+            'server_url': os.environ.get('TEI_RERANK_SERVER_URL'),
+        }
+    )
+
+
+@pytest.mark.parametrize('setup_tei_mock', [['none']], indirect=True)
+def test_invoke_model(setup_tei_mock):
+    model = HuggingfaceTeiRerankModel()
+    # model name is only used in mock
+    model_name = 'reranker'
+
+    result = model.invoke(
+        model=model_name,
+        credentials={
+            'server_url': os.environ.get('TEI_RERANK_SERVER_URL'),
+        },
+        query='Who is Kasumi?',
+        docs=[
+            'Kasumi is a girl\'s name of Japanese origin meaning "mist".',
+            'Her music is a kawaii bass, a mix of future bass, pop, and kawaii music ',
+            'and she leads a team named PopiParty.',
+        ],
+        score_threshold=0.8,
+    )
+
+    assert isinstance(result, RerankResult)
+    assert len(result.docs) == 1
+    assert result.docs[0].index == 0
+    assert result.docs[0].score >= 0.8
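+
+    # One more sketch of a check, assuming the top-ranked document keeps its
+    # original text (the helper calls TEI with return_text=True):
+    assert result.docs[0].text == 'Kasumi is a girl\'s name of Japanese origin meaning "mist".'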