Mirror of https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git (synced 2025-08-14 02:35:55 +08:00)
feat: support weixin ernie-bot-4 and chat mode (#1375)

commit 7c9b585a47
parent c039f4af83
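For context, a minimal usage sketch of the chat-style interface this commit introduces. The model name, messages and credentials below are illustrative placeholders; the constructor arguments mirror the fields visible in the diff (model, api_key, secret_key, streaming), not a documented public API.

    from langchain.schema import HumanMessage, SystemMessage

    from core.third_party.langchain.llms.wenxin import Wenxin

    # Placeholder credentials; real values come from the provider configuration.
    llm = Wenxin(model='ernie-bot-4', api_key='my-api-key',
                 secret_key='my-secret-key', streaming=False)

    # The wrapper is now a chat model, so it is called with a list of messages
    # instead of a single prompt string.
    result = llm([SystemMessage(content='You are a helpful assistant.'),
                  HumanMessage(content='1 + 1 = ?')])
    print(result.content)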
@@ -6,17 +6,16 @@ from langchain.schema import LLMResult
 
 from core.model_providers.error import LLMBadRequestError
 from core.model_providers.models.llm.base import BaseLLM
-from core.model_providers.models.entity.message import PromptMessage, MessageType
+from core.model_providers.models.entity.message import PromptMessage
 from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
 from core.third_party.langchain.llms.wenxin import Wenxin
 
 
 class WenxinModel(BaseLLM):
-    model_mode: ModelMode = ModelMode.COMPLETION
+    model_mode: ModelMode = ModelMode.CHAT
 
     def _init_client(self) -> Any:
         provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
-        # TODO load price_config from configs(db)
         return Wenxin(
             model=self.name,
             streaming=self.streaming,
@@ -38,7 +37,13 @@ class WenxinModel(BaseLLM):
         :return:
         """
         prompts = self._get_prompt_from_messages(messages)
-        return self._client.generate([prompts], stop, callbacks)
+        generate_kwargs = {'stop': stop, 'callbacks': callbacks, 'messages': [prompts]}
+
+        if 'functions' in kwargs:
+            generate_kwargs['functions'] = kwargs['functions']
+
+        return self._client.generate(**generate_kwargs)
 
     def get_num_tokens(self, messages: List[PromptMessage]) -> int:
         """
@@ -48,7 +53,7 @@ class WenxinModel(BaseLLM):
         :return:
         """
         prompts = self._get_prompt_from_messages(messages)
-        return max(self._client.get_num_tokens(prompts), 0)
+        return max(self._client.get_num_tokens_from_messages(prompts), 0)
 
     def _set_model_kwargs(self, model_kwargs: ModelKwargs):
         provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
@@ -58,3 +63,7 @@ class WenxinModel(BaseLLM):
 
     def handle_exceptions(self, ex: Exception) -> Exception:
         return LLMBadRequestError(f"Wenxin: {str(ex)}")
+
+    @property
+    def support_streaming(self):
+        return True
@@ -2,6 +2,8 @@ import json
 from json import JSONDecodeError
 from typing import Type
 
+from langchain.schema import HumanMessage
+
 from core.helper import encrypter
 from core.model_providers.models.base import BaseProviderModel
 from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode
@@ -23,20 +25,25 @@ class WenxinProvider(BaseModelProvider):
     def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
         if model_type == ModelType.TEXT_GENERATION:
             return [
+                {
+                    'id': 'ernie-bot-4',
+                    'name': 'ERNIE-Bot-4',
+                    'mode': ModelMode.CHAT.value,
+                },
                 {
                     'id': 'ernie-bot',
                     'name': 'ERNIE-Bot',
-                    'mode': ModelMode.COMPLETION.value,
+                    'mode': ModelMode.CHAT.value,
                 },
                 {
                     'id': 'ernie-bot-turbo',
                     'name': 'ERNIE-Bot-turbo',
-                    'mode': ModelMode.COMPLETION.value,
+                    'mode': ModelMode.CHAT.value,
                 },
                 {
                     'id': 'bloomz-7b',
                     'name': 'BLOOMZ-7B',
-                    'mode': ModelMode.COMPLETION.value,
+                    'mode': ModelMode.CHAT.value,
                 }
             ]
         else:
@@ -68,11 +75,12 @@ class WenxinProvider(BaseModelProvider):
         :return:
         """
         model_max_tokens = {
+            'ernie-bot-4': 4800,
             'ernie-bot': 4800,
             'ernie-bot-turbo': 11200,
         }
 
-        if model_name in ['ernie-bot', 'ernie-bot-turbo']:
+        if model_name in ['ernie-bot-4', 'ernie-bot', 'ernie-bot-turbo']:
             return ModelKwargsRules(
                 temperature=KwargRule[float](min=0.01, max=1, default=0.95, precision=2),
                 top_p=KwargRule[float](min=0.01, max=1, default=0.8, precision=2),
@@ -111,7 +119,7 @@ class WenxinProvider(BaseModelProvider):
                 **credential_kwargs
             )
 
-            llm("ping")
+            llm([HumanMessage(content='ping')])
         except Exception as ex:
             raise CredentialsValidateFailedError(str(ex))
 
@@ -5,6 +5,12 @@
     "system_config": null,
     "model_flexibility": "fixed",
     "price_config": {
+        "ernie-bot-4": {
+            "prompt": "0",
+            "completion": "0",
+            "unit": "0.001",
+            "currency": "RMB"
+        },
         "ernie-bot": {
             "prompt": "0.012",
             "completion": "0.012",
api/core/third_party/langchain/llms/wenxin.py (vendored, 198 changes)
@@ -8,12 +8,15 @@ from typing import (
     Any,
     Dict,
     List,
-    Optional, Iterator,
+    Optional, Iterator, Tuple,
 )
 
 import requests
+from langchain.chat_models.base import BaseChatModel
 from langchain.llms.utils import enforce_stop_tokens
-from langchain.schema.output import GenerationChunk
+from langchain.schema import BaseMessage, ChatMessage, HumanMessage, AIMessage, SystemMessage
+from langchain.schema.messages import AIMessageChunk
+from langchain.schema.output import GenerationChunk, ChatResult, ChatGenerationChunk, ChatGeneration
 from pydantic import BaseModel, Extra, Field, PrivateAttr, root_validator
 
 from langchain.callbacks.manager import (
@@ -61,6 +64,7 @@ class _WenxinEndpointClient(BaseModel):
             raise ValueError(f"Wenxin Model name is required")
 
         model_url_map = {
+            'ernie-bot-4': 'completions_pro',
             'ernie-bot': 'completions',
             'ernie-bot-turbo': 'eb-instant',
             'bloomz-7b': 'bloomz_7b1',
@@ -70,6 +74,7 @@ class _WenxinEndpointClient(BaseModel):
 
         access_token = self.get_access_token()
         api_url = f"{self.base_url}{model_url_map[request['model']]}?access_token={access_token}"
+        del request['model']
 
         headers = {"Content-Type": "application/json"}
         response = requests.post(api_url,
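The two client changes above route the new ernie-bot-4 id to the completions_pro endpoint and drop the model id from the request body before posting. A runnable sketch of that routing; base_url, access_token and the request payload are placeholders, not values from the commit:

    model_url_map = {
        'ernie-bot-4': 'completions_pro',
        'ernie-bot': 'completions',
        'ernie-bot-turbo': 'eb-instant',
        'bloomz-7b': 'bloomz_7b1',
    }

    base_url = 'https://example.invalid/wenxinworkshop/chat/'  # placeholder
    access_token = '<access-token>'                            # placeholder
    request = {'model': 'ernie-bot-4', 'messages': [{'role': 'user', 'content': 'ping'}]}

    api_url = f"{base_url}{model_url_map[request['model']]}?access_token={access_token}"
    del request['model']  # the model id is only used for routing, not sent in the body
    print(api_url)        # .../completions_pro?access_token=<access-token>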
@@ -86,22 +91,21 @@ class _WenxinEndpointClient(BaseModel):
                     f"Wenxin API {json_response['error_code']}"
                     f" error: {json_response['error_msg']}"
                 )
-            return json_response["result"]
+            return json_response
         else:
             return response
 
 
-class Wenxin(LLM):
-    """Wrapper around Wenxin large language models.
-    To use, you should have the environment variable
-    ``WENXIN_API_KEY`` and ``WENXIN_SECRET_KEY`` set with your API key,
-    or pass them as a named parameter to the constructor.
-    Example:
-        .. code-block:: python
-            from langchain.llms.wenxin import Wenxin
-            wenxin = Wenxin(model="<model_name>", api_key="my-api-key",
-                secret_key="my-group-id")
-    """
+class Wenxin(BaseChatModel):
+    """Wrapper around Wenxin large language models."""
+
+    @property
+    def lc_secrets(self) -> Dict[str, str]:
+        return {"api_key": "API_KEY", "secret_key": "SECRET_KEY"}
+
+    @property
+    def lc_serializable(self) -> bool:
+        return True
 
     _client: _WenxinEndpointClient = PrivateAttr()
     model: str = "ernie-bot"
@@ -161,64 +165,89 @@ class Wenxin(LLM):
             secret_key=self.secret_key,
         )
 
-    def _call(
+    def _convert_message_to_dict(self, message: BaseMessage) -> dict:
+        if isinstance(message, ChatMessage):
+            message_dict = {"role": message.role, "content": message.content}
+        elif isinstance(message, HumanMessage):
+            message_dict = {"role": "user", "content": message.content}
+        elif isinstance(message, AIMessage):
+            message_dict = {"role": "assistant", "content": message.content}
+        elif isinstance(message, SystemMessage):
+            message_dict = {"role": "system", "content": message.content}
+        else:
+            raise ValueError(f"Got unknown type {message}")
+        return message_dict
+
+    def _create_message_dicts(
+            self, messages: List[BaseMessage]
+    ) -> Tuple[List[Dict[str, Any]], str]:
+        dict_messages = []
+        system = None
+        for m in messages:
+            message = self._convert_message_to_dict(m)
+            if message['role'] == 'system':
+                if not system:
+                    system = message['content']
+                else:
+                    system += f"\n{message['content']}"
+                continue
+
+            if dict_messages:
+                previous_message = dict_messages[-1]
+                if previous_message['role'] == message['role']:
+                    dict_messages[-1]['content'] += f"\n{message['content']}"
+                else:
+                    dict_messages.append(message)
+            else:
+                dict_messages.append(message)
+
+        return dict_messages, system
+
+    def _generate(
         self,
-        prompt: str,
+        messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
-    ) -> str:
-        r"""Call out to Wenxin's completion endpoint to chat
-        Args:
-            prompt: The prompt to pass into the model.
-        Returns:
-            The string generated by the model.
-        Example:
-            .. code-block:: python
-                response = wenxin("Tell me a joke.")
-        """
+    ) -> ChatResult:
         if self.streaming:
-            completion = ""
+            generation: Optional[ChatGenerationChunk] = None
+            llm_output: Optional[Dict] = None
             for chunk in self._stream(
-                prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
+                messages=messages, stop=stop, run_manager=run_manager, **kwargs
             ):
-                completion += chunk.text
+                if chunk.generation_info is not None \
+                        and 'token_usage' in chunk.generation_info:
+                    llm_output = {"token_usage": chunk.generation_info['token_usage'], "model_name": self.model}
+
+                if generation is None:
+                    generation = chunk
+                else:
+                    generation += chunk
+            assert generation is not None
+            return ChatResult(generations=[generation], llm_output=llm_output)
         else:
+            message_dicts, system = self._create_message_dicts(messages)
             request = self._default_params
-            request["messages"] = [{"role": "user", "content": prompt}]
+            request["messages"] = message_dicts
+            if system:
+                request["system"] = system
             request.update(kwargs)
-            completion = self._client.post(request)
-
-        if stop is not None:
-            completion = enforce_stop_tokens(completion, stop)
-
-        return completion
+            response = self._client.post(request)
+            return self._create_chat_result(response)
 
     def _stream(
         self,
-        prompt: str,
+        messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
-    ) -> Iterator[GenerationChunk]:
-        r"""Call wenxin completion_stream and return the resulting generator.
-
-        Args:
-            prompt: The prompt to pass into the model.
-            stop: Optional list of stop words to use when generating.
-        Returns:
-            A generator representing the stream of tokens from Wenxin.
-        Example:
-            .. code-block:: python
-
-                prompt = "Write a poem about a stream."
-                prompt = f"\n\nHuman: {prompt}\n\nAssistant:"
-                generator = wenxin.stream(prompt)
-                for token in generator:
-                    yield token
-        """
+    ) -> Iterator[ChatGenerationChunk]:
+        message_dicts, system = self._create_message_dicts(messages)
         request = self._default_params
-        request["messages"] = [{"role": "user", "content": prompt}]
+        request["messages"] = message_dicts
+        if system:
+            request["system"] = system
         request.update(kwargs)
 
         for token in self._client.post(request).iter_lines():
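The bulk of the hunk above replaces the single-prompt _call with a chat-style _generate plus two helpers. The merging rule in _create_message_dicts is easiest to see in isolation; below is a minimal, self-contained re-statement of that logic for illustration only (function name and inputs are invented, not part of the commit): system messages are pulled out into a separate string, and consecutive messages with the same role are merged.

    from typing import Dict, List, Optional, Tuple

    def merge_wenxin_messages(messages: List[Dict[str, str]]) -> Tuple[List[Dict[str, str]], Optional[str]]:
        dict_messages: List[Dict[str, str]] = []
        system: Optional[str] = None
        for message in messages:
            if message['role'] == 'system':
                # System content goes into a separate `system` field.
                system = message['content'] if not system else system + f"\n{message['content']}"
                continue
            if dict_messages and dict_messages[-1]['role'] == message['role']:
                # Consecutive messages with the same role are merged into one entry.
                dict_messages[-1]['content'] += f"\n{message['content']}"
            else:
                dict_messages.append(message)
        return dict_messages, system

    dict_messages, system = merge_wenxin_messages([
        {'role': 'system', 'content': 'You are terse.'},
        {'role': 'user', 'content': 'Hello.'},
        {'role': 'user', 'content': 'What is 1 + 1?'},
        {'role': 'assistant', 'content': '2'},
    ])
    # dict_messages -> [{'role': 'user', 'content': 'Hello.\nWhat is 1 + 1?'},
    #                   {'role': 'assistant', 'content': '2'}]
    # system        -> 'You are terse.'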
@@ -228,12 +257,18 @@ class Wenxin(LLM):
                 if token.startswith('data:'):
                     completion = json.loads(token[5:])
 
-                    yield GenerationChunk(text=completion['result'])
-                    if run_manager:
-                        run_manager.on_llm_new_token(completion['result'])
+                    chunk_dict = {
+                        'message': AIMessageChunk(content=completion['result']),
+                    }
 
                     if completion['is_end']:
-                        break
+                        token_usage = completion['usage']
+                        token_usage['completion_tokens'] = token_usage['total_tokens'] - token_usage['prompt_tokens']
+                        chunk_dict['generation_info'] = dict({'token_usage': token_usage})
+
+                    yield ChatGenerationChunk(**chunk_dict)
+                    if run_manager:
+                        run_manager.on_llm_new_token(completion['result'])
                 else:
                     try:
                         json_response = json.loads(token)
|
|||||||
f" error: {json_response['error_msg']}, "
|
f" error: {json_response['error_msg']}, "
|
||||||
f"please confirm if the model you have chosen is already paid for."
|
f"please confirm if the model you have chosen is already paid for."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _create_chat_result(self, response: Dict[str, Any]) -> ChatResult:
|
||||||
|
generations = [ChatGeneration(
|
||||||
|
message=AIMessage(content=response['result']),
|
||||||
|
)]
|
||||||
|
token_usage = response.get("usage")
|
||||||
|
token_usage['completion_tokens'] = token_usage['total_tokens'] - token_usage['prompt_tokens']
|
||||||
|
|
||||||
|
llm_output = {"token_usage": token_usage, "model_name": self.model}
|
||||||
|
return ChatResult(generations=generations, llm_output=llm_output)
|
||||||
|
|
||||||
|
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
|
||||||
|
"""Get the number of tokens in the messages.
|
||||||
|
|
||||||
|
Useful for checking if an input will fit in a model's context window.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages: The message inputs to tokenize.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The sum of the number of tokens across the messages.
|
||||||
|
"""
|
||||||
|
return sum([self.get_num_tokens(m.content) for m in messages])
|
||||||
|
|
||||||
|
def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
|
||||||
|
overall_token_usage: dict = {}
|
||||||
|
for output in llm_outputs:
|
||||||
|
if output is None:
|
||||||
|
# Happens in streaming
|
||||||
|
continue
|
||||||
|
token_usage = output["token_usage"]
|
||||||
|
for k, v in token_usage.items():
|
||||||
|
if k in overall_token_usage:
|
||||||
|
overall_token_usage[k] += v
|
||||||
|
else:
|
||||||
|
overall_token_usage[k] = v
|
||||||
|
return {"token_usage": overall_token_usage, "model_name": self.model}
|
||||||
|
@@ -56,9 +56,8 @@ def test_run(mock_decrypt, mocker):
     mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
 
     model = get_mock_model('ernie-bot')
-    messages = [PromptMessage(content='Human: 1 + 1=? \nAssistant: Integer answer is:')]
+    messages = [PromptMessage(type=MessageType.USER, content='Human: 1 + 1=? \nAssistant: Integer answer is:')]
     rst = model.run(
-        messages,
-        stop=['\nHuman:'],
+        messages
     )
     assert len(rst.content) > 0
@@ -2,6 +2,8 @@ import pytest
 from unittest.mock import patch
 import json
 
+from langchain.schema import AIMessage, ChatGeneration, ChatResult
+
 from core.model_providers.providers.base import CredentialsValidateFailedError
 from core.model_providers.providers.wenxin_provider import WenxinProvider
 from models.provider import ProviderType, Provider
@@ -24,7 +26,8 @@ def decrypt_side_effect(tenant_id, encrypted_key):
 
 
 def test_is_provider_credentials_valid_or_raise_valid(mocker):
-    mocker.patch('core.third_party.langchain.llms.wenxin.Wenxin._call', return_value="abc")
+    mocker.patch('core.third_party.langchain.llms.wenxin.Wenxin._generate',
+                 return_value=ChatResult(generations=[ChatGeneration(message=AIMessage(content='abc'))]))
 
     MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(VALIDATE_CREDENTIAL)
 