chore: all model.query replace to db.session.query (#19521)

This commit is contained in:
非法操作 2025-05-12 15:19:41 +08:00 committed by GitHub
parent b00f94df64
commit 14cd71ed0a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 99 additions and 67 deletions

View File

@ -3,6 +3,7 @@ import logging
from flask import request
from flask_login import current_user
from flask_restful import Resource, fields, inputs, marshal, marshal_with, reqparse
from sqlalchemy import select
from werkzeug.exceptions import Unauthorized
import services
@ -88,9 +89,8 @@ class WorkspaceListApi(Resource):
parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
args = parser.parse_args()
tenants = Tenant.query.order_by(Tenant.created_at.desc()).paginate(
page=args["page"], per_page=args["limit"], error_out=False
)
stmt = select(Tenant).order_by(Tenant.created_at.desc())
tenants = db.paginate(select=stmt, page=args["page"], per_page=args["limit"], error_out=False)
has_more = False
if tenants.has_next:
@ -162,7 +162,7 @@ class CustomConfigWorkspaceApi(Resource):
parser.add_argument("replace_webapp_logo", type=str, location="json")
args = parser.parse_args()
tenant = Tenant.query.filter(Tenant.id == current_user.current_tenant_id).one_or_404()
tenant = db.get_or_404(Tenant, current_user.current_tenant_id)
custom_config_dict = {
"remove_webapp_brand": args["remove_webapp_brand"],
@ -226,7 +226,7 @@ class WorkspaceInfoApi(Resource):
parser.add_argument("name", type=str, required=True, location="json")
args = parser.parse_args()
tenant = Tenant.query.filter(Tenant.id == current_user.current_tenant_id).one_or_404()
tenant = db.get_or_404(Tenant, current_user.current_tenant_id)
tenant.name = args["name"]
db.session.commit()

View File

@ -347,14 +347,18 @@ class NotionExtractor(BaseExtractor):
@classmethod
def _get_access_token(cls, tenant_id: str, notion_workspace_id: str) -> str:
data_source_binding = DataSourceOauthBinding.query.filter(
db.and_(
DataSourceOauthBinding.tenant_id == tenant_id,
DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.disabled == False,
DataSourceOauthBinding.source_info["workspace_id"] == f'"{notion_workspace_id}"',
data_source_binding = (
db.session.query(DataSourceOauthBinding)
.filter(
db.and_(
DataSourceOauthBinding.tenant_id == tenant_id,
DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.disabled == False,
DataSourceOauthBinding.source_info["workspace_id"] == f'"{notion_workspace_id}"',
)
)
).first()
.first()
)
if not data_source_binding:
raise Exception(

View File

@ -61,13 +61,17 @@ class NotionOAuth(OAuthDataSource):
"total": len(pages),
}
# save data source binding
data_source_binding = DataSourceOauthBinding.query.filter(
db.and_(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.access_token == access_token,
data_source_binding = (
db.session.query(DataSourceOauthBinding)
.filter(
db.and_(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.access_token == access_token,
)
)
).first()
.first()
)
if data_source_binding:
data_source_binding.source_info = source_info
data_source_binding.disabled = False
@ -97,13 +101,17 @@ class NotionOAuth(OAuthDataSource):
"total": len(pages),
}
# save data source binding
data_source_binding = DataSourceOauthBinding.query.filter(
db.and_(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.access_token == access_token,
data_source_binding = (
db.session.query(DataSourceOauthBinding)
.filter(
db.and_(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.access_token == access_token,
)
)
).first()
.first()
)
if data_source_binding:
data_source_binding.source_info = source_info
data_source_binding.disabled = False
@ -121,14 +129,18 @@ class NotionOAuth(OAuthDataSource):
def sync_data_source(self, binding_id: str):
# save data source binding
data_source_binding = DataSourceOauthBinding.query.filter(
db.and_(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.id == binding_id,
DataSourceOauthBinding.disabled == False,
data_source_binding = (
db.session.query(DataSourceOauthBinding)
.filter(
db.and_(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.id == binding_id,
DataSourceOauthBinding.disabled == False,
)
)
).first()
.first()
)
if data_source_binding:
# get all authorized pages
pages = self.get_authorized_pages(data_source_binding.access_token)

View File

@ -45,7 +45,7 @@ def mail_clean_document_notify_task():
if plan != "sandbox":
knowledge_details = []
# check tenant
tenant = Tenant.query.filter(Tenant.id == tenant_id).first()
tenant = db.session.query(Tenant).filter(Tenant.id == tenant_id).first()
if not tenant:
continue
# check current owner

View File

@ -300,9 +300,9 @@ class AccountService:
"""Link account integrate"""
try:
# Query whether there is an existing binding record for the same provider
account_integrate: Optional[AccountIntegrate] = AccountIntegrate.query.filter_by(
account_id=account.id, provider=provider
).first()
account_integrate: Optional[AccountIntegrate] = (
db.session.query(AccountIntegrate).filter_by(account_id=account.id, provider=provider).first()
)
if account_integrate:
# If it exists, update the record
@ -851,7 +851,7 @@ class TenantService:
@staticmethod
def get_custom_config(tenant_id: str) -> dict:
tenant = Tenant.query.filter(Tenant.id == tenant_id).one_or_404()
tenant = db.get_or_404(Tenant, tenant_id)
return cast(dict, tenant.custom_config_dict)

View File

@ -4,7 +4,7 @@ from typing import cast
import pandas as pd
from flask_login import current_user
from sqlalchemy import or_
from sqlalchemy import or_, select
from werkzeug.datastructures import FileStorage
from werkzeug.exceptions import NotFound
@ -124,8 +124,9 @@ class AppAnnotationService:
if not app:
raise NotFound("App not found")
if keyword:
annotations = (
MessageAnnotation.query.filter(MessageAnnotation.app_id == app_id)
stmt = (
select(MessageAnnotation)
.filter(MessageAnnotation.app_id == app_id)
.filter(
or_(
MessageAnnotation.question.ilike("%{}%".format(keyword)),
@ -133,14 +134,14 @@ class AppAnnotationService:
)
)
.order_by(MessageAnnotation.created_at.desc(), MessageAnnotation.id.desc())
.paginate(page=page, per_page=limit, max_per_page=100, error_out=False)
)
else:
annotations = (
MessageAnnotation.query.filter(MessageAnnotation.app_id == app_id)
stmt = (
select(MessageAnnotation)
.filter(MessageAnnotation.app_id == app_id)
.order_by(MessageAnnotation.created_at.desc(), MessageAnnotation.id.desc())
.paginate(page=page, per_page=limit, max_per_page=100, error_out=False)
)
annotations = db.paginate(select=stmt, page=page, per_page=limit, max_per_page=100, error_out=False)
return annotations.items, annotations.total
@classmethod
@ -325,13 +326,16 @@ class AppAnnotationService:
if not annotation:
raise NotFound("Annotation not found")
annotation_hit_histories = (
AppAnnotationHitHistory.query.filter(
stmt = (
select(AppAnnotationHitHistory)
.filter(
AppAnnotationHitHistory.app_id == app_id,
AppAnnotationHitHistory.annotation_id == annotation_id,
)
.order_by(AppAnnotationHitHistory.created_at.desc())
.paginate(page=page, per_page=limit, max_per_page=100, error_out=False)
)
annotation_hit_histories = db.paginate(
select=stmt, page=page, per_page=limit, max_per_page=100, error_out=False
)
return annotation_hit_histories.items, annotation_hit_histories.total

View File

@ -1087,14 +1087,18 @@ class DocumentService:
exist_document[data_source_info["notion_page_id"]] = document.id
for notion_info in notion_info_list:
workspace_id = notion_info.workspace_id
data_source_binding = DataSourceOauthBinding.query.filter(
db.and_(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.disabled == False,
DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"',
data_source_binding = (
db.session.query(DataSourceOauthBinding)
.filter(
db.and_(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.disabled == False,
DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"',
)
)
).first()
.first()
)
if not data_source_binding:
raise ValueError("Data source binding not found.")
for page in notion_info.pages:
@ -1302,14 +1306,18 @@ class DocumentService:
notion_info_list = document_data.data_source.info_list.notion_info_list
for notion_info in notion_info_list:
workspace_id = notion_info.workspace_id
data_source_binding = DataSourceOauthBinding.query.filter(
db.and_(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.disabled == False,
DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"',
data_source_binding = (
db.session.query(DataSourceOauthBinding)
.filter(
db.and_(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.disabled == False,
DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"',
)
)
).first()
.first()
)
if not data_source_binding:
raise ValueError("Data source binding not found.")
for page in notion_info.pages:

View File

@ -44,14 +44,18 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
page_id = data_source_info["notion_page_id"]
page_type = data_source_info["type"]
page_edited_time = data_source_info["last_edited_time"]
data_source_binding = DataSourceOauthBinding.query.filter(
db.and_(
DataSourceOauthBinding.tenant_id == document.tenant_id,
DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.disabled == False,
DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"',
data_source_binding = (
db.session.query(DataSourceOauthBinding)
.filter(
db.and_(
DataSourceOauthBinding.tenant_id == document.tenant_id,
DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.disabled == False,
DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"',
)
)
).first()
.first()
)
if not data_source_binding:
raise ValueError("Data source binding not found.")