chore: change model.query to db.session.query (#19551)

Co-authored-by: QuantumGhost <obelisk.reg+git@gmail.com>
Authored by 非法操作 on 2025-05-13 09:13:12 +08:00, committed by GitHub
parent f1e7099541
commit 085bd1aa93
GPG Key ID: B5690EEEBB952194
10 changed files with 74 additions and 32 deletions
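
The change is mechanical throughout: every use of the legacy Flask-SQLAlchemy `Model.query` class property is rewritten as an explicit `db.session.query(Model)` call, and the one paginated query is rebuilt on a SQLAlchemy 2.0-style `select()` statement with `db.paginate()`. A minimal before/after sketch of the pattern, assuming a Flask-SQLAlchemy `db` object and a hypothetical `Item` model standing in for RecommendedApp, DatasetDocument, etc.:

from flask import Flask
from flask_sqlalchemy import SQLAlchemy
from sqlalchemy import select

app = Flask(__name__)
app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///:memory:"
db = SQLAlchemy(app)

class Item(db.Model):  # hypothetical model for illustration only
    id = db.Column(db.Integer, primary_key=True)
    name = db.Column(db.String(255))

with app.app_context():
    db.create_all()

    # Legacy style (what the commit removes): the Model.query class property.
    # item = Item.query.filter(Item.name == "demo").first()

    # Session style (what the commit introduces): the same query through db.session.
    item = db.session.query(Item).filter(Item.name == "demo").first()

    # SQLAlchemy 2.0 style, used in the migration/pagination hunk below.
    item = db.session.execute(select(Item).filter(Item.name == "demo")).scalars().first()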

View File

@@ -552,11 +552,12 @@ def old_metadata_migration():
     page = 1
     while True:
         try:
-            documents = (
-                DatasetDocument.query.filter(DatasetDocument.doc_metadata is not None)
+            stmt = (
+                select(DatasetDocument)
+                .filter(DatasetDocument.doc_metadata.is_not(None))
                 .order_by(DatasetDocument.created_at.desc())
-                .paginate(page=page, per_page=50)
             )
+            documents = db.paginate(select=stmt, page=page, per_page=50, max_per_page=50, error_out=False)
         except NotFound:
             break
         if not documents:
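
Two details in this hunk are worth calling out. The old filter `DatasetDocument.doc_metadata is not None` applied Python's `is not` to a column object, which evaluates to a constant boolean rather than building SQL; the rewrite uses SQLAlchemy's `.is_not(None)`, which renders an `IS NOT NULL` clause. Pagination also moves from the legacy `Query.paginate()` to Flask-SQLAlchemy's `db.paginate()`, which accepts a 2.0-style `select()` statement. A minimal sketch of the resulting page loop, assuming the `db` object and `DatasetDocument` model from this file:

from sqlalchemy import select

page = 1
while True:
    stmt = (
        select(DatasetDocument)
        .filter(DatasetDocument.doc_metadata.is_not(None))  # SQL "IS NOT NULL", not a Python identity test
        .order_by(DatasetDocument.created_at.desc())
    )
    # error_out=False returns an empty page instead of raising NotFound when the page is out of range.
    documents = db.paginate(select=stmt, page=page, per_page=50, max_per_page=50, error_out=False)
    if not documents.items:
        break
    for dataset_document in documents.items:
        ...  # migrate each document's metadata here
    page += 1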

View File

@@ -66,7 +66,7 @@ class InstalledAppsListApi(Resource):
         parser.add_argument("app_id", type=str, required=True, help="Invalid app_id")
         args = parser.parse_args()
 
-        recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == args["app_id"]).first()
+        recommended_app = db.session.query(RecommendedApp).filter(RecommendedApp.app_id == args["app_id"]).first()
         if recommended_app is None:
             raise NotFound("App not found")
@@ -79,9 +79,11 @@ class InstalledAppsListApi(Resource):
         if not app.is_public:
             raise Forbidden("You can't install a non-public app")
 
-        installed_app = InstalledApp.query.filter(
-            and_(InstalledApp.app_id == args["app_id"], InstalledApp.tenant_id == current_tenant_id)
-        ).first()
+        installed_app = (
+            db.session.query(InstalledApp)
+            .filter(and_(InstalledApp.app_id == args["app_id"], InstalledApp.tenant_id == current_tenant_id))
+            .first()
+        )
 
         if installed_app is None:
             # todo: position

View File

@@ -1,3 +1,5 @@
+import logging
+
 from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
 from core.app.entities.app_invoke_entities import InvokeFrom
 from core.app.entities.queue_entities import QueueRetrieverResourcesEvent
@@ -7,6 +9,8 @@ from extensions.ext_database import db
 from models.dataset import ChildChunk, DatasetQuery, DocumentSegment
 from models.dataset import Document as DatasetDocument
 
+_logger = logging.getLogger(__name__)
+
 
 class DatasetIndexToolCallbackHandler:
     """Callback handler for dataset tool."""
@@ -42,9 +46,14 @@ class DatasetIndexToolCallbackHandler:
         """Handle tool end."""
         for document in documents:
             if document.metadata is not None:
-                dataset_document = DatasetDocument.query.filter(
-                    DatasetDocument.id == document.metadata["document_id"]
-                ).first()
+                document_id = document.metadata["document_id"]
+                dataset_document = db.session.query(DatasetDocument).filter(DatasetDocument.id == document_id).first()
+                if not dataset_document:
+                    _logger.warning(
+                        "Expected DatasetDocument record to exist, but none was found, document_id=%s",
+                        document_id,
+                    )
+                    continue
                 if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
                     child_chunk = (
                         db.session.query(ChildChunk)

View File

@@ -660,10 +660,10 @@ class IndexingRunner:
         """
         Update the document indexing status.
         """
-        count = DatasetDocument.query.filter_by(id=document_id, is_paused=True).count()
+        count = db.session.query(DatasetDocument).filter_by(id=document_id, is_paused=True).count()
         if count > 0:
             raise DocumentIsPausedError()
-        document = DatasetDocument.query.filter_by(id=document_id).first()
+        document = db.session.query(DatasetDocument).filter_by(id=document_id).first()
         if not document:
             raise DocumentIsDeletedPausedError()
@@ -672,7 +672,7 @@ class IndexingRunner:
         if extra_update_params:
             update_params.update(extra_update_params)
 
-        DatasetDocument.query.filter_by(id=document_id).update(update_params)
+        db.session.query(DatasetDocument).filter_by(id=document_id).update(update_params)
         db.session.commit()
 
     @staticmethod
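
The `update()` used here (and in the NotionExtractor hunk below) is SQLAlchemy's query-level bulk update: it emits a single `UPDATE ... WHERE` statement without loading the row into the session, so the explicit `db.session.commit()` is still required. A minimal sketch, assuming the `DatasetDocument` model and a `document_id` variable from the surrounding method; the column and value are illustrative:

# Keys can be mapped attributes (as in this commit) or plain column-name strings.
update_params = {DatasetDocument.indexing_status: "completed"}  # illustrative value

# One bulk UPDATE statement; no ORM instances are loaded or refreshed by this call.
db.session.query(DatasetDocument).filter_by(id=document_id).update(update_params)
db.session.commit()  # the bulk update does not commit on its own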

View File

@@ -317,7 +317,7 @@ class NotionExtractor(BaseExtractor):
         data_source_info["last_edited_time"] = last_edited_time
         update_params = {DocumentModel.data_source_info: json.dumps(data_source_info)}
 
-        DocumentModel.query.filter_by(id=document_model.id).update(update_params)
+        db.session.query(DocumentModel).filter_by(id=document_model.id).update(update_params)
         db.session.commit()
 
     def get_notion_last_edited_time(self) -> str:

View File

@@ -238,11 +238,15 @@ class DatasetRetrieval:
             for record in records:
                 segment = record.segment
                 dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
-                document = DatasetDocument.query.filter(
-                    DatasetDocument.id == segment.document_id,
-                    DatasetDocument.enabled == True,
-                    DatasetDocument.archived == False,
-                ).first()
+                document = (
+                    db.session.query(DatasetDocument)
+                    .filter(
+                        DatasetDocument.id == segment.document_id,
+                        DatasetDocument.enabled == True,
+                        DatasetDocument.archived == False,
+                    )
+                    .first()
+                )
                 if dataset and document:
                     source = {
                         "dataset_id": dataset.id,
@@ -506,9 +510,11 @@ class DatasetRetrieval:
         dify_documents = [document for document in documents if document.provider == "dify"]
         for document in dify_documents:
             if document.metadata is not None:
-                dataset_document = DatasetDocument.query.filter(
-                    DatasetDocument.id == document.metadata["document_id"]
-                ).first()
+                dataset_document = (
+                    db.session.query(DatasetDocument)
+                    .filter(DatasetDocument.id == document.metadata["document_id"])
+                    .first()
+                )
                 if dataset_document:
                     if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
                         child_chunk = (
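
One thing the rewrite keeps unchanged is the `DatasetDocument.enabled == True` / `DatasetDocument.archived == False` comparisons in the first hunk above. In a SQLAlchemy filter these build SQL boolean expressions, so the `== True` spelling is intentional even though linters usually flag it; `.is_(True)` / `.is_(False)` are the equivalent forms that avoid the warning. A small sketch of the same filter in 2.0 style, assuming the models used above:

from sqlalchemy import select

stmt = select(DatasetDocument).filter(
    DatasetDocument.id == segment.document_id,
    DatasetDocument.enabled.is_(True),    # equivalent to DatasetDocument.enabled == True
    DatasetDocument.archived.is_(False),  # equivalent to DatasetDocument.archived == False
)
document = db.session.execute(stmt).scalars().first()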

View File

@@ -186,11 +186,15 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
             for record in records:
                 segment = record.segment
                 dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
-                document = DatasetDocument.query.filter(
-                    DatasetDocument.id == segment.document_id,
-                    DatasetDocument.enabled == True,
-                    DatasetDocument.archived == False,
-                ).first()
+                document = (
+                    db.session.query(DatasetDocument)  # type: ignore
+                    .filter(
+                        DatasetDocument.id == segment.document_id,
+                        DatasetDocument.enabled == True,
+                        DatasetDocument.archived == False,
+                    )
+                    .first()
+                )
                 if dataset and document:
                     source = {
                         "dataset_id": dataset.id,

View File

@@ -1,4 +1,5 @@
 import datetime
+import logging
 import time
 
 import click
@@ -20,6 +21,8 @@ from models.model import (
 from models.web import SavedMessage
 from services.feature_service import FeatureService
 
+_logger = logging.getLogger(__name__)
+
 
 @app.celery.task(queue="dataset")
 def clean_messages():
@@ -46,7 +49,14 @@ def clean_messages():
             break
         for message in messages:
             plan_sandbox_clean_message_day = message.created_at
-            app = App.query.filter_by(id=message.app_id).first()
+            app = db.session.query(App).filter_by(id=message.app_id).first()
+            if not app:
+                _logger.warning(
+                    "Expected App record to exist, but none was found, app_id=%s, message_id=%s",
+                    message.app_id,
+                    message.id,
+                )
+                continue
             features_cache_key = f"features:{app.tenant_id}"
             plan_cache = redis_client.get(features_cache_key)
             if plan_cache is None:

View File

@@ -54,7 +54,7 @@ def mail_clean_document_notify_task():
             )
             if not current_owner_join:
                 continue
-            account = Account.query.filter(Account.id == current_owner_join.account_id).first()
+            account = db.session.query(Account).filter(Account.id == current_owner_join.account_id).first()
             if not account:
                 continue

View File

@@ -1,3 +1,4 @@
+import logging
 from typing import Optional
 
 from core.model_manager import ModelInstance, ModelManager
@@ -12,6 +13,8 @@ from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment
 from models.dataset import Document as DatasetDocument
 from services.entities.knowledge_entities.knowledge_entities import ParentMode
 
+_logger = logging.getLogger(__name__)
+
 
 class VectorService:
     @classmethod
@@ -22,7 +25,14 @@ class VectorService:
         for segment in segments:
             if doc_form == IndexType.PARENT_CHILD_INDEX:
-                document = DatasetDocument.query.filter_by(id=segment.document_id).first()
+                document = db.session.query(DatasetDocument).filter_by(id=segment.document_id).first()
+                if not document:
+                    _logger.warning(
+                        "Expected DatasetDocument record to exist, but none was found, document_id=%s, segment_id=%s",
+                        segment.document_id,
+                        segment.id,
+                    )
+                    continue
                 # get the process rule
                 processing_rule = (
                     db.session.query(DatasetProcessRule)
@@ -52,7 +62,7 @@ class VectorService:
                     raise ValueError("The knowledge base index technique is not high quality!")
                 cls.generate_child_chunks(segment, document, dataset, embedding_model_instance, processing_rule, False)
             else:
-                document = Document(
+                document = Document(  # type: ignore
                     page_content=segment.content,
                     metadata={
                         "doc_id": segment.index_node_id,
@@ -64,7 +74,7 @@ class VectorService:
                 documents.append(document)
         if len(documents) > 0:
             index_processor = IndexProcessorFactory(doc_form).init_index_processor()
-            index_processor.load(dataset, documents, with_keywords=True, keywords_list=keywords_list)
+            index_processor.load(dataset, documents, with_keywords=True, keywords_list=keywords_list)  # type: ignore
 
     @classmethod
     def update_segment_vector(cls, keywords: Optional[list[str]], segment: DocumentSegment, dataset: Dataset):