diff --git a/api/core/model_manager.py b/api/core/model_manager.py index a64abf1e2a..04c742f09f 100644 --- a/api/core/model_manager.py +++ b/api/core/model_manager.py @@ -191,7 +191,7 @@ class ModelInstance: self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance) return cast( - int, + list[int], self._round_robin_invoke( function=self.model_type_instance.get_num_tokens, model=self.model, @@ -240,7 +240,7 @@ class ModelInstance: self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance) return cast( - int, + list[int], self._round_robin_invoke( function=self.model_type_instance.get_num_tokens, model=self.model, diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index a7e9bc809f..f9f5672b15 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -1,7 +1,7 @@ import json from collections import defaultdict from json import JSONDecodeError -from typing import Optional, cast +from typing import Any, Optional, cast from sqlalchemy.exc import IntegrityError @@ -350,7 +350,7 @@ class ProviderManager: :param tenant_id: workspace id :return: """ - providers = db.session.query(Provider).filter(Provider.tenant_id == tenant_id, Provider.is_valid is True).all() + providers = db.session.query(Provider).filter(Provider.tenant_id == tenant_id, Provider.is_valid == True).all() # noqa provider_name_to_provider_records_dict = defaultdict(list) for provider in providers: @@ -369,7 +369,7 @@ class ProviderManager: # Get all provider model records of the workspace provider_models = ( db.session.query(ProviderModel) - .filter(ProviderModel.tenant_id == tenant_id, ProviderModel.is_valid is True) + .filter(ProviderModel.tenant_id == tenant_id, ProviderModel.is_valid == True) # noqa .all() ) @@ -739,9 +739,9 @@ class ProviderManager: if not cached_provider_credentials: try: - provider_credentials = json.loads(provider_record.encrypted_config) + provider_credentials: dict[str, Any] = json.loads(provider_record.encrypted_config) except JSONDecodeError: - provider_credentials = {} + provider_credentials: dict[str, Any] = {} # Get provider credential secret variables provider_credential_secret_variables = self._extract_secret_variables( @@ -758,7 +758,9 @@ class ProviderManager: if variable in provider_credentials: try: provider_credentials[variable] = encrypter.decrypt_token_with_decoding( - provider_credentials.get(variable), self.decoding_rsa_key, self.decoding_cipher_rsa + provider_credentials.get(variable, ""), + self.decoding_rsa_key, + self.decoding_cipher_rsa, ) except ValueError: pass diff --git a/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py b/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py index 08b6e89806..484f6ea069 100644 --- a/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py +++ b/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py @@ -88,7 +88,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool): DocumentSegment.dataset_id.in_(self.dataset_ids), DocumentSegment.completed_at.isnot(None), DocumentSegment.status == "completed", - DocumentSegment.enabled is True, + DocumentSegment.enabled == True, # noqa DocumentSegment.index_node_id.in_(index_node_ids), ).all() @@ -109,8 +109,8 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool): dataset = Dataset.query.filter_by(id=segment.dataset_id).first() document = Document.query.filter( Document.id == segment.document_id, - Document.enabled is True, - Document.archived is False, + Document.enabled == True, # noqa + Document.archived == False, # noqa ).first() if dataset and document: source = { diff --git a/api/services/entities/model_provider_entities.py b/api/services/entities/model_provider_entities.py index 1034bcfa1d..79226ffa52 100644 --- a/api/services/entities/model_provider_entities.py +++ b/api/services/entities/model_provider_entities.py @@ -7,7 +7,6 @@ from configs import dify_config from core.entities.model_entities import ( ModelWithProviderEntity, ProviderModelWithStatusEntity, - SimpleModelProviderEntity, ) from core.entities.provider_entities import ProviderQuotaType, QuotaConfiguration from core.model_runtime.entities.common_entities import I18nObject @@ -162,7 +161,7 @@ class ModelWithProviderEntityResponse(ModelWithProviderEntity): Model with provider entity. """ - provider: SimpleModelProviderEntity + provider: SimpleProviderEntityResponse def __init__(self, tenant_id: str, model: ModelWithProviderEntity) -> None: dump_model = model.model_dump()