fix: errors occrus during rebasing

This commit is contained in:
Yeuoly 2024-12-26 13:20:12 +08:00
parent 80d8e47e42
commit e231cf2c48
4 changed files with 14 additions and 13 deletions

View File

@ -191,7 +191,7 @@ class ModelInstance:
self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance)
return cast(
int,
list[int],
self._round_robin_invoke(
function=self.model_type_instance.get_num_tokens,
model=self.model,
@ -240,7 +240,7 @@ class ModelInstance:
self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance)
return cast(
int,
list[int],
self._round_robin_invoke(
function=self.model_type_instance.get_num_tokens,
model=self.model,

View File

@ -1,7 +1,7 @@
import json
from collections import defaultdict
from json import JSONDecodeError
from typing import Optional, cast
from typing import Any, Optional, cast
from sqlalchemy.exc import IntegrityError
@ -350,7 +350,7 @@ class ProviderManager:
:param tenant_id: workspace id
:return:
"""
providers = db.session.query(Provider).filter(Provider.tenant_id == tenant_id, Provider.is_valid is True).all()
providers = db.session.query(Provider).filter(Provider.tenant_id == tenant_id, Provider.is_valid == True).all() # noqa
provider_name_to_provider_records_dict = defaultdict(list)
for provider in providers:
@ -369,7 +369,7 @@ class ProviderManager:
# Get all provider model records of the workspace
provider_models = (
db.session.query(ProviderModel)
.filter(ProviderModel.tenant_id == tenant_id, ProviderModel.is_valid is True)
.filter(ProviderModel.tenant_id == tenant_id, ProviderModel.is_valid == True) # noqa
.all()
)
@ -739,9 +739,9 @@ class ProviderManager:
if not cached_provider_credentials:
try:
provider_credentials = json.loads(provider_record.encrypted_config)
provider_credentials: dict[str, Any] = json.loads(provider_record.encrypted_config)
except JSONDecodeError:
provider_credentials = {}
provider_credentials: dict[str, Any] = {}
# Get provider credential secret variables
provider_credential_secret_variables = self._extract_secret_variables(
@ -758,7 +758,9 @@ class ProviderManager:
if variable in provider_credentials:
try:
provider_credentials[variable] = encrypter.decrypt_token_with_decoding(
provider_credentials.get(variable), self.decoding_rsa_key, self.decoding_cipher_rsa
provider_credentials.get(variable, ""),
self.decoding_rsa_key,
self.decoding_cipher_rsa,
)
except ValueError:
pass

View File

@ -88,7 +88,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
DocumentSegment.dataset_id.in_(self.dataset_ids),
DocumentSegment.completed_at.isnot(None),
DocumentSegment.status == "completed",
DocumentSegment.enabled is True,
DocumentSegment.enabled == True, # noqa
DocumentSegment.index_node_id.in_(index_node_ids),
).all()
@ -109,8 +109,8 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
dataset = Dataset.query.filter_by(id=segment.dataset_id).first()
document = Document.query.filter(
Document.id == segment.document_id,
Document.enabled is True,
Document.archived is False,
Document.enabled == True, # noqa
Document.archived == False, # noqa
).first()
if dataset and document:
source = {

View File

@ -7,7 +7,6 @@ from configs import dify_config
from core.entities.model_entities import (
ModelWithProviderEntity,
ProviderModelWithStatusEntity,
SimpleModelProviderEntity,
)
from core.entities.provider_entities import ProviderQuotaType, QuotaConfiguration
from core.model_runtime.entities.common_entities import I18nObject
@ -162,7 +161,7 @@ class ModelWithProviderEntityResponse(ModelWithProviderEntity):
Model with provider entity.
"""
provider: SimpleModelProviderEntity
provider: SimpleProviderEntityResponse
def __init__(self, tenant_id: str, model: ModelWithProviderEntity) -> None:
dump_model = model.model_dump()