feat: use xinference client instead of xinference (#1339)

This commit is contained in:
takatost 2023-10-13 15:46:09 +08:00 committed by GitHub
parent 9822f687f7
commit 3efaa713da
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 83 additions and 14 deletions

View File

@ -1,8 +1,7 @@
from core.third_party.langchain.embeddings.xinference_embedding import XinferenceEmbedding as XinferenceEmbeddings
from core.model_providers.error import LLMBadRequestError
from core.model_providers.providers.base import BaseModelProvider
from core.model_providers.models.embedding.base import BaseEmbedding
from core.third_party.langchain.embeddings.xinference_embedding import XinferenceEmbeddings
class XinferenceEmbedding(BaseEmbedding):

View File

@ -2,7 +2,6 @@ import json
from typing import Type
import requests
from langchain.embeddings import XinferenceEmbeddings
from core.helper import encrypter
from core.model_providers.models.embedding.xinference_embedding import XinferenceEmbedding
@ -11,6 +10,7 @@ from core.model_providers.models.llm.xinference_model import XinferenceModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.model_providers.models.base import BaseProviderModel
from core.third_party.langchain.embeddings.xinference_embedding import XinferenceEmbeddings
from core.third_party.langchain.llms.xinference_llm import XinferenceLLM
from models.provider import ProviderType

View File

@ -1,21 +1,54 @@
from typing import List
from typing import List, Optional, Any
import numpy as np
from langchain.embeddings import XinferenceEmbeddings
from langchain.embeddings.base import Embeddings
from xinference_client.client.restful.restful_client import Client
class XinferenceEmbeddings(Embeddings):
    """Embeddings produced by a model served through the Xinference REST client.

    Each text is sent to the launched model identified by ``model_uid`` and
    the returned embedding is scaled to unit L2 length before it is handed
    back to the caller.
    """

    # RESTful ``Client`` created from ``server_url`` in ``__init__``.
    client: Any
    server_url: Optional[str]
    """URL of the xinference server"""
    model_uid: Optional[str]
    """UID of the launched model"""

    def __init__(
        self, server_url: Optional[str] = None, model_uid: Optional[str] = None
    ):
        """Validate the connection settings and create the REST client.

        Args:
            server_url: URL of the xinference server.
            model_uid: UID of the launched model.

        Raises:
            ValueError: If ``server_url`` or ``model_uid`` is ``None``.
        """
        super().__init__()

        if server_url is None:
            raise ValueError("Please provide server URL")

        if model_uid is None:
            raise ValueError("Please provide the model UID")

        self.server_url = server_url
        self.model_uid = model_uid
        self.client = Client(server_url)

    @staticmethod
    def _normalize(embedding: Any) -> List[float]:
        """Return *embedding* as a list of floats scaled to unit L2 length."""
        vector = np.asarray(embedding, dtype=float)
        norm = np.linalg.norm(vector)
        if norm == 0:
            # Dividing by a zero norm would fill the result with NaN; an
            # all-zero (or empty) vector is returned unchanged instead.
            return vector.tolist()
        return (vector / norm).tolist()

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed every text in *texts*, issuing one request per text.

        Args:
            texts: The documents to embed.

        Returns:
            One normalized embedding per input text, in input order.
        """
        model = self.client.get_model(self.model_uid)
        return [
            self._normalize(model.create_embedding(text)["data"][0]["embedding"])
            for text in texts
        ]

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query string.

        Args:
            text: The query to embed.

        Returns:
            The normalized embedding of *text*.
        """
        model = self.client.get_model(self.model_uid)
        response = model.create_embedding(text)
        return self._normalize(response["data"][0]["embedding"])


# Earlier name of this class, kept so imports that still use it keep working.
XinferenceEmbedding = XinferenceEmbeddings

View File

@ -1,16 +1,53 @@
from typing import Optional, List, Any, Union, Generator
from typing import Optional, List, Any, Union, Generator, Mapping
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms import Xinference
from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens
from xinference.client import (
from xinference_client.client.restful.restful_client import (
RESTfulChatglmCppChatModelHandle,
RESTfulChatModelHandle,
RESTfulGenerateModelHandle,
RESTfulGenerateModelHandle, Client,
)
class XinferenceLLM(Xinference):
class XinferenceLLM(LLM):
client: Any
server_url: Optional[str]
"""URL of the xinference server"""
model_uid: Optional[str]
"""UID of the launched model"""
def __init__(
    self, server_url: Optional[str] = None, model_uid: Optional[str] = None
):
    """Store the connection settings, validate them and create the client.

    Both values are forwarded to the base-class initializer first, so the
    checks below read them back from the instance.

    Args:
        server_url: URL of the xinference server.
        model_uid: UID of the launched model.

    Raises:
        ValueError: If ``server_url`` or ``model_uid`` is ``None``.
    """
    super().__init__(server_url=server_url, model_uid=model_uid)

    # Checked in this order so a missing server URL is reported first.
    required_settings = (
        (self.server_url, "Please provide server URL"),
        (self.model_uid, "Please provide the model UID"),
    )
    for setting, message in required_settings:
        if setting is None:
            raise ValueError(message)

    self.client = Client(server_url)
@property
def _llm_type(self) -> str:
    """Name of this LLM type; always the string ``"xinference"``."""
    llm_type = "xinference"
    return llm_type
@property
def _identifying_params(self) -> Mapping[str, Any]:
    """Return the settings that identify this instance.

    The mapping holds the configured ``server_url`` and ``model_uid``,
    in that order.
    """
    return dict(server_url=self.server_url, model_uid=self.model_uid)
def _call(
self,
prompt: str,

View File

@ -49,7 +49,7 @@ huggingface_hub~=0.16.4
transformers~=4.31.0
stripe~=5.5.0
pandas==1.5.3
xinference==0.5.2
xinference-client~=0.1.2
safetensors==0.3.2
zhipuai==1.0.7
werkzeug==2.3.7