From 1779cea6e3a06e9f997e894d1d60bc746257d99e Mon Sep 17 00:00:00 2001 From: takatost Date: Fri, 12 Jan 2024 16:48:38 +0800 Subject: [PATCH] fix: model provider credentials null value validate failed (#2009) --- api/core/entities/provider_configuration.py | 15 ++------------- .../model_providers/model_provider_factory.py | 16 ++++++++++------ .../schema_validators/common_validator.py | 2 +- 3 files changed, 13 insertions(+), 20 deletions(-) diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py index cb31d01e99..e99b238963 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -165,7 +165,7 @@ class ProviderConfiguration(BaseModel): if value == '[__HIDDEN__]' and key in original_credentials: credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key]) - model_provider_factory.provider_credentials_validate( + credentials = model_provider_factory.provider_credentials_validate( self.provider.provider, credentials ) @@ -308,24 +308,13 @@ class ProviderConfiguration(BaseModel): if value == '[__HIDDEN__]' and key in original_credentials: credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key]) - model_provider_factory.model_credentials_validate( + credentials = model_provider_factory.model_credentials_validate( provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials ) - model_schema = ( - model_provider_factory.get_provider_instance(self.provider.provider) - .get_model_instance(model_type)._get_customizable_model_schema( - model=model, - credentials=credentials - ) - ) - - if model_schema: - credentials['schema'] = json.dumps(encoders.jsonable_encoder(model_schema)) - for key, value in credentials.items(): if key in provider_credential_secret_variables: credentials[key] = encrypter.encrypt_token(self.tenant_id, value) diff --git a/api/core/model_runtime/model_providers/model_provider_factory.py b/api/core/model_runtime/model_providers/model_provider_factory.py index 375017c563..06932b018d 100644 --- a/api/core/model_runtime/model_providers/model_provider_factory.py +++ b/api/core/model_runtime/model_providers/model_provider_factory.py @@ -61,7 +61,7 @@ class ModelProviderFactory: # return providers return providers - def provider_credentials_validate(self, provider: str, credentials: dict) -> None: + def provider_credentials_validate(self, provider: str, credentials: dict) -> dict: """ Validate provider credentials @@ -80,13 +80,15 @@ class ModelProviderFactory: # validate provider credential schema validator = ProviderCredentialSchemaValidator(provider_credential_schema) - validator.validate_and_filter(credentials) + filtered_credentials = validator.validate_and_filter(credentials) # validate the credentials, raise exception if validation failed - model_provider_instance.validate_provider_credentials(credentials) + model_provider_instance.validate_provider_credentials(filtered_credentials) + + return filtered_credentials def model_credentials_validate(self, provider: str, model_type: ModelType, - model: str, credentials: dict) -> None: + model: str, credentials: dict) -> dict: """ Validate model credentials @@ -107,13 +109,15 @@ class ModelProviderFactory: # validate model credential schema validator = ModelCredentialSchemaValidator(model_type, model_credential_schema) - validator.validate_and_filter(credentials) + filtered_credentials = validator.validate_and_filter(credentials) # get model instance of the model type model_instance = model_provider_instance.get_model_instance(model_type) # call validate_credentials method of model type to validate credentials, raise exception if validation failed - model_instance.validate_credentials(model, credentials) + model_instance.validate_credentials(model, filtered_credentials) + + return filtered_credentials def get_models(self, provider: Optional[str] = None, diff --git a/api/core/model_runtime/schema_validators/common_validator.py b/api/core/model_runtime/schema_validators/common_validator.py index 3e6f3526ef..fe705d6943 100644 --- a/api/core/model_runtime/schema_validators/common_validator.py +++ b/api/core/model_runtime/schema_validators/common_validator.py @@ -46,7 +46,7 @@ class CommonValidator: :return: validated credential form schema value """ # If the variable does not exist in credentials - if credential_form_schema.variable not in credentials: + if credential_form_schema.variable not in credentials or not credentials[credential_form_schema.variable]: # If required is True, an exception is thrown if credential_form_schema.required: raise ValueError(f'Variable {credential_form_schema.variable} is required')