diff --git a/api/core/model_providers/model_provider_factory.py b/api/core/model_providers/model_provider_factory.py index 0517676807..77f7757347 100644 --- a/api/core/model_providers/model_provider_factory.py +++ b/api/core/model_providers/model_provider_factory.py @@ -51,6 +51,9 @@ class ModelProviderFactory: elif provider_name == 'chatglm': from core.model_providers.providers.chatglm_provider import ChatGLMProvider return ChatGLMProvider + elif provider_name == 'baichuan': + from core.model_providers.providers.baichuan_provider import BaichuanProvider + return BaichuanProvider elif provider_name == 'azure_openai': from core.model_providers.providers.azure_openai_provider import AzureOpenAIProvider return AzureOpenAIProvider diff --git a/api/core/model_providers/models/llm/baichuan_model.py b/api/core/model_providers/models/llm/baichuan_model.py new file mode 100644 index 0000000000..e614547fa3 --- /dev/null +++ b/api/core/model_providers/models/llm/baichuan_model.py @@ -0,0 +1,61 @@ +from typing import List, Optional, Any + +from langchain.callbacks.manager import Callbacks +from langchain.schema import LLMResult + +from core.model_providers.error import LLMBadRequestError +from core.model_providers.models.llm.base import BaseLLM +from core.model_providers.models.entity.message import PromptMessage +from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs +from core.third_party.langchain.llms.baichuan_llm import BaichuanChatLLM + + +class BaichuanModel(BaseLLM): + model_mode: ModelMode = ModelMode.CHAT + + def _init_client(self) -> Any: + provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs) + return BaichuanChatLLM( + streaming=self.streaming, + callbacks=self.callbacks, + **self.credentials, + **provider_model_kwargs + ) + + def _run(self, messages: List[PromptMessage], + stop: Optional[List[str]] = None, + callbacks: Callbacks = None, + **kwargs) -> LLMResult: + """ + run predict by prompt messages and stop words. + + :param messages: + :param stop: + :param callbacks: + :return: + """ + prompts = self._get_prompt_from_messages(messages) + return self._client.generate([prompts], stop, callbacks) + + def get_num_tokens(self, messages: List[PromptMessage]) -> int: + """ + get num tokens of prompt messages. + + :param messages: + :return: + """ + prompts = self._get_prompt_from_messages(messages) + return max(self._client.get_num_tokens_from_messages(prompts), 0) + + def _set_model_kwargs(self, model_kwargs: ModelKwargs): + provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs) + for k, v in provider_model_kwargs.items(): + if hasattr(self.client, k): + setattr(self.client, k, v) + + def handle_exceptions(self, ex: Exception) -> Exception: + return LLMBadRequestError(f"Baichuan: {str(ex)}") + + @property + def support_streaming(self): + return True diff --git a/api/core/model_providers/providers/baichuan_provider.py b/api/core/model_providers/providers/baichuan_provider.py new file mode 100644 index 0000000000..12c475f92d --- /dev/null +++ b/api/core/model_providers/providers/baichuan_provider.py @@ -0,0 +1,167 @@ +import json +from json import JSONDecodeError +from typing import Type + +from langchain.schema import HumanMessage + +from core.helper import encrypter +from core.model_providers.models.base import BaseProviderModel +from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType +from core.model_providers.models.llm.baichuan_model import BaichuanModel +from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError +from core.third_party.langchain.llms.baichuan_llm import BaichuanChatLLM +from models.provider import ProviderType + + +class BaichuanProvider(BaseModelProvider): + + @property + def provider_name(self): + """ + Returns the name of a provider. + """ + return 'baichuan' + + def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]: + if model_type == ModelType.TEXT_GENERATION: + return [ + { + 'id': 'baichuan2-53b', + 'name': 'Baichuan2-53B', + } + ] + else: + return [] + + def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]: + """ + Returns the model class. + + :param model_type: + :return: + """ + if model_type == ModelType.TEXT_GENERATION: + model_class = BaichuanModel + else: + raise NotImplementedError + + return model_class + + def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules: + """ + get model parameter rules. + + :param model_name: + :param model_type: + :return: + """ + return ModelKwargsRules( + temperature=KwargRule[float](min=0, max=1, default=0.3, precision=2), + top_p=KwargRule[float](min=0, max=0.99, default=0.85, precision=2), + presence_penalty=KwargRule[float](enabled=False), + frequency_penalty=KwargRule[float](enabled=False), + max_tokens=KwargRule[int](enabled=False), + ) + + @classmethod + def is_provider_credentials_valid_or_raise(cls, credentials: dict): + """ + Validates the given credentials. + """ + if 'api_key' not in credentials: + raise CredentialsValidateFailedError('Baichuan api_key must be provided.') + + if 'secret_key' not in credentials: + raise CredentialsValidateFailedError('Baichuan secret_key must be provided.') + + try: + credential_kwargs = { + 'api_key': credentials['api_key'], + 'secret_key': credentials['secret_key'], + } + + llm = BaichuanChatLLM( + temperature=0, + **credential_kwargs + ) + + llm([HumanMessage(content='ping')]) + except Exception as ex: + raise CredentialsValidateFailedError(str(ex)) + + @classmethod + def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict: + credentials['api_key'] = encrypter.encrypt_token(tenant_id, credentials['api_key']) + credentials['secret_key'] = encrypter.encrypt_token(tenant_id, credentials['secret_key']) + return credentials + + def get_provider_credentials(self, obfuscated: bool = False) -> dict: + if self.provider.provider_type == ProviderType.CUSTOM.value: + try: + credentials = json.loads(self.provider.encrypted_config) + except JSONDecodeError: + credentials = { + 'api_key': None, + 'secret_key': None, + } + + if credentials['api_key']: + credentials['api_key'] = encrypter.decrypt_token( + self.provider.tenant_id, + credentials['api_key'] + ) + + if obfuscated: + credentials['api_key'] = encrypter.obfuscated_token(credentials['api_key']) + + if credentials['secret_key']: + credentials['secret_key'] = encrypter.decrypt_token( + self.provider.tenant_id, + credentials['secret_key'] + ) + + if obfuscated: + credentials['secret_key'] = encrypter.obfuscated_token(credentials['secret_key']) + + return credentials + else: + return {} + + def should_deduct_quota(self): + return True + + @classmethod + def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict): + """ + check model credentials valid. + + :param model_name: + :param model_type: + :param credentials: + """ + return + + @classmethod + def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType, + credentials: dict) -> dict: + """ + encrypt model credentials for save. + + :param tenant_id: + :param model_name: + :param model_type: + :param credentials: + :return: + """ + return {} + + def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict: + """ + get credentials for llm use. + + :param model_name: + :param model_type: + :param obfuscated: + :return: + """ + return self.get_provider_credentials(obfuscated) diff --git a/api/core/model_providers/rules/_providers.json b/api/core/model_providers/rules/_providers.json index f004dc5261..92d56be824 100644 --- a/api/core/model_providers/rules/_providers.json +++ b/api/core/model_providers/rules/_providers.json @@ -7,10 +7,11 @@ "spark", "wenxin", "zhipuai", + "baichuan", "chatglm", "replicate", "huggingface_hub", "xinference", "openllm", "localai" -] \ No newline at end of file +] diff --git a/api/core/model_providers/rules/baichuan.json b/api/core/model_providers/rules/baichuan.json new file mode 100644 index 0000000000..237b0d24d2 --- /dev/null +++ b/api/core/model_providers/rules/baichuan.json @@ -0,0 +1,15 @@ +{ + "support_provider_types": [ + "custom" + ], + "system_config": null, + "model_flexibility": "fixed", + "price_config": { + "baichuan2-53b": { + "prompt": "0.01", + "completion": "0.01", + "unit": "0.001", + "currency": "RMB" + } + } +} \ No newline at end of file diff --git a/api/core/third_party/langchain/llms/baichuan_llm.py b/api/core/third_party/langchain/llms/baichuan_llm.py new file mode 100644 index 0000000000..baa5dc0b10 --- /dev/null +++ b/api/core/third_party/langchain/llms/baichuan_llm.py @@ -0,0 +1,315 @@ +"""Wrapper around Baichuan APIs.""" +from __future__ import annotations + +import hashlib +import json +import logging +import time +from typing import ( + Any, + Dict, + List, + Optional, Iterator, +) + +import requests +from langchain.chat_models.base import BaseChatModel +from langchain.schema import BaseMessage, ChatMessage, HumanMessage, AIMessage, SystemMessage +from langchain.schema.messages import AIMessageChunk +from langchain.schema.output import ChatResult, ChatGenerationChunk, ChatGeneration +from pydantic import Extra, root_validator, BaseModel + +from langchain.callbacks.manager import ( + CallbackManagerForLLMRun, +) +from langchain.utils import get_from_dict_or_env + +logger = logging.getLogger(__name__) + + +class BaichuanModelAPI(BaseModel): + api_key: str + secret_key: str + + base_url: str = "https://api.baichuan-ai.com/v1" + + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + + def do_request(self, model: str, messages: list[dict], parameters: dict, **kwargs: Any): + stream = 'stream' in kwargs and kwargs['stream'] + + url = self.base_url + ("/stream/chat" if stream else "/chat") + + data = { + "model": model, + "messages": messages, + "parameters": parameters + } + + json_data = json.dumps(data) + time_stamp = int(time.time()) + signature = self._calculate_md5(self.secret_key + json_data + str(time_stamp)) + + headers = { + "Content-Type": "application/json", + "Authorization": "Bearer " + self.api_key, + "X-BC-Request-Id": "your requestId", + "X-BC-Timestamp": str(time_stamp), + "X-BC-Signature": signature, + "X-BC-Sign-Algo": "MD5", + } + + response = requests.post(url, data=json_data, headers=headers, stream=stream, timeout=(5, 60)) + + if not response.ok: + raise ValueError(f"HTTP {response.status_code} error: {response.text}") + + if not stream: + json_response = response.json() + if json_response['code'] != 0: + raise ValueError( + f"API {json_response['code']}" + f" error: {json_response['msg']}" + ) + return json_response + else: + return response + + def _calculate_md5(self, input_string): + md5 = hashlib.md5() + md5.update(input_string.encode('utf-8')) + encrypted = md5.hexdigest() + return encrypted + + +class BaichuanChatLLM(BaseChatModel): + """Wrapper around Baichuan large language models. + To use, you should pass the api_key as a named parameter to the constructor. + Example: + .. code-block:: python + from core.third_party.langchain.llms.baichuan_llm import BaichuanChatLLM + model = BaichuanChatLLM(model="", api_key="my-api-key", secret_key="my-secret-key") + """ + + @property + def lc_secrets(self) -> Dict[str, str]: + return {"api_key": "API_KEY", "secret_key": "SECRET_KEY"} + + @property + def lc_serializable(self) -> bool: + return True + + client: Any = None #: :meta private: + model: str = "Baichuan2-53B" + """Model name to use.""" + temperature: float = 0.3 + """A non-negative float that tunes the degree of randomness in generation.""" + top_p: float = 0.85 + """Total probability mass of tokens to consider at each step.""" + streaming: bool = False + """Whether to stream the response or return it all at once.""" + api_key: Optional[str] = None + secret_key: Optional[str] = None + + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + + @root_validator() + def validate_environment(cls, values: Dict) -> Dict: + """Validate that api key and python package exists in environment.""" + values["api_key"] = get_from_dict_or_env( + values, "api_key", "BAICHUAN_API_KEY" + ) + + values["secret_key"] = get_from_dict_or_env( + values, "secret_key", "BAICHUAN_SECRET_KEY" + ) + + values['client'] = BaichuanModelAPI( + api_key=values['api_key'], + secret_key=values['secret_key'] + ) + return values + + @property + def _default_params(self) -> Dict[str, Any]: + """Get the default parameters for calling OpenAI API.""" + return { + "model": self.model, + "parameters": { + "temperature": self.temperature, + "top_p": self.top_p + } + } + + @property + def _identifying_params(self) -> Dict[str, Any]: + """Get the identifying parameters.""" + return self._default_params + + @property + def _llm_type(self) -> str: + """Return type of llm.""" + return "baichuan" + + def _convert_message_to_dict(self, message: BaseMessage) -> dict: + if isinstance(message, ChatMessage): + message_dict = {"role": message.role, "content": message.content} + elif isinstance(message, HumanMessage): + message_dict = {"role": "user", "content": message.content} + elif isinstance(message, AIMessage): + message_dict = {"role": "assistant", "content": message.content} + elif isinstance(message, SystemMessage): + message_dict = {"role": "user", "content": message.content} + else: + raise ValueError(f"Got unknown type {message}") + return message_dict + + def _convert_dict_to_message(self, _dict: Dict[str, Any]) -> BaseMessage: + role = _dict["role"] + if role == "user": + return HumanMessage(content=_dict["content"]) + elif role == "assistant": + return AIMessage(content=_dict["content"]) + elif role == "system": + return SystemMessage(content=_dict["content"]) + else: + return ChatMessage(content=_dict["content"], role=role) + + def _create_message_dicts( + self, messages: List[BaseMessage] + ) -> List[Dict[str, Any]]: + dict_messages = [] + for m in messages: + message = self._convert_message_to_dict(m) + if dict_messages: + previous_message = dict_messages[-1] + if previous_message['role'] == message['role']: + dict_messages[-1]['content'] += f"\n{message['content']}" + else: + dict_messages.append(message) + else: + dict_messages.append(message) + + return dict_messages + + def _generate( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> ChatResult: + if self.streaming: + generation: Optional[ChatGenerationChunk] = None + llm_output: Optional[Dict] = None + for chunk in self._stream( + messages=messages, stop=stop, run_manager=run_manager, **kwargs + ): + if generation is None: + generation = chunk + else: + generation += chunk + + if chunk.generation_info is not None \ + and 'token_usage' in chunk.generation_info: + llm_output = {"token_usage": chunk.generation_info['token_usage'], "model_name": self.model} + + assert generation is not None + return ChatResult(generations=[generation], llm_output=llm_output) + else: + message_dicts = self._create_message_dicts(messages) + params = self._default_params + params["messages"] = message_dicts + params.update(kwargs) + response = self.client.do_request(**params) + return self._create_chat_result(response) + + def _stream( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> Iterator[ChatGenerationChunk]: + message_dicts = self._create_message_dicts(messages) + params = self._default_params + params["messages"] = message_dicts + params.update(kwargs) + + for event in self.client.do_request(stream=True, **params).iter_lines(): + if event: + event = event.decode("utf-8") + + meta = json.loads(event) + + if meta['code'] != 0: + raise ValueError( + f"API {meta['code']}" + f" error: {meta['msg']}" + ) + + content = meta['data']['messages'][0]['content'] + + chunk_kwargs = { + 'message': AIMessageChunk(content=content), + } + + if 'usage' in meta: + token_usage = meta['usage'] + overall_token_usage = { + 'prompt_tokens': token_usage.get('prompt_tokens', 0), + 'completion_tokens': token_usage.get('answer_tokens', 0), + 'total_tokens': token_usage.get('total_tokens', 0) + } + chunk_kwargs['generation_info'] = {'token_usage': overall_token_usage} + + yield ChatGenerationChunk(**chunk_kwargs) + if run_manager: + run_manager.on_llm_new_token(content) + + def _create_chat_result(self, response: Dict[str, Any]) -> ChatResult: + data = response["data"] + generations = [] + for res in data["messages"]: + message = self._convert_dict_to_message(res) + gen = ChatGeneration( + message=message + ) + generations.append(gen) + usage = response.get("usage") + token_usage = { + 'prompt_tokens': usage.get('prompt_tokens', 0), + 'completion_tokens': usage.get('answer_tokens', 0), + 'total_tokens': usage.get('total_tokens', 0) + } + llm_output = {"token_usage": token_usage, "model_name": self.model} + return ChatResult(generations=generations, llm_output=llm_output) + + def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int: + """Get the number of tokens in the messages. + + Useful for checking if an input will fit in a model's context window. + + Args: + messages: The message inputs to tokenize. + + Returns: + The sum of the number of tokens across the messages. + """ + return sum([self.get_num_tokens(m.content) for m in messages]) + + def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict: + token_usage: dict = {} + for output in llm_outputs: + if output is None: + # Happens in streaming + continue + token_usage = output["token_usage"] + + return {"token_usage": token_usage, "model_name": self.model} diff --git a/api/tests/integration_tests/.env.example b/api/tests/integration_tests/.env.example index fed6c12e54..7d00ae0f6a 100644 --- a/api/tests/integration_tests/.env.example +++ b/api/tests/integration_tests/.env.example @@ -35,6 +35,10 @@ WENXIN_SECRET_KEY= # ZhipuAI Credentials ZHIPUAI_API_KEY= +# Baichuan Credentials +BAICHUAN_API_KEY= +BAICHUAN_SECRET_KEY= + # ChatGLM Credentials CHATGLM_API_BASE= diff --git a/api/tests/integration_tests/models/llm/test_baichuan_model.py b/api/tests/integration_tests/models/llm/test_baichuan_model.py new file mode 100644 index 0000000000..15610e1d1d --- /dev/null +++ b/api/tests/integration_tests/models/llm/test_baichuan_model.py @@ -0,0 +1,81 @@ +import json +import os +from unittest.mock import patch + + +from core.model_providers.models.entity.message import PromptMessage, MessageType +from core.model_providers.models.entity.model_params import ModelKwargs +from core.model_providers.models.llm.baichuan_model import BaichuanModel +from core.model_providers.providers.baichuan_provider import BaichuanProvider +from models.provider import Provider, ProviderType + + +def get_mock_provider(valid_api_key, valid_secret_key): + return Provider( + id='provider_id', + tenant_id='tenant_id', + provider_name='baichuan', + provider_type=ProviderType.CUSTOM.value, + encrypted_config=json.dumps({ + 'api_key': valid_api_key, + 'secret_key': valid_secret_key, + }), + is_valid=True, + ) + + +def get_mock_model(model_name: str, streaming: bool = False): + model_kwargs = ModelKwargs( + temperature=0.01, + ) + valid_api_key = os.environ['BAICHUAN_API_KEY'] + valid_secret_key = os.environ['BAICHUAN_SECRET_KEY'] + model_provider = BaichuanProvider(provider=get_mock_provider(valid_api_key, valid_secret_key)) + return BaichuanModel( + model_provider=model_provider, + name=model_name, + model_kwargs=model_kwargs, + streaming=streaming + ) + + +def decrypt_side_effect(tenant_id, encrypted_api_key): + return encrypted_api_key + + +@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect) +def test_chat_get_num_tokens(mock_decrypt): + model = get_mock_model('baichuan2-53b') + rst = model.get_num_tokens([ + PromptMessage(type=MessageType.SYSTEM, content='you are a kindness Assistant.'), + PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?') + ]) + assert rst > 0 + + +@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect) +def test_chat_run(mock_decrypt, mocker): + mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None) + + model = get_mock_model('baichuan2-53b') + messages = [ + PromptMessage(type=MessageType.HUMAN, content='Are you Human? you MUST only answer `y` or `n`?') + ] + rst = model.run( + messages, + ) + assert len(rst.content) > 0 + + +@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect) +def test_chat_stream_run(mock_decrypt, mocker): + mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None) + + model = get_mock_model('baichuan2-53b', streaming=True) + messages = [ + PromptMessage(type=MessageType.HUMAN, content='Are you Human? you MUST only answer `y` or `n`?') + ] + rst = model.run( + messages + ) + assert len(rst.content) > 0 diff --git a/api/tests/unit_tests/model_providers/test_baichuan_provider.py b/api/tests/unit_tests/model_providers/test_baichuan_provider.py new file mode 100644 index 0000000000..6d4b832405 --- /dev/null +++ b/api/tests/unit_tests/model_providers/test_baichuan_provider.py @@ -0,0 +1,97 @@ +import pytest +from unittest.mock import patch +import json + +from langchain.schema import ChatResult, ChatGeneration, AIMessage + +from core.model_providers.providers.baichuan_provider import BaichuanProvider +from core.model_providers.providers.base import CredentialsValidateFailedError +from models.provider import ProviderType, Provider + + +PROVIDER_NAME = 'baichuan' +MODEL_PROVIDER_CLASS = BaichuanProvider +VALIDATE_CREDENTIAL = { + 'api_key': 'valid_key', + 'secret_key': 'valid_key', +} + + +def encrypt_side_effect(tenant_id, encrypt_key): + return f'encrypted_{encrypt_key}' + + +def decrypt_side_effect(tenant_id, encrypted_key): + return encrypted_key.replace('encrypted_', '') + + +def test_is_provider_credentials_valid_or_raise_valid(mocker): + mocker.patch('core.third_party.langchain.llms.baichuan_llm.BaichuanChatLLM._generate', + return_value=ChatResult(generations=[ChatGeneration(message=AIMessage(content='abc'))])) + + MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(VALIDATE_CREDENTIAL) + + +def test_is_provider_credentials_valid_or_raise_invalid(): + # raise CredentialsValidateFailedError if api_key is not in credentials + with pytest.raises(CredentialsValidateFailedError): + MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise({}) + + credential = VALIDATE_CREDENTIAL.copy() + credential['api_key'] = 'invalid_key' + credential['secret_key'] = 'invalid_key' + + # raise CredentialsValidateFailedError if api_key is invalid + with pytest.raises(CredentialsValidateFailedError): + MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(credential) + + +@patch('core.helper.encrypter.encrypt_token', side_effect=encrypt_side_effect) +def test_encrypt_credentials(mock_encrypt): + result = MODEL_PROVIDER_CLASS.encrypt_provider_credentials('tenant_id', VALIDATE_CREDENTIAL.copy()) + assert result['api_key'] == f'encrypted_{VALIDATE_CREDENTIAL["api_key"]}' + assert result['secret_key'] == f'encrypted_{VALIDATE_CREDENTIAL["secret_key"]}' + + +@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect) +def test_get_credentials_custom(mock_decrypt): + encrypted_credential = VALIDATE_CREDENTIAL.copy() + encrypted_credential['api_key'] = 'encrypted_' + encrypted_credential['api_key'] + encrypted_credential['secret_key'] = 'encrypted_' + encrypted_credential['secret_key'] + + provider = Provider( + id='provider_id', + tenant_id='tenant_id', + provider_name=PROVIDER_NAME, + provider_type=ProviderType.CUSTOM.value, + encrypted_config=json.dumps(encrypted_credential), + is_valid=True, + ) + model_provider = MODEL_PROVIDER_CLASS(provider=provider) + result = model_provider.get_provider_credentials() + assert result['api_key'] == 'valid_key' + assert result['secret_key'] == 'valid_key' + + +@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect) +def test_get_credentials_obfuscated(mock_decrypt): + encrypted_credential = VALIDATE_CREDENTIAL.copy() + encrypted_credential['api_key'] = 'encrypted_' + encrypted_credential['api_key'] + encrypted_credential['secret_key'] = 'encrypted_' + encrypted_credential['secret_key'] + + provider = Provider( + id='provider_id', + tenant_id='tenant_id', + provider_name=PROVIDER_NAME, + provider_type=ProviderType.CUSTOM.value, + encrypted_config=json.dumps(encrypted_credential), + is_valid=True, + ) + model_provider = MODEL_PROVIDER_CLASS(provider=provider) + result = model_provider.get_provider_credentials(obfuscated=True) + middle_token = result['api_key'][6:-2] + secret_key_middle_token = result['secret_key'][6:-2] + assert len(middle_token) == max(len(VALIDATE_CREDENTIAL['api_key']) - 8, 0) + assert len(secret_key_middle_token) == max(len(VALIDATE_CREDENTIAL['secret_key']) - 8, 0) + assert all(char == '*' for char in middle_token) + assert all(char == '*' for char in secret_key_middle_token)