chore: all model.query replace to db.session.query (#19521)

This commit is contained in:
非法操作 2025-05-12 15:19:41 +08:00 committed by GitHub
parent b00f94df64
commit 14cd71ed0a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 99 additions and 67 deletions

View File

@ -3,6 +3,7 @@ import logging
from flask import request from flask import request
from flask_login import current_user from flask_login import current_user
from flask_restful import Resource, fields, inputs, marshal, marshal_with, reqparse from flask_restful import Resource, fields, inputs, marshal, marshal_with, reqparse
from sqlalchemy import select
from werkzeug.exceptions import Unauthorized from werkzeug.exceptions import Unauthorized
import services import services
@ -88,9 +89,8 @@ class WorkspaceListApi(Resource):
parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args") parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
args = parser.parse_args() args = parser.parse_args()
tenants = Tenant.query.order_by(Tenant.created_at.desc()).paginate( stmt = select(Tenant).order_by(Tenant.created_at.desc())
page=args["page"], per_page=args["limit"], error_out=False tenants = db.paginate(select=stmt, page=args["page"], per_page=args["limit"], error_out=False)
)
has_more = False has_more = False
if tenants.has_next: if tenants.has_next:
@ -162,7 +162,7 @@ class CustomConfigWorkspaceApi(Resource):
parser.add_argument("replace_webapp_logo", type=str, location="json") parser.add_argument("replace_webapp_logo", type=str, location="json")
args = parser.parse_args() args = parser.parse_args()
tenant = Tenant.query.filter(Tenant.id == current_user.current_tenant_id).one_or_404() tenant = db.get_or_404(Tenant, current_user.current_tenant_id)
custom_config_dict = { custom_config_dict = {
"remove_webapp_brand": args["remove_webapp_brand"], "remove_webapp_brand": args["remove_webapp_brand"],
@ -226,7 +226,7 @@ class WorkspaceInfoApi(Resource):
parser.add_argument("name", type=str, required=True, location="json") parser.add_argument("name", type=str, required=True, location="json")
args = parser.parse_args() args = parser.parse_args()
tenant = Tenant.query.filter(Tenant.id == current_user.current_tenant_id).one_or_404() tenant = db.get_or_404(Tenant, current_user.current_tenant_id)
tenant.name = args["name"] tenant.name = args["name"]
db.session.commit() db.session.commit()

View File

@ -347,14 +347,18 @@ class NotionExtractor(BaseExtractor):
@classmethod @classmethod
def _get_access_token(cls, tenant_id: str, notion_workspace_id: str) -> str: def _get_access_token(cls, tenant_id: str, notion_workspace_id: str) -> str:
data_source_binding = DataSourceOauthBinding.query.filter( data_source_binding = (
db.and_( db.session.query(DataSourceOauthBinding)
DataSourceOauthBinding.tenant_id == tenant_id, .filter(
DataSourceOauthBinding.provider == "notion", db.and_(
DataSourceOauthBinding.disabled == False, DataSourceOauthBinding.tenant_id == tenant_id,
DataSourceOauthBinding.source_info["workspace_id"] == f'"{notion_workspace_id}"', DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.disabled == False,
DataSourceOauthBinding.source_info["workspace_id"] == f'"{notion_workspace_id}"',
)
) )
).first() .first()
)
if not data_source_binding: if not data_source_binding:
raise Exception( raise Exception(

View File

@ -61,13 +61,17 @@ class NotionOAuth(OAuthDataSource):
"total": len(pages), "total": len(pages),
} }
# save data source binding # save data source binding
data_source_binding = DataSourceOauthBinding.query.filter( data_source_binding = (
db.and_( db.session.query(DataSourceOauthBinding)
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, .filter(
DataSourceOauthBinding.provider == "notion", db.and_(
DataSourceOauthBinding.access_token == access_token, DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.access_token == access_token,
)
) )
).first() .first()
)
if data_source_binding: if data_source_binding:
data_source_binding.source_info = source_info data_source_binding.source_info = source_info
data_source_binding.disabled = False data_source_binding.disabled = False
@ -97,13 +101,17 @@ class NotionOAuth(OAuthDataSource):
"total": len(pages), "total": len(pages),
} }
# save data source binding # save data source binding
data_source_binding = DataSourceOauthBinding.query.filter( data_source_binding = (
db.and_( db.session.query(DataSourceOauthBinding)
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, .filter(
DataSourceOauthBinding.provider == "notion", db.and_(
DataSourceOauthBinding.access_token == access_token, DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.access_token == access_token,
)
) )
).first() .first()
)
if data_source_binding: if data_source_binding:
data_source_binding.source_info = source_info data_source_binding.source_info = source_info
data_source_binding.disabled = False data_source_binding.disabled = False
@ -121,14 +129,18 @@ class NotionOAuth(OAuthDataSource):
def sync_data_source(self, binding_id: str): def sync_data_source(self, binding_id: str):
# save data source binding # save data source binding
data_source_binding = DataSourceOauthBinding.query.filter( data_source_binding = (
db.and_( db.session.query(DataSourceOauthBinding)
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, .filter(
DataSourceOauthBinding.provider == "notion", db.and_(
DataSourceOauthBinding.id == binding_id, DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.disabled == False, DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.id == binding_id,
DataSourceOauthBinding.disabled == False,
)
) )
).first() .first()
)
if data_source_binding: if data_source_binding:
# get all authorized pages # get all authorized pages
pages = self.get_authorized_pages(data_source_binding.access_token) pages = self.get_authorized_pages(data_source_binding.access_token)

View File

@ -45,7 +45,7 @@ def mail_clean_document_notify_task():
if plan != "sandbox": if plan != "sandbox":
knowledge_details = [] knowledge_details = []
# check tenant # check tenant
tenant = Tenant.query.filter(Tenant.id == tenant_id).first() tenant = db.session.query(Tenant).filter(Tenant.id == tenant_id).first()
if not tenant: if not tenant:
continue continue
# check current owner # check current owner

View File

@ -300,9 +300,9 @@ class AccountService:
"""Link account integrate""" """Link account integrate"""
try: try:
# Query whether there is an existing binding record for the same provider # Query whether there is an existing binding record for the same provider
account_integrate: Optional[AccountIntegrate] = AccountIntegrate.query.filter_by( account_integrate: Optional[AccountIntegrate] = (
account_id=account.id, provider=provider db.session.query(AccountIntegrate).filter_by(account_id=account.id, provider=provider).first()
).first() )
if account_integrate: if account_integrate:
# If it exists, update the record # If it exists, update the record
@ -851,7 +851,7 @@ class TenantService:
@staticmethod @staticmethod
def get_custom_config(tenant_id: str) -> dict: def get_custom_config(tenant_id: str) -> dict:
tenant = Tenant.query.filter(Tenant.id == tenant_id).one_or_404() tenant = db.get_or_404(Tenant, tenant_id)
return cast(dict, tenant.custom_config_dict) return cast(dict, tenant.custom_config_dict)

View File

@ -4,7 +4,7 @@ from typing import cast
import pandas as pd import pandas as pd
from flask_login import current_user from flask_login import current_user
from sqlalchemy import or_ from sqlalchemy import or_, select
from werkzeug.datastructures import FileStorage from werkzeug.datastructures import FileStorage
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
@ -124,8 +124,9 @@ class AppAnnotationService:
if not app: if not app:
raise NotFound("App not found") raise NotFound("App not found")
if keyword: if keyword:
annotations = ( stmt = (
MessageAnnotation.query.filter(MessageAnnotation.app_id == app_id) select(MessageAnnotation)
.filter(MessageAnnotation.app_id == app_id)
.filter( .filter(
or_( or_(
MessageAnnotation.question.ilike("%{}%".format(keyword)), MessageAnnotation.question.ilike("%{}%".format(keyword)),
@ -133,14 +134,14 @@ class AppAnnotationService:
) )
) )
.order_by(MessageAnnotation.created_at.desc(), MessageAnnotation.id.desc()) .order_by(MessageAnnotation.created_at.desc(), MessageAnnotation.id.desc())
.paginate(page=page, per_page=limit, max_per_page=100, error_out=False)
) )
else: else:
annotations = ( stmt = (
MessageAnnotation.query.filter(MessageAnnotation.app_id == app_id) select(MessageAnnotation)
.filter(MessageAnnotation.app_id == app_id)
.order_by(MessageAnnotation.created_at.desc(), MessageAnnotation.id.desc()) .order_by(MessageAnnotation.created_at.desc(), MessageAnnotation.id.desc())
.paginate(page=page, per_page=limit, max_per_page=100, error_out=False)
) )
annotations = db.paginate(select=stmt, page=page, per_page=limit, max_per_page=100, error_out=False)
return annotations.items, annotations.total return annotations.items, annotations.total
@classmethod @classmethod
@ -325,13 +326,16 @@ class AppAnnotationService:
if not annotation: if not annotation:
raise NotFound("Annotation not found") raise NotFound("Annotation not found")
annotation_hit_histories = ( stmt = (
AppAnnotationHitHistory.query.filter( select(AppAnnotationHitHistory)
.filter(
AppAnnotationHitHistory.app_id == app_id, AppAnnotationHitHistory.app_id == app_id,
AppAnnotationHitHistory.annotation_id == annotation_id, AppAnnotationHitHistory.annotation_id == annotation_id,
) )
.order_by(AppAnnotationHitHistory.created_at.desc()) .order_by(AppAnnotationHitHistory.created_at.desc())
.paginate(page=page, per_page=limit, max_per_page=100, error_out=False) )
annotation_hit_histories = db.paginate(
select=stmt, page=page, per_page=limit, max_per_page=100, error_out=False
) )
return annotation_hit_histories.items, annotation_hit_histories.total return annotation_hit_histories.items, annotation_hit_histories.total

View File

@ -1087,14 +1087,18 @@ class DocumentService:
exist_document[data_source_info["notion_page_id"]] = document.id exist_document[data_source_info["notion_page_id"]] = document.id
for notion_info in notion_info_list: for notion_info in notion_info_list:
workspace_id = notion_info.workspace_id workspace_id = notion_info.workspace_id
data_source_binding = DataSourceOauthBinding.query.filter( data_source_binding = (
db.and_( db.session.query(DataSourceOauthBinding)
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, .filter(
DataSourceOauthBinding.provider == "notion", db.and_(
DataSourceOauthBinding.disabled == False, DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"', DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.disabled == False,
DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"',
)
) )
).first() .first()
)
if not data_source_binding: if not data_source_binding:
raise ValueError("Data source binding not found.") raise ValueError("Data source binding not found.")
for page in notion_info.pages: for page in notion_info.pages:
@ -1302,14 +1306,18 @@ class DocumentService:
notion_info_list = document_data.data_source.info_list.notion_info_list notion_info_list = document_data.data_source.info_list.notion_info_list
for notion_info in notion_info_list: for notion_info in notion_info_list:
workspace_id = notion_info.workspace_id workspace_id = notion_info.workspace_id
data_source_binding = DataSourceOauthBinding.query.filter( data_source_binding = (
db.and_( db.session.query(DataSourceOauthBinding)
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, .filter(
DataSourceOauthBinding.provider == "notion", db.and_(
DataSourceOauthBinding.disabled == False, DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"', DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.disabled == False,
DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"',
)
) )
).first() .first()
)
if not data_source_binding: if not data_source_binding:
raise ValueError("Data source binding not found.") raise ValueError("Data source binding not found.")
for page in notion_info.pages: for page in notion_info.pages:

View File

@ -44,14 +44,18 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
page_id = data_source_info["notion_page_id"] page_id = data_source_info["notion_page_id"]
page_type = data_source_info["type"] page_type = data_source_info["type"]
page_edited_time = data_source_info["last_edited_time"] page_edited_time = data_source_info["last_edited_time"]
data_source_binding = DataSourceOauthBinding.query.filter( data_source_binding = (
db.and_( db.session.query(DataSourceOauthBinding)
DataSourceOauthBinding.tenant_id == document.tenant_id, .filter(
DataSourceOauthBinding.provider == "notion", db.and_(
DataSourceOauthBinding.disabled == False, DataSourceOauthBinding.tenant_id == document.tenant_id,
DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"', DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.disabled == False,
DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"',
)
) )
).first() .first()
)
if not data_source_binding: if not data_source_binding:
raise ValueError("Data source binding not found.") raise ValueError("Data source binding not found.")