mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-12 22:29:03 +08:00
fix: xinference chat support (#939)
This commit is contained in:
parent
f53242c081
commit
e0a48c4972
@ -1,13 +1,13 @@
|
||||
from typing import List, Optional, Any
|
||||
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.llms import Xinference
|
||||
from langchain.schema import LLMResult
|
||||
|
||||
from core.model_providers.error import LLMBadRequestError
|
||||
from core.model_providers.models.llm.base import BaseLLM
|
||||
from core.model_providers.models.entity.message import PromptMessage
|
||||
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
|
||||
from core.third_party.langchain.llms.xinference_llm import XinferenceLLM
|
||||
|
||||
|
||||
class XinferenceModel(BaseLLM):
|
||||
@ -16,8 +16,9 @@ class XinferenceModel(BaseLLM):
|
||||
def _init_client(self) -> Any:
|
||||
self.provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
|
||||
|
||||
client = Xinference(
|
||||
**self.credentials,
|
||||
client = XinferenceLLM(
|
||||
server_url=self.credentials['server_url'],
|
||||
model_uid=self.credentials['model_uid'],
|
||||
)
|
||||
|
||||
client.callbacks = self.callbacks
|
||||
|
@ -1,7 +1,8 @@
|
||||
import json
|
||||
from typing import Type
|
||||
|
||||
from langchain.llms import Xinference
|
||||
import requests
|
||||
from xinference.client import RESTfulGenerateModelHandle, RESTfulChatModelHandle, RESTfulChatglmCppChatModelHandle
|
||||
|
||||
from core.helper import encrypter
|
||||
from core.model_providers.models.embedding.xinference_embedding import XinferenceEmbedding
|
||||
@ -10,6 +11,7 @@ from core.model_providers.models.llm.xinference_model import XinferenceModel
|
||||
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
|
||||
|
||||
from core.model_providers.models.base import BaseProviderModel
|
||||
from core.third_party.langchain.llms.xinference_llm import XinferenceLLM
|
||||
from models.provider import ProviderType
|
||||
|
||||
|
||||
@ -48,13 +50,32 @@ class XinferenceProvider(BaseModelProvider):
|
||||
:param model_type:
|
||||
:return:
|
||||
"""
|
||||
return ModelKwargsRules(
|
||||
temperature=KwargRule[float](min=0, max=2, default=1),
|
||||
top_p=KwargRule[float](min=0, max=1, default=0.7),
|
||||
presence_penalty=KwargRule[float](min=-2, max=2, default=0),
|
||||
frequency_penalty=KwargRule[float](min=-2, max=2, default=0),
|
||||
max_tokens=KwargRule[int](min=10, max=4000, default=256),
|
||||
)
|
||||
credentials = self.get_model_credentials(model_name, model_type)
|
||||
if credentials['model_format'] == "ggmlv3" and credentials["model_handle_type"] == "chatglm":
|
||||
return ModelKwargsRules(
|
||||
temperature=KwargRule[float](min=0.01, max=2, default=1),
|
||||
top_p=KwargRule[float](min=0, max=1, default=0.7),
|
||||
presence_penalty=KwargRule[float](enabled=False),
|
||||
frequency_penalty=KwargRule[float](enabled=False),
|
||||
max_tokens=KwargRule[int](min=10, max=4000, default=256),
|
||||
)
|
||||
elif credentials['model_format'] == "ggmlv3":
|
||||
return ModelKwargsRules(
|
||||
temperature=KwargRule[float](min=0.01, max=2, default=1),
|
||||
top_p=KwargRule[float](min=0, max=1, default=0.7),
|
||||
presence_penalty=KwargRule[float](min=-2, max=2, default=0),
|
||||
frequency_penalty=KwargRule[float](min=-2, max=2, default=0),
|
||||
max_tokens=KwargRule[int](min=10, max=4000, default=256),
|
||||
)
|
||||
else:
|
||||
return ModelKwargsRules(
|
||||
temperature=KwargRule[float](min=0.01, max=2, default=1),
|
||||
top_p=KwargRule[float](min=0, max=1, default=0.7),
|
||||
presence_penalty=KwargRule[float](enabled=False),
|
||||
frequency_penalty=KwargRule[float](enabled=False),
|
||||
max_tokens=KwargRule[int](alias='max_new_tokens', min=10, max=4000, default=256),
|
||||
)
|
||||
|
||||
|
||||
@classmethod
|
||||
def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
|
||||
@ -77,11 +98,11 @@ class XinferenceProvider(BaseModelProvider):
|
||||
'model_uid': credentials['model_uid'],
|
||||
}
|
||||
|
||||
llm = Xinference(
|
||||
llm = XinferenceLLM(
|
||||
**credential_kwargs
|
||||
)
|
||||
|
||||
llm("ping", generate_config={'max_tokens': 10})
|
||||
llm("ping")
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
@ -97,7 +118,11 @@ class XinferenceProvider(BaseModelProvider):
|
||||
:param credentials:
|
||||
:return:
|
||||
"""
|
||||
extra_credentials = cls._get_extra_credentials(credentials)
|
||||
credentials.update(extra_credentials)
|
||||
|
||||
credentials['server_url'] = encrypter.encrypt_token(tenant_id, credentials['server_url'])
|
||||
|
||||
return credentials
|
||||
|
||||
def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
|
||||
@ -132,6 +157,30 @@ class XinferenceProvider(BaseModelProvider):
|
||||
|
||||
return credentials
|
||||
|
||||
@classmethod
|
||||
def _get_extra_credentials(self, credentials: dict) -> dict:
|
||||
url = f"{credentials['server_url']}/v1/models/{credentials['model_uid']}"
|
||||
response = requests.get(url)
|
||||
if response.status_code != 200:
|
||||
raise RuntimeError(
|
||||
f"Failed to get the model description, detail: {response.json()['detail']}"
|
||||
)
|
||||
desc = response.json()
|
||||
|
||||
extra_credentials = {
|
||||
'model_format': desc['model_format'],
|
||||
}
|
||||
if desc["model_format"] == "ggmlv3" and "chatglm" in desc["model_name"]:
|
||||
extra_credentials['model_handle_type'] = 'chatglm'
|
||||
elif "generate" in desc["model_ability"]:
|
||||
extra_credentials['model_handle_type'] = 'generate'
|
||||
elif "chat" in desc["model_ability"]:
|
||||
extra_credentials['model_handle_type'] = 'chat'
|
||||
else:
|
||||
raise NotImplementedError(f"Model handle type not supported.")
|
||||
|
||||
return extra_credentials
|
||||
|
||||
@classmethod
|
||||
def is_provider_credentials_valid_or_raise(cls, credentials: dict):
|
||||
return
|
||||
|
132
api/core/third_party/langchain/llms/xinference_llm.py
vendored
Normal file
132
api/core/third_party/langchain/llms/xinference_llm.py
vendored
Normal file
@ -0,0 +1,132 @@
|
||||
from typing import Optional, List, Any, Union, Generator
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.llms import Xinference
|
||||
from langchain.llms.utils import enforce_stop_tokens
|
||||
from xinference.client import RESTfulChatglmCppChatModelHandle, \
|
||||
RESTfulChatModelHandle, RESTfulGenerateModelHandle
|
||||
|
||||
|
||||
class XinferenceLLM(Xinference):
|
||||
def _call(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Call the xinference model and return the output.
|
||||
|
||||
Args:
|
||||
prompt: The prompt to use for generation.
|
||||
stop: Optional list of stop words to use when generating.
|
||||
generate_config: Optional dictionary for the configuration used for
|
||||
generation.
|
||||
|
||||
Returns:
|
||||
The generated string by the model.
|
||||
"""
|
||||
model = self.client.get_model(self.model_uid)
|
||||
|
||||
if isinstance(model, RESTfulChatModelHandle):
|
||||
generate_config: "LlamaCppGenerateConfig" = kwargs.get("generate_config", {})
|
||||
|
||||
if stop:
|
||||
generate_config["stop"] = stop
|
||||
|
||||
if generate_config and generate_config.get("stream"):
|
||||
combined_text_output = ""
|
||||
for token in self._stream_generate(
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
run_manager=run_manager,
|
||||
generate_config=generate_config,
|
||||
):
|
||||
combined_text_output += token
|
||||
return combined_text_output
|
||||
else:
|
||||
completion = model.chat(prompt=prompt, generate_config=generate_config)
|
||||
return completion["choices"][0]["text"]
|
||||
elif isinstance(model, RESTfulGenerateModelHandle):
|
||||
generate_config: "LlamaCppGenerateConfig" = kwargs.get("generate_config", {})
|
||||
|
||||
if stop:
|
||||
generate_config["stop"] = stop
|
||||
|
||||
if generate_config and generate_config.get("stream"):
|
||||
combined_text_output = ""
|
||||
for token in self._stream_generate(
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
run_manager=run_manager,
|
||||
generate_config=generate_config,
|
||||
):
|
||||
combined_text_output += token
|
||||
return combined_text_output
|
||||
|
||||
else:
|
||||
completion = model.generate(prompt=prompt, generate_config=generate_config)
|
||||
return completion["choices"][0]["text"]
|
||||
elif isinstance(model, RESTfulChatglmCppChatModelHandle):
|
||||
generate_config: "ChatglmCppGenerateConfig" = kwargs.get("generate_config", {})
|
||||
|
||||
if generate_config and generate_config.get("stream"):
|
||||
combined_text_output = ""
|
||||
for token in self._stream_generate(
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
run_manager=run_manager,
|
||||
generate_config=generate_config,
|
||||
):
|
||||
combined_text_output += token
|
||||
completion = combined_text_output
|
||||
else:
|
||||
completion = model.chat(prompt=prompt, generate_config=generate_config)
|
||||
completion = completion["choices"][0]["text"]
|
||||
|
||||
if stop is not None:
|
||||
completion = enforce_stop_tokens(completion, stop)
|
||||
|
||||
return completion
|
||||
|
||||
|
||||
def _stream_generate(
|
||||
self,
|
||||
model: Union["RESTfulGenerateModelHandle", "RESTfulChatModelHandle", "RESTfulChatglmCppChatModelHandle"],
|
||||
prompt: str,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
generate_config: Optional[Union["LlamaCppGenerateConfig", "PytorchGenerateConfig", "ChatglmCppGenerateConfig"]] = None,
|
||||
) -> Generator[str, None, None]:
|
||||
"""
|
||||
Args:
|
||||
prompt: The prompt to use for generation.
|
||||
model: The model used for generation.
|
||||
stop: Optional list of stop words to use when generating.
|
||||
generate_config: Optional dictionary for the configuration used for
|
||||
generation.
|
||||
|
||||
Yields:
|
||||
A string token.
|
||||
"""
|
||||
if isinstance(model, RESTfulGenerateModelHandle):
|
||||
streaming_response = model.generate(
|
||||
prompt=prompt, generate_config=generate_config
|
||||
)
|
||||
else:
|
||||
streaming_response = model.chat(
|
||||
prompt=prompt, generate_config=generate_config
|
||||
)
|
||||
|
||||
for chunk in streaming_response:
|
||||
if isinstance(chunk, dict):
|
||||
choices = chunk.get("choices", [])
|
||||
if choices:
|
||||
choice = choices[0]
|
||||
if isinstance(choice, dict):
|
||||
token = choice.get("text", "")
|
||||
log_probs = choice.get("logprobs")
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(
|
||||
token=token, verbose=self.verbose, log_probs=log_probs
|
||||
)
|
||||
yield token
|
@ -4,7 +4,6 @@ import json
|
||||
|
||||
from core.model_providers.models.entity.model_params import ModelType
|
||||
from core.model_providers.providers.base import CredentialsValidateFailedError
|
||||
from core.model_providers.providers.replicate_provider import ReplicateProvider
|
||||
from core.model_providers.providers.xinference_provider import XinferenceProvider
|
||||
from models.provider import ProviderType, Provider, ProviderModel
|
||||
|
||||
@ -25,7 +24,7 @@ def decrypt_side_effect(tenant_id, encrypted_key):
|
||||
|
||||
|
||||
def test_is_credentials_valid_or_raise_valid(mocker):
|
||||
mocker.patch('langchain.llms.xinference.Xinference._call',
|
||||
mocker.patch('core.third_party.langchain.llms.xinference_llm.XinferenceLLM._call',
|
||||
return_value="abc")
|
||||
|
||||
MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
|
||||
@ -53,8 +52,15 @@ def test_is_credentials_valid_or_raise_invalid():
|
||||
|
||||
|
||||
@patch('core.helper.encrypter.encrypt_token', side_effect=encrypt_side_effect)
|
||||
def test_encrypt_model_credentials(mock_encrypt):
|
||||
def test_encrypt_model_credentials(mock_encrypt, mocker):
|
||||
api_key = 'http://127.0.0.1:9997/'
|
||||
|
||||
mocker.patch('core.model_providers.providers.xinference_provider.XinferenceProvider._get_extra_credentials',
|
||||
return_value={
|
||||
'model_handle_type': 'generate',
|
||||
'model_format': 'ggmlv3'
|
||||
})
|
||||
|
||||
result = MODEL_PROVIDER_CLASS.encrypt_model_credentials(
|
||||
tenant_id='tenant_id',
|
||||
model_name='test_model_name',
|
||||
|
Loading…
x
Reference in New Issue
Block a user