mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-11 17:49:02 +08:00
refactor: reduce duplciate code by inheritance (#13073)
This commit is contained in:
parent
23c68efa2d
commit
d44882c1b5
@ -1,29 +1,13 @@
|
|||||||
import json
|
|
||||||
import time
|
|
||||||
from decimal import Decimal
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from urllib.parse import urljoin
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import requests
|
|
||||||
|
|
||||||
from core.entities.embedding_type import EmbeddingInputType
|
from core.entities.embedding_type import EmbeddingInputType
|
||||||
from core.model_runtime.entities.common_entities import I18nObject
|
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
|
||||||
from core.model_runtime.entities.model_entities import (
|
from core.model_runtime.model_providers.openai_api_compatible.text_embedding.text_embedding import (
|
||||||
AIModelEntity,
|
OAICompatEmbeddingModel,
|
||||||
FetchFrom,
|
|
||||||
ModelPropertyKey,
|
|
||||||
ModelType,
|
|
||||||
PriceConfig,
|
|
||||||
PriceType,
|
|
||||||
)
|
)
|
||||||
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
|
||||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
|
||||||
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
|
||||||
from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOaiApiCompat
|
|
||||||
|
|
||||||
|
|
||||||
class OAICompatEmbeddingModel(_CommonOaiApiCompat, TextEmbeddingModel):
|
class PerfXCloudEmbeddingModel(OAICompatEmbeddingModel):
|
||||||
"""
|
"""
|
||||||
Model class for an OpenAI API-compatible text embedding model.
|
Model class for an OpenAI API-compatible text embedding model.
|
||||||
"""
|
"""
|
||||||
@ -47,86 +31,10 @@ class OAICompatEmbeddingModel(_CommonOaiApiCompat, TextEmbeddingModel):
|
|||||||
:return: embeddings result
|
:return: embeddings result
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Prepare headers and payload for the request
|
|
||||||
headers = {"Content-Type": "application/json"}
|
|
||||||
|
|
||||||
api_key = credentials.get("api_key")
|
|
||||||
if api_key:
|
|
||||||
headers["Authorization"] = f"Bearer {api_key}"
|
|
||||||
endpoint_url: Optional[str]
|
|
||||||
if "endpoint_url" not in credentials or credentials["endpoint_url"] == "":
|
if "endpoint_url" not in credentials or credentials["endpoint_url"] == "":
|
||||||
endpoint_url = "https://cloud.perfxlab.cn/v1/"
|
credentials["endpoint_url"] = "https://cloud.perfxlab.cn/v1/"
|
||||||
else:
|
|
||||||
endpoint_url = credentials.get("endpoint_url")
|
|
||||||
assert endpoint_url is not None, "endpoint_url is required in credentials"
|
|
||||||
if not endpoint_url.endswith("/"):
|
|
||||||
endpoint_url += "/"
|
|
||||||
|
|
||||||
assert isinstance(endpoint_url, str)
|
return OAICompatEmbeddingModel._invoke(self, model, credentials, texts, user, input_type)
|
||||||
endpoint_url = urljoin(endpoint_url, "embeddings")
|
|
||||||
|
|
||||||
extra_model_kwargs = {}
|
|
||||||
if user:
|
|
||||||
extra_model_kwargs["user"] = user
|
|
||||||
|
|
||||||
extra_model_kwargs["encoding_format"] = "float"
|
|
||||||
|
|
||||||
# get model properties
|
|
||||||
context_size = self._get_context_size(model, credentials)
|
|
||||||
max_chunks = self._get_max_chunks(model, credentials)
|
|
||||||
|
|
||||||
inputs = []
|
|
||||||
indices = []
|
|
||||||
used_tokens = 0
|
|
||||||
|
|
||||||
for i, text in enumerate(texts):
|
|
||||||
# Here token count is only an approximation based on the GPT2 tokenizer
|
|
||||||
# TODO: Optimize for better token estimation and chunking
|
|
||||||
num_tokens = self._get_num_tokens_by_gpt2(text)
|
|
||||||
|
|
||||||
if num_tokens >= context_size:
|
|
||||||
cutoff = int(np.floor(len(text) * (context_size / num_tokens)))
|
|
||||||
# if num tokens is larger than context length, only use the start
|
|
||||||
inputs.append(text[0:cutoff])
|
|
||||||
else:
|
|
||||||
inputs.append(text)
|
|
||||||
indices += [i]
|
|
||||||
|
|
||||||
batched_embeddings = []
|
|
||||||
_iter = range(0, len(inputs), max_chunks)
|
|
||||||
|
|
||||||
for i in _iter:
|
|
||||||
# Prepare the payload for the request
|
|
||||||
payload = {"input": inputs[i : i + max_chunks], "model": model, **extra_model_kwargs}
|
|
||||||
|
|
||||||
# Make the request to the OpenAI API
|
|
||||||
response = requests.post(endpoint_url, headers=headers, data=json.dumps(payload), timeout=(10, 300))
|
|
||||||
|
|
||||||
response.raise_for_status() # Raise an exception for HTTP errors
|
|
||||||
response_data = response.json()
|
|
||||||
|
|
||||||
# Extract embeddings and used tokens from the response
|
|
||||||
embeddings_batch = [data["embedding"] for data in response_data["data"]]
|
|
||||||
embedding_used_tokens = response_data["usage"]["total_tokens"]
|
|
||||||
|
|
||||||
used_tokens += embedding_used_tokens
|
|
||||||
batched_embeddings += embeddings_batch
|
|
||||||
|
|
||||||
# calc usage
|
|
||||||
usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens)
|
|
||||||
|
|
||||||
return TextEmbeddingResult(embeddings=batched_embeddings, usage=usage, model=model)
|
|
||||||
|
|
||||||
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
|
|
||||||
"""
|
|
||||||
Approximate number of tokens for given messages using GPT2 tokenizer
|
|
||||||
|
|
||||||
:param model: model name
|
|
||||||
:param credentials: model credentials
|
|
||||||
:param texts: texts to embed
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
return sum(self._get_num_tokens_by_gpt2(text) for text in texts)
|
|
||||||
|
|
||||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||||
"""
|
"""
|
||||||
@ -136,93 +44,7 @@ class OAICompatEmbeddingModel(_CommonOaiApiCompat, TextEmbeddingModel):
|
|||||||
:param credentials: model credentials
|
:param credentials: model credentials
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
try:
|
|
||||||
headers = {"Content-Type": "application/json"}
|
|
||||||
|
|
||||||
api_key = credentials.get("api_key")
|
|
||||||
|
|
||||||
if api_key:
|
|
||||||
headers["Authorization"] = f"Bearer {api_key}"
|
|
||||||
|
|
||||||
endpoint_url: Optional[str]
|
|
||||||
if "endpoint_url" not in credentials or credentials["endpoint_url"] == "":
|
if "endpoint_url" not in credentials or credentials["endpoint_url"] == "":
|
||||||
endpoint_url = "https://cloud.perfxlab.cn/v1/"
|
credentials["endpoint_url"] = "https://cloud.perfxlab.cn/v1/"
|
||||||
else:
|
|
||||||
endpoint_url = credentials.get("endpoint_url")
|
|
||||||
assert endpoint_url is not None, "endpoint_url is required in credentials"
|
|
||||||
if not endpoint_url.endswith("/"):
|
|
||||||
endpoint_url += "/"
|
|
||||||
|
|
||||||
assert isinstance(endpoint_url, str)
|
OAICompatEmbeddingModel.validate_credentials(self, model, credentials)
|
||||||
endpoint_url = urljoin(endpoint_url, "embeddings")
|
|
||||||
|
|
||||||
payload = {"input": "ping", "model": model}
|
|
||||||
|
|
||||||
response = requests.post(url=endpoint_url, headers=headers, data=json.dumps(payload), timeout=(10, 300))
|
|
||||||
|
|
||||||
if response.status_code != 200:
|
|
||||||
raise CredentialsValidateFailedError(
|
|
||||||
f"Credentials validation failed with status code {response.status_code}"
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
json_result = response.json()
|
|
||||||
except json.JSONDecodeError as e:
|
|
||||||
raise CredentialsValidateFailedError("Credentials validation failed: JSON decode error")
|
|
||||||
|
|
||||||
if "model" not in json_result:
|
|
||||||
raise CredentialsValidateFailedError("Credentials validation failed: invalid response")
|
|
||||||
except CredentialsValidateFailedError:
|
|
||||||
raise
|
|
||||||
except Exception as ex:
|
|
||||||
raise CredentialsValidateFailedError(str(ex))
|
|
||||||
|
|
||||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
|
|
||||||
"""
|
|
||||||
generate custom model entities from credentials
|
|
||||||
"""
|
|
||||||
entity = AIModelEntity(
|
|
||||||
model=model,
|
|
||||||
label=I18nObject(en_US=model),
|
|
||||||
model_type=ModelType.TEXT_EMBEDDING,
|
|
||||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
|
||||||
model_properties={
|
|
||||||
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", 512)),
|
|
||||||
ModelPropertyKey.MAX_CHUNKS: 1,
|
|
||||||
},
|
|
||||||
parameter_rules=[],
|
|
||||||
pricing=PriceConfig(
|
|
||||||
input=Decimal(credentials.get("input_price", 0)),
|
|
||||||
unit=Decimal(credentials.get("unit", 0)),
|
|
||||||
currency=credentials.get("currency", "USD"),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
return entity
|
|
||||||
|
|
||||||
def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
|
|
||||||
"""
|
|
||||||
Calculate response usage
|
|
||||||
|
|
||||||
:param model: model name
|
|
||||||
:param credentials: model credentials
|
|
||||||
:param tokens: input tokens
|
|
||||||
:return: usage
|
|
||||||
"""
|
|
||||||
# get input price info
|
|
||||||
input_price_info = self.get_price(
|
|
||||||
model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens
|
|
||||||
)
|
|
||||||
|
|
||||||
# transform usage
|
|
||||||
usage = EmbeddingUsage(
|
|
||||||
tokens=tokens,
|
|
||||||
total_tokens=tokens,
|
|
||||||
unit_price=input_price_info.unit_price,
|
|
||||||
price_unit=input_price_info.unit,
|
|
||||||
total_price=input_price_info.total_amount,
|
|
||||||
currency=input_price_info.currency,
|
|
||||||
latency=time.perf_counter() - self.started_at,
|
|
||||||
)
|
|
||||||
|
|
||||||
return usage
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user