fix: xinference last token being ignored (#1013)

Uranus authored on 2023-08-25 18:15:05 +08:00, committed by GitHub
parent 915e26527b
commit 2d9616c29c

@@ -3,17 +3,20 @@ from typing import Optional, List, Any, Union, Generator
 from langchain.callbacks.manager import CallbackManagerForLLMRun
 from langchain.llms import Xinference
 from langchain.llms.utils import enforce_stop_tokens
-from xinference.client import RESTfulChatglmCppChatModelHandle, \
-    RESTfulChatModelHandle, RESTfulGenerateModelHandle
+from xinference.client import (
+    RESTfulChatglmCppChatModelHandle,
+    RESTfulChatModelHandle,
+    RESTfulGenerateModelHandle,
+)


 class XinferenceLLM(Xinference):
     def _call(
-            self,
-            prompt: str,
-            stop: Optional[List[str]] = None,
-            run_manager: Optional[CallbackManagerForLLMRun] = None,
-            **kwargs: Any,
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         """Call the xinference model and return the output.
@@ -29,7 +32,9 @@ class XinferenceLLM(Xinference):
         model = self.client.get_model(self.model_uid)

         if isinstance(model, RESTfulChatModelHandle):
-            generate_config: "LlamaCppGenerateConfig" = kwargs.get("generate_config", {})
+            generate_config: "LlamaCppGenerateConfig" = kwargs.get(
+                "generate_config", {}
+            )

             if stop:
                 generate_config["stop"] = stop
@@ -37,10 +42,10 @@ class XinferenceLLM(Xinference):
             if generate_config and generate_config.get("stream"):
                 combined_text_output = ""
                 for token in self._stream_generate(
-                        model=model,
-                        prompt=prompt,
-                        run_manager=run_manager,
-                        generate_config=generate_config,
+                    model=model,
+                    prompt=prompt,
+                    run_manager=run_manager,
+                    generate_config=generate_config,
                 ):
                     combined_text_output += token
                 return combined_text_output
@@ -48,7 +53,9 @@ class XinferenceLLM(Xinference):
                 completion = model.chat(prompt=prompt, generate_config=generate_config)
                 return completion["choices"][0]["message"]["content"]
         elif isinstance(model, RESTfulGenerateModelHandle):
-            generate_config: "LlamaCppGenerateConfig" = kwargs.get("generate_config", {})
+            generate_config: "LlamaCppGenerateConfig" = kwargs.get(
+                "generate_config", {}
+            )

             if stop:
                 generate_config["stop"] = stop
@@ -56,27 +63,31 @@ class XinferenceLLM(Xinference):
             if generate_config and generate_config.get("stream"):
                 combined_text_output = ""
                 for token in self._stream_generate(
-                        model=model,
-                        prompt=prompt,
-                        run_manager=run_manager,
-                        generate_config=generate_config,
+                    model=model,
+                    prompt=prompt,
+                    run_manager=run_manager,
+                    generate_config=generate_config,
                 ):
                     combined_text_output += token
                 return combined_text_output
             else:
-                completion = model.generate(prompt=prompt, generate_config=generate_config)
+                completion = model.generate(
+                    prompt=prompt, generate_config=generate_config
+                )
                 return completion["choices"][0]["text"]
         elif isinstance(model, RESTfulChatglmCppChatModelHandle):
-            generate_config: "ChatglmCppGenerateConfig" = kwargs.get("generate_config", {})
+            generate_config: "ChatglmCppGenerateConfig" = kwargs.get(
+                "generate_config", {}
+            )

             if generate_config and generate_config.get("stream"):
                 combined_text_output = ""
                 for token in self._stream_generate(
-                        model=model,
-                        prompt=prompt,
-                        run_manager=run_manager,
-                        generate_config=generate_config,
+                    model=model,
+                    prompt=prompt,
+                    run_manager=run_manager,
+                    generate_config=generate_config,
                 ):
                     combined_text_output += token
                 completion = combined_text_output
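
For orientation, here is a hedged usage sketch of the class this diff touches. It is not part of the commit: the constructor arguments are assumed to come from the parent langchain Xinference class, and the server URL and model UID are placeholders. Calling the LLM runs the _call method shown above, which forwards stop sequences into generate_config["stop"].

# Illustrative sketch only, not part of this commit. Assumes XinferenceLLM
# (defined in the file shown here) is importable, a xinference server is
# running, and the placeholders below are replaced with real values.
llm = XinferenceLLM(
    server_url="http://127.0.0.1:9997",  # placeholder endpoint
    model_uid="<your-model-uid>",        # placeholder model UID
)
# Invoking the LLM routes through _call above; the stop list ends up in
# generate_config["stop"] for the matching model handle.
print(llm("Q: What does xinference serve? A:", stop=["\n"]))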
@@ -90,12 +101,21 @@ class XinferenceLLM(Xinference):
         return completion

     def _stream_generate(
-            self,
-            model: Union["RESTfulGenerateModelHandle", "RESTfulChatModelHandle", "RESTfulChatglmCppChatModelHandle"],
-            prompt: str,
-            run_manager: Optional[CallbackManagerForLLMRun] = None,
-            generate_config: Optional[
-                Union["LlamaCppGenerateConfig", "PytorchGenerateConfig", "ChatglmCppGenerateConfig"]] = None,
+        self,
+        model: Union[
+            "RESTfulGenerateModelHandle",
+            "RESTfulChatModelHandle",
+            "RESTfulChatglmCppChatModelHandle",
+        ],
+        prompt: str,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        generate_config: Optional[
+            Union[
+                "LlamaCppGenerateConfig",
+                "PytorchGenerateConfig",
+                "ChatglmCppGenerateConfig",
+            ]
+        ] = None,
     ) -> Generator[str, None, None]:
         """
         Args:
@@ -108,7 +128,9 @@ class XinferenceLLM(Xinference):
         Yields:
             A string token.
         """
-        if isinstance(model, (RESTfulChatModelHandle, RESTfulChatglmCppChatModelHandle)):
+        if isinstance(
+            model, (RESTfulChatModelHandle, RESTfulChatglmCppChatModelHandle)
+        ):
             streaming_response = model.chat(
                 prompt=prompt, generate_config=generate_config
             )
@@ -123,14 +145,10 @@ class XinferenceLLM(Xinference):
                 if choices:
                     choice = choices[0]
                     if isinstance(choice, dict):
-                        if 'finish_reason' in choice and choice['finish_reason'] \
-                                and choice['finish_reason'] in ['stop', 'length']:
-                            break
-
-                        if 'text' in choice:
+                        if "text" in choice:
                             token = choice.get("text", "")
-                        elif 'delta' in choice and 'content' in choice['delta']:
-                            token = choice.get('delta').get('content')
+                        elif "delta" in choice and "content" in choice["delta"]:
+                            token = choice.get("delta").get("content")
                         else:
                             continue
                         log_probs = choice.get("logprobs")
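
The hunk above is the actual fix behind the commit title. The old loop broke out of the stream as soon as a chunk reported a finish_reason of "stop" or "length", but a backend can attach the finish_reason to the same chunk that carries the final piece of text, so that text was never yielded: the "last token being ignored". A self-contained sketch with made-up chunk data (illustrative only, simplified from the loop above) shows the difference:

# Illustrative sketch only: fake stream chunks in the shape the loop above
# consumes; the final chunk carries both its text and the finish_reason.
chunks = [
    {"choices": [{"text": "Hello"}]},
    {"choices": [{"text": " world", "finish_reason": "stop"}]},
]

def old_stream(chunks):
    for chunk in chunks:
        choice = chunk["choices"][0]
        if choice.get("finish_reason") in ("stop", "length"):
            break  # old behaviour: bail out before reading this chunk's text
        if "text" in choice:
            yield choice["text"]

def fixed_stream(chunks):
    for chunk in chunks:
        choice = chunk["choices"][0]
        if "text" in choice:
            yield choice["text"]  # new behaviour: always emit the text

print("".join(old_stream(chunks)))    # -> "Hello"        (last token dropped)
print("".join(fixed_stream(chunks)))  # -> "Hello world"  (fixed)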