fix: errors occrus during rebasing

This commit is contained in:
Yeuoly 2024-12-26 13:20:12 +08:00
parent 80d8e47e42
commit e231cf2c48
4 changed files with 14 additions and 13 deletions

View File

@ -191,7 +191,7 @@ class ModelInstance:
self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance) self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance)
return cast( return cast(
int, list[int],
self._round_robin_invoke( self._round_robin_invoke(
function=self.model_type_instance.get_num_tokens, function=self.model_type_instance.get_num_tokens,
model=self.model, model=self.model,
@ -240,7 +240,7 @@ class ModelInstance:
self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance) self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance)
return cast( return cast(
int, list[int],
self._round_robin_invoke( self._round_robin_invoke(
function=self.model_type_instance.get_num_tokens, function=self.model_type_instance.get_num_tokens,
model=self.model, model=self.model,

View File

@ -1,7 +1,7 @@
import json import json
from collections import defaultdict from collections import defaultdict
from json import JSONDecodeError from json import JSONDecodeError
from typing import Optional, cast from typing import Any, Optional, cast
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
@ -350,7 +350,7 @@ class ProviderManager:
:param tenant_id: workspace id :param tenant_id: workspace id
:return: :return:
""" """
providers = db.session.query(Provider).filter(Provider.tenant_id == tenant_id, Provider.is_valid is True).all() providers = db.session.query(Provider).filter(Provider.tenant_id == tenant_id, Provider.is_valid == True).all() # noqa
provider_name_to_provider_records_dict = defaultdict(list) provider_name_to_provider_records_dict = defaultdict(list)
for provider in providers: for provider in providers:
@ -369,7 +369,7 @@ class ProviderManager:
# Get all provider model records of the workspace # Get all provider model records of the workspace
provider_models = ( provider_models = (
db.session.query(ProviderModel) db.session.query(ProviderModel)
.filter(ProviderModel.tenant_id == tenant_id, ProviderModel.is_valid is True) .filter(ProviderModel.tenant_id == tenant_id, ProviderModel.is_valid == True) # noqa
.all() .all()
) )
@ -739,9 +739,9 @@ class ProviderManager:
if not cached_provider_credentials: if not cached_provider_credentials:
try: try:
provider_credentials = json.loads(provider_record.encrypted_config) provider_credentials: dict[str, Any] = json.loads(provider_record.encrypted_config)
except JSONDecodeError: except JSONDecodeError:
provider_credentials = {} provider_credentials: dict[str, Any] = {}
# Get provider credential secret variables # Get provider credential secret variables
provider_credential_secret_variables = self._extract_secret_variables( provider_credential_secret_variables = self._extract_secret_variables(
@ -758,7 +758,9 @@ class ProviderManager:
if variable in provider_credentials: if variable in provider_credentials:
try: try:
provider_credentials[variable] = encrypter.decrypt_token_with_decoding( provider_credentials[variable] = encrypter.decrypt_token_with_decoding(
provider_credentials.get(variable), self.decoding_rsa_key, self.decoding_cipher_rsa provider_credentials.get(variable, ""),
self.decoding_rsa_key,
self.decoding_cipher_rsa,
) )
except ValueError: except ValueError:
pass pass

View File

@ -88,7 +88,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
DocumentSegment.dataset_id.in_(self.dataset_ids), DocumentSegment.dataset_id.in_(self.dataset_ids),
DocumentSegment.completed_at.isnot(None), DocumentSegment.completed_at.isnot(None),
DocumentSegment.status == "completed", DocumentSegment.status == "completed",
DocumentSegment.enabled is True, DocumentSegment.enabled == True, # noqa
DocumentSegment.index_node_id.in_(index_node_ids), DocumentSegment.index_node_id.in_(index_node_ids),
).all() ).all()
@ -109,8 +109,8 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
dataset = Dataset.query.filter_by(id=segment.dataset_id).first() dataset = Dataset.query.filter_by(id=segment.dataset_id).first()
document = Document.query.filter( document = Document.query.filter(
Document.id == segment.document_id, Document.id == segment.document_id,
Document.enabled is True, Document.enabled == True, # noqa
Document.archived is False, Document.archived == False, # noqa
).first() ).first()
if dataset and document: if dataset and document:
source = { source = {

View File

@ -7,7 +7,6 @@ from configs import dify_config
from core.entities.model_entities import ( from core.entities.model_entities import (
ModelWithProviderEntity, ModelWithProviderEntity,
ProviderModelWithStatusEntity, ProviderModelWithStatusEntity,
SimpleModelProviderEntity,
) )
from core.entities.provider_entities import ProviderQuotaType, QuotaConfiguration from core.entities.provider_entities import ProviderQuotaType, QuotaConfiguration
from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.common_entities import I18nObject
@ -162,7 +161,7 @@ class ModelWithProviderEntityResponse(ModelWithProviderEntity):
Model with provider entity. Model with provider entity.
""" """
provider: SimpleModelProviderEntity provider: SimpleProviderEntityResponse
def __init__(self, tenant_id: str, model: ModelWithProviderEntity) -> None: def __init__(self, tenant_id: str, model: ModelWithProviderEntity) -> None:
dump_model = model.model_dump() dump_model = model.model_dump()