From e0a48c4972f47929fe99d4992e419a7974968aaf Mon Sep 17 00:00:00 2001
From: takatost
Date: Mon, 21 Aug 2023 20:44:29 +0800
Subject: [PATCH] fix: xinference chat support (#939)

---
 .../models/llm/xinference_model.py      |   7 +-
 .../providers/xinference_provider.py    |  69 +++++++--
 .../langchain/llms/xinference_llm.py    | 132 ++++++++++++++++++
 .../test_xinference_provider.py         |  12 +-
 4 files changed, 204 insertions(+), 16 deletions(-)
 create mode 100644 api/core/third_party/langchain/llms/xinference_llm.py

diff --git a/api/core/model_providers/models/llm/xinference_model.py b/api/core/model_providers/models/llm/xinference_model.py
index ef3a83c352..8af6356c99 100644
--- a/api/core/model_providers/models/llm/xinference_model.py
+++ b/api/core/model_providers/models/llm/xinference_model.py
@@ -1,13 +1,13 @@
 from typing import List, Optional, Any

 from langchain.callbacks.manager import Callbacks
-from langchain.llms import Xinference
 from langchain.schema import LLMResult

 from core.model_providers.error import LLMBadRequestError
 from core.model_providers.models.llm.base import BaseLLM
 from core.model_providers.models.entity.message import PromptMessage
 from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
+from core.third_party.langchain.llms.xinference_llm import XinferenceLLM


 class XinferenceModel(BaseLLM):
@@ -16,8 +16,9 @@ class XinferenceModel(BaseLLM):

     def _init_client(self) -> Any:
         self.provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
-        client = Xinference(
-            **self.credentials,
+        client = XinferenceLLM(
+            server_url=self.credentials['server_url'],
+            model_uid=self.credentials['model_uid'],
         )

         client.callbacks = self.callbacks
diff --git a/api/core/model_providers/providers/xinference_provider.py b/api/core/model_providers/providers/xinference_provider.py
index 3152499c86..a2412220b0 100644
--- a/api/core/model_providers/providers/xinference_provider.py
+++ b/api/core/model_providers/providers/xinference_provider.py
@@ -1,7 +1,8 @@
 import json
 from typing import Type

-from langchain.llms import Xinference
+import requests
+from xinference.client import RESTfulGenerateModelHandle, RESTfulChatModelHandle, RESTfulChatglmCppChatModelHandle

 from core.helper import encrypter
 from core.model_providers.models.embedding.xinference_embedding import XinferenceEmbedding
@@ -10,6 +11,7 @@ from core.model_providers.models.llm.xinference_model import XinferenceModel

 from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
 from core.model_providers.models.base import BaseProviderModel
+from core.third_party.langchain.llms.xinference_llm import XinferenceLLM
 from models.provider import ProviderType


@@ -48,13 +50,32 @@ class XinferenceProvider(BaseModelProvider):
         :param model_type:
         :return:
         """
-        return ModelKwargsRules(
-            temperature=KwargRule[float](min=0, max=2, default=1),
-            top_p=KwargRule[float](min=0, max=1, default=0.7),
-            presence_penalty=KwargRule[float](min=-2, max=2, default=0),
-            frequency_penalty=KwargRule[float](min=-2, max=2, default=0),
-            max_tokens=KwargRule[int](min=10, max=4000, default=256),
-        )
+        credentials = self.get_model_credentials(model_name, model_type)
+        if credentials['model_format'] == "ggmlv3" and credentials["model_handle_type"] == "chatglm":
+            return ModelKwargsRules(
+                temperature=KwargRule[float](min=0.01, max=2, default=1),
+                top_p=KwargRule[float](min=0, max=1, default=0.7),
+                presence_penalty=KwargRule[float](enabled=False),
+                frequency_penalty=KwargRule[float](enabled=False),
+                max_tokens=KwargRule[int](min=10, max=4000, default=256),
+            )
+        elif credentials['model_format'] == "ggmlv3":
+            return ModelKwargsRules(
+                temperature=KwargRule[float](min=0.01, max=2, default=1),
+                top_p=KwargRule[float](min=0, max=1, default=0.7),
+                presence_penalty=KwargRule[float](min=-2, max=2, default=0),
+                frequency_penalty=KwargRule[float](min=-2, max=2, default=0),
+                max_tokens=KwargRule[int](min=10, max=4000, default=256),
+            )
+        else:
+            return ModelKwargsRules(
+                temperature=KwargRule[float](min=0.01, max=2, default=1),
+                top_p=KwargRule[float](min=0, max=1, default=0.7),
+                presence_penalty=KwargRule[float](enabled=False),
+                frequency_penalty=KwargRule[float](enabled=False),
+                max_tokens=KwargRule[int](alias='max_new_tokens', min=10, max=4000, default=256),
+            )
+

     @classmethod
     def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
@@ -77,11 +98,11 @@ class XinferenceProvider(BaseModelProvider):
                 'model_uid': credentials['model_uid'],
             }

-            llm = Xinference(
+            llm = XinferenceLLM(
                 **credential_kwargs
             )

-            llm("ping", generate_config={'max_tokens': 10})
+            llm("ping")
         except Exception as ex:
             raise CredentialsValidateFailedError(str(ex))

@@ -97,7 +118,11 @@ class XinferenceProvider(BaseModelProvider):
         :param credentials:
         :return:
         """
+        extra_credentials = cls._get_extra_credentials(credentials)
+        credentials.update(extra_credentials)
+
         credentials['server_url'] = encrypter.encrypt_token(tenant_id, credentials['server_url'])
+
         return credentials

     def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
@@ -132,6 +157,30 @@ class XinferenceProvider(BaseModelProvider):

         return credentials

+    @classmethod
+    def _get_extra_credentials(self, credentials: dict) -> dict:
+        url = f"{credentials['server_url']}/v1/models/{credentials['model_uid']}"
+        response = requests.get(url)
+        if response.status_code != 200:
+            raise RuntimeError(
+                f"Failed to get the model description, detail: {response.json()['detail']}"
+            )
+        desc = response.json()
+
+        extra_credentials = {
+            'model_format': desc['model_format'],
+        }
+        if desc["model_format"] == "ggmlv3" and "chatglm" in desc["model_name"]:
+            extra_credentials['model_handle_type'] = 'chatglm'
+        elif "generate" in desc["model_ability"]:
+            extra_credentials['model_handle_type'] = 'generate'
+        elif "chat" in desc["model_ability"]:
+            extra_credentials['model_handle_type'] = 'chat'
+        else:
+            raise NotImplementedError(f"Model handle type not supported.")
+
+        return extra_credentials
+
     @classmethod
     def is_provider_credentials_valid_or_raise(cls, credentials: dict):
         return
diff --git a/api/core/third_party/langchain/llms/xinference_llm.py b/api/core/third_party/langchain/llms/xinference_llm.py
new file mode 100644
index 0000000000..c69bfe2e4e
--- /dev/null
+++ b/api/core/third_party/langchain/llms/xinference_llm.py
@@ -0,0 +1,132 @@
+from typing import Optional, List, Any, Union, Generator
+
+from langchain.callbacks.manager import CallbackManagerForLLMRun
+from langchain.llms import Xinference
+from langchain.llms.utils import enforce_stop_tokens
+from xinference.client import RESTfulChatglmCppChatModelHandle, \
+    RESTfulChatModelHandle, RESTfulGenerateModelHandle
+
+
+class XinferenceLLM(Xinference):
+    def _call(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> str:
+        """Call the xinference model and return the output.
+
+        Args:
+            prompt: The prompt to use for generation.
+            stop: Optional list of stop words to use when generating.
+            generate_config: Optional dictionary for the configuration used for
+                generation.
+
+        Returns:
+            The generated string by the model.
+        """
+        model = self.client.get_model(self.model_uid)
+
+        if isinstance(model, RESTfulChatModelHandle):
+            generate_config: "LlamaCppGenerateConfig" = kwargs.get("generate_config", {})
+
+            if stop:
+                generate_config["stop"] = stop
+
+            if generate_config and generate_config.get("stream"):
+                combined_text_output = ""
+                for token in self._stream_generate(
+                    model=model,
+                    prompt=prompt,
+                    run_manager=run_manager,
+                    generate_config=generate_config,
+                ):
+                    combined_text_output += token
+                return combined_text_output
+            else:
+                completion = model.chat(prompt=prompt, generate_config=generate_config)
+                return completion["choices"][0]["text"]
+        elif isinstance(model, RESTfulGenerateModelHandle):
+            generate_config: "LlamaCppGenerateConfig" = kwargs.get("generate_config", {})
+
+            if stop:
+                generate_config["stop"] = stop
+
+            if generate_config and generate_config.get("stream"):
+                combined_text_output = ""
+                for token in self._stream_generate(
+                    model=model,
+                    prompt=prompt,
+                    run_manager=run_manager,
+                    generate_config=generate_config,
+                ):
+                    combined_text_output += token
+                return combined_text_output
+
+            else:
+                completion = model.generate(prompt=prompt, generate_config=generate_config)
+                return completion["choices"][0]["text"]
+        elif isinstance(model, RESTfulChatglmCppChatModelHandle):
+            generate_config: "ChatglmCppGenerateConfig" = kwargs.get("generate_config", {})
+
+            if generate_config and generate_config.get("stream"):
+                combined_text_output = ""
+                for token in self._stream_generate(
+                    model=model,
+                    prompt=prompt,
+                    run_manager=run_manager,
+                    generate_config=generate_config,
+                ):
+                    combined_text_output += token
+                completion = combined_text_output
+            else:
+                completion = model.chat(prompt=prompt, generate_config=generate_config)
+                completion = completion["choices"][0]["text"]
+
+            if stop is not None:
+                completion = enforce_stop_tokens(completion, stop)
+
+            return completion
+
+
+    def _stream_generate(
+        self,
+        model: Union["RESTfulGenerateModelHandle", "RESTfulChatModelHandle", "RESTfulChatglmCppChatModelHandle"],
+        prompt: str,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        generate_config: Optional[Union["LlamaCppGenerateConfig", "PytorchGenerateConfig", "ChatglmCppGenerateConfig"]] = None,
+    ) -> Generator[str, None, None]:
+        """
+        Args:
+            prompt: The prompt to use for generation.
+            model: The model used for generation.
+            stop: Optional list of stop words to use when generating.
+            generate_config: Optional dictionary for the configuration used for
+                generation.
+
+        Yields:
+            A string token.
+ """ + if isinstance(model, RESTfulGenerateModelHandle): + streaming_response = model.generate( + prompt=prompt, generate_config=generate_config + ) + else: + streaming_response = model.chat( + prompt=prompt, generate_config=generate_config + ) + + for chunk in streaming_response: + if isinstance(chunk, dict): + choices = chunk.get("choices", []) + if choices: + choice = choices[0] + if isinstance(choice, dict): + token = choice.get("text", "") + log_probs = choice.get("logprobs") + if run_manager: + run_manager.on_llm_new_token( + token=token, verbose=self.verbose, log_probs=log_probs + ) + yield token diff --git a/api/tests/unit_tests/model_providers/test_xinference_provider.py b/api/tests/unit_tests/model_providers/test_xinference_provider.py index 4cf85dcce3..84a7985ba4 100644 --- a/api/tests/unit_tests/model_providers/test_xinference_provider.py +++ b/api/tests/unit_tests/model_providers/test_xinference_provider.py @@ -4,7 +4,6 @@ import json from core.model_providers.models.entity.model_params import ModelType from core.model_providers.providers.base import CredentialsValidateFailedError -from core.model_providers.providers.replicate_provider import ReplicateProvider from core.model_providers.providers.xinference_provider import XinferenceProvider from models.provider import ProviderType, Provider, ProviderModel @@ -25,7 +24,7 @@ def decrypt_side_effect(tenant_id, encrypted_key): def test_is_credentials_valid_or_raise_valid(mocker): - mocker.patch('langchain.llms.xinference.Xinference._call', + mocker.patch('core.third_party.langchain.llms.xinference_llm.XinferenceLLM._call', return_value="abc") MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise( @@ -53,8 +52,15 @@ def test_is_credentials_valid_or_raise_invalid(): @patch('core.helper.encrypter.encrypt_token', side_effect=encrypt_side_effect) -def test_encrypt_model_credentials(mock_encrypt): +def test_encrypt_model_credentials(mock_encrypt, mocker): api_key = 'http://127.0.0.1:9997/' + + mocker.patch('core.model_providers.providers.xinference_provider.XinferenceProvider._get_extra_credentials', + return_value={ + 'model_handle_type': 'generate', + 'model_format': 'ggmlv3' + }) + result = MODEL_PROVIDER_CLASS.encrypt_model_credentials( tenant_id='tenant_id', model_name='test_model_name',