mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-12 02:49:03 +08:00
add embedding input type parameter (#8724)
This commit is contained in:
parent
debe5953a8
commit
4669eb24be
@ -5,6 +5,7 @@ from typing import Optional, cast
|
||||
import numpy as np
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from core.embedding.embedding_constant import EmbeddingInputType
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities.model_entities import ModelPropertyKey
|
||||
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||
@ -56,7 +57,9 @@ class CacheEmbedding(Embeddings):
|
||||
for i in range(0, len(embedding_queue_texts), max_chunks):
|
||||
batch_texts = embedding_queue_texts[i : i + max_chunks]
|
||||
|
||||
embedding_result = self._model_instance.invoke_text_embedding(texts=batch_texts, user=self._user)
|
||||
embedding_result = self._model_instance.invoke_text_embedding(
|
||||
texts=batch_texts, user=self._user, input_type=EmbeddingInputType.DOCUMENT
|
||||
)
|
||||
|
||||
for vector in embedding_result.embeddings:
|
||||
try:
|
||||
@ -100,7 +103,9 @@ class CacheEmbedding(Embeddings):
|
||||
redis_client.expire(embedding_cache_key, 600)
|
||||
return list(np.frombuffer(base64.b64decode(embedding), dtype="float"))
|
||||
try:
|
||||
embedding_result = self._model_instance.invoke_text_embedding(texts=[text], user=self._user)
|
||||
embedding_result = self._model_instance.invoke_text_embedding(
|
||||
texts=[text], user=self._user, input_type=EmbeddingInputType.QUERY
|
||||
)
|
||||
|
||||
embedding_results = embedding_result.embeddings[0]
|
||||
embedding_results = (embedding_results / np.linalg.norm(embedding_results)).tolist()
|
||||
|
10
api/core/embedding/embedding_constant.py
Normal file
10
api/core/embedding/embedding_constant.py
Normal file
@ -0,0 +1,10 @@
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class EmbeddingInputType(Enum):
|
||||
"""
|
||||
Enum for embedding input type.
|
||||
"""
|
||||
|
||||
DOCUMENT = "document"
|
||||
QUERY = "query"
|
@ -3,6 +3,7 @@ import os
|
||||
from collections.abc import Callable, Generator, Sequence
|
||||
from typing import IO, Optional, Union, cast
|
||||
|
||||
from core.embedding.embedding_constant import EmbeddingInputType
|
||||
from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
|
||||
from core.entities.provider_entities import ModelLoadBalancingConfiguration
|
||||
from core.errors.error import ProviderTokenNotInitError
|
||||
@ -158,12 +159,15 @@ class ModelInstance:
|
||||
tools=tools,
|
||||
)
|
||||
|
||||
def invoke_text_embedding(self, texts: list[str], user: Optional[str] = None) -> TextEmbeddingResult:
|
||||
def invoke_text_embedding(
|
||||
self, texts: list[str], user: Optional[str] = None, input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT
|
||||
) -> TextEmbeddingResult:
|
||||
"""
|
||||
Invoke large language model
|
||||
|
||||
:param texts: texts to embed
|
||||
:param user: unique user id
|
||||
:param input_type: input type
|
||||
:return: embeddings result
|
||||
"""
|
||||
if not isinstance(self.model_type_instance, TextEmbeddingModel):
|
||||
@ -176,6 +180,7 @@ class ModelInstance:
|
||||
credentials=self.credentials,
|
||||
texts=texts,
|
||||
user=user,
|
||||
input_type=input_type,
|
||||
)
|
||||
|
||||
def get_text_embedding_num_tokens(self, texts: list[str]) -> int:
|
||||
|
@ -4,6 +4,7 @@ from typing import Optional
|
||||
|
||||
from pydantic import ConfigDict
|
||||
|
||||
from core.embedding.embedding_constant import EmbeddingInputType
|
||||
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
|
||||
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
|
||||
from core.model_runtime.model_providers.__base.ai_model import AIModel
|
||||
@ -20,35 +21,47 @@ class TextEmbeddingModel(AIModel):
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
def invoke(
|
||||
self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
texts: list[str],
|
||||
user: Optional[str] = None,
|
||||
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
|
||||
) -> TextEmbeddingResult:
|
||||
"""
|
||||
Invoke large language model
|
||||
Invoke text embedding model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param texts: texts to embed
|
||||
:param user: unique user id
|
||||
:param input_type: input type
|
||||
:return: embeddings result
|
||||
"""
|
||||
self.started_at = time.perf_counter()
|
||||
|
||||
try:
|
||||
return self._invoke(model, credentials, texts, user)
|
||||
return self._invoke(model, credentials, texts, user, input_type)
|
||||
except Exception as e:
|
||||
raise self._transform_invoke_error(e)
|
||||
|
||||
@abstractmethod
|
||||
def _invoke(
|
||||
self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
texts: list[str],
|
||||
user: Optional[str] = None,
|
||||
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
|
||||
) -> TextEmbeddingResult:
|
||||
"""
|
||||
Invoke large language model
|
||||
Invoke text embedding model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param texts: texts to embed
|
||||
:param user: unique user id
|
||||
:param input_type: input type
|
||||
:return: embeddings result
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
@ -7,6 +7,7 @@ import numpy as np
|
||||
import tiktoken
|
||||
from openai import AzureOpenAI
|
||||
|
||||
from core.embedding.embedding_constant import EmbeddingInputType
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, PriceType
|
||||
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
@ -17,7 +18,12 @@ from core.model_runtime.model_providers.azure_openai._constant import EMBEDDING_
|
||||
|
||||
class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel):
|
||||
def _invoke(
|
||||
self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
texts: list[str],
|
||||
user: Optional[str] = None,
|
||||
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
|
||||
) -> TextEmbeddingResult:
|
||||
base_model_name = credentials["base_model_name"]
|
||||
credentials_kwargs = self._to_credential_kwargs(credentials)
|
||||
|
@ -4,6 +4,7 @@ from typing import Optional
|
||||
|
||||
from requests import post
|
||||
|
||||
from core.embedding.embedding_constant import EmbeddingInputType
|
||||
from core.model_runtime.entities.model_entities import PriceType
|
||||
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
||||
from core.model_runtime.errors.invoke import (
|
||||
@ -35,7 +36,12 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel):
|
||||
api_base: str = "http://api.baichuan-ai.com/v1/embeddings"
|
||||
|
||||
def _invoke(
|
||||
self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
texts: list[str],
|
||||
user: Optional[str] = None,
|
||||
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
|
||||
) -> TextEmbeddingResult:
|
||||
"""
|
||||
Invoke text embedding model
|
||||
|
@ -13,6 +13,7 @@ from botocore.exceptions import (
|
||||
UnknownServiceError,
|
||||
)
|
||||
|
||||
from core.embedding.embedding_constant import EmbeddingInputType
|
||||
from core.model_runtime.entities.model_entities import PriceType
|
||||
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
||||
from core.model_runtime.errors.invoke import (
|
||||
@ -30,7 +31,12 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class BedrockTextEmbeddingModel(TextEmbeddingModel):
|
||||
def _invoke(
|
||||
self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
texts: list[str],
|
||||
user: Optional[str] = None,
|
||||
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
|
||||
) -> TextEmbeddingResult:
|
||||
"""
|
||||
Invoke text embedding model
|
||||
|
@ -5,6 +5,7 @@ import cohere
|
||||
import numpy as np
|
||||
from cohere.core import RequestOptions
|
||||
|
||||
from core.embedding.embedding_constant import EmbeddingInputType
|
||||
from core.model_runtime.entities.model_entities import PriceType
|
||||
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
||||
from core.model_runtime.errors.invoke import (
|
||||
@ -25,7 +26,12 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
|
||||
"""
|
||||
|
||||
def _invoke(
|
||||
self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
texts: list[str],
|
||||
user: Optional[str] = None,
|
||||
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
|
||||
) -> TextEmbeddingResult:
|
||||
"""
|
||||
Invoke text embedding model
|
||||
|
@ -6,6 +6,7 @@ import numpy as np
|
||||
import requests
|
||||
from huggingface_hub import HfApi, InferenceClient
|
||||
|
||||
from core.embedding.embedding_constant import EmbeddingInputType
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType, PriceType
|
||||
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
||||
@ -18,7 +19,12 @@ HUGGINGFACE_ENDPOINT_API = "https://api.endpoints.huggingface.cloud/v2/endpoint/
|
||||
|
||||
class HuggingfaceHubTextEmbeddingModel(_CommonHuggingfaceHub, TextEmbeddingModel):
|
||||
def _invoke(
|
||||
self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
texts: list[str],
|
||||
user: Optional[str] = None,
|
||||
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
|
||||
) -> TextEmbeddingResult:
|
||||
client = InferenceClient(token=credentials["huggingfacehub_api_token"])
|
||||
|
||||
|
@ -1,6 +1,7 @@
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
from core.embedding.embedding_constant import EmbeddingInputType
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType
|
||||
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
||||
@ -23,7 +24,12 @@ class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel):
|
||||
"""
|
||||
|
||||
def _invoke(
|
||||
self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
texts: list[str],
|
||||
user: Optional[str] = None,
|
||||
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
|
||||
) -> TextEmbeddingResult:
|
||||
"""
|
||||
Invoke text embedding model
|
||||
|
@ -9,6 +9,7 @@ from tencentcloud.common.profile.client_profile import ClientProfile
|
||||
from tencentcloud.common.profile.http_profile import HttpProfile
|
||||
from tencentcloud.hunyuan.v20230901 import hunyuan_client, models
|
||||
|
||||
from core.embedding.embedding_constant import EmbeddingInputType
|
||||
from core.model_runtime.entities.model_entities import PriceType
|
||||
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
||||
from core.model_runtime.errors.invoke import (
|
||||
@ -26,7 +27,12 @@ class HunyuanTextEmbeddingModel(TextEmbeddingModel):
|
||||
"""
|
||||
|
||||
def _invoke(
|
||||
self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
texts: list[str],
|
||||
user: Optional[str] = None,
|
||||
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
|
||||
) -> TextEmbeddingResult:
|
||||
"""
|
||||
Invoke text embedding model
|
||||
|
@ -4,6 +4,7 @@ from typing import Optional
|
||||
|
||||
from requests import post
|
||||
|
||||
from core.embedding.embedding_constant import EmbeddingInputType
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType
|
||||
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
||||
@ -60,7 +61,12 @@ class JinaTextEmbeddingModel(TextEmbeddingModel):
|
||||
return data
|
||||
|
||||
def _invoke(
|
||||
self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
texts: list[str],
|
||||
user: Optional[str] = None,
|
||||
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
|
||||
) -> TextEmbeddingResult:
|
||||
"""
|
||||
Invoke text embedding model
|
||||
|
@ -5,6 +5,7 @@ from typing import Optional
|
||||
from requests import post
|
||||
from yarl import URL
|
||||
|
||||
from core.embedding.embedding_constant import EmbeddingInputType
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType
|
||||
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
||||
@ -26,7 +27,12 @@ class LocalAITextEmbeddingModel(TextEmbeddingModel):
|
||||
"""
|
||||
|
||||
def _invoke(
|
||||
self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
texts: list[str],
|
||||
user: Optional[str] = None,
|
||||
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
|
||||
) -> TextEmbeddingResult:
|
||||
"""
|
||||
Invoke text embedding model
|
||||
|
@ -4,6 +4,7 @@ from typing import Optional
|
||||
|
||||
from requests import post
|
||||
|
||||
from core.embedding.embedding_constant import EmbeddingInputType
|
||||
from core.model_runtime.entities.model_entities import PriceType
|
||||
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
||||
from core.model_runtime.errors.invoke import (
|
||||
@ -34,7 +35,12 @@ class MinimaxTextEmbeddingModel(TextEmbeddingModel):
|
||||
api_base: str = "https://api.minimax.chat/v1/embeddings"
|
||||
|
||||
def _invoke(
|
||||
self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
texts: list[str],
|
||||
user: Optional[str] = None,
|
||||
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
|
||||
) -> TextEmbeddingResult:
|
||||
"""
|
||||
Invoke text embedding model
|
||||
|
@ -4,6 +4,7 @@ from typing import Optional
|
||||
|
||||
import requests
|
||||
|
||||
from core.embedding.embedding_constant import EmbeddingInputType
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType
|
||||
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
||||
@ -27,7 +28,12 @@ class MixedBreadTextEmbeddingModel(TextEmbeddingModel):
|
||||
api_base: str = "https://api.mixedbread.ai/v1"
|
||||
|
||||
def _invoke(
|
||||
self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
texts: list[str],
|
||||
user: Optional[str] = None,
|
||||
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
|
||||
) -> TextEmbeddingResult:
|
||||
"""
|
||||
Invoke text embedding model
|
||||
|
@ -5,6 +5,7 @@ from typing import Optional
|
||||
from nomic import embed
|
||||
from nomic import login as nomic_login
|
||||
|
||||
from core.embedding.embedding_constant import EmbeddingInputType
|
||||
from core.model_runtime.entities.model_entities import PriceType
|
||||
from core.model_runtime.entities.text_embedding_entities import (
|
||||
EmbeddingUsage,
|
||||
@ -46,6 +47,7 @@ class NomicTextEmbeddingModel(_CommonNomic, TextEmbeddingModel):
|
||||
credentials: dict,
|
||||
texts: list[str],
|
||||
user: Optional[str] = None,
|
||||
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
|
||||
) -> TextEmbeddingResult:
|
||||
"""
|
||||
Invoke text embedding model
|
||||
|
@ -4,6 +4,7 @@ from typing import Optional
|
||||
|
||||
from requests import post
|
||||
|
||||
from core.embedding.embedding_constant import EmbeddingInputType
|
||||
from core.model_runtime.entities.model_entities import PriceType
|
||||
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
||||
from core.model_runtime.errors.invoke import (
|
||||
@ -27,7 +28,12 @@ class NvidiaTextEmbeddingModel(TextEmbeddingModel):
|
||||
models: list[str] = ["NV-Embed-QA"]
|
||||
|
||||
def _invoke(
|
||||
self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
texts: list[str],
|
||||
user: Optional[str] = None,
|
||||
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
|
||||
) -> TextEmbeddingResult:
|
||||
"""
|
||||
Invoke text embedding model
|
||||
|
@ -6,6 +6,7 @@ from typing import Optional
|
||||
import numpy as np
|
||||
import oci
|
||||
|
||||
from core.embedding.embedding_constant import EmbeddingInputType
|
||||
from core.model_runtime.entities.model_entities import PriceType
|
||||
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
||||
from core.model_runtime.errors.invoke import (
|
||||
@ -41,7 +42,12 @@ class OCITextEmbeddingModel(TextEmbeddingModel):
|
||||
"""
|
||||
|
||||
def _invoke(
|
||||
self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
texts: list[str],
|
||||
user: Optional[str] = None,
|
||||
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
|
||||
) -> TextEmbeddingResult:
|
||||
"""
|
||||
Invoke text embedding model
|
||||
@ -50,6 +56,7 @@ class OCITextEmbeddingModel(TextEmbeddingModel):
|
||||
:param credentials: model credentials
|
||||
:param texts: texts to embed
|
||||
:param user: unique user id
|
||||
:param input_type: input type
|
||||
:return: embeddings result
|
||||
"""
|
||||
# get model properties
|
||||
|
@ -8,6 +8,7 @@ from urllib.parse import urljoin
|
||||
import numpy as np
|
||||
import requests
|
||||
|
||||
from core.embedding.embedding_constant import EmbeddingInputType
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.entities.model_entities import (
|
||||
AIModelEntity,
|
||||
@ -38,7 +39,12 @@ class OllamaEmbeddingModel(TextEmbeddingModel):
|
||||
"""
|
||||
|
||||
def _invoke(
|
||||
self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
texts: list[str],
|
||||
user: Optional[str] = None,
|
||||
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
|
||||
) -> TextEmbeddingResult:
|
||||
"""
|
||||
Invoke text embedding model
|
||||
|
@ -6,6 +6,7 @@ import numpy as np
|
||||
import tiktoken
|
||||
from openai import OpenAI
|
||||
|
||||
from core.embedding.embedding_constant import EmbeddingInputType
|
||||
from core.model_runtime.entities.model_entities import PriceType
|
||||
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
@ -19,7 +20,12 @@ class OpenAITextEmbeddingModel(_CommonOpenAI, TextEmbeddingModel):
|
||||
"""
|
||||
|
||||
def _invoke(
|
||||
self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
texts: list[str],
|
||||
user: Optional[str] = None,
|
||||
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
|
||||
) -> TextEmbeddingResult:
|
||||
"""
|
||||
Invoke text embedding model
|
||||
|
@ -7,6 +7,7 @@ from urllib.parse import urljoin
|
||||
import numpy as np
|
||||
import requests
|
||||
|
||||
from core.embedding.embedding_constant import EmbeddingInputType
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.entities.model_entities import (
|
||||
AIModelEntity,
|
||||
@ -28,7 +29,12 @@ class OAICompatEmbeddingModel(_CommonOaiApiCompat, TextEmbeddingModel):
|
||||
"""
|
||||
|
||||
def _invoke(
|
||||
self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
texts: list[str],
|
||||
user: Optional[str] = None,
|
||||
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
|
||||
) -> TextEmbeddingResult:
|
||||
"""
|
||||
Invoke text embedding model
|
||||
|
@ -5,6 +5,7 @@ from typing import Optional
|
||||
from requests import post
|
||||
from requests.exceptions import ConnectionError, InvalidSchema, MissingSchema
|
||||
|
||||
from core.embedding.embedding_constant import EmbeddingInputType
|
||||
from core.model_runtime.entities.model_entities import PriceType
|
||||
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
||||
from core.model_runtime.errors.invoke import (
|
||||
@ -25,7 +26,12 @@ class OpenLLMTextEmbeddingModel(TextEmbeddingModel):
|
||||
"""
|
||||
|
||||
def _invoke(
|
||||
self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
texts: list[str],
|
||||
user: Optional[str] = None,
|
||||
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
|
||||
) -> TextEmbeddingResult:
|
||||
"""
|
||||
Invoke text embedding model
|
||||
|
@ -7,6 +7,7 @@ from urllib.parse import urljoin
|
||||
import numpy as np
|
||||
import requests
|
||||
|
||||
from core.embedding.embedding_constant import EmbeddingInputType
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.entities.model_entities import (
|
||||
AIModelEntity,
|
||||
@ -28,7 +29,12 @@ class OAICompatEmbeddingModel(_CommonOaiApiCompat, TextEmbeddingModel):
|
||||
"""
|
||||
|
||||
def _invoke(
|
||||
self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
texts: list[str],
|
||||
user: Optional[str] = None,
|
||||
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
|
||||
) -> TextEmbeddingResult:
|
||||
"""
|
||||
Invoke text embedding model
|
||||
@ -37,6 +43,7 @@ class OAICompatEmbeddingModel(_CommonOaiApiCompat, TextEmbeddingModel):
|
||||
:param credentials: model credentials
|
||||
:param texts: texts to embed
|
||||
:param user: unique user id
|
||||
:param input_type: input type
|
||||
:return: embeddings result
|
||||
"""
|
||||
|
||||
|
@ -4,6 +4,7 @@ from typing import Optional
|
||||
|
||||
from replicate import Client as ReplicateClient
|
||||
|
||||
from core.embedding.embedding_constant import EmbeddingInputType
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType, PriceType
|
||||
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
||||
@ -14,7 +15,12 @@ from core.model_runtime.model_providers.replicate._common import _CommonReplicat
|
||||
|
||||
class ReplicateEmbeddingModel(_CommonReplicate, TextEmbeddingModel):
|
||||
def _invoke(
|
||||
self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
texts: list[str],
|
||||
user: Optional[str] = None,
|
||||
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
|
||||
) -> TextEmbeddingResult:
|
||||
client = ReplicateClient(api_token=credentials["replicate_api_token"], timeout=30)
|
||||
|
||||
|
@ -6,6 +6,7 @@ from typing import Any, Optional
|
||||
|
||||
import boto3
|
||||
|
||||
from core.embedding.embedding_constant import EmbeddingInputType
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType
|
||||
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
||||
@ -53,7 +54,12 @@ class SageMakerEmbeddingModel(TextEmbeddingModel):
|
||||
return embeddings
|
||||
|
||||
def _invoke(
|
||||
self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
texts: list[str],
|
||||
user: Optional[str] = None,
|
||||
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
|
||||
) -> TextEmbeddingResult:
|
||||
"""
|
||||
Invoke text embedding model
|
||||
|
@ -1,5 +1,6 @@
|
||||
from typing import Optional
|
||||
|
||||
from core.embedding.embedding_constant import EmbeddingInputType
|
||||
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
|
||||
from core.model_runtime.model_providers.openai_api_compatible.text_embedding.text_embedding import (
|
||||
OAICompatEmbeddingModel,
|
||||
@ -16,7 +17,12 @@ class SiliconflowTextEmbeddingModel(OAICompatEmbeddingModel):
|
||||
super().validate_credentials(model, credentials)
|
||||
|
||||
def _invoke(
|
||||
self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
texts: list[str],
|
||||
user: Optional[str] = None,
|
||||
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
|
||||
) -> TextEmbeddingResult:
|
||||
self._add_custom_parameters(credentials)
|
||||
return super()._invoke(model, credentials, texts, user)
|
||||
|
@ -4,6 +4,7 @@ from typing import Optional
|
||||
import dashscope
|
||||
import numpy as np
|
||||
|
||||
from core.embedding.embedding_constant import EmbeddingInputType
|
||||
from core.model_runtime.entities.model_entities import PriceType
|
||||
from core.model_runtime.entities.text_embedding_entities import (
|
||||
EmbeddingUsage,
|
||||
@ -27,6 +28,7 @@ class TongyiTextEmbeddingModel(_CommonTongyi, TextEmbeddingModel):
|
||||
credentials: dict,
|
||||
texts: list[str],
|
||||
user: Optional[str] = None,
|
||||
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
|
||||
) -> TextEmbeddingResult:
|
||||
"""
|
||||
Invoke text embedding model
|
||||
|
@ -7,6 +7,7 @@ import numpy as np
|
||||
from openai import OpenAI
|
||||
from tokenizers import Tokenizer
|
||||
|
||||
from core.embedding.embedding_constant import EmbeddingInputType
|
||||
from core.model_runtime.entities.model_entities import PriceType
|
||||
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
@ -22,7 +23,14 @@ class UpstageTextEmbeddingModel(_CommonUpstage, TextEmbeddingModel):
|
||||
def _get_tokenizer(self) -> Tokenizer:
|
||||
return Tokenizer.from_pretrained("upstage/solar-1-mini-tokenizer")
|
||||
|
||||
def _invoke(self, model: str, credentials: dict, texts: list[str], user: str | None = None) -> TextEmbeddingResult:
|
||||
def _invoke(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
texts: list[str],
|
||||
user: str | None = None,
|
||||
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
|
||||
) -> TextEmbeddingResult:
|
||||
"""
|
||||
Invoke text embedding model
|
||||
|
||||
|
@ -9,6 +9,7 @@ from google.cloud import aiplatform
|
||||
from google.oauth2 import service_account
|
||||
from vertexai.language_models import TextEmbeddingModel as VertexTextEmbeddingModel
|
||||
|
||||
from core.embedding.embedding_constant import EmbeddingInputType
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.entities.model_entities import (
|
||||
AIModelEntity,
|
||||
@ -30,7 +31,12 @@ class VertexAiTextEmbeddingModel(_CommonVertexAi, TextEmbeddingModel):
|
||||
"""
|
||||
|
||||
def _invoke(
|
||||
self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
texts: list[str],
|
||||
user: Optional[str] = None,
|
||||
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
|
||||
) -> TextEmbeddingResult:
|
||||
"""
|
||||
Invoke text embedding model
|
||||
|
@ -2,6 +2,7 @@ import time
|
||||
from decimal import Decimal
|
||||
from typing import Optional
|
||||
|
||||
from core.embedding.embedding_constant import EmbeddingInputType
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.entities.model_entities import (
|
||||
AIModelEntity,
|
||||
@ -41,7 +42,12 @@ class VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel):
|
||||
"""
|
||||
|
||||
def _invoke(
|
||||
self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
texts: list[str],
|
||||
user: Optional[str] = None,
|
||||
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
|
||||
) -> TextEmbeddingResult:
|
||||
"""
|
||||
Invoke text embedding model
|
||||
|
@ -7,6 +7,7 @@ from typing import Any, Optional
|
||||
import numpy as np
|
||||
from requests import Response, post
|
||||
|
||||
from core.embedding.embedding_constant import EmbeddingInputType
|
||||
from core.model_runtime.entities.model_entities import PriceType
|
||||
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
@ -70,7 +71,12 @@ class WenxinTextEmbeddingModel(TextEmbeddingModel):
|
||||
return WenxinTextEmbedding(api_key, secret_key)
|
||||
|
||||
def _invoke(
|
||||
self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
texts: list[str],
|
||||
user: Optional[str] = None,
|
||||
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
|
||||
) -> TextEmbeddingResult:
|
||||
"""
|
||||
Invoke text embedding model
|
||||
|
@ -3,6 +3,7 @@ from typing import Optional
|
||||
|
||||
from xinference_client.client.restful.restful_client import Client, RESTfulEmbeddingModelHandle
|
||||
|
||||
from core.embedding.embedding_constant import EmbeddingInputType
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType
|
||||
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
||||
@ -25,7 +26,12 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
|
||||
"""
|
||||
|
||||
def _invoke(
|
||||
self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
texts: list[str],
|
||||
user: Optional[str] = None,
|
||||
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
|
||||
) -> TextEmbeddingResult:
|
||||
"""
|
||||
Invoke text embedding model
|
||||
@ -40,6 +46,7 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
|
||||
:param credentials: model credentials
|
||||
:param texts: texts to embed
|
||||
:param user: unique user id
|
||||
:param input_type: input type
|
||||
:return: embeddings result
|
||||
"""
|
||||
server_url = credentials["server_url"]
|
||||
|
@ -1,6 +1,7 @@
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
from core.embedding.embedding_constant import EmbeddingInputType
|
||||
from core.model_runtime.entities.model_entities import PriceType
|
||||
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
@ -15,7 +16,12 @@ class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel):
|
||||
"""
|
||||
|
||||
def _invoke(
|
||||
self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
texts: list[str],
|
||||
user: Optional[str] = None,
|
||||
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
|
||||
) -> TextEmbeddingResult:
|
||||
"""
|
||||
Invoke text embedding model
|
||||
|
Loading…
x
Reference in New Issue
Block a user