From 0796791de54e4b55c5098ef388347db96b71686e Mon Sep 17 00:00:00 2001
From: takatost
Date: Sat, 26 Aug 2023 19:48:34 +0800
Subject: [PATCH] feat: hf inference endpoint stream support (#1028)

---
 .../models/llm/anthropic_model.py             |  4 +-
 .../models/llm/azure_openai_model.py          |  6 +-
 api/core/model_providers/models/llm/base.py   |  8 +-
 .../models/llm/chatglm_model.py               |  4 -
 .../models/llm/huggingface_hub_model.py       | 17 +++-
 .../models/llm/openai_model.py                |  4 +-
 .../models/llm/openllm_model.py               |  4 -
 .../models/llm/replicate_model.py             |  6 +-
 .../model_providers/models/llm/spark_model.py |  6 +-
 .../models/llm/tongyi_model.py                |  4 +-
 .../models/llm/wenxin_model.py                |  4 -
 .../models/llm/xinference_model.py            |  4 +-
 .../llms/huggingface_endpoint_llm.py          | 93 ++++++++++++++++++-
 .../test_huggingface_hub_provider.py          |  4 +-
 14 files changed, 128 insertions(+), 40 deletions(-)

diff --git a/api/core/model_providers/models/llm/anthropic_model.py b/api/core/model_providers/models/llm/anthropic_model.py
index dd6c17798d..62a9d992ba 100644
--- a/api/core/model_providers/models/llm/anthropic_model.py
+++ b/api/core/model_providers/models/llm/anthropic_model.py
@@ -75,7 +75,7 @@ class AnthropicModel(BaseLLM):
         else:
             return ex
 
-    @classmethod
-    def support_streaming(cls):
+    @property
+    def support_streaming(self):
         return True
 
diff --git a/api/core/model_providers/models/llm/azure_openai_model.py b/api/core/model_providers/models/llm/azure_openai_model.py
index a5a0d13d99..d97330ae3b 100644
--- a/api/core/model_providers/models/llm/azure_openai_model.py
+++ b/api/core/model_providers/models/llm/azure_openai_model.py
@@ -141,6 +141,6 @@ class AzureOpenAIModel(BaseLLM):
         else:
             return ex
 
-    @classmethod
-    def support_streaming(cls):
-        return True
\ No newline at end of file
+    @property
+    def support_streaming(self):
+        return True
diff --git a/api/core/model_providers/models/llm/base.py b/api/core/model_providers/models/llm/base.py
index 8662e73275..4093db3387 100644
--- a/api/core/model_providers/models/llm/base.py
+++ b/api/core/model_providers/models/llm/base.py
@@ -138,7 +138,7 @@ class BaseLLM(BaseProviderModel):
             result = self._run(
                 messages=messages,
                 stop=stop,
-                callbacks=callbacks if not (self.streaming and not self.support_streaming()) else None,
+                callbacks=callbacks if not (self.streaming and not self.support_streaming) else None,
                 **kwargs
             )
         except Exception as ex:
@@ -149,7 +149,7 @@ class BaseLLM(BaseProviderModel):
         else:
             completion_content = result.generations[0][0].text
 
-        if self.streaming and not self.support_streaming():
+        if self.streaming and not self.support_streaming:
             # use FakeLLM to simulate streaming when current model not support streaming but streaming is True
             prompts = self._get_prompt_from_messages(messages, ModelMode.CHAT)
             fake_llm = FakeLLM(
@@ -298,8 +298,8 @@ class BaseLLM(BaseProviderModel):
         else:
             self.client.callbacks.extend(callbacks)
 
-    @classmethod
-    def support_streaming(cls):
+    @property
+    def support_streaming(self):
         return False
 
     def get_prompt(self, mode: str,
diff --git a/api/core/model_providers/models/llm/chatglm_model.py b/api/core/model_providers/models/llm/chatglm_model.py
index f3ce9ceaf0..5f22cdf6af 100644
--- a/api/core/model_providers/models/llm/chatglm_model.py
+++ b/api/core/model_providers/models/llm/chatglm_model.py
@@ -61,7 +61,3 @@ class ChatGLMModel(BaseLLM):
             return LLMBadRequestError(f"ChatGLM: {str(ex)}")
         else:
             return ex
-
-    @classmethod
-    def support_streaming(cls):
-        return False
diff --git a/api/core/model_providers/models/llm/huggingface_hub_model.py b/api/core/model_providers/models/llm/huggingface_hub_model.py
index fb381bf64d..e42b597f0b 100644
--- a/api/core/model_providers/models/llm/huggingface_hub_model.py
+++ b/api/core/model_providers/models/llm/huggingface_hub_model.py
@@ -17,12 +17,18 @@ class HuggingfaceHubModel(BaseLLM):
     def _init_client(self) -> Any:
         provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
         if self.credentials['huggingfacehub_api_type'] == 'inference_endpoints':
+            streaming = self.streaming
+
+            if 'baichuan' in self.name.lower():
+                streaming = False
+
             client = HuggingFaceEndpointLLM(
                 endpoint_url=self.credentials['huggingfacehub_endpoint_url'],
                 task=self.credentials['task_type'],
                 model_kwargs=provider_model_kwargs,
                 huggingfacehub_api_token=self.credentials['huggingfacehub_api_token'],
-                callbacks=self.callbacks
+                callbacks=self.callbacks,
+                streaming=streaming
             )
         else:
             client = HuggingFaceHub(
@@ -76,7 +82,10 @@ def handle_exceptions(self, ex: Exception) -> Exception:
     def handle_exceptions(self, ex: Exception) -> Exception:
         return LLMBadRequestError(f"Huggingface Hub: {str(ex)}")
 
-    @classmethod
-    def support_streaming(cls):
-        return False
+    @property
+    def support_streaming(self):
+        if self.credentials['huggingfacehub_api_type'] == 'inference_endpoints':
+            if 'baichuan' in self.name.lower():
+                return False
+
+        return True
diff --git a/api/core/model_providers/models/llm/openai_model.py b/api/core/model_providers/models/llm/openai_model.py
index 91db37df6f..c63b041199 100644
--- a/api/core/model_providers/models/llm/openai_model.py
+++ b/api/core/model_providers/models/llm/openai_model.py
@@ -154,8 +154,8 @@ class OpenAIModel(BaseLLM):
         else:
             return ex
 
-    @classmethod
-    def support_streaming(cls):
+    @property
+    def support_streaming(self):
         return True
 
     # def is_model_valid_or_raise(self):
diff --git a/api/core/model_providers/models/llm/openllm_model.py b/api/core/model_providers/models/llm/openllm_model.py
index 217d893c48..0ee6ce0f64 100644
--- a/api/core/model_providers/models/llm/openllm_model.py
+++ b/api/core/model_providers/models/llm/openllm_model.py
@@ -63,7 +63,3 @@ class OpenLLMModel(BaseLLM):
 
     def handle_exceptions(self, ex: Exception) -> Exception:
         return LLMBadRequestError(f"OpenLLM: {str(ex)}")
-
-    @classmethod
-    def support_streaming(cls):
-        return False
diff --git a/api/core/model_providers/models/llm/replicate_model.py b/api/core/model_providers/models/llm/replicate_model.py
index e740440ac2..becc212ad9 100644
--- a/api/core/model_providers/models/llm/replicate_model.py
+++ b/api/core/model_providers/models/llm/replicate_model.py
@@ -91,6 +91,6 @@ class ReplicateModel(BaseLLM):
         else:
             return ex
 
-    @classmethod
-    def support_streaming(cls):
-        return True
\ No newline at end of file
+    @property
+    def support_streaming(self):
+        return True
diff --git a/api/core/model_providers/models/llm/spark_model.py b/api/core/model_providers/models/llm/spark_model.py
index a7b63ae058..ecbcb103e0 100644
--- a/api/core/model_providers/models/llm/spark_model.py
+++ b/api/core/model_providers/models/llm/spark_model.py
@@ -65,6 +65,6 @@ class SparkModel(BaseLLM):
         else:
             return ex
 
-    @classmethod
-    def support_streaming(cls):
-        return True
\ No newline at end of file
+    @property
+    def support_streaming(self):
+        return True
diff --git a/api/core/model_providers/models/llm/tongyi_model.py b/api/core/model_providers/models/llm/tongyi_model.py
index 7138338a0c..a66606e16b 100644
--- a/api/core/model_providers/models/llm/tongyi_model.py
+++ b/api/core/model_providers/models/llm/tongyi_model.py
@@ -69,6 +69,6 @@ class TongyiModel(BaseLLM):
         else:
             return ex
 
-    @classmethod
-    def support_streaming(cls):
+    @property
+    def support_streaming(self):
         return True
diff --git a/api/core/model_providers/models/llm/wenxin_model.py b/api/core/model_providers/models/llm/wenxin_model.py
index 0f42ad27b5..3a9e534fac 100644
--- a/api/core/model_providers/models/llm/wenxin_model.py
+++ b/api/core/model_providers/models/llm/wenxin_model.py
@@ -57,7 +57,3 @@ class WenxinModel(BaseLLM):
 
     def handle_exceptions(self, ex: Exception) -> Exception:
         return LLMBadRequestError(f"Wenxin: {str(ex)}")
-
-    @classmethod
-    def support_streaming(cls):
-        return False
diff --git a/api/core/model_providers/models/llm/xinference_model.py b/api/core/model_providers/models/llm/xinference_model.py
index a058a601b1..551450bec3 100644
--- a/api/core/model_providers/models/llm/xinference_model.py
+++ b/api/core/model_providers/models/llm/xinference_model.py
@@ -74,6 +74,6 @@ class XinferenceModel(BaseLLM):
     def handle_exceptions(self, ex: Exception) -> Exception:
         return LLMBadRequestError(f"Xinference: {str(ex)}")
 
-    @classmethod
-    def support_streaming(cls):
+    @property
+    def support_streaming(self):
         return True
diff --git a/api/core/third_party/langchain/llms/huggingface_endpoint_llm.py b/api/core/third_party/langchain/llms/huggingface_endpoint_llm.py
index 71ee684e3d..0b2adba3c9 100644
--- a/api/core/third_party/langchain/llms/huggingface_endpoint_llm.py
+++ b/api/core/third_party/langchain/llms/huggingface_endpoint_llm.py
@@ -1,7 +1,11 @@
-from typing import Dict
+from typing import Dict, Any, Optional, List, Iterable, Iterator
 
+from huggingface_hub import InferenceClient
+from langchain.callbacks.manager import CallbackManagerForLLMRun
+from langchain.llms.huggingface_endpoint import VALID_TASKS
 from langchain.llms import HuggingFaceEndpoint
-from pydantic import Extra, root_validator
+from langchain.llms.utils import enforce_stop_tokens
+from pydantic import root_validator
 
 from langchain.utils import get_from_dict_or_env
@@ -27,6 +31,8 @@
             huggingfacehub_api_token="my-api-key"
         )
     """
+    client: Any
+    streaming: bool = False
 
     @root_validator(allow_reuse=True)
     def validate_environment(cls, values: Dict) -> Dict:
@@ -35,5 +41,88 @@
             values, "huggingfacehub_api_token", "HUGGINGFACEHUB_API_TOKEN"
         )
+        values['client'] = InferenceClient(values['endpoint_url'], token=huggingfacehub_api_token)
+
         values["huggingfacehub_api_token"] = huggingfacehub_api_token
         return values
+
+    def _call(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> str:
+        """Call out to HuggingFace Hub's inference endpoint.
+
+        Args:
+            prompt: The prompt to pass into the model.
+            stop: Optional list of stop words to use when generating.
+
+        Returns:
+            The string generated by the model.
+
+        Example:
+            .. code-block:: python
+
+                response = hf("Tell me a joke.")
+        """
+        _model_kwargs = self.model_kwargs or {}
+
+        # merge model kwargs with call-time kwargs
+        params = {**_model_kwargs, **kwargs}
+
+        # generation parameters, including stop sequences
+        gen_kwargs = {
+            **params,
+            'stop_sequences': stop
+        }
+
+        response = self.client.text_generation(prompt, stream=self.streaming, details=True, **gen_kwargs)
+
+        if self.streaming and isinstance(response, Iterable):
+            combined_text_output = ""
+            for token in self._stream_response(response, run_manager):
+                combined_text_output += token
+            completion = combined_text_output
+        else:
+            completion = response.generated_text
+
+        if self.task == "text-generation":
+            text = completion
+            # Remove prompt if included in generated text.
+            if text.startswith(prompt):
+                text = text[len(prompt) :]
+        elif self.task == "text2text-generation":
+            text = completion
+        else:
+            raise ValueError(
+                f"Got invalid task {self.task}, "
+                f"currently only {VALID_TASKS} are supported"
+            )
+
+        if stop is not None:
+            # This is a bit hacky, but I can't figure out a better way to enforce
+            # stop tokens when making calls to huggingface_hub.
+            text = enforce_stop_tokens(text, stop)
+
+        return text
+
+    def _stream_response(
+        self,
+        response: Iterable,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+    ) -> Iterator[str]:
+        for r in response:
+            # skip special tokens
+            if r.token.special:
+                continue
+
+            token = r.token.text
+            if run_manager:
+                run_manager.on_llm_new_token(
+                    token=token, verbose=self.verbose, log_probs=None
+                )
+
+            # yield the generated token
+            yield token
diff --git a/api/tests/unit_tests/model_providers/test_huggingface_hub_provider.py b/api/tests/unit_tests/model_providers/test_huggingface_hub_provider.py
index 7f77d3c212..468d56038e 100644
--- a/api/tests/unit_tests/model_providers/test_huggingface_hub_provider.py
+++ b/api/tests/unit_tests/model_providers/test_huggingface_hub_provider.py
@@ -63,7 +63,7 @@ def test_hosted_inference_api_is_credentials_valid_or_raise_invalid(mock_model_i
 
 def test_inference_endpoints_is_credentials_valid_or_raise_valid(mocker):
     mocker.patch('huggingface_hub.hf_api.HfApi.whoami', return_value=None)
-    mocker.patch('langchain.llms.huggingface_endpoint.HuggingFaceEndpoint._call', return_value="abc")
+    mocker.patch('core.third_party.langchain.llms.huggingface_endpoint_llm.HuggingFaceEndpointLLM._call', return_value="abc")
 
     MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
         model_name='test_model_name',
@@ -71,8 +71,10 @@
         credentials=INFERENCE_ENDPOINTS_VALIDATE_CREDENTIAL
     )
 
+
 def test_inference_endpoints_is_credentials_valid_or_raise_invalid(mocker):
     mocker.patch('huggingface_hub.hf_api.HfApi.whoami', return_value=None)
+    mocker.patch('core.third_party.langchain.llms.huggingface_endpoint_llm.HuggingFaceEndpointLLM._call', return_value="abc")
 
     with pytest.raises(CredentialsValidateFailedError):
         MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
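
Usage sketch (illustrative, not part of the patch): the snippet below shows one way the
streaming path added to HuggingFaceEndpointLLM could be driven end to end. The endpoint
URL, the API token, and the PrintTokenHandler callback are placeholder assumptions made
for the example; only the HuggingFaceEndpointLLM class and its streaming/callbacks
parameters come from this patch.

    from langchain.callbacks.base import BaseCallbackHandler

    from core.third_party.langchain.llms.huggingface_endpoint_llm import HuggingFaceEndpointLLM


    class PrintTokenHandler(BaseCallbackHandler):
        """Illustrative handler: print each token as the endpoint streams it."""

        def on_llm_new_token(self, token: str, **kwargs) -> None:
            print(token, end="", flush=True)


    llm = HuggingFaceEndpointLLM(
        endpoint_url="https://example.endpoints.huggingface.cloud",  # placeholder endpoint
        task="text-generation",
        model_kwargs={"temperature": 0.7, "max_new_tokens": 64},
        huggingfacehub_api_token="hf_xxx",  # placeholder token
        streaming=True,
        callbacks=[PrintTokenHandler()],
    )

    # With streaming=True, _call iterates the InferenceClient.text_generation
    # stream, forwarding each non-special token to the callbacks before
    # returning the assembled completion as a plain string. With
    # streaming=False, the endpoint returns the completion in one shot.
    completion = llm("Tell me a joke.")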