Mirror of https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git (synced 2025-08-12 03:29:01 +08:00)
Feat: Add model provider Text Embedding Inference for embedding and rerank (#7132)
This commit is contained in: parent 4cbeb6815b, commit 5b32f2e0dd
@@ -0,0 +1,11 @@
import logging

from core.model_runtime.model_providers.__base.model_provider import ModelProvider

logger = logging.getLogger(__name__)


class HuggingfaceTeiProvider(ModelProvider):

    def validate_provider_credentials(self, credentials: dict) -> None:
        pass
@@ -0,0 +1,36 @@
provider: huggingface_tei
label:
  en_US: Text Embedding Inference
description:
  en_US: A blazing fast inference solution for text embeddings models.
  zh_Hans: 用于文本嵌入模型的超快速推理解决方案。
background: "#FFF8DC"
help:
  title:
    en_US: How to deploy Text Embedding Inference
    zh_Hans: 如何部署 Text Embedding Inference
  url:
    en_US: https://github.com/huggingface/text-embeddings-inference
supported_model_types:
  - text-embedding
  - rerank
configurate_methods:
  - customizable-model
model_credential_schema:
  model:
    label:
      en_US: Model Name
      zh_Hans: 模型名称
    placeholder:
      en_US: Enter your model name
      zh_Hans: 输入模型名称
  credential_form_schemas:
    - variable: server_url
      label:
        zh_Hans: 服务器URL
        en_US: Server url
      type: secret-input
      required: true
      placeholder:
        zh_Hans: 在此输入Text Embedding Inference的服务器地址,如 http://192.168.1.100:8080
        en_US: Enter the url of your Text Embedding Inference, e.g. http://192.168.1.100:8080
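For reference, a minimal sketch (not part of the diff) of the credentials this schema collects for a customizable TEI model; the URL below is a placeholder:

# Hypothetical credentials gathered from the form above; server_url is the only
# credential field, while the model name is entered via model_credential_schema.model.
credentials = {
    'server_url': 'http://192.168.1.100:8080',  # placeholder TEI endpoint
}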
@@ -0,0 +1,137 @@
from typing import Optional

import httpx

from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType
from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
from core.model_runtime.errors.invoke import (
    InvokeAuthorizationError,
    InvokeBadRequestError,
    InvokeConnectionError,
    InvokeError,
    InvokeRateLimitError,
    InvokeServerUnavailableError,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.rerank_model import RerankModel
from core.model_runtime.model_providers.huggingface_tei.tei_helper import TeiHelper


class HuggingfaceTeiRerankModel(RerankModel):
    """
    Model class for Text Embedding Inference rerank model.
    """

    def _invoke(
        self,
        model: str,
        credentials: dict,
        query: str,
        docs: list[str],
        score_threshold: Optional[float] = None,
        top_n: Optional[int] = None,
        user: Optional[str] = None,
    ) -> RerankResult:
        """
        Invoke rerank model

        :param model: model name
        :param credentials: model credentials
        :param query: search query
        :param docs: docs for reranking
        :param score_threshold: score threshold
        :param top_n: top n
        :param user: unique user id
        :return: rerank result
        """
        if len(docs) == 0:
            return RerankResult(model=model, docs=[])
        server_url = credentials['server_url']

        if server_url.endswith('/'):
            server_url = server_url[:-1]

        try:
            results = TeiHelper.invoke_rerank(server_url, query, docs)

            rerank_documents = []
            for result in results:
                rerank_document = RerankDocument(
                    index=result['index'],
                    text=result['text'],
                    score=result['score'],
                )
                if score_threshold is None or result['score'] >= score_threshold:
                    rerank_documents.append(rerank_document)
                    if top_n is not None and len(rerank_documents) >= top_n:
                        break

            return RerankResult(model=model, docs=rerank_documents)
        except httpx.HTTPStatusError as e:
            raise InvokeServerUnavailableError(str(e))

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        Validate model credentials

        :param model: model name
        :param credentials: model credentials
        :return:
        """
        try:
            server_url = credentials['server_url']
            extra_args = TeiHelper.get_tei_extra_parameter(server_url, model)
            if extra_args.model_type != 'reranker':
                raise CredentialsValidateFailedError('Current model is not a rerank model')

            credentials['context_size'] = extra_args.max_input_length

            self.invoke(
                model=model,
                credentials=credentials,
                query='Whose kasumi',
                docs=[
                    'Kasumi is a girl\'s name of Japanese origin meaning "mist".',
                    'Her music is a kawaii bass, a mix of future bass, pop, and kawaii music ',
                    'and she leads a team named PopiParty.',
                ],
                score_threshold=0.8,
            )
        except Exception as ex:
            raise CredentialsValidateFailedError(str(ex))

    @property
    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
        """
        Map model invoke error to unified error

        The key is the error type thrown to the caller
        The value is the error type thrown by the model,
        which needs to be converted into a unified error type for the caller.

        :return: Invoke error mapping
        """
        return {
            InvokeConnectionError: [InvokeConnectionError],
            InvokeServerUnavailableError: [InvokeServerUnavailableError],
            InvokeRateLimitError: [InvokeRateLimitError],
            InvokeAuthorizationError: [InvokeAuthorizationError],
            InvokeBadRequestError: [InvokeBadRequestError, KeyError, ValueError],
        }

    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
        """
        used to define customizable model schema
        """
        entity = AIModelEntity(
            model=model,
            label=I18nObject(en_US=model),
            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
            model_type=ModelType.RERANK,
            model_properties={
                ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', 512)),
            },
            parameter_rules=[],
        )

        return entity
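To make the client-side filtering in `_invoke` concrete, a small sketch with invented scores (not data from the commit): documents below the threshold are skipped, and iteration stops once top_n documents have been kept.

# Hypothetical TEI /rerank response, already sorted by score descending.
results = [
    {'index': 0, 'text': 'doc a', 'score': 0.99},
    {'index': 2, 'text': 'doc c', 'score': 0.62},
    {'index': 1, 'text': 'doc b', 'score': 0.12},
]
score_threshold, top_n = 0.5, 2
kept = []
for result in results:
    if score_threshold is None or result['score'] >= score_threshold:
        kept.append(result)
        if top_n is not None and len(kept) >= top_n:
            break
# kept holds the two highest-scoring documents; the 0.12 entry is never added.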
@@ -0,0 +1,183 @@
from threading import Lock
from time import time
from typing import Optional

import httpx
from requests.adapters import HTTPAdapter
from requests.exceptions import ConnectionError, MissingSchema, Timeout
from requests.sessions import Session
from yarl import URL


class TeiModelExtraParameter:
    model_type: str
    max_input_length: int
    max_client_batch_size: int

    def __init__(self, model_type: str, max_input_length: int, max_client_batch_size: Optional[int] = None) -> None:
        self.model_type = model_type
        self.max_input_length = max_input_length
        self.max_client_batch_size = max_client_batch_size


cache = {}
cache_lock = Lock()


class TeiHelper:
    @staticmethod
    def get_tei_extra_parameter(server_url: str, model_name: str) -> TeiModelExtraParameter:
        TeiHelper._clean_cache()
        with cache_lock:
            if model_name not in cache:
                cache[model_name] = {
                    'expires': time() + 300,
                    'value': TeiHelper._get_tei_extra_parameter(server_url),
                }
            return cache[model_name]['value']

    @staticmethod
    def _clean_cache() -> None:
        try:
            with cache_lock:
                expired_keys = [model_uid for model_uid, model in cache.items() if model['expires'] < time()]
                for model_uid in expired_keys:
                    del cache[model_uid]
        except RuntimeError as e:
            pass

    @staticmethod
    def _get_tei_extra_parameter(server_url: str) -> TeiModelExtraParameter:
        """
        get tei model extra parameter like model_type, max_input_length, max_batch_requests
        """

        url = str(URL(server_url) / 'info')

        # this method is surrounded by a lock, and default requests may hang forever, so we just set an Adapter with max_retries=3
        session = Session()
        session.mount('http://', HTTPAdapter(max_retries=3))
        session.mount('https://', HTTPAdapter(max_retries=3))

        try:
            response = session.get(url, timeout=10)
        except (MissingSchema, ConnectionError, Timeout) as e:
            raise RuntimeError(f'get tei model extra parameter failed, url: {url}, error: {e}')
        if response.status_code != 200:
            raise RuntimeError(
                f'get tei model extra parameter failed, status code: {response.status_code}, response: {response.text}'
            )

        response_json = response.json()

        model_type = response_json.get('model_type', {})
        if len(model_type.keys()) < 1:
            raise RuntimeError('model_type is empty')
        model_type = list(model_type.keys())[0]
        if model_type not in ['embedding', 'reranker']:
            raise RuntimeError(f'invalid model_type: {model_type}')

        max_input_length = response_json.get('max_input_length', 512)
        max_client_batch_size = response_json.get('max_client_batch_size', 1)

        return TeiModelExtraParameter(
            model_type=model_type,
            max_input_length=max_input_length,
            max_client_batch_size=max_client_batch_size
        )

    @staticmethod
    def invoke_tokenize(server_url: str, texts: list[str]) -> list[list[dict]]:
        """
        Invoke tokenize endpoint

        Example response:
        [
            [
                {
                    "id": 0,
                    "text": "<s>",
                    "special": true,
                    "start": null,
                    "stop": null
                },
                {
                    "id": 7704,
                    "text": "str",
                    "special": false,
                    "start": 0,
                    "stop": 3
                },
                < MORE TOKENS >
            ]
        ]

        :param server_url: server url
        :param texts: texts to tokenize
        """
        resp = httpx.post(
            f'{server_url}/tokenize',
            json={'inputs': texts},
        )
        resp.raise_for_status()
        return resp.json()

    @staticmethod
    def invoke_embeddings(server_url: str, texts: list[str]) -> dict:
        """
        Invoke embeddings endpoint

        Example response:
        {
            "object": "list",
            "data": [
                {
                    "object": "embedding",
                    "embedding": [...],
                    "index": 0
                }
            ],
            "model": "MODEL_NAME",
            "usage": {
                "prompt_tokens": 3,
                "total_tokens": 3
            }
        }

        :param server_url: server url
        :param texts: texts to embed
        """
        # Use OpenAI compatible API here, which has usage tracking
        resp = httpx.post(
            f'{server_url}/v1/embeddings',
            json={'input': texts},
        )
        resp.raise_for_status()
        return resp.json()

    @staticmethod
    def invoke_rerank(server_url: str, query: str, docs: list[str]) -> list[dict]:
        """
        Invoke rerank endpoint

        Example response:
        [
            {
                "index": 0,
                "text": "Deep Learning is ...",
                "score": 0.9950755
            }
        ]

        :param server_url: server url
        :param query: query to rerank against
        :param docs: docs to rerank
        """
        params = {'query': query, 'texts': docs, 'return_text': True}

        response = httpx.post(
            server_url + '/rerank',
            json=params,
        )
        response.raise_for_status()
        return response.json()
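A rough usage sketch for the helper, assuming a TEI server is reachable at a placeholder URL; the calls mirror the endpoints used above (/info, /tokenize, /v1/embeddings, /rerank):

server_url = 'http://localhost:8080'  # placeholder, not a value from the commit

info = TeiHelper.get_tei_extra_parameter(server_url, 'my-model')  # result is cached ~300s per model name
print(info.model_type, info.max_input_length, info.max_client_batch_size)

tokens = TeiHelper.invoke_tokenize(server_url, ['hello world'])            # POST /tokenize
embeddings = TeiHelper.invoke_embeddings(server_url, ['hello world'])      # POST /v1/embeddings
ranked = TeiHelper.invoke_rerank(server_url, 'greeting', ['hello world'])  # POST /rerank (reranker models only)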
@@ -0,0 +1,204 @@
import time
from typing import Optional

from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from core.model_runtime.errors.invoke import (
    InvokeAuthorizationError,
    InvokeBadRequestError,
    InvokeConnectionError,
    InvokeError,
    InvokeRateLimitError,
    InvokeServerUnavailableError,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.model_runtime.model_providers.huggingface_tei.tei_helper import TeiHelper


class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel):
    """
    Model class for Text Embedding Inference text embedding model.
    """

    def _invoke(
        self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
    ) -> TextEmbeddingResult:
        """
        Invoke text embedding model

        credentials should be like:
        {
            'server_url': 'server url',
            'model_uid': 'model uid',
        }

        :param model: model name
        :param credentials: model credentials
        :param texts: texts to embed
        :param user: unique user id
        :return: embeddings result
        """
        server_url = credentials['server_url']

        if server_url.endswith('/'):
            server_url = server_url[:-1]


        # get model properties
        context_size = self._get_context_size(model, credentials)
        max_chunks = self._get_max_chunks(model, credentials)

        inputs = []
        indices = []
        used_tokens = 0

        # get tokenized results from TEI
        batched_tokenize_result = TeiHelper.invoke_tokenize(server_url, texts)

        for i, (text, tokenize_result) in enumerate(zip(texts, batched_tokenize_result)):

            # Check if the number of tokens is larger than the context size
            num_tokens = len(tokenize_result)

            if num_tokens >= context_size:
                # Find the best cutoff point
                pre_special_token_count = 0
                for token in tokenize_result:
                    if token['special']:
                        pre_special_token_count += 1
                    else:
                        break
                rest_special_token_count = len([token for token in tokenize_result if token['special']]) - pre_special_token_count

                # Calculate the cutoff point, leave 20 extra space to avoid exceeding the limit
                token_cutoff = context_size - rest_special_token_count - 20

                # Find the cutoff index
                cutpoint_token = tokenize_result[token_cutoff]
                cutoff = cutpoint_token['start']

                inputs.append(text[0: cutoff])
            else:
                inputs.append(text)
            indices += [i]

        batched_embeddings = []
        _iter = range(0, len(inputs), max_chunks)

        try:
            used_tokens = 0
            for i in _iter:
                iter_texts = inputs[i : i + max_chunks]
                results = TeiHelper.invoke_embeddings(server_url, iter_texts)
                embeddings = results['data']
                embeddings = [embedding['embedding'] for embedding in embeddings]
                batched_embeddings.extend(embeddings)

                usage = results['usage']
                used_tokens += usage['total_tokens']
        except RuntimeError as e:
            raise InvokeServerUnavailableError(str(e))

        usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens)

        result = TextEmbeddingResult(model=model, embeddings=batched_embeddings, usage=usage)

        return result

    def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
        """
        Get number of tokens for given prompt messages

        :param model: model name
        :param credentials: model credentials
        :param texts: texts to embed
        :return:
        """
        num_tokens = 0
        server_url = credentials['server_url']

        if server_url.endswith('/'):
            server_url = server_url[:-1]

        batch_tokens = TeiHelper.invoke_tokenize(server_url, texts)
        num_tokens = sum(len(tokens) for tokens in batch_tokens)
        return num_tokens

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        Validate model credentials

        :param model: model name
        :param credentials: model credentials
        :return:
        """
        try:
            server_url = credentials['server_url']
            extra_args = TeiHelper.get_tei_extra_parameter(server_url, model)
            print(extra_args)
            if extra_args.model_type != 'embedding':
                raise CredentialsValidateFailedError('Current model is not an embedding model')

            credentials['context_size'] = extra_args.max_input_length
            credentials['max_chunks'] = extra_args.max_client_batch_size
            self._invoke(model=model, credentials=credentials, texts=['ping'])
        except Exception as ex:
            raise CredentialsValidateFailedError(str(ex))

    @property
    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
        return {
            InvokeConnectionError: [InvokeConnectionError],
            InvokeServerUnavailableError: [InvokeServerUnavailableError],
            InvokeRateLimitError: [InvokeRateLimitError],
            InvokeAuthorizationError: [InvokeAuthorizationError],
            InvokeBadRequestError: [KeyError],
        }

    def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
        """
        Calculate response usage

        :param model: model name
        :param credentials: model credentials
        :param tokens: input tokens
        :return: usage
        """
        # get input price info
        input_price_info = self.get_price(
            model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens
        )

        # transform usage
        usage = EmbeddingUsage(
            tokens=tokens,
            total_tokens=tokens,
            unit_price=input_price_info.unit_price,
            price_unit=input_price_info.unit,
            total_price=input_price_info.total_amount,
            currency=input_price_info.currency,
            latency=time.perf_counter() - self.started_at,
        )

        return usage

    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
        """
        used to define customizable model schema
        """

        entity = AIModelEntity(
            model=model,
            label=I18nObject(en_US=model),
            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
            model_type=ModelType.TEXT_EMBEDDING,
            model_properties={
                ModelPropertyKey.MAX_CHUNKS: int(credentials.get('max_chunks', 1)),
                ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', 512)),
            },
            parameter_rules=[],
        )

        return entity
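To illustrate the truncation arithmetic in `_invoke`, a small worked sketch with invented token counts (not from the commit):

# Suppose the tokenizer returns 600 tokens for a text while context_size is 512,
# with one leading special token ("<s>") and one trailing special token ("</s>").
context_size = 512
pre_special_token_count = 1                               # leading specials
rest_special_token_count = 2 - pre_special_token_count    # trailing specials -> 1
token_cutoff = context_size - rest_special_token_count - 20   # 512 - 1 - 20 = 491
# tokenize_result[491]['start'] is the character offset where the text is cut,
# keeping the request comfortably under the model's input limit.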
@@ -93,6 +93,8 @@ CODE_MAX_STRING_LENGTH = "80000"
CODE_EXECUTION_ENDPOINT = "http://127.0.0.1:8194"
CODE_EXECUTION_API_KEY = "dify-sandbox"
FIRECRAWL_API_KEY = "fc-"
TEI_EMBEDDING_SERVER_URL = "http://a.abc.com:11451"
TEI_RERANK_SERVER_URL = "http://a.abc.com:11451"

[tool.poetry]
name = "dify-api"
@@ -0,0 +1,94 @@
from api.core.model_runtime.model_providers.huggingface_tei.tei_helper import TeiModelExtraParameter


class MockTEIClass:
    @staticmethod
    def get_tei_extra_parameter(server_url: str, model_name: str) -> TeiModelExtraParameter:
        # During mock, we don't have a real server to query, so we just return a dummy value
        if 'rerank' in model_name:
            model_type = 'reranker'
        else:
            model_type = 'embedding'

        return TeiModelExtraParameter(model_type=model_type, max_input_length=512, max_client_batch_size=1)

    @staticmethod
    def invoke_tokenize(server_url: str, texts: list[str]) -> list[list[dict]]:
        # Use space as token separator, and split the text into tokens
        tokenized_texts = []
        for text in texts:
            tokens = text.split(' ')
            current_index = 0
            tokenized_text = []
            for idx, token in enumerate(tokens):
                s_token = {
                    'id': idx,
                    'text': token,
                    'special': False,
                    'start': current_index,
                    'stop': current_index + len(token),
                }
                current_index += len(token) + 1
                tokenized_text.append(s_token)
            tokenized_texts.append(tokenized_text)
        return tokenized_texts

    @staticmethod
    def invoke_embeddings(server_url: str, texts: list[str]) -> dict:
        # {
        #     "object": "list",
        #     "data": [
        #         {
        #             "object": "embedding",
        #             "embedding": [...],
        #             "index": 0
        #         }
        #     ],
        #     "model": "MODEL_NAME",
        #     "usage": {
        #         "prompt_tokens": 3,
        #         "total_tokens": 3
        #     }
        # }
        embeddings = []
        for idx, text in enumerate(texts):
            embedding = [0.1] * 768
            embeddings.append(
                {
                    'object': 'embedding',
                    'embedding': embedding,
                    'index': idx,
                }
            )
        return {
            'object': 'list',
            'data': embeddings,
            'model': 'MODEL_NAME',
            'usage': {
                'prompt_tokens': sum(len(text.split(' ')) for text in texts),
                'total_tokens': sum(len(text.split(' ')) for text in texts),
            },
        }

    @staticmethod
    def invoke_rerank(server_url: str, query: str, texts: list[str]) -> list[dict]:
        # Example response:
        # [
        #     {
        #         "index": 0,
        #         "text": "Deep Learning is ...",
        #         "score": 0.9950755
        #     }
        # ]
        reranked_docs = []
        for idx, text in enumerate(texts):
            reranked_docs.append(
                {
                    'index': idx,
                    'text': text,
                    'score': 0.9,
                }
            )
            # For mock, only return the first document
            break
        return reranked_docs
@@ -0,0 +1,72 @@
import os

import pytest
from api.core.model_runtime.model_providers.huggingface_tei.text_embedding.text_embedding import TeiHelper

from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.huggingface_tei.text_embedding.text_embedding import (
    HuggingfaceTeiTextEmbeddingModel,
)
from tests.integration_tests.model_runtime.__mock.huggingface_tei import MockTEIClass

MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true'


@pytest.fixture
def setup_tei_mock(request, monkeypatch: pytest.MonkeyPatch):
    if MOCK:
        monkeypatch.setattr(TeiHelper, 'get_tei_extra_parameter', MockTEIClass.get_tei_extra_parameter)
        monkeypatch.setattr(TeiHelper, 'invoke_tokenize', MockTEIClass.invoke_tokenize)
        monkeypatch.setattr(TeiHelper, 'invoke_embeddings', MockTEIClass.invoke_embeddings)
        monkeypatch.setattr(TeiHelper, 'invoke_rerank', MockTEIClass.invoke_rerank)
    yield

    if MOCK:
        monkeypatch.undo()


@pytest.mark.parametrize('setup_tei_mock', [['none']], indirect=True)
def test_validate_credentials(setup_tei_mock):
    model = HuggingfaceTeiTextEmbeddingModel()
    # model name is only used in mock
    model_name = 'embedding'

    if MOCK:
        # TEI Provider will check model type by API endpoint; on a real server the model type is correct,
        # so we don't need to check model type here. Only check in mock
        with pytest.raises(CredentialsValidateFailedError):
            model.validate_credentials(
                model='reranker',
                credentials={
                    'server_url': os.environ.get('TEI_EMBEDDING_SERVER_URL', ""),
                }
            )

    model.validate_credentials(
        model=model_name,
        credentials={
            'server_url': os.environ.get('TEI_EMBEDDING_SERVER_URL', ""),
        }
    )

@pytest.mark.parametrize('setup_tei_mock', [['none']], indirect=True)
def test_invoke_model(setup_tei_mock):
    model = HuggingfaceTeiTextEmbeddingModel()
    model_name = 'embedding'

    result = model.invoke(
        model=model_name,
        credentials={
            'server_url': os.environ.get('TEI_EMBEDDING_SERVER_URL', ""),
        },
        texts=[
            "hello",
            "world"
        ],
        user="abc-123"
    )

    assert isinstance(result, TextEmbeddingResult)
    assert len(result.embeddings) == 2
    assert result.usage.total_tokens > 0
@@ -0,0 +1,76 @@
import os

import pytest

from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.huggingface_tei.rerank.rerank import (
    HuggingfaceTeiRerankModel,
)
from core.model_runtime.model_providers.huggingface_tei.text_embedding.text_embedding import TeiHelper
from tests.integration_tests.model_runtime.__mock.huggingface_tei import MockTEIClass

MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true'


@pytest.fixture
def setup_tei_mock(request, monkeypatch: pytest.MonkeyPatch):
    if MOCK:
        monkeypatch.setattr(TeiHelper, 'get_tei_extra_parameter', MockTEIClass.get_tei_extra_parameter)
        monkeypatch.setattr(TeiHelper, 'invoke_tokenize', MockTEIClass.invoke_tokenize)
        monkeypatch.setattr(TeiHelper, 'invoke_embeddings', MockTEIClass.invoke_embeddings)
        monkeypatch.setattr(TeiHelper, 'invoke_rerank', MockTEIClass.invoke_rerank)
    yield

    if MOCK:
        monkeypatch.undo()


@pytest.mark.parametrize('setup_tei_mock', [['none']], indirect=True)
def test_validate_credentials(setup_tei_mock):
    model = HuggingfaceTeiRerankModel()
    # model name is only used in mock
    model_name = 'reranker'

    if MOCK:
        # TEI Provider will check model type by API endpoint; on a real server the model type is correct,
        # so we don't need to check model type here. Only check in mock
        with pytest.raises(CredentialsValidateFailedError):
            model.validate_credentials(
                model='embedding',
                credentials={
                    'server_url': os.environ.get('TEI_RERANK_SERVER_URL'),
                }
            )

    model.validate_credentials(
        model=model_name,
        credentials={
            'server_url': os.environ.get('TEI_RERANK_SERVER_URL'),
        }
    )

@pytest.mark.parametrize('setup_tei_mock', [['none']], indirect=True)
def test_invoke_model(setup_tei_mock):
    model = HuggingfaceTeiRerankModel()
    # model name is only used in mock
    model_name = 'reranker'

    result = model.invoke(
        model=model_name,
        credentials={
            'server_url': os.environ.get('TEI_RERANK_SERVER_URL'),
        },
        query="Who is Kasumi?",
        docs=[
            "Kasumi is a girl's name of Japanese origin meaning \"mist\".",
            "Her music is a kawaii bass, a mix of future bass, pop, and kawaii music ",
            "and she leads a team named PopiParty."
        ],
        score_threshold=0.8
    )

    assert isinstance(result, RerankResult)
    assert len(result.docs) == 1
    assert result.docs[0].index == 0
    assert result.docs[0].score >= 0.8