feat: add zhipuai (#1188)

This commit is contained in:
takatost 2023-09-18 17:32:31 +08:00 committed by GitHub
parent c8bd76cd66
commit 827c97f0d3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
36 changed files with 1087 additions and 111 deletions

View File

@ -246,7 +246,8 @@ class ModelProviderModelParameterRuleApi(Resource):
'enabled': v.enabled,
'min': v.min,
'max': v.max,
'default': v.default
'default': v.default,
'precision': v.precision
}
for k, v in vars(parameter_rules).items()
}
@ -290,10 +291,15 @@ class ModelProviderFreeQuotaQualificationVerifyApi(Resource):
@login_required
@account_initialization_required
def get(self, provider_name: str):
parser = reqparse.RequestParser()
parser.add_argument('token', type=str, required=False, nullable=True, location='args')
args = parser.parse_args()
provider_service = ProviderService()
result = provider_service.free_quota_qualification_verify(
tenant_id=current_user.current_tenant_id,
provider_name=provider_name
provider_name=provider_name,
token=args['token']
)
return result

View File

@ -63,7 +63,18 @@ class LLMCallbackHandler(BaseCallbackHandler):
self.conversation_message_task.append_message_text(response.generations[0][0].text)
self.llm_message.completion = response.generations[0][0].text
self.llm_message.completion_tokens = self.model_instance.get_num_tokens([PromptMessage(content=self.llm_message.completion)])
if response.llm_output and 'token_usage' in response.llm_output:
if 'prompt_tokens' in response.llm_output['token_usage']:
self.llm_message.prompt_tokens = response.llm_output['token_usage']['prompt_tokens']
if 'completion_tokens' in response.llm_output['token_usage']:
self.llm_message.completion_tokens = response.llm_output['token_usage']['completion_tokens']
else:
self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
[PromptMessage(content=self.llm_message.completion)])
else:
self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
[PromptMessage(content=self.llm_message.completion)])
self.conversation_message_task.save_message(self.llm_message)

View File

@ -2,13 +2,8 @@ import enum
import logging
from typing import List, Dict, Optional, Any
import openai
from flask import current_app
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from openai import InvalidRequestError
from openai.error import APIConnectionError, APIError, ServiceUnavailableError, Timeout, RateLimitError, \
AuthenticationError, OpenAIError
from pydantic import BaseModel
from core.model_providers.error import LLMBadRequestError
@ -86,6 +81,12 @@ class SensitiveWordAvoidanceChain(Chain):
result = self._check_moderation(text)
if not result:
raise LLMBadRequestError(self.sensitive_word_avoidance_rule.canned_response)
raise SensitiveWordAvoidanceError(self.sensitive_word_avoidance_rule.canned_response)
return {self.output_key: text}
class SensitiveWordAvoidanceError(Exception):
def __init__(self, message):
super().__init__(message)
self.message = message

View File

@ -7,6 +7,7 @@ from requests.exceptions import ChunkedEncodingError
from core.agent.agent_executor import AgentExecuteResult, PlanningStrategy
from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
from core.callback_handler.llm_callback_handler import LLMCallbackHandler
from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceError
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException
from core.model_providers.error import LLMBadRequestError
from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
@ -76,28 +77,53 @@ class Completion:
app_model_config=app_model_config
)
# parse sensitive_word_avoidance_chain
chain_callback = MainChainGatherCallbackHandler(conversation_message_task)
sensitive_word_avoidance_chain = orchestrator_rule_parser.to_sensitive_word_avoidance_chain(final_model_instance, [chain_callback])
if sensitive_word_avoidance_chain:
query = sensitive_word_avoidance_chain.run(query)
# get agent executor
agent_executor = orchestrator_rule_parser.to_agent_executor(
conversation_message_task=conversation_message_task,
memory=memory,
rest_tokens=rest_tokens_for_context_and_memory,
chain_callback=chain_callback
)
# run agent executor
agent_execute_result = None
if agent_executor:
should_use_agent = agent_executor.should_use_agent(query)
if should_use_agent:
agent_execute_result = agent_executor.run(query)
# run the final llm
try:
# parse sensitive_word_avoidance_chain
chain_callback = MainChainGatherCallbackHandler(conversation_message_task)
sensitive_word_avoidance_chain = orchestrator_rule_parser.to_sensitive_word_avoidance_chain(
final_model_instance, [chain_callback])
if sensitive_word_avoidance_chain:
try:
query = sensitive_word_avoidance_chain.run(query)
except SensitiveWordAvoidanceError as ex:
cls.run_final_llm(
model_instance=final_model_instance,
mode=app.mode,
app_model_config=app_model_config,
query=query,
inputs=inputs,
agent_execute_result=None,
conversation_message_task=conversation_message_task,
memory=memory,
fake_response=ex.message
)
return
# get agent executor
agent_executor = orchestrator_rule_parser.to_agent_executor(
conversation_message_task=conversation_message_task,
memory=memory,
rest_tokens=rest_tokens_for_context_and_memory,
chain_callback=chain_callback,
retriever_from=retriever_from
)
# run agent executor
agent_execute_result = None
if agent_executor:
should_use_agent = agent_executor.should_use_agent(query)
if should_use_agent:
agent_execute_result = agent_executor.run(query)
# When no extra pre prompt is specified,
# the output of the agent can be used directly as the main output content without calling LLM again
fake_response = None
if not app_model_config.pre_prompt and agent_execute_result and agent_execute_result.output \
and agent_execute_result.strategy not in [PlanningStrategy.ROUTER,
PlanningStrategy.REACT_ROUTER]:
fake_response = agent_execute_result.output
# run the final llm
cls.run_final_llm(
model_instance=final_model_instance,
mode=app.mode,
@ -106,7 +132,8 @@ class Completion:
inputs=inputs,
agent_execute_result=agent_execute_result,
conversation_message_task=conversation_message_task,
memory=memory
memory=memory,
fake_response=fake_response
)
except ConversationTaskStoppedException:
return
@ -121,14 +148,8 @@ class Completion:
inputs: dict,
agent_execute_result: Optional[AgentExecuteResult],
conversation_message_task: ConversationMessageTask,
memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory]):
# When no extra pre prompt is specified,
# the output of the agent can be used directly as the main output content without calling LLM again
fake_response = None
if not app_model_config.pre_prompt and agent_execute_result and agent_execute_result.output \
and agent_execute_result.strategy not in [PlanningStrategy.ROUTER, PlanningStrategy.REACT_ROUTER]:
fake_response = agent_execute_result.output
memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory],
fake_response: Optional[str]):
# get llm prompt
prompt_messages, stop_words = model_instance.get_prompt(
mode=mode,

View File

@ -1,32 +1,34 @@
import logging
import openai
from flask import current_app
from core.model_providers.error import LLMBadRequestError
from core.model_providers.providers.base import BaseModelProvider
from core.model_providers.providers.hosted import hosted_config, hosted_model_providers
from models.provider import ProviderType
def check_moderation(model_provider: BaseModelProvider, text: str) -> bool:
if current_app.config['HOSTED_MODERATION_ENABLED'] and current_app.config['HOSTED_MODERATION_PROVIDERS']:
moderation_providers = current_app.config['HOSTED_MODERATION_PROVIDERS'].split(',')
if hosted_config.moderation.enabled is True and hosted_model_providers.openai:
if model_provider.provider.provider_type == ProviderType.SYSTEM.value \
and model_provider.provider_name in moderation_providers:
and model_provider.provider_name in hosted_config.moderation.providers:
# 2000 text per chunk
length = 2000
chunks = [text[i:i + length] for i in range(0, len(text), length)]
text_chunks = [text[i:i + length] for i in range(0, len(text), length)]
try:
moderation_result = openai.Moderation.create(input=chunks,
api_key=current_app.config['HOSTED_OPENAI_API_KEY'])
except Exception as ex:
logging.exception(ex)
raise LLMBadRequestError('Rate limit exceeded, please try again later.')
max_text_chunks = 32
chunks = [text_chunks[i:i + max_text_chunks] for i in range(0, len(text_chunks), max_text_chunks)]
for result in moderation_result.results:
if result['flagged'] is True:
return False
for text_chunk in chunks:
try:
moderation_result = openai.Moderation.create(input=text_chunk,
api_key=hosted_model_providers.openai.api_key)
except Exception as ex:
logging.exception(ex)
raise LLMBadRequestError('Rate limit exceeded, please try again later.')
for result in moderation_result.results:
if result['flagged'] is True:
return False
return True

View File

@ -45,6 +45,9 @@ class ModelProviderFactory:
elif provider_name == 'wenxin':
from core.model_providers.providers.wenxin_provider import WenxinProvider
return WenxinProvider
elif provider_name == 'zhipuai':
from core.model_providers.providers.zhipuai_provider import ZhipuAIProvider
return ZhipuAIProvider
elif provider_name == 'chatglm':
from core.model_providers.providers.chatglm_provider import ChatGLMProvider
return ChatGLMProvider

View File

@ -0,0 +1,22 @@
from core.model_providers.error import LLMBadRequestError
from core.model_providers.providers.base import BaseModelProvider
from core.model_providers.models.embedding.base import BaseEmbedding
from core.third_party.langchain.embeddings.zhipuai_embedding import ZhipuAIEmbeddings
class ZhipuAIEmbedding(BaseEmbedding):
def __init__(self, model_provider: BaseModelProvider, name: str):
credentials = model_provider.get_model_credentials(
model_name=name,
model_type=self.type
)
client = ZhipuAIEmbeddings(
model=name,
**credentials,
)
super().__init__(model_provider, client, name)
def handle_exceptions(self, ex: Exception) -> Exception:
return LLMBadRequestError(f"ZhipuAI embedding: {str(ex)}")

View File

@ -49,6 +49,7 @@ class KwargRule(Generic[T], BaseModel):
max: Optional[T] = None
default: Optional[T] = None
alias: Optional[str] = None
precision: Optional[int] = None
class ModelKwargsRules(BaseModel):

View File

@ -0,0 +1,61 @@
from typing import List, Optional, Any
from langchain.callbacks.manager import Callbacks
from langchain.schema import LLMResult
from core.model_providers.error import LLMBadRequestError
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
from core.third_party.langchain.llms.zhipuai_llm import ZhipuAIChatLLM
class ZhipuAIModel(BaseLLM):
model_mode: ModelMode = ModelMode.CHAT
def _init_client(self) -> Any:
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
return ZhipuAIChatLLM(
streaming=self.streaming,
callbacks=self.callbacks,
**self.credentials,
**provider_model_kwargs
)
def _run(self, messages: List[PromptMessage],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs) -> LLMResult:
"""
run predict by prompt messages and stop words.
:param messages:
:param stop:
:param callbacks:
:return:
"""
prompts = self._get_prompt_from_messages(messages)
return self._client.generate([prompts], stop, callbacks)
def get_num_tokens(self, messages: List[PromptMessage]) -> int:
"""
get num tokens of prompt messages.
:param messages:
:return:
"""
prompts = self._get_prompt_from_messages(messages)
return max(self._client.get_num_tokens_from_messages(prompts), 0)
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
for k, v in provider_model_kwargs.items():
if hasattr(self.client, k):
setattr(self.client, k, v)
def handle_exceptions(self, ex: Exception) -> Exception:
return LLMBadRequestError(f"ZhipuAI: {str(ex)}")
@property
def support_streaming(self):
return True

View File

@ -23,14 +23,18 @@ class OpenAIModeration(BaseModeration):
# 2000 text per chunk
length = 2000
chunks = [text[i:i + length] for i in range(0, len(text), length)]
text_chunks = [text[i:i + length] for i in range(0, len(text), length)]
moderation_result = self._client.create(input=chunks,
api_key=credentials['openai_api_key'])
max_text_chunks = 32
chunks = [text_chunks[i:i + max_text_chunks] for i in range(0, len(text_chunks), max_text_chunks)]
for result in moderation_result.results:
if result['flagged'] is True:
return False
for text_chunk in chunks:
moderation_result = self._client.create(input=text_chunk,
api_key=credentials['openai_api_key'])
for result in moderation_result.results:
if result['flagged'] is True:
return False
return True

View File

@ -69,11 +69,11 @@ class AnthropicProvider(BaseModelProvider):
:return:
"""
return ModelKwargsRules(
temperature=KwargRule[float](min=0, max=1, default=1),
top_p=KwargRule[float](min=0, max=1, default=0.7),
temperature=KwargRule[float](min=0, max=1, default=1, precision=2),
top_p=KwargRule[float](min=0, max=1, default=0.7, precision=2),
presence_penalty=KwargRule[float](enabled=False),
frequency_penalty=KwargRule[float](enabled=False),
max_tokens=KwargRule[int](alias="max_tokens_to_sample", min=10, max=100000, default=256),
max_tokens=KwargRule[int](alias="max_tokens_to_sample", min=10, max=100000, default=256, precision=0),
)
@classmethod

View File

@ -164,14 +164,14 @@ class AzureOpenAIProvider(BaseModelProvider):
model_credentials = self.get_model_credentials(model_name, model_type)
return ModelKwargsRules(
temperature=KwargRule[float](min=0, max=2, default=1),
top_p=KwargRule[float](min=0, max=1, default=1),
presence_penalty=KwargRule[float](min=-2, max=2, default=0),
frequency_penalty=KwargRule[float](min=-2, max=2, default=0),
temperature=KwargRule[float](min=0, max=2, default=1, precision=2),
top_p=KwargRule[float](min=0, max=1, default=1, precision=2),
presence_penalty=KwargRule[float](min=-2, max=2, default=0, precision=2),
frequency_penalty=KwargRule[float](min=-2, max=2, default=0, precision=2),
max_tokens=KwargRule[int](min=10, max=base_model_max_tokens.get(
model_credentials['base_model_name'],
4097
), default=16),
), default=16, precision=0),
)
@classmethod

View File

@ -64,11 +64,11 @@ class ChatGLMProvider(BaseModelProvider):
}
return ModelKwargsRules(
temperature=KwargRule[float](min=0, max=2, default=1),
top_p=KwargRule[float](min=0, max=1, default=0.7),
temperature=KwargRule[float](min=0, max=2, default=1, precision=2),
top_p=KwargRule[float](min=0, max=1, default=0.7, precision=2),
presence_penalty=KwargRule[float](enabled=False),
frequency_penalty=KwargRule[float](enabled=False),
max_tokens=KwargRule[int](alias='max_token', min=10, max=model_max_tokens.get(model_name), default=2048),
max_tokens=KwargRule[int](alias='max_token', min=10, max=model_max_tokens.get(model_name), default=2048, precision=0),
)
@classmethod

View File

@ -45,6 +45,18 @@ class HostedModelProviders(BaseModel):
hosted_model_providers = HostedModelProviders()
class HostedModerationConfig(BaseModel):
enabled: bool = False
providers: list[str] = []
class HostedConfig(BaseModel):
moderation = HostedModerationConfig()
hosted_config = HostedConfig()
def init_app(app: Flask):
if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true':
langchain.verbose = True
@ -78,3 +90,9 @@ def init_app(app: Flask):
paid_min_quantity=app.config.get("HOSTED_ANTHROPIC_PAID_MIN_QUANTITY"),
paid_max_quantity=app.config.get("HOSTED_ANTHROPIC_PAID_MAX_QUANTITY"),
)
if app.config.get("HOSTED_MODERATION_ENABLED") and app.config.get("HOSTED_MODERATION_PROVIDERS"):
hosted_config.moderation = HostedModerationConfig(
enabled=app.config.get("HOSTED_MODERATION_ENABLED"),
providers=app.config.get("HOSTED_MODERATION_PROVIDERS").split(',')
)

View File

@ -47,11 +47,11 @@ class HuggingfaceHubProvider(BaseModelProvider):
:return:
"""
return ModelKwargsRules(
temperature=KwargRule[float](min=0, max=2, default=1),
top_p=KwargRule[float](min=0.01, max=0.99, default=0.7),
temperature=KwargRule[float](min=0, max=2, default=1, precision=2),
top_p=KwargRule[float](min=0.01, max=0.99, default=0.7, precision=2),
presence_penalty=KwargRule[float](enabled=False),
frequency_penalty=KwargRule[float](enabled=False),
max_tokens=KwargRule[int](alias='max_new_tokens', min=10, max=4000, default=200),
max_tokens=KwargRule[int](alias='max_new_tokens', min=10, max=4000, default=200, precision=0),
)
@classmethod

View File

@ -52,9 +52,9 @@ class LocalAIProvider(BaseModelProvider):
:return:
"""
return ModelKwargsRules(
temperature=KwargRule[float](min=0, max=2, default=0.7),
top_p=KwargRule[float](min=0, max=1, default=1),
max_tokens=KwargRule[int](min=10, max=4097, default=16),
temperature=KwargRule[float](min=0, max=2, default=0.7, precision=2),
top_p=KwargRule[float](min=0, max=1, default=1, precision=2),
max_tokens=KwargRule[int](min=10, max=4097, default=16, precision=0),
)
@classmethod

View File

@ -74,11 +74,11 @@ class MinimaxProvider(BaseModelProvider):
}
return ModelKwargsRules(
temperature=KwargRule[float](min=0.01, max=1, default=0.9),
top_p=KwargRule[float](min=0, max=1, default=0.95),
temperature=KwargRule[float](min=0.01, max=1, default=0.9, precision=2),
top_p=KwargRule[float](min=0, max=1, default=0.95, precision=2),
presence_penalty=KwargRule[float](enabled=False),
frequency_penalty=KwargRule[float](enabled=False),
max_tokens=KwargRule[int](min=10, max=model_max_tokens.get(model_name, 6144), default=1024),
max_tokens=KwargRule[int](min=10, max=model_max_tokens.get(model_name, 6144), default=1024, precision=0),
)
@classmethod

View File

@ -133,11 +133,11 @@ class OpenAIProvider(BaseModelProvider):
}
return ModelKwargsRules(
temperature=KwargRule[float](min=0, max=2, default=1),
top_p=KwargRule[float](min=0, max=1, default=1),
presence_penalty=KwargRule[float](min=-2, max=2, default=0),
frequency_penalty=KwargRule[float](min=-2, max=2, default=0),
max_tokens=KwargRule[int](min=10, max=model_max_tokens.get(model_name, 4097), default=16),
temperature=KwargRule[float](min=0, max=2, default=1, precision=2),
top_p=KwargRule[float](min=0, max=1, default=1, precision=2),
presence_penalty=KwargRule[float](min=-2, max=2, default=0, precision=2),
frequency_penalty=KwargRule[float](min=-2, max=2, default=0, precision=2),
max_tokens=KwargRule[int](min=10, max=model_max_tokens.get(model_name, 4097), default=16, precision=0),
)
@classmethod

View File

@ -45,11 +45,11 @@ class OpenLLMProvider(BaseModelProvider):
:return:
"""
return ModelKwargsRules(
temperature=KwargRule[float](min=0.01, max=2, default=1),
top_p=KwargRule[float](min=0, max=1, default=0.7),
presence_penalty=KwargRule[float](min=-2, max=2, default=0),
frequency_penalty=KwargRule[float](min=-2, max=2, default=0),
max_tokens=KwargRule[int](alias='max_new_tokens', min=10, max=4000, default=128),
temperature=KwargRule[float](min=0.01, max=2, default=1, precision=2),
top_p=KwargRule[float](min=0, max=1, default=0.7, precision=2),
presence_penalty=KwargRule[float](min=-2, max=2, default=0, precision=2),
frequency_penalty=KwargRule[float](min=-2, max=2, default=0, precision=2),
max_tokens=KwargRule[int](alias='max_new_tokens', min=10, max=4000, default=128, precision=0),
)
@classmethod

View File

@ -72,6 +72,7 @@ class ReplicateProvider(BaseModelProvider):
min=float(value.get('minimum')) if value.get('minimum') is not None else None,
max=float(value.get('maximum')) if value.get('maximum') is not None else None,
default=float(value.get('default')) if value.get('default') is not None else None,
precision = 2
)
if key == 'temperature':
model_kwargs_rules.temperature = kwarg_rule
@ -84,6 +85,7 @@ class ReplicateProvider(BaseModelProvider):
min=int(value.get('minimum')) if value.get('minimum') is not None else 1,
max=int(value.get('maximum')) if value.get('maximum') is not None else 8000,
default=int(value.get('default')) if value.get('default') is not None else 500,
precision = 0
)
return model_kwargs_rules

View File

@ -62,11 +62,11 @@ class SparkProvider(BaseModelProvider):
:return:
"""
return ModelKwargsRules(
temperature=KwargRule[float](min=0, max=1, default=0.5),
temperature=KwargRule[float](min=0, max=1, default=0.5, precision=2),
top_p=KwargRule[float](enabled=False),
presence_penalty=KwargRule[float](enabled=False),
frequency_penalty=KwargRule[float](enabled=False),
max_tokens=KwargRule[int](min=10, max=4096, default=2048),
max_tokens=KwargRule[int](min=10, max=4096, default=2048, precision=0),
)
@classmethod

View File

@ -64,10 +64,10 @@ class TongyiProvider(BaseModelProvider):
return ModelKwargsRules(
temperature=KwargRule[float](enabled=False),
top_p=KwargRule[float](min=0, max=1, default=0.8),
top_p=KwargRule[float](min=0, max=1, default=0.8, precision=2),
presence_penalty=KwargRule[float](enabled=False),
frequency_penalty=KwargRule[float](enabled=False),
max_tokens=KwargRule[int](min=10, max=model_max_tokens.get(model_name), default=1024),
max_tokens=KwargRule[int](min=10, max=model_max_tokens.get(model_name), default=1024, precision=0),
)
@classmethod

View File

@ -63,8 +63,8 @@ class WenxinProvider(BaseModelProvider):
"""
if model_name in ['ernie-bot', 'ernie-bot-turbo']:
return ModelKwargsRules(
temperature=KwargRule[float](min=0.01, max=1, default=0.95),
top_p=KwargRule[float](min=0.01, max=1, default=0.8),
temperature=KwargRule[float](min=0.01, max=1, default=0.95, precision=2),
top_p=KwargRule[float](min=0.01, max=1, default=0.8, precision=2),
presence_penalty=KwargRule[float](enabled=False),
frequency_penalty=KwargRule[float](enabled=False),
max_tokens=KwargRule[int](enabled=False),

View File

@ -53,27 +53,27 @@ class XinferenceProvider(BaseModelProvider):
credentials = self.get_model_credentials(model_name, model_type)
if credentials['model_format'] == "ggmlv3" and credentials["model_handle_type"] == "chatglm":
return ModelKwargsRules(
temperature=KwargRule[float](min=0.01, max=2, default=1),
top_p=KwargRule[float](min=0, max=1, default=0.7),
temperature=KwargRule[float](min=0.01, max=2, default=1, precision=2),
top_p=KwargRule[float](min=0, max=1, default=0.7, precision=2),
presence_penalty=KwargRule[float](enabled=False),
frequency_penalty=KwargRule[float](enabled=False),
max_tokens=KwargRule[int](min=10, max=4000, default=256),
max_tokens=KwargRule[int](min=10, max=4000, default=256, precision=0),
)
elif credentials['model_format'] == "ggmlv3":
return ModelKwargsRules(
temperature=KwargRule[float](min=0.01, max=2, default=1),
top_p=KwargRule[float](min=0, max=1, default=0.7),
presence_penalty=KwargRule[float](min=-2, max=2, default=0),
frequency_penalty=KwargRule[float](min=-2, max=2, default=0),
max_tokens=KwargRule[int](min=10, max=4000, default=256),
temperature=KwargRule[float](min=0.01, max=2, default=1, precision=2),
top_p=KwargRule[float](min=0, max=1, default=0.7, precision=2),
presence_penalty=KwargRule[float](min=-2, max=2, default=0, precision=2),
frequency_penalty=KwargRule[float](min=-2, max=2, default=0, precision=2),
max_tokens=KwargRule[int](min=10, max=4000, default=256, precision=0),
)
else:
return ModelKwargsRules(
temperature=KwargRule[float](min=0.01, max=2, default=1),
top_p=KwargRule[float](min=0, max=1, default=0.7),
temperature=KwargRule[float](min=0.01, max=2, default=1, precision=2),
top_p=KwargRule[float](min=0, max=1, default=0.7, precision=2),
presence_penalty=KwargRule[float](enabled=False),
frequency_penalty=KwargRule[float](enabled=False),
max_tokens=KwargRule[int](min=10, max=4000, default=256),
max_tokens=KwargRule[int](min=10, max=4000, default=256, precision=0),
)

View File

@ -0,0 +1,176 @@
import json
from json import JSONDecodeError
from typing import Type
from langchain.schema import HumanMessage
from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.embedding.zhipuai_embedding import ZhipuAIEmbedding
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
from core.model_providers.models.llm.zhipuai_model import ZhipuAIModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.third_party.langchain.llms.zhipuai_llm import ZhipuAIChatLLM
from models.provider import ProviderType, ProviderQuotaType
class ZhipuAIProvider(BaseModelProvider):
@property
def provider_name(self):
"""
Returns the name of a provider.
"""
return 'zhipuai'
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
if model_type == ModelType.TEXT_GENERATION:
return [
{
'id': 'chatglm_pro',
'name': 'chatglm_pro',
},
{
'id': 'chatglm_std',
'name': 'chatglm_std',
},
{
'id': 'chatglm_lite',
'name': 'chatglm_lite',
},
{
'id': 'chatglm_lite_32k',
'name': 'chatglm_lite_32k',
}
]
elif model_type == ModelType.EMBEDDINGS:
return [
{
'id': 'text_embedding',
'name': 'text_embedding',
}
]
else:
return []
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.
:param model_type:
:return:
"""
if model_type == ModelType.TEXT_GENERATION:
model_class = ZhipuAIModel
elif model_type == ModelType.EMBEDDINGS:
model_class = ZhipuAIEmbedding
else:
raise NotImplementedError
return model_class
def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
"""
get model parameter rules.
:param model_name:
:param model_type:
:return:
"""
return ModelKwargsRules(
temperature=KwargRule[float](min=0.01, max=1, default=0.95, precision=2),
top_p=KwargRule[float](min=0.1, max=0.9, default=0.8, precision=1),
presence_penalty=KwargRule[float](enabled=False),
frequency_penalty=KwargRule[float](enabled=False),
max_tokens=KwargRule[int](enabled=False),
)
@classmethod
def is_provider_credentials_valid_or_raise(cls, credentials: dict):
"""
Validates the given credentials.
"""
if 'api_key' not in credentials:
raise CredentialsValidateFailedError('ZhipuAI api_key must be provided.')
try:
credential_kwargs = {
'api_key': credentials['api_key']
}
llm = ZhipuAIChatLLM(
temperature=0.01,
**credential_kwargs
)
llm([HumanMessage(content='ping')])
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
@classmethod
def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
credentials['api_key'] = encrypter.encrypt_token(tenant_id, credentials['api_key'])
return credentials
def get_provider_credentials(self, obfuscated: bool = False) -> dict:
if self.provider.provider_type == ProviderType.CUSTOM.value \
or (self.provider.provider_type == ProviderType.SYSTEM.value
and self.provider.quota_type == ProviderQuotaType.FREE.value):
try:
credentials = json.loads(self.provider.encrypted_config)
except JSONDecodeError:
credentials = {
'api_key': None,
}
if credentials['api_key']:
credentials['api_key'] = encrypter.decrypt_token(
self.provider.tenant_id,
credentials['api_key']
)
if obfuscated:
credentials['api_key'] = encrypter.obfuscated_token(credentials['api_key'])
return credentials
else:
return {}
def should_deduct_quota(self):
return True
@classmethod
def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
"""
check model credentials valid.
:param model_name:
:param model_type:
:param credentials:
"""
return
@classmethod
def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
credentials: dict) -> dict:
"""
encrypt model credentials for save.
:param tenant_id:
:param model_name:
:param model_type:
:param credentials:
:return:
"""
return {}
def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
"""
get credentials for llm use.
:param model_name:
:param model_type:
:param obfuscated:
:return:
"""
return self.get_provider_credentials(obfuscated)

View File

@ -6,6 +6,7 @@
"tongyi",
"spark",
"wenxin",
"zhipuai",
"chatglm",
"replicate",
"huggingface_hub",

View File

@ -0,0 +1,44 @@
{
"support_provider_types": [
"system",
"custom"
],
"system_config": {
"supported_quota_types": [
"free"
],
"quota_unit": "tokens"
},
"model_flexibility": "fixed",
"price_config": {
"chatglm_pro": {
"prompt": "0.01",
"completion": "0.01",
"unit": "0.001",
"currency": "RMB"
},
"chatglm_std": {
"prompt": "0.005",
"completion": "0.005",
"unit": "0.001",
"currency": "RMB"
},
"chatglm_lite": {
"prompt": "0.002",
"completion": "0.002",
"unit": "0.001",
"currency": "RMB"
},
"chatglm_lite_32k": {
"prompt": "0.0004",
"completion": "0.0004",
"unit": "0.001",
"currency": "RMB"
},
"text_embedding": {
"completion": "0",
"unit": "0.001",
"currency": "RMB"
}
}
}

View File

@ -0,0 +1,64 @@
"""Wrapper around ZhipuAI embedding models."""
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Extra, root_validator
from langchain.embeddings.base import Embeddings
from langchain.utils import get_from_dict_or_env
from core.third_party.langchain.llms.zhipuai_llm import ZhipuModelAPI
class ZhipuAIEmbeddings(BaseModel, Embeddings):
"""Wrapper around ZhipuAI embedding models.
1024 dimensions.
"""
client: Any #: :meta private:
model: str
"""Model name to use."""
base_url: str = "https://open.bigmodel.cn/api/paas/v3/model-api"
api_key: Optional[str] = None
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
values["api_key"] = get_from_dict_or_env(
values, "api_key", "ZHIPUAI_API_KEY"
)
values['client'] = ZhipuModelAPI(api_key=values['api_key'], base_url=values['base_url'])
return values
def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Call out to ZhipuAI's embedding endpoint.
Args:
texts: The list of texts to embed.
Returns:
List of embeddings, one for each text.
"""
embeddings = []
for text in texts:
response = self.client.invoke(model=self.model, prompt=text)
data = response["data"]
embeddings.append(data.get('embedding'))
return [list(map(float, e)) for e in embeddings]
def embed_query(self, text: str) -> List[float]:
"""Call out to ZhipuAI's embedding endpoint.
Args:
text: The text to embed.
Returns:
Embeddings for the text.
"""
return self.embed_documents([text])[0]

View File

@ -0,0 +1,315 @@
"""Wrapper around ZhipuAI APIs."""
from __future__ import annotations
import json
import logging
import posixpath
from typing import (
Any,
Dict,
List,
Optional, Iterator, Sequence,
)
import zhipuai
from langchain.chat_models.base import BaseChatModel
from langchain.schema import BaseMessage, ChatMessage, HumanMessage, AIMessage, SystemMessage
from langchain.schema.messages import AIMessageChunk
from langchain.schema.output import ChatResult, ChatGenerationChunk, ChatGeneration
from pydantic import Extra, root_validator, BaseModel
from langchain.callbacks.manager import (
CallbackManagerForLLMRun,
)
from langchain.utils import get_from_dict_or_env
from zhipuai.model_api.api import InvokeType
from zhipuai.utils import jwt_token
from zhipuai.utils.http_client import post, stream
from zhipuai.utils.sse_client import SSEClient
logger = logging.getLogger(__name__)
class ZhipuModelAPI(BaseModel):
base_url: str
api_key: str
api_timeout_seconds = 60
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
def invoke(self, **kwargs):
url = self._build_api_url(kwargs, InvokeType.SYNC)
response = post(url, self._generate_token(), kwargs, self.api_timeout_seconds)
if not response['success']:
raise ValueError(
f"Error Code: {response['code']}, Message: {response['msg']} "
)
return response
def sse_invoke(self, **kwargs):
url = self._build_api_url(kwargs, InvokeType.SSE)
data = stream(url, self._generate_token(), kwargs, self.api_timeout_seconds)
return SSEClient(data)
def _build_api_url(self, kwargs, *path):
if kwargs:
if "model" not in kwargs:
raise Exception("model param missed")
model = kwargs.pop("model")
else:
model = "-"
return posixpath.join(self.base_url, model, *path)
def _generate_token(self):
if not self.api_key:
raise Exception(
"api_key not provided, you could provide it."
)
try:
return jwt_token.generate_token(self.api_key)
except Exception:
raise ValueError(
f"Your api_key is invalid, please check it."
)
class ZhipuAIChatLLM(BaseChatModel):
"""Wrapper around ZhipuAI large language models.
To use, you should pass the api_key as a named parameter to the constructor.
Example:
.. code-block:: python
from core.third_party.langchain.llms.zhipuai import ZhipuAI
model = ZhipuAI(model="<model_name>", api_key="my-api-key")
"""
@property
def lc_secrets(self) -> Dict[str, str]:
return {"api_key": "API_KEY"}
@property
def lc_serializable(self) -> bool:
return True
client: Any = None #: :meta private:
model: str = "chatglm_lite"
"""Model name to use."""
temperature: float = 0.95
"""A non-negative float that tunes the degree of randomness in generation."""
top_p: float = 0.7
"""Total probability mass of tokens to consider at each step."""
streaming: bool = False
"""Whether to stream the response or return it all at once."""
api_key: Optional[str] = None
base_url: str = "https://open.bigmodel.cn/api/paas/v3/model-api"
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
values["api_key"] = get_from_dict_or_env(
values, "api_key", "ZHIPUAI_API_KEY"
)
if 'test' in values['base_url']:
values['model'] = 'chatglm_130b_test'
values['client'] = ZhipuModelAPI(api_key=values['api_key'], base_url=values['base_url'])
return values
@property
def _default_params(self) -> Dict[str, Any]:
"""Get the default parameters for calling OpenAI API."""
return {
"model": self.model,
"temperature": self.temperature,
"top_p": self.top_p
}
@property
def _identifying_params(self) -> Dict[str, Any]:
"""Get the identifying parameters."""
return self._default_params
@property
def _llm_type(self) -> str:
"""Return type of llm."""
return "zhipuai"
def _convert_message_to_dict(self, message: BaseMessage) -> dict:
if isinstance(message, ChatMessage):
message_dict = {"role": message.role, "content": message.content}
elif isinstance(message, HumanMessage):
message_dict = {"role": "user", "content": message.content}
elif isinstance(message, AIMessage):
message_dict = {"role": "assistant", "content": message.content}
elif isinstance(message, SystemMessage):
message_dict = {"role": "user", "content": message.content}
else:
raise ValueError(f"Got unknown type {message}")
return message_dict
def _convert_dict_to_message(self, _dict: Dict[str, Any]) -> BaseMessage:
role = _dict["role"]
if role == "user":
return HumanMessage(content=_dict["content"])
elif role == "assistant":
return AIMessage(content=_dict["content"])
elif role == "system":
return SystemMessage(content=_dict["content"])
else:
return ChatMessage(content=_dict["content"], role=role)
def _create_message_dicts(
self, messages: List[BaseMessage]
) -> List[Dict[str, Any]]:
dict_messages = []
for m in messages:
message = self._convert_message_to_dict(m)
if dict_messages:
previous_message = dict_messages[-1]
if previous_message['role'] == message['role']:
dict_messages[-1]['content'] += f"\n{message['content']}"
else:
dict_messages.append(message)
else:
dict_messages.append(message)
return dict_messages
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
if self.streaming:
generation: Optional[ChatGenerationChunk] = None
llm_output: Optional[Dict] = None
for chunk in self._stream(
messages=messages, stop=stop, run_manager=run_manager, **kwargs
):
if chunk.generation_info is not None \
and 'token_usage' in chunk.generation_info:
llm_output = {"token_usage": chunk.generation_info['token_usage'], "model_name": self.model}
continue
if generation is None:
generation = chunk
else:
generation += chunk
assert generation is not None
return ChatResult(generations=[generation], llm_output=llm_output)
else:
message_dicts = self._create_message_dicts(messages)
request = self._default_params
request["prompt"] = message_dicts
request.update(kwargs)
response = self.client.invoke(**request)
return self._create_chat_result(response)
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
message_dicts = self._create_message_dicts(messages)
request = self._default_params
request["prompt"] = message_dicts
request.update(kwargs)
for event in self.client.sse_invoke(incremental=True, **request).events():
if event.event == "add":
yield ChatGenerationChunk(message=AIMessageChunk(content=event.data))
if run_manager:
run_manager.on_llm_new_token(event.data)
elif event.event == "error" or event.event == "interrupted":
raise ValueError(
f"{event.data}"
)
elif event.event == "finish":
meta = json.loads(event.meta)
token_usage = meta['usage']
if token_usage is not None:
if 'prompt_tokens' not in token_usage:
token_usage['prompt_tokens'] = 0
if 'completion_tokens' not in token_usage:
token_usage['completion_tokens'] = token_usage['total_tokens']
yield ChatGenerationChunk(
message=AIMessageChunk(content=event.data),
generation_info=dict({'token_usage': token_usage})
)
def _create_chat_result(self, response: Dict[str, Any]) -> ChatResult:
data = response["data"]
generations = []
for res in data["choices"]:
message = self._convert_dict_to_message(res)
gen = ChatGeneration(
message=message
)
generations.append(gen)
token_usage = data.get("usage")
if token_usage is not None:
if 'prompt_tokens' not in token_usage:
token_usage['prompt_tokens'] = 0
if 'completion_tokens' not in token_usage:
token_usage['completion_tokens'] = token_usage['total_tokens']
llm_output = {"token_usage": token_usage, "model_name": self.model}
return ChatResult(generations=generations, llm_output=llm_output)
# def get_token_ids(self, text: str) -> List[int]:
# """Return the ordered ids of the tokens in a text.
#
# Args:
# text: The string input to tokenize.
#
# Returns:
# A list of ids corresponding to the tokens in the text, in order they occur
# in the text.
# """
# from core.third_party.transformers.Token import ChatGLMTokenizer
#
# tokenizer = ChatGLMTokenizer.from_pretrained("THUDM/chatglm2-6b")
# return tokenizer.encode(text)
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
"""Get the number of tokens in the messages.
Useful for checking if an input will fit in a model's context window.
Args:
messages: The message inputs to tokenize.
Returns:
The sum of the number of tokens across the messages.
"""
return sum([self.get_num_tokens(m.content) for m in messages])
def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
overall_token_usage: dict = {}
for output in llm_outputs:
if output is None:
# Happens in streaming
continue
token_usage = output["token_usage"]
for k, v in token_usage.items():
if k in overall_token_usage:
overall_token_usage[k] += v
else:
overall_token_usage[k] = v
return {"token_usage": overall_token_usage, "model_name": self.model}

View File

@ -50,4 +50,5 @@ transformers~=4.31.0
stripe~=5.5.0
pandas==1.5.3
xinference==0.4.2
safetensors==0.3.2
safetensors==0.3.2
zhipuai==1.0.7

View File

@ -548,7 +548,7 @@ class ProviderService:
'result': 'success'
}
def free_quota_qualification_verify(self, tenant_id: str, provider_name: str):
def free_quota_qualification_verify(self, tenant_id: str, provider_name: str, token: Optional[str]):
api_key = os.environ.get("FREE_QUOTA_APPLY_API_KEY")
api_base_url = os.environ.get("FREE_QUOTA_APPLY_BASE_URL")
api_url = api_base_url + '/api/v1/providers/qualification-verify'
@ -557,8 +557,11 @@ class ProviderService:
'Content-Type': 'application/json',
'Authorization': f"Bearer {api_key}"
}
json_data = {'workspace_id': tenant_id, 'provider_name': provider_name}
if token:
json_data['token'] = token
response = requests.post(api_url, headers=headers,
json={'workspace_id': tenant_id, 'provider_name': provider_name})
json=json_data)
if not response.ok:
logging.error(f"Request FREE QUOTA APPLY SERVER Error: {response.status_code} ")
raise ValueError(f"Error: {response.status_code} ")

View File

@ -31,6 +31,9 @@ TONGYI_DASHSCOPE_API_KEY=
WENXIN_API_KEY=
WENXIN_SECRET_KEY=
# ZhipuAI Credentials
ZHIPUAI_API_KEY=
# ChatGLM Credentials
CHATGLM_API_BASE=

View File

@ -0,0 +1,50 @@
import json
import os
from unittest.mock import patch
from core.model_providers.models.embedding.zhipuai_embedding import ZhipuAIEmbedding
from core.model_providers.providers.zhipuai_provider import ZhipuAIProvider
from models.provider import Provider, ProviderType
def get_mock_provider(valid_api_key):
return Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name='zhipuai',
provider_type=ProviderType.CUSTOM.value,
encrypted_config=json.dumps({
'api_key': valid_api_key
}),
is_valid=True,
)
def get_mock_embedding_model():
model_name = 'text_embedding'
valid_api_key = os.environ['ZHIPUAI_API_KEY']
provider = ZhipuAIProvider(provider=get_mock_provider(valid_api_key))
return ZhipuAIEmbedding(
model_provider=provider,
name=model_name
)
def decrypt_side_effect(tenant_id, encrypted_api_key):
return encrypted_api_key
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_embedding(mock_decrypt):
embedding_model = get_mock_embedding_model()
rst = embedding_model.client.embed_query('test')
assert isinstance(rst, list)
assert len(rst) == 1024
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_doc_embedding(mock_decrypt):
embedding_model = get_mock_embedding_model()
rst = embedding_model.client.embed_documents(['test', 'test2'])
assert isinstance(rst, list)
assert len(rst[0]) == 1024

View File

@ -0,0 +1,79 @@
import json
import os
from unittest.mock import patch
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelKwargs
from core.model_providers.models.llm.zhipuai_model import ZhipuAIModel
from core.model_providers.providers.zhipuai_provider import ZhipuAIProvider
from models.provider import Provider, ProviderType
def get_mock_provider(valid_api_key):
return Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name='zhipuai',
provider_type=ProviderType.CUSTOM.value,
encrypted_config=json.dumps({
'api_key': valid_api_key
}),
is_valid=True,
)
def get_mock_model(model_name: str, streaming: bool = False):
model_kwargs = ModelKwargs(
temperature=0.01,
)
valid_api_key = os.environ['ZHIPUAI_API_KEY']
model_provider = ZhipuAIProvider(provider=get_mock_provider(valid_api_key))
return ZhipuAIModel(
model_provider=model_provider,
name=model_name,
model_kwargs=model_kwargs,
streaming=streaming
)
def decrypt_side_effect(tenant_id, encrypted_api_key):
return encrypted_api_key
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_chat_get_num_tokens(mock_decrypt):
model = get_mock_model('chatglm_lite')
rst = model.get_num_tokens([
PromptMessage(type=MessageType.SYSTEM, content='you are a kindness Assistant.'),
PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
])
assert rst > 0
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_chat_run(mock_decrypt, mocker):
mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
model = get_mock_model('chatglm_lite')
messages = [
PromptMessage(type=MessageType.HUMAN, content='Are you Human? you MUST only answer `y` or `n`?')
]
rst = model.run(
messages,
)
assert len(rst.content) > 0
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_chat_stream_run(mock_decrypt, mocker):
mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
model = get_mock_model('chatglm_lite', streaming=True)
messages = [
PromptMessage(type=MessageType.HUMAN, content='Are you Human? you MUST only answer `y` or `n`?')
]
rst = model.run(
messages
)
assert len(rst.content) > 0

View File

@ -39,7 +39,7 @@ def test_is_provider_credentials_valid_or_raise_invalid():
MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise({})
credential = VALIDATE_CREDENTIAL.copy()
credential['api_key'] = 'invalid_key'
del credential['api_key']
# raise CredentialsValidateFailedError if api_key is invalid
with pytest.raises(CredentialsValidateFailedError):

View File

@ -0,0 +1,88 @@
import pytest
from unittest.mock import patch
import json
from langchain.schema import ChatResult, ChatGeneration, AIMessage
from core.model_providers.providers.base import CredentialsValidateFailedError
from core.model_providers.providers.zhipuai_provider import ZhipuAIProvider
from models.provider import ProviderType, Provider
PROVIDER_NAME = 'zhipuai'
MODEL_PROVIDER_CLASS = ZhipuAIProvider
VALIDATE_CREDENTIAL = {
'api_key': 'valid_key',
}
def encrypt_side_effect(tenant_id, encrypt_key):
return f'encrypted_{encrypt_key}'
def decrypt_side_effect(tenant_id, encrypted_key):
return encrypted_key.replace('encrypted_', '')
def test_is_provider_credentials_valid_or_raise_valid(mocker):
mocker.patch('core.third_party.langchain.llms.zhipuai_llm.ZhipuAIChatLLM._generate',
return_value=ChatResult(generations=[ChatGeneration(message=AIMessage(content='abc'))]))
MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(VALIDATE_CREDENTIAL)
def test_is_provider_credentials_valid_or_raise_invalid():
# raise CredentialsValidateFailedError if api_key is not in credentials
with pytest.raises(CredentialsValidateFailedError):
MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise({})
credential = VALIDATE_CREDENTIAL.copy()
credential['api_key'] = 'invalid_key'
# raise CredentialsValidateFailedError if api_key is invalid
with pytest.raises(CredentialsValidateFailedError):
MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(credential)
@patch('core.helper.encrypter.encrypt_token', side_effect=encrypt_side_effect)
def test_encrypt_credentials(mock_encrypt):
result = MODEL_PROVIDER_CLASS.encrypt_provider_credentials('tenant_id', VALIDATE_CREDENTIAL.copy())
assert result['api_key'] == f'encrypted_{VALIDATE_CREDENTIAL["api_key"]}'
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_credentials_custom(mock_decrypt):
encrypted_credential = VALIDATE_CREDENTIAL.copy()
encrypted_credential['api_key'] = 'encrypted_' + encrypted_credential['api_key']
provider = Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name=PROVIDER_NAME,
provider_type=ProviderType.CUSTOM.value,
encrypted_config=json.dumps(encrypted_credential),
is_valid=True,
)
model_provider = MODEL_PROVIDER_CLASS(provider=provider)
result = model_provider.get_provider_credentials()
assert result['api_key'] == 'valid_key'
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_credentials_obfuscated(mock_decrypt):
encrypted_credential = VALIDATE_CREDENTIAL.copy()
encrypted_credential['api_key'] = 'encrypted_' + encrypted_credential['api_key']
provider = Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name=PROVIDER_NAME,
provider_type=ProviderType.CUSTOM.value,
encrypted_config=json.dumps(encrypted_credential),
is_valid=True,
)
model_provider = MODEL_PROVIDER_CLASS(provider=provider)
result = model_provider.get_provider_credentials(obfuscated=True)
middle_token = result['api_key'][6:-2]
assert len(middle_token) == max(len(VALIDATE_CREDENTIAL['api_key']) - 8, 0)
assert all(char == '*' for char in middle_token)