fix: xinference chat support (#939)

takatost 2023-08-21 20:44:29 +08:00 committed by GitHub
parent f53242c081
commit e0a48c4972
4 changed files with 204 additions and 16 deletions

View File

@@ -1,13 +1,13 @@
 from typing import List, Optional, Any
 
 from langchain.callbacks.manager import Callbacks
-from langchain.llms import Xinference
 from langchain.schema import LLMResult
 
 from core.model_providers.error import LLMBadRequestError
 from core.model_providers.models.llm.base import BaseLLM
 from core.model_providers.models.entity.message import PromptMessage
 from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
+from core.third_party.langchain.llms.xinference_llm import XinferenceLLM
 
 
 class XinferenceModel(BaseLLM):
@@ -16,8 +16,9 @@ class XinferenceModel(BaseLLM):
     def _init_client(self) -> Any:
         self.provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
-        client = Xinference(
-            **self.credentials,
+        client = XinferenceLLM(
+            server_url=self.credentials['server_url'],
+            model_uid=self.credentials['model_uid'],
         )
 
         client.callbacks = self.callbacks
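
The switch from `**self.credentials` to explicit keyword arguments matters because the provider now stores provider-side bookkeeping keys (`model_format`, `model_handle_type`, added by `_get_extra_credentials` below) in the same dict as `server_url` and `model_uid`, so unpacking the whole dict would leak those keys into the LangChain wrapper. A minimal sketch of the assumed credentials shape, with placeholder values:

```python
# Hypothetical credentials dict after XinferenceProvider has merged in the extra
# keys fetched from the Xinference server (values are illustrative only).
credentials = {
    'server_url': 'http://127.0.0.1:9997',
    'model_uid': 'my-model-uid',       # placeholder UID
    'model_format': 'ggmlv3',          # added by _get_extra_credentials
    'model_handle_type': 'generate',   # added by _get_extra_credentials
}

# Only the two connection fields are forwarded to the wrapper:
client = XinferenceLLM(
    server_url=credentials['server_url'],
    model_uid=credentials['model_uid'],
)
```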

View File

@@ -1,7 +1,8 @@
 import json
 from typing import Type
 
-from langchain.llms import Xinference
+import requests
+from xinference.client import RESTfulGenerateModelHandle, RESTfulChatModelHandle, RESTfulChatglmCppChatModelHandle
 
 from core.helper import encrypter
 from core.model_providers.models.embedding.xinference_embedding import XinferenceEmbedding
@@ -10,6 +11,7 @@ from core.model_providers.models.llm.xinference_model import XinferenceModel
 from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
 from core.model_providers.models.base import BaseProviderModel
+from core.third_party.langchain.llms.xinference_llm import XinferenceLLM
 from models.provider import ProviderType
@@ -48,13 +50,32 @@ class XinferenceProvider(BaseModelProvider):
         :param model_type:
         :return:
         """
-        return ModelKwargsRules(
-            temperature=KwargRule[float](min=0, max=2, default=1),
-            top_p=KwargRule[float](min=0, max=1, default=0.7),
-            presence_penalty=KwargRule[float](min=-2, max=2, default=0),
-            frequency_penalty=KwargRule[float](min=-2, max=2, default=0),
-            max_tokens=KwargRule[int](min=10, max=4000, default=256),
-        )
+        credentials = self.get_model_credentials(model_name, model_type)
+        if credentials['model_format'] == "ggmlv3" and credentials["model_handle_type"] == "chatglm":
+            return ModelKwargsRules(
+                temperature=KwargRule[float](min=0.01, max=2, default=1),
+                top_p=KwargRule[float](min=0, max=1, default=0.7),
+                presence_penalty=KwargRule[float](enabled=False),
+                frequency_penalty=KwargRule[float](enabled=False),
+                max_tokens=KwargRule[int](min=10, max=4000, default=256),
+            )
+        elif credentials['model_format'] == "ggmlv3":
+            return ModelKwargsRules(
+                temperature=KwargRule[float](min=0.01, max=2, default=1),
+                top_p=KwargRule[float](min=0, max=1, default=0.7),
+                presence_penalty=KwargRule[float](min=-2, max=2, default=0),
+                frequency_penalty=KwargRule[float](min=-2, max=2, default=0),
+                max_tokens=KwargRule[int](min=10, max=4000, default=256),
+            )
+        else:
+            return ModelKwargsRules(
+                temperature=KwargRule[float](min=0.01, max=2, default=1),
+                top_p=KwargRule[float](min=0, max=1, default=0.7),
+                presence_penalty=KwargRule[float](enabled=False),
+                frequency_penalty=KwargRule[float](enabled=False),
+                max_tokens=KwargRule[int](alias='max_new_tokens', min=10, max=4000, default=256),
+            )
 
     @classmethod
     def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
@@ -77,11 +98,11 @@ class XinferenceProvider(BaseModelProvider):
                 'model_uid': credentials['model_uid'],
             }
 
-            llm = Xinference(
+            llm = XinferenceLLM(
                 **credential_kwargs
            )
 
-            llm("ping", generate_config={'max_tokens': 10})
+            llm("ping")
         except Exception as ex:
             raise CredentialsValidateFailedError(str(ex))
@@ -97,7 +118,11 @@ class XinferenceProvider(BaseModelProvider):
         :param credentials:
         :return:
         """
+        extra_credentials = cls._get_extra_credentials(credentials)
+        credentials.update(extra_credentials)
+
         credentials['server_url'] = encrypter.encrypt_token(tenant_id, credentials['server_url'])
 
         return credentials
 
     def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
@@ -132,6 +157,30 @@ class XinferenceProvider(BaseModelProvider):
 
         return credentials
 
+    @classmethod
+    def _get_extra_credentials(self, credentials: dict) -> dict:
+        url = f"{credentials['server_url']}/v1/models/{credentials['model_uid']}"
+        response = requests.get(url)
+        if response.status_code != 200:
+            raise RuntimeError(
+                f"Failed to get the model description, detail: {response.json()['detail']}"
+            )
+
+        desc = response.json()
+        extra_credentials = {
+            'model_format': desc['model_format'],
+        }
+
+        if desc["model_format"] == "ggmlv3" and "chatglm" in desc["model_name"]:
+            extra_credentials['model_handle_type'] = 'chatglm'
+        elif "generate" in desc["model_ability"]:
+            extra_credentials['model_handle_type'] = 'generate'
+        elif "chat" in desc["model_ability"]:
+            extra_credentials['model_handle_type'] = 'chat'
+        else:
+            raise NotImplementedError(f"Model handle type not supported.")
+
+        return extra_credentials
+
     @classmethod
     def is_provider_credentials_valid_or_raise(cls, credentials: dict):
         return
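
`_get_extra_credentials` derives `model_format` and `model_handle_type` from the model description served at `/v1/models/{model_uid}`. A sketch of the kind of payload the branching above assumes; the field values are illustrative, not captured from a real Xinference server:

```python
# Illustrative model description; only the fields read by _get_extra_credentials are shown.
desc = {
    "model_name": "chatglm2",            # "chatglm" in the name + ggmlv3 format -> 'chatglm' handle
    "model_format": "ggmlv3",
    "model_ability": ["embed", "chat"],  # otherwise 'generate' or 'chat' is picked from here
}
```

For a non-chatglm ggmlv3 model, the handle type would instead resolve to `'generate'` or `'chat'` from `model_ability`.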

View File

@@ -0,0 +1,132 @@
+from typing import Optional, List, Any, Union, Generator
+
+from langchain.callbacks.manager import CallbackManagerForLLMRun
+from langchain.llms import Xinference
+from langchain.llms.utils import enforce_stop_tokens
+from xinference.client import RESTfulChatglmCppChatModelHandle, \
+    RESTfulChatModelHandle, RESTfulGenerateModelHandle
+
+
+class XinferenceLLM(Xinference):
+    def _call(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> str:
+        """Call the xinference model and return the output.
+
+        Args:
+            prompt: The prompt to use for generation.
+            stop: Optional list of stop words to use when generating.
+            generate_config: Optional dictionary for the configuration used for
+                generation.
+
+        Returns:
+            The generated string by the model.
+        """
+        model = self.client.get_model(self.model_uid)
+        if isinstance(model, RESTfulChatModelHandle):
+            generate_config: "LlamaCppGenerateConfig" = kwargs.get("generate_config", {})
+
+            if stop:
+                generate_config["stop"] = stop
+
+            if generate_config and generate_config.get("stream"):
+                combined_text_output = ""
+                for token in self._stream_generate(
+                    model=model,
+                    prompt=prompt,
+                    run_manager=run_manager,
+                    generate_config=generate_config,
+                ):
+                    combined_text_output += token
+                return combined_text_output
+            else:
+                completion = model.chat(prompt=prompt, generate_config=generate_config)
+                return completion["choices"][0]["text"]
+        elif isinstance(model, RESTfulGenerateModelHandle):
+            generate_config: "LlamaCppGenerateConfig" = kwargs.get("generate_config", {})
+
+            if stop:
+                generate_config["stop"] = stop
+
+            if generate_config and generate_config.get("stream"):
+                combined_text_output = ""
+                for token in self._stream_generate(
+                    model=model,
+                    prompt=prompt,
+                    run_manager=run_manager,
+                    generate_config=generate_config,
+                ):
+                    combined_text_output += token
+                return combined_text_output
+            else:
+                completion = model.generate(prompt=prompt, generate_config=generate_config)
+                return completion["choices"][0]["text"]
+        elif isinstance(model, RESTfulChatglmCppChatModelHandle):
+            generate_config: "ChatglmCppGenerateConfig" = kwargs.get("generate_config", {})
+
+            if generate_config and generate_config.get("stream"):
+                combined_text_output = ""
+                for token in self._stream_generate(
+                    model=model,
+                    prompt=prompt,
+                    run_manager=run_manager,
+                    generate_config=generate_config,
+                ):
+                    combined_text_output += token
+                completion = combined_text_output
+            else:
+                completion = model.chat(prompt=prompt, generate_config=generate_config)
+                completion = completion["choices"][0]["text"]
+
+            if stop is not None:
+                completion = enforce_stop_tokens(completion, stop)
+
+            return completion
+
+    def _stream_generate(
+        self,
+        model: Union["RESTfulGenerateModelHandle", "RESTfulChatModelHandle", "RESTfulChatglmCppChatModelHandle"],
+        prompt: str,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        generate_config: Optional[Union["LlamaCppGenerateConfig", "PytorchGenerateConfig", "ChatglmCppGenerateConfig"]] = None,
+    ) -> Generator[str, None, None]:
+        """
+        Args:
+            prompt: The prompt to use for generation.
+            model: The model used for generation.
+            stop: Optional list of stop words to use when generating.
+            generate_config: Optional dictionary for the configuration used for
+                generation.
+
+        Yields:
+            A string token.
+        """
+        if isinstance(model, RESTfulGenerateModelHandle):
+            streaming_response = model.generate(
+                prompt=prompt, generate_config=generate_config
+            )
+        else:
+            streaming_response = model.chat(
+                prompt=prompt, generate_config=generate_config
+            )
+
+        for chunk in streaming_response:
+            if isinstance(chunk, dict):
+                choices = chunk.get("choices", [])
+                if choices:
+                    choice = choices[0]
+                    if isinstance(choice, dict):
+                        token = choice.get("text", "")
+                        log_probs = choice.get("logprobs")
+                        if run_manager:
+                            run_manager.on_llm_new_token(
+                                token=token, verbose=self.verbose, log_probs=log_probs
+                            )
+                        yield token
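
Putting the pieces together, the new wrapper would be driven roughly as below. The server URL and model UID are placeholders, and the `generate_config` keys follow what `_call` reads (`stream`, `stop`, `max_tokens`); treat this as a sketch rather than the project's own usage:

```python
from core.third_party.langchain.llms.xinference_llm import XinferenceLLM

# Placeholder endpoint and UID; a real deployment supplies its own values.
llm = XinferenceLLM(
    server_url="http://127.0.0.1:9997",
    model_uid="my-model-uid",
)

# Blocking call: _call picks chat()/generate() (or the chatglm handle) based on
# the handle type returned by client.get_model().
text = llm("ping", generate_config={"max_tokens": 10})

# Streaming: with "stream" set, _call accumulates tokens from _stream_generate and
# still returns the combined string; run_manager.on_llm_new_token fires per token.
streamed = llm("ping", generate_config={"stream": True, "max_tokens": 10})
```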

View File

@@ -4,7 +4,6 @@ import json
 
 from core.model_providers.models.entity.model_params import ModelType
 from core.model_providers.providers.base import CredentialsValidateFailedError
-from core.model_providers.providers.replicate_provider import ReplicateProvider
 from core.model_providers.providers.xinference_provider import XinferenceProvider
 from models.provider import ProviderType, Provider, ProviderModel
@@ -25,7 +24,7 @@ def decrypt_side_effect(tenant_id, encrypted_key):
 
 def test_is_credentials_valid_or_raise_valid(mocker):
-    mocker.patch('langchain.llms.xinference.Xinference._call',
+    mocker.patch('core.third_party.langchain.llms.xinference_llm.XinferenceLLM._call',
                  return_value="abc")
 
     MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
@@ -53,8 +52,15 @@ def test_is_credentials_valid_or_raise_invalid():
 
 @patch('core.helper.encrypter.encrypt_token', side_effect=encrypt_side_effect)
-def test_encrypt_model_credentials(mock_encrypt):
+def test_encrypt_model_credentials(mock_encrypt, mocker):
     api_key = 'http://127.0.0.1:9997/'
+    mocker.patch('core.model_providers.providers.xinference_provider.XinferenceProvider._get_extra_credentials',
+                 return_value={
+                     'model_handle_type': 'generate',
+                     'model_format': 'ggmlv3'
+                 })
     result = MODEL_PROVIDER_CLASS.encrypt_model_credentials(
         tenant_id='tenant_id',
         model_name='test_model_name',