Feat: Add model provider Text Embedding Inference for embedding and rerank (#7132)

This commit is contained in:
Yanyi Liu 2024-08-09 19:12:13 +08:00 committed by GitHub
parent 4cbeb6815b
commit 5b32f2e0dd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 815 additions and 0 deletions

View File

@ -0,0 +1,11 @@
import logging
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
logger = logging.getLogger(__name__)
class HuggingfaceTeiProvider(ModelProvider):
def validate_provider_credentials(self, credentials: dict) -> None:
pass

View File

@ -0,0 +1,36 @@
provider: huggingface_tei
label:
en_US: Text Embedding Inference
description:
en_US: A blazing fast inference solution for text embeddings models.
zh_Hans: 用于文本嵌入模型的超快速推理解决方案。
background: "#FFF8DC"
help:
title:
en_US: How to deploy Text Embedding Inference
zh_Hans: 如何部署 Text Embedding Inference
url:
en_US: https://github.com/huggingface/text-embeddings-inference
supported_model_types:
- text-embedding
- rerank
configurate_methods:
- customizable-model
model_credential_schema:
model:
label:
en_US: Model Name
zh_Hans: 模型名称
placeholder:
en_US: Enter your model name
zh_Hans: 输入模型名称
credential_form_schemas:
- variable: server_url
label:
zh_Hans: 服务器URL
en_US: Server url
type: secret-input
required: true
placeholder:
zh_Hans: 在此输入Text Embedding Inference的服务器地址如 http://192.168.1.100:8080
en_US: Enter the url of your Text Embedding Inference, e.g. http://192.168.1.100:8080

View File

@ -0,0 +1,137 @@
from typing import Optional
import httpx
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType
from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
InvokeConnectionError,
InvokeError,
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.rerank_model import RerankModel
from core.model_runtime.model_providers.huggingface_tei.tei_helper import TeiHelper
class HuggingfaceTeiRerankModel(RerankModel):
"""
Model class for Text Embedding Inference rerank model.
"""
def _invoke(
self,
model: str,
credentials: dict,
query: str,
docs: list[str],
score_threshold: Optional[float] = None,
top_n: Optional[int] = None,
user: Optional[str] = None,
) -> RerankResult:
"""
Invoke rerank model
:param model: model name
:param credentials: model credentials
:param query: search query
:param docs: docs for reranking
:param score_threshold: score threshold
:param top_n: top n
:param user: unique user id
:return: rerank result
"""
if len(docs) == 0:
return RerankResult(model=model, docs=[])
server_url = credentials['server_url']
if server_url.endswith('/'):
server_url = server_url[:-1]
try:
results = TeiHelper.invoke_rerank(server_url, query, docs)
rerank_documents = []
for result in results:
rerank_document = RerankDocument(
index=result['index'],
text=result['text'],
score=result['score'],
)
if score_threshold is None or result['score'] >= score_threshold:
rerank_documents.append(rerank_document)
if top_n is not None and len(rerank_documents) >= top_n:
break
return RerankResult(model=model, docs=rerank_documents)
except httpx.HTTPStatusError as e:
raise InvokeServerUnavailableError(str(e))
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials
:param model: model name
:param credentials: model credentials
:return:
"""
try:
server_url = credentials['server_url']
extra_args = TeiHelper.get_tei_extra_parameter(server_url, model)
if extra_args.model_type != 'reranker':
raise CredentialsValidateFailedError('Current model is not a rerank model')
credentials['context_size'] = extra_args.max_input_length
self.invoke(
model=model,
credentials=credentials,
query='Whose kasumi',
docs=[
'Kasumi is a girl\'s name of Japanese origin meaning "mist".',
'Her music is a kawaii bass, a mix of future bass, pop, and kawaii music ',
'and she leads a team named PopiParty.',
],
score_threshold=0.8,
)
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller
The value is the error type thrown by the model,
which needs to be converted into a unified error type for the caller.
:return: Invoke error mapping
"""
return {
InvokeConnectionError: [InvokeConnectionError],
InvokeServerUnavailableError: [InvokeServerUnavailableError],
InvokeRateLimitError: [InvokeRateLimitError],
InvokeAuthorizationError: [InvokeAuthorizationError],
InvokeBadRequestError: [InvokeBadRequestError, KeyError, ValueError],
}
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
"""
used to define customizable model schema
"""
entity = AIModelEntity(
model=model,
label=I18nObject(en_US=model),
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=ModelType.RERANK,
model_properties={
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', 512)),
},
parameter_rules=[],
)
return entity

View File

@ -0,0 +1,183 @@
from threading import Lock
from time import time
from typing import Optional
import httpx
from requests.adapters import HTTPAdapter
from requests.exceptions import ConnectionError, MissingSchema, Timeout
from requests.sessions import Session
from yarl import URL
class TeiModelExtraParameter:
model_type: str
max_input_length: int
max_client_batch_size: int
def __init__(self, model_type: str, max_input_length: int, max_client_batch_size: Optional[int] = None) -> None:
self.model_type = model_type
self.max_input_length = max_input_length
self.max_client_batch_size = max_client_batch_size
cache = {}
cache_lock = Lock()
class TeiHelper:
@staticmethod
def get_tei_extra_parameter(server_url: str, model_name: str) -> TeiModelExtraParameter:
TeiHelper._clean_cache()
with cache_lock:
if model_name not in cache:
cache[model_name] = {
'expires': time() + 300,
'value': TeiHelper._get_tei_extra_parameter(server_url),
}
return cache[model_name]['value']
@staticmethod
def _clean_cache() -> None:
try:
with cache_lock:
expired_keys = [model_uid for model_uid, model in cache.items() if model['expires'] < time()]
for model_uid in expired_keys:
del cache[model_uid]
except RuntimeError as e:
pass
@staticmethod
def _get_tei_extra_parameter(server_url: str) -> TeiModelExtraParameter:
"""
get tei model extra parameter like model_type, max_input_length, max_batch_requests
"""
url = str(URL(server_url) / 'info')
# this method is surrounded by a lock, and default requests may hang forever, so we just set a Adapter with max_retries=3
session = Session()
session.mount('http://', HTTPAdapter(max_retries=3))
session.mount('https://', HTTPAdapter(max_retries=3))
try:
response = session.get(url, timeout=10)
except (MissingSchema, ConnectionError, Timeout) as e:
raise RuntimeError(f'get tei model extra parameter failed, url: {url}, error: {e}')
if response.status_code != 200:
raise RuntimeError(
f'get tei model extra parameter failed, status code: {response.status_code}, response: {response.text}'
)
response_json = response.json()
model_type = response_json.get('model_type', {})
if len(model_type.keys()) < 1:
raise RuntimeError('model_type is empty')
model_type = list(model_type.keys())[0]
if model_type not in ['embedding', 'reranker']:
raise RuntimeError(f'invalid model_type: {model_type}')
max_input_length = response_json.get('max_input_length', 512)
max_client_batch_size = response_json.get('max_client_batch_size', 1)
return TeiModelExtraParameter(
model_type=model_type,
max_input_length=max_input_length,
max_client_batch_size=max_client_batch_size
)
@staticmethod
def invoke_tokenize(server_url: str, texts: list[str]) -> list[list[dict]]:
"""
Invoke tokenize endpoint
Example response:
[
[
{
"id": 0,
"text": "<s>",
"special": true,
"start": null,
"stop": null
},
{
"id": 7704,
"text": "str",
"special": false,
"start": 0,
"stop": 3
},
< MORE TOKENS >
]
]
:param server_url: server url
:param texts: texts to tokenize
"""
resp = httpx.post(
f'{server_url}/tokenize',
json={'inputs': texts},
)
resp.raise_for_status()
return resp.json()
@staticmethod
def invoke_embeddings(server_url: str, texts: list[str]) -> dict:
"""
Invoke embeddings endpoint
Example response:
{
"object": "list",
"data": [
{
"object": "embedding",
"embedding": [...],
"index": 0
}
],
"model": "MODEL_NAME",
"usage": {
"prompt_tokens": 3,
"total_tokens": 3
}
}
:param server_url: server url
:param texts: texts to embed
"""
# Use OpenAI compatible API here, which has usage tracking
resp = httpx.post(
f'{server_url}/v1/embeddings',
json={'input': texts},
)
resp.raise_for_status()
return resp.json()
@staticmethod
def invoke_rerank(server_url: str, query: str, docs: list[str]) -> list[dict]:
"""
Invoke rerank endpoint
Example response:
[
{
"index": 0,
"text": "Deep Learning is ...",
"score": 0.9950755
}
]
:param server_url: server url
:param texts: texts to rerank
:param candidates: candidates to rerank
"""
params = {'query': query, 'texts': docs, 'return_text': True}
response = httpx.post(
server_url + '/rerank',
json=params,
)
response.raise_for_status()
return response.json()

View File

@ -0,0 +1,204 @@
import time
from typing import Optional
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
InvokeConnectionError,
InvokeError,
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.model_runtime.model_providers.huggingface_tei.tei_helper import TeiHelper
class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel):
"""
Model class for Text Embedding Inference text embedding model.
"""
def _invoke(
self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
) -> TextEmbeddingResult:
"""
Invoke text embedding model
credentials should be like:
{
'server_url': 'server url',
'model_uid': 'model uid',
}
:param model: model name
:param credentials: model credentials
:param texts: texts to embed
:param user: unique user id
:return: embeddings result
"""
server_url = credentials['server_url']
if server_url.endswith('/'):
server_url = server_url[:-1]
# get model properties
context_size = self._get_context_size(model, credentials)
max_chunks = self._get_max_chunks(model, credentials)
inputs = []
indices = []
used_tokens = 0
# get tokenized results from TEI
batched_tokenize_result = TeiHelper.invoke_tokenize(server_url, texts)
for i, (text, tokenize_result) in enumerate(zip(texts, batched_tokenize_result)):
# Check if the number of tokens is larger than the context size
num_tokens = len(tokenize_result)
if num_tokens >= context_size:
# Find the best cutoff point
pre_special_token_count = 0
for token in tokenize_result:
if token['special']:
pre_special_token_count += 1
else:
break
rest_special_token_count = len([token for token in tokenize_result if token['special']]) - pre_special_token_count
# Calculate the cutoff point, leave 20 extra space to avoid exceeding the limit
token_cutoff = context_size - rest_special_token_count - 20
# Find the cutoff index
cutpoint_token = tokenize_result[token_cutoff]
cutoff = cutpoint_token['start']
inputs.append(text[0: cutoff])
else:
inputs.append(text)
indices += [i]
batched_embeddings = []
_iter = range(0, len(inputs), max_chunks)
try:
used_tokens = 0
for i in _iter:
iter_texts = inputs[i : i + max_chunks]
results = TeiHelper.invoke_embeddings(server_url, iter_texts)
embeddings = results['data']
embeddings = [embedding['embedding'] for embedding in embeddings]
batched_embeddings.extend(embeddings)
usage = results['usage']
used_tokens += usage['total_tokens']
except RuntimeError as e:
raise InvokeServerUnavailableError(str(e))
usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens)
result = TextEmbeddingResult(model=model, embeddings=batched_embeddings, usage=usage)
return result
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
"""
Get number of tokens for given prompt messages
:param model: model name
:param credentials: model credentials
:param texts: texts to embed
:return:
"""
num_tokens = 0
server_url = credentials['server_url']
if server_url.endswith('/'):
server_url = server_url[:-1]
batch_tokens = TeiHelper.invoke_tokenize(server_url, texts)
num_tokens = sum(len(tokens) for tokens in batch_tokens)
return num_tokens
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials
:param model: model name
:param credentials: model credentials
:return:
"""
try:
server_url = credentials['server_url']
extra_args = TeiHelper.get_tei_extra_parameter(server_url, model)
print(extra_args)
if extra_args.model_type != 'embedding':
raise CredentialsValidateFailedError('Current model is not a embedding model')
credentials['context_size'] = extra_args.max_input_length
credentials['max_chunks'] = extra_args.max_client_batch_size
self._invoke(model=model, credentials=credentials, texts=['ping'])
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
return {
InvokeConnectionError: [InvokeConnectionError],
InvokeServerUnavailableError: [InvokeServerUnavailableError],
InvokeRateLimitError: [InvokeRateLimitError],
InvokeAuthorizationError: [InvokeAuthorizationError],
InvokeBadRequestError: [KeyError],
}
def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
"""
Calculate response usage
:param model: model name
:param credentials: model credentials
:param tokens: input tokens
:return: usage
"""
# get input price info
input_price_info = self.get_price(
model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens
)
# transform usage
usage = EmbeddingUsage(
tokens=tokens,
total_tokens=tokens,
unit_price=input_price_info.unit_price,
price_unit=input_price_info.unit,
total_price=input_price_info.total_amount,
currency=input_price_info.currency,
latency=time.perf_counter() - self.started_at,
)
return usage
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
"""
used to define customizable model schema
"""
entity = AIModelEntity(
model=model,
label=I18nObject(en_US=model),
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=ModelType.TEXT_EMBEDDING,
model_properties={
ModelPropertyKey.MAX_CHUNKS: int(credentials.get('max_chunks', 1)),
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', 512)),
},
parameter_rules=[],
)
return entity

View File

@ -93,6 +93,8 @@ CODE_MAX_STRING_LENGTH = "80000"
CODE_EXECUTION_ENDPOINT = "http://127.0.0.1:8194" CODE_EXECUTION_ENDPOINT = "http://127.0.0.1:8194"
CODE_EXECUTION_API_KEY = "dify-sandbox" CODE_EXECUTION_API_KEY = "dify-sandbox"
FIRECRAWL_API_KEY = "fc-" FIRECRAWL_API_KEY = "fc-"
TEI_EMBEDDING_SERVER_URL = "http://a.abc.com:11451"
TEI_RERANK_SERVER_URL = "http://a.abc.com:11451"
[tool.poetry] [tool.poetry]
name = "dify-api" name = "dify-api"

View File

@ -0,0 +1,94 @@
from api.core.model_runtime.model_providers.huggingface_tei.tei_helper import TeiModelExtraParameter
class MockTEIClass:
@staticmethod
def get_tei_extra_parameter(server_url: str, model_name: str) -> TeiModelExtraParameter:
# During mock, we don't have a real server to query, so we just return a dummy value
if 'rerank' in model_name:
model_type = 'reranker'
else:
model_type = 'embedding'
return TeiModelExtraParameter(model_type=model_type, max_input_length=512, max_client_batch_size=1)
@staticmethod
def invoke_tokenize(server_url: str, texts: list[str]) -> list[list[dict]]:
# Use space as token separator, and split the text into tokens
tokenized_texts = []
for text in texts:
tokens = text.split(' ')
current_index = 0
tokenized_text = []
for idx, token in enumerate(tokens):
s_token = {
'id': idx,
'text': token,
'special': False,
'start': current_index,
'stop': current_index + len(token),
}
current_index += len(token) + 1
tokenized_text.append(s_token)
tokenized_texts.append(tokenized_text)
return tokenized_texts
@staticmethod
def invoke_embeddings(server_url: str, texts: list[str]) -> dict:
# {
# "object": "list",
# "data": [
# {
# "object": "embedding",
# "embedding": [...],
# "index": 0
# }
# ],
# "model": "MODEL_NAME",
# "usage": {
# "prompt_tokens": 3,
# "total_tokens": 3
# }
# }
embeddings = []
for idx, text in enumerate(texts):
embedding = [0.1] * 768
embeddings.append(
{
'object': 'embedding',
'embedding': embedding,
'index': idx,
}
)
return {
'object': 'list',
'data': embeddings,
'model': 'MODEL_NAME',
'usage': {
'prompt_tokens': sum(len(text.split(' ')) for text in texts),
'total_tokens': sum(len(text.split(' ')) for text in texts),
},
}
def invoke_rerank(server_url: str, query: str, texts: list[str]) -> list[dict]:
# Example response:
# [
# {
# "index": 0,
# "text": "Deep Learning is ...",
# "score": 0.9950755
# }
# ]
reranked_docs = []
for idx, text in enumerate(texts):
reranked_docs.append(
{
'index': idx,
'text': text,
'score': 0.9,
}
)
# For mock, only return the first document
break
return reranked_docs

View File

@ -0,0 +1,72 @@
import os
import pytest
from api.core.model_runtime.model_providers.huggingface_tei.text_embedding.text_embedding import TeiHelper
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.huggingface_tei.text_embedding.text_embedding import (
HuggingfaceTeiTextEmbeddingModel,
)
from tests.integration_tests.model_runtime.__mock.huggingface_tei import MockTEIClass
MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true'
@pytest.fixture
def setup_tei_mock(request, monkeypatch: pytest.MonkeyPatch):
if MOCK:
monkeypatch.setattr(TeiHelper, 'get_tei_extra_parameter', MockTEIClass.get_tei_extra_parameter)
monkeypatch.setattr(TeiHelper, 'invoke_tokenize', MockTEIClass.invoke_tokenize)
monkeypatch.setattr(TeiHelper, 'invoke_embeddings', MockTEIClass.invoke_embeddings)
monkeypatch.setattr(TeiHelper, 'invoke_rerank', MockTEIClass.invoke_rerank)
yield
if MOCK:
monkeypatch.undo()
@pytest.mark.parametrize('setup_tei_mock', [['none']], indirect=True)
def test_validate_credentials(setup_tei_mock):
model = HuggingfaceTeiTextEmbeddingModel()
# model name is only used in mock
model_name = 'embedding'
if MOCK:
# TEI Provider will check model type by API endpoint, at real server, the model type is correct.
# So we dont need to check model type here. Only check in mock
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='reranker',
credentials={
'server_url': os.environ.get('TEI_EMBEDDING_SERVER_URL', ""),
}
)
model.validate_credentials(
model=model_name,
credentials={
'server_url': os.environ.get('TEI_EMBEDDING_SERVER_URL', ""),
}
)
@pytest.mark.parametrize('setup_tei_mock', [['none']], indirect=True)
def test_invoke_model(setup_tei_mock):
model = HuggingfaceTeiTextEmbeddingModel()
model_name = 'embedding'
result = model.invoke(
model=model_name,
credentials={
'server_url': os.environ.get('TEI_EMBEDDING_SERVER_URL', ""),
},
texts=[
"hello",
"world"
],
user="abc-123"
)
assert isinstance(result, TextEmbeddingResult)
assert len(result.embeddings) == 2
assert result.usage.total_tokens > 0

View File

@ -0,0 +1,76 @@
import os
import pytest
from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.huggingface_tei.rerank.rerank import (
HuggingfaceTeiRerankModel,
)
from core.model_runtime.model_providers.huggingface_tei.text_embedding.text_embedding import TeiHelper
from tests.integration_tests.model_runtime.__mock.huggingface_tei import MockTEIClass
MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true'
@pytest.fixture
def setup_tei_mock(request, monkeypatch: pytest.MonkeyPatch):
if MOCK:
monkeypatch.setattr(TeiHelper, 'get_tei_extra_parameter', MockTEIClass.get_tei_extra_parameter)
monkeypatch.setattr(TeiHelper, 'invoke_tokenize', MockTEIClass.invoke_tokenize)
monkeypatch.setattr(TeiHelper, 'invoke_embeddings', MockTEIClass.invoke_embeddings)
monkeypatch.setattr(TeiHelper, 'invoke_rerank', MockTEIClass.invoke_rerank)
yield
if MOCK:
monkeypatch.undo()
@pytest.mark.parametrize('setup_tei_mock', [['none']], indirect=True)
def test_validate_credentials(setup_tei_mock):
model = HuggingfaceTeiRerankModel()
# model name is only used in mock
model_name = 'reranker'
if MOCK:
# TEI Provider will check model type by API endpoint, at real server, the model type is correct.
# So we dont need to check model type here. Only check in mock
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='embedding',
credentials={
'server_url': os.environ.get('TEI_RERANK_SERVER_URL'),
}
)
model.validate_credentials(
model=model_name,
credentials={
'server_url': os.environ.get('TEI_RERANK_SERVER_URL'),
}
)
@pytest.mark.parametrize('setup_tei_mock', [['none']], indirect=True)
def test_invoke_model(setup_tei_mock):
model = HuggingfaceTeiRerankModel()
# model name is only used in mock
model_name = 'reranker'
result = model.invoke(
model=model_name,
credentials={
'server_url': os.environ.get('TEI_RERANK_SERVER_URL'),
},
query="Who is Kasumi?",
docs=[
"Kasumi is a girl's name of Japanese origin meaning \"mist\".",
"Her music is a kawaii bass, a mix of future bass, pop, and kawaii music ",
"and she leads a team named PopiParty."
],
score_threshold=0.8
)
assert isinstance(result, RerankResult)
assert len(result.docs) == 1
assert result.docs[0].index == 0
assert result.docs[0].score >= 0.8