feat: use xinference client instead of xinference (#1339)
commit 3efaa713da
parent 9822f687f7
--- a/api/core/model_providers/models/embedding/xinference_embedding.py
+++ b/api/core/model_providers/models/embedding/xinference_embedding.py
@@ -1,8 +1,7 @@
-from core.third_party.langchain.embeddings.xinference_embedding import XinferenceEmbedding as XinferenceEmbeddings
-
 from core.model_providers.error import LLMBadRequestError
 from core.model_providers.providers.base import BaseModelProvider
 from core.model_providers.models.embedding.base import BaseEmbedding
+from core.third_party.langchain.embeddings.xinference_embedding import XinferenceEmbeddings
 
 
 class XinferenceEmbedding(BaseEmbedding):
--- a/api/core/model_providers/providers/xinference_provider.py
+++ b/api/core/model_providers/providers/xinference_provider.py
@@ -2,7 +2,6 @@ import json
 from typing import Type
 
 import requests
-from langchain.embeddings import XinferenceEmbeddings
 
 from core.helper import encrypter
 from core.model_providers.models.embedding.xinference_embedding import XinferenceEmbedding
@@ -11,6 +10,7 @@ from core.model_providers.models.llm.xinference_model import XinferenceModel
 from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
 
 from core.model_providers.models.base import BaseProviderModel
+from core.third_party.langchain.embeddings.xinference_embedding import XinferenceEmbeddings
 from core.third_party.langchain.llms.xinference_llm import XinferenceLLM
 from models.provider import ProviderType
 
--- a/api/core/third_party/langchain/embeddings/xinference_embedding.py
+++ b/api/core/third_party/langchain/embeddings/xinference_embedding.py
@@ -1,21 +1,54 @@
-from typing import List
+from typing import List, Optional, Any
 
 import numpy as np
-from langchain.embeddings import XinferenceEmbeddings
+from langchain.embeddings.base import Embeddings
+from xinference_client.client.restful.restful_client import Client
 
 
-class XinferenceEmbedding(XinferenceEmbeddings):
+class XinferenceEmbeddings(Embeddings):
+    client: Any
+    server_url: Optional[str]
+    """URL of the xinference server"""
+    model_uid: Optional[str]
+    """UID of the launched model"""
+
+    def __init__(
+        self, server_url: Optional[str] = None, model_uid: Optional[str] = None
+    ):
+
+        super().__init__()
+
+        if server_url is None:
+            raise ValueError("Please provide server URL")
+
+        if model_uid is None:
+            raise ValueError("Please provide the model UID")
+
+        self.server_url = server_url
+
+        self.model_uid = model_uid
+
+        self.client = Client(server_url)
 
     def embed_documents(self, texts: List[str]) -> List[List[float]]:
-        vectors = super().embed_documents(texts)
+        model = self.client.get_model(self.model_uid)
 
+        embeddings = [
+            model.create_embedding(text)["data"][0]["embedding"] for text in texts
+        ]
+        vectors = [list(map(float, e)) for e in embeddings]
         normalized_vectors = [(vector / np.linalg.norm(vector)).tolist() for vector in vectors]
 
         return normalized_vectors
 
     def embed_query(self, text: str) -> List[float]:
-        vector = super().embed_query(text)
+        model = self.client.get_model(self.model_uid)
 
+        embedding_res = model.create_embedding(text)
+
+        embedding = embedding_res["data"][0]["embedding"]
+
+        vector = list(map(float, embedding))
         normalized_vector = (vector / np.linalg.norm(vector)).tolist()
 
         return normalized_vector
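Below is a usage sketch of the rewritten embeddings wrapper. The server URL and model UID are placeholder assumptions; any embedding model already launched on a reachable Xinference server works the same way. Note that both methods L2-normalize the raw embeddings via numpy before returning them.

from core.third_party.langchain.embeddings.xinference_embedding import XinferenceEmbeddings

# Point the wrapper at a running Xinference server (URL and UID are assumed values).
embeddings = XinferenceEmbeddings(
    server_url="http://127.0.0.1:9997",
    model_uid="my-embedding-model",
)

# Each returned vector is normalized to unit length.
doc_vectors = embeddings.embed_documents(["first document", "second document"])
query_vector = embeddings.embed_query("a search query")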
--- a/api/core/third_party/langchain/llms/xinference_llm.py
+++ b/api/core/third_party/langchain/llms/xinference_llm.py
@@ -1,16 +1,53 @@
-from typing import Optional, List, Any, Union, Generator
+from typing import Optional, List, Any, Union, Generator, Mapping
 
 from langchain.callbacks.manager import CallbackManagerForLLMRun
-from langchain.llms import Xinference
+from langchain.llms.base import LLM
 from langchain.llms.utils import enforce_stop_tokens
-from xinference.client import (
+from xinference_client.client.restful.restful_client import (
     RESTfulChatglmCppChatModelHandle,
     RESTfulChatModelHandle,
-    RESTfulGenerateModelHandle,
+    RESTfulGenerateModelHandle, Client,
 )
 
 
-class XinferenceLLM(Xinference):
+class XinferenceLLM(LLM):
+    client: Any
+    server_url: Optional[str]
+    """URL of the xinference server"""
+    model_uid: Optional[str]
+    """UID of the launched model"""
+
+    def __init__(
+        self, server_url: Optional[str] = None, model_uid: Optional[str] = None
+    ):
+        super().__init__(
+            **{
+                "server_url": server_url,
+                "model_uid": model_uid,
+            }
+        )
+
+        if self.server_url is None:
+            raise ValueError("Please provide server URL")
+
+        if self.model_uid is None:
+            raise ValueError("Please provide the model UID")
+
+        self.client = Client(server_url)
+
+    @property
+    def _llm_type(self) -> str:
+        """Return type of llm."""
+        return "xinference"
+
+    @property
+    def _identifying_params(self) -> Mapping[str, Any]:
+        """Get the identifying parameters."""
+        return {
+            **{"server_url": self.server_url},
+            **{"model_uid": self.model_uid},
+        }
+
     def _call(
         self,
         prompt: str,
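A minimal usage sketch for the rewritten LLM wrapper follows; the server URL and model UID are assumptions, and the model must already be launched on the server. Because XinferenceLLM now subclasses LangChain's base LLM rather than langchain.llms.Xinference, the standard callable interface routes through the _call method shown above.

from core.third_party.langchain.llms.xinference_llm import XinferenceLLM

# Assumed endpoint and UID of an already-launched generate/chat model.
llm = XinferenceLLM(
    server_url="http://127.0.0.1:9997",
    model_uid="my-llm-uid",
)

print(llm("Briefly explain what Xinference is."))  # invokes _call under the hood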
--- a/api/requirements.txt
+++ b/api/requirements.txt
@@ -49,7 +49,7 @@ huggingface_hub~=0.16.4
 transformers~=4.31.0
 stripe~=5.5.0
 pandas==1.5.3
-xinference==0.5.2
+xinference-client~=0.1.2
 safetensors==0.3.2
 zhipuai==1.0.7
 werkzeug==2.3.7
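This requirements change is the substance of the commit: the full xinference package, which ships model-serving dependencies, is swapped for the lightweight xinference-client, which only talks to a remote server over REST. A quick sanity-check sketch, assuming a server is reachable at the placeholder URL and the model UID is one you have launched:

from xinference_client.client.restful.restful_client import Client

client = Client("http://127.0.0.1:9997")   # assumed server URL
model = client.get_model("my-model-uid")   # assumed UID of a launched model
result = model.create_embedding("hello")   # same REST call the wrapper uses
print(result["data"][0]["embedding"][:4])  # first few dimensions of the vector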