diff --git a/api/core/model_providers/models/embedding/xinference_embedding.py b/api/core/model_providers/models/embedding/xinference_embedding.py index 839eeea357..ba8dd2d27d 100644 --- a/api/core/model_providers/models/embedding/xinference_embedding.py +++ b/api/core/model_providers/models/embedding/xinference_embedding.py @@ -1,4 +1,4 @@ -from langchain.embeddings import XinferenceEmbeddings +from core.third_party.langchain.embeddings.xinference_embedding import XinferenceEmbedding as XinferenceEmbeddings from replicate.exceptions import ModelError, ReplicateError from core.model_providers.error import LLMBadRequestError @@ -14,7 +14,8 @@ class XinferenceEmbedding(BaseEmbedding): ) client = XinferenceEmbeddings( - **credentials, + server_url=credentials['server_url'], + model_uid=credentials['model_uid'], ) super().__init__(model_provider, client, name) diff --git a/api/core/third_party/langchain/embeddings/xinference_embedding.py b/api/core/third_party/langchain/embeddings/xinference_embedding.py new file mode 100644 index 0000000000..371e240e9f --- /dev/null +++ b/api/core/third_party/langchain/embeddings/xinference_embedding.py @@ -0,0 +1,21 @@ +from typing import List + +import numpy as np +from langchain.embeddings import XinferenceEmbeddings + + +class XinferenceEmbedding(XinferenceEmbeddings): + + def embed_documents(self, texts: List[str]) -> List[List[float]]: + vectors = super().embed_documents(texts) + + normalized_vectors = [(vector / np.linalg.norm(vector)).tolist() for vector in vectors] + + return normalized_vectors + + def embed_query(self, text: str) -> List[float]: + vector = super().embed_query(text) + + normalized_vector = (vector / np.linalg.norm(vector)).tolist() + + return normalized_vector