From 2d9616c29c6b62bc4e756352a067d80a32d06b65 Mon Sep 17 00:00:00 2001
From: Uranus <109661872+UranusSeven@users.noreply.github.com>
Date: Fri, 25 Aug 2023 18:15:05 +0800
Subject: [PATCH] fix: xinference last token being ignored (#1013)

---
 .../langchain/llms/xinference_llm.py          | 92 +++++++++++--------
 1 file changed, 55 insertions(+), 37 deletions(-)

diff --git a/api/core/third_party/langchain/llms/xinference_llm.py b/api/core/third_party/langchain/llms/xinference_llm.py
index 7010e56d29..c65688adc7 100644
--- a/api/core/third_party/langchain/llms/xinference_llm.py
+++ b/api/core/third_party/langchain/llms/xinference_llm.py
@@ -3,17 +3,20 @@ from typing import Optional, List, Any, Union, Generator
 from langchain.callbacks.manager import CallbackManagerForLLMRun
 from langchain.llms import Xinference
 from langchain.llms.utils import enforce_stop_tokens
-from xinference.client import RESTfulChatglmCppChatModelHandle, \
-    RESTfulChatModelHandle, RESTfulGenerateModelHandle
+from xinference.client import (
+    RESTfulChatglmCppChatModelHandle,
+    RESTfulChatModelHandle,
+    RESTfulGenerateModelHandle,
+)
 
 
 class XinferenceLLM(Xinference):
     def _call(
-            self,
-            prompt: str,
-            stop: Optional[List[str]] = None,
-            run_manager: Optional[CallbackManagerForLLMRun] = None,
-            **kwargs: Any,
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         """Call the xinference model and return the output.
 
@@ -29,7 +32,9 @@ class XinferenceLLM(Xinference):
         model = self.client.get_model(self.model_uid)
 
         if isinstance(model, RESTfulChatModelHandle):
-            generate_config: "LlamaCppGenerateConfig" = kwargs.get("generate_config", {})
+            generate_config: "LlamaCppGenerateConfig" = kwargs.get(
+                "generate_config", {}
+            )
 
             if stop:
                 generate_config["stop"] = stop
@@ -37,10 +42,10 @@ class XinferenceLLM(Xinference):
             if generate_config and generate_config.get("stream"):
                 combined_text_output = ""
                 for token in self._stream_generate(
-                        model=model,
-                        prompt=prompt,
-                        run_manager=run_manager,
-                        generate_config=generate_config,
+                    model=model,
+                    prompt=prompt,
+                    run_manager=run_manager,
+                    generate_config=generate_config,
                 ):
                     combined_text_output += token
                 return combined_text_output
@@ -48,7 +53,9 @@ class XinferenceLLM(Xinference):
                 completion = model.chat(prompt=prompt, generate_config=generate_config)
                 return completion["choices"][0]["message"]["content"]
         elif isinstance(model, RESTfulGenerateModelHandle):
-            generate_config: "LlamaCppGenerateConfig" = kwargs.get("generate_config", {})
+            generate_config: "LlamaCppGenerateConfig" = kwargs.get(
+                "generate_config", {}
+            )
 
             if stop:
                 generate_config["stop"] = stop
@@ -56,27 +63,31 @@ class XinferenceLLM(Xinference):
             if generate_config and generate_config.get("stream"):
                 combined_text_output = ""
                 for token in self._stream_generate(
-                        model=model,
-                        prompt=prompt,
-                        run_manager=run_manager,
-                        generate_config=generate_config,
+                    model=model,
+                    prompt=prompt,
+                    run_manager=run_manager,
+                    generate_config=generate_config,
                 ):
                     combined_text_output += token
                 return combined_text_output
             else:
-                completion = model.generate(prompt=prompt, generate_config=generate_config)
+                completion = model.generate(
+                    prompt=prompt, generate_config=generate_config
+                )
                 return completion["choices"][0]["text"]
 
         elif isinstance(model, RESTfulChatglmCppChatModelHandle):
-            generate_config: "ChatglmCppGenerateConfig" = kwargs.get("generate_config", {})
+            generate_config: "ChatglmCppGenerateConfig" = kwargs.get(
+                "generate_config", {}
+            )
 
             if generate_config and generate_config.get("stream"):
                 combined_text_output = ""
                 for token in self._stream_generate(
-                        model=model,
-                        prompt=prompt,
-                        run_manager=run_manager,
-                        generate_config=generate_config,
+                    model=model,
+                    prompt=prompt,
+                    run_manager=run_manager,
+                    generate_config=generate_config,
                 ):
                     combined_text_output += token
                 completion = combined_text_output
@@ -90,12 +101,21 @@ class XinferenceLLM(Xinference):
             return completion
 
     def _stream_generate(
-            self,
-            model: Union["RESTfulGenerateModelHandle", "RESTfulChatModelHandle", "RESTfulChatglmCppChatModelHandle"],
-            prompt: str,
-            run_manager: Optional[CallbackManagerForLLMRun] = None,
-            generate_config: Optional[
-                Union["LlamaCppGenerateConfig", "PytorchGenerateConfig", "ChatglmCppGenerateConfig"]] = None,
+        self,
+        model: Union[
+            "RESTfulGenerateModelHandle",
+            "RESTfulChatModelHandle",
+            "RESTfulChatglmCppChatModelHandle",
+        ],
+        prompt: str,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        generate_config: Optional[
+            Union[
+                "LlamaCppGenerateConfig",
+                "PytorchGenerateConfig",
+                "ChatglmCppGenerateConfig",
+            ]
+        ] = None,
     ) -> Generator[str, None, None]:
         """
         Args:
@@ -108,7 +128,9 @@ class XinferenceLLM(Xinference):
         Yields:
             A string token.
         """
-        if isinstance(model, (RESTfulChatModelHandle, RESTfulChatglmCppChatModelHandle)):
+        if isinstance(
+            model, (RESTfulChatModelHandle, RESTfulChatglmCppChatModelHandle)
+        ):
             streaming_response = model.chat(
                 prompt=prompt, generate_config=generate_config
             )
@@ -123,14 +145,10 @@ class XinferenceLLM(Xinference):
                 if choices:
                     choice = choices[0]
                     if isinstance(choice, dict):
-                        if 'finish_reason' in choice and choice['finish_reason'] \
-                                and choice['finish_reason'] in ['stop', 'length']:
-                            break
-
-                        if 'text' in choice:
+                        if "text" in choice:
                             token = choice.get("text", "")
-                        elif 'delta' in choice and 'content' in choice['delta']:
-                            token = choice.get('delta').get('content')
+                        elif "delta" in choice and "content" in choice["delta"]:
+                            token = choice.get("delta").get("content")
                         else:
                             continue
                         log_probs = choice.get("logprobs")
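
Note, outside the patch itself: a minimal runnable sketch of the bug this commit fixes, under the assumption (matching the code removed above) that the final streamed chunk can carry its last piece of `text`/`delta` content together with a `finish_reason`. The chunk payloads and the `collect` helper below are hypothetical, for illustration only; they are not part of the Dify or Xinference APIs.

# Illustrative sketch only (not part of the patch): toy chunks shaped like the
# streaming responses handled in _stream_generate above; payloads are hypothetical.
from typing import Any, Dict, Iterable, List

TOY_CHUNKS: List[Dict[str, Any]] = [
    {"choices": [{"delta": {"content": "Hello"}, "finish_reason": None}]},
    {"choices": [{"delta": {"content": " world"}, "finish_reason": None}]},
    # Assumed final chunk: the last content arrives together with finish_reason.
    {"choices": [{"delta": {"content": "!"}, "finish_reason": "stop"}]},
]


def collect(chunks: Iterable[Dict[str, Any]], break_on_finish: bool) -> str:
    """Accumulate streamed text, optionally bailing out early on finish_reason."""
    text = ""
    for chunk in chunks:
        choice = chunk["choices"][0]
        # Old behaviour: break before reading this chunk's content, dropping it.
        if break_on_finish and choice.get("finish_reason") in ("stop", "length"):
            break
        if "text" in choice:
            text += choice.get("text", "")
        elif "delta" in choice and "content" in choice["delta"]:
            text += choice["delta"]["content"]
    return text


print(collect(TOY_CHUNKS, break_on_finish=True))   # Hello world   (last token lost)
print(collect(TOY_CHUNKS, break_on_finish=False))  # Hello world!  (patched behaviour)

With the early break removed, the loop in _stream_generate reads every chunk's text or delta content and simply lets the stream end on its own, so the final token is no longer dropped.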