fix: xinference last token being ignored (#1013)
commit 2d9616c29c
parent 915e26527b
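The functional fix is in the final hunk of _stream_generate: the old loop broke out as soon as a streamed chunk carried a finish_reason of "stop" or "length", before reading that chunk's "text"/"delta" payload, so the last token was silently dropped. Everything else in the diff is formatting. Below is a minimal, self-contained sketch (not the dify or xinference code; the chunk layout is assumed from the diff) of why breaking on finish_reason loses the final token:

from typing import Dict, Generator, List

# Simulated stream: the final chunk carries both the closing text and a finish_reason.
CHUNKS: List[Dict] = [
    {"choices": [{"text": "Hello"}]},
    {"choices": [{"text": ", world"}]},
    {"choices": [{"text": "!", "finish_reason": "stop"}]},
]


def stream_old(chunks: List[Dict]) -> Generator[str, None, None]:
    """Old behaviour: bail out on finish_reason before reading the chunk's text."""
    for chunk in chunks:
        choice = chunk["choices"][0]
        if choice.get("finish_reason") in ("stop", "length"):
            break  # the trailing "!" in this chunk is never yielded
        if "text" in choice:
            yield choice["text"]


def stream_new(chunks: List[Dict]) -> Generator[str, None, None]:
    """New behaviour: always read the text; the loop ends when the stream does."""
    for chunk in chunks:
        choice = chunk["choices"][0]
        if "text" in choice:
            yield choice["text"]


print("".join(stream_old(CHUNKS)))  # Hello, world      <- last token ignored
print("".join(stream_new(CHUNKS)))  # Hello, world!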
@@ -3,17 +3,20 @@ from typing import Optional, List, Any, Union, Generator
 from langchain.callbacks.manager import CallbackManagerForLLMRun
 from langchain.llms import Xinference
 from langchain.llms.utils import enforce_stop_tokens
-from xinference.client import RESTfulChatglmCppChatModelHandle, \
-    RESTfulChatModelHandle, RESTfulGenerateModelHandle
+from xinference.client import (
+    RESTfulChatglmCppChatModelHandle,
+    RESTfulChatModelHandle,
+    RESTfulGenerateModelHandle,
+)
 
 
 class XinferenceLLM(Xinference):
     def _call(
         self,
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> str:
         """Call the xinference model and return the output.
 
@@ -29,7 +32,9 @@ class XinferenceLLM(Xinference):
         model = self.client.get_model(self.model_uid)
 
         if isinstance(model, RESTfulChatModelHandle):
-            generate_config: "LlamaCppGenerateConfig" = kwargs.get("generate_config", {})
+            generate_config: "LlamaCppGenerateConfig" = kwargs.get(
+                "generate_config", {}
+            )
 
             if stop:
                 generate_config["stop"] = stop
@@ -37,10 +42,10 @@ class XinferenceLLM(Xinference):
             if generate_config and generate_config.get("stream"):
                 combined_text_output = ""
                 for token in self._stream_generate(
                     model=model,
                     prompt=prompt,
                     run_manager=run_manager,
                     generate_config=generate_config,
                 ):
                     combined_text_output += token
                 return combined_text_output
@@ -48,7 +53,9 @@ class XinferenceLLM(Xinference):
                 completion = model.chat(prompt=prompt, generate_config=generate_config)
                 return completion["choices"][0]["message"]["content"]
         elif isinstance(model, RESTfulGenerateModelHandle):
-            generate_config: "LlamaCppGenerateConfig" = kwargs.get("generate_config", {})
+            generate_config: "LlamaCppGenerateConfig" = kwargs.get(
+                "generate_config", {}
+            )
 
             if stop:
                 generate_config["stop"] = stop
@@ -56,27 +63,31 @@ class XinferenceLLM(Xinference):
             if generate_config and generate_config.get("stream"):
                 combined_text_output = ""
                 for token in self._stream_generate(
                     model=model,
                     prompt=prompt,
                     run_manager=run_manager,
                     generate_config=generate_config,
                 ):
                     combined_text_output += token
                 return combined_text_output
 
             else:
-                completion = model.generate(prompt=prompt, generate_config=generate_config)
+                completion = model.generate(
+                    prompt=prompt, generate_config=generate_config
+                )
                 return completion["choices"][0]["text"]
         elif isinstance(model, RESTfulChatglmCppChatModelHandle):
-            generate_config: "ChatglmCppGenerateConfig" = kwargs.get("generate_config", {})
+            generate_config: "ChatglmCppGenerateConfig" = kwargs.get(
+                "generate_config", {}
+            )
 
             if generate_config and generate_config.get("stream"):
                 combined_text_output = ""
                 for token in self._stream_generate(
                     model=model,
                     prompt=prompt,
                     run_manager=run_manager,
                     generate_config=generate_config,
                 ):
                     combined_text_output += token
                 completion = combined_text_output
@@ -90,12 +101,21 @@ class XinferenceLLM(Xinference):
             return completion
 
     def _stream_generate(
         self,
-        model: Union["RESTfulGenerateModelHandle", "RESTfulChatModelHandle", "RESTfulChatglmCppChatModelHandle"],
-        prompt: str,
-        run_manager: Optional[CallbackManagerForLLMRun] = None,
-        generate_config: Optional[
-            Union["LlamaCppGenerateConfig", "PytorchGenerateConfig", "ChatglmCppGenerateConfig"]] = None,
+        model: Union[
+            "RESTfulGenerateModelHandle",
+            "RESTfulChatModelHandle",
+            "RESTfulChatglmCppChatModelHandle",
+        ],
+        prompt: str,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        generate_config: Optional[
+            Union[
+                "LlamaCppGenerateConfig",
+                "PytorchGenerateConfig",
+                "ChatglmCppGenerateConfig",
+            ]
+        ] = None,
     ) -> Generator[str, None, None]:
         """
         Args:
@@ -108,7 +128,9 @@ class XinferenceLLM(Xinference):
         Yields:
             A string token.
         """
-        if isinstance(model, (RESTfulChatModelHandle, RESTfulChatglmCppChatModelHandle)):
+        if isinstance(
+            model, (RESTfulChatModelHandle, RESTfulChatglmCppChatModelHandle)
+        ):
             streaming_response = model.chat(
                 prompt=prompt, generate_config=generate_config
             )
@@ -123,14 +145,10 @@ class XinferenceLLM(Xinference):
                 if choices:
                     choice = choices[0]
                     if isinstance(choice, dict):
-                        if 'finish_reason' in choice and choice['finish_reason'] \
-                                and choice['finish_reason'] in ['stop', 'length']:
-                            break
-
-                        if 'text' in choice:
+                        if "text" in choice:
                             token = choice.get("text", "")
-                        elif 'delta' in choice and 'content' in choice['delta']:
-                            token = choice.get('delta').get('content')
+                        elif "delta" in choice and "content" in choice["delta"]:
+                            token = choice.get("delta").get("content")
                         else:
                             continue
                         log_probs = choice.get("logprobs")
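For completeness, a hedged usage sketch of how the patched streaming path might be exercised. The module path, endpoint URL and model UID are placeholders; server_url and model_uid are assumed to be the constructor parameters inherited from langchain's Xinference wrapper, and generate_config is forwarded to _call via **kwargs:

# Usage sketch only; adjust the import to wherever XinferenceLLM lives in your tree.
from xinference_llm import XinferenceLLM  # module path assumed

llm = XinferenceLLM(
    server_url="http://127.0.0.1:9997",  # placeholder xinference endpoint
    model_uid="my-model-uid",            # placeholder model UID
)

# With stream enabled, _call() concatenates the tokens yielded by _stream_generate();
# after this fix the final token (delivered alongside finish_reason) is included.
text = llm(
    "Write a haiku about inference servers.",
    generate_config={"stream": True},
)
print(text)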