fix: replace all dataset.Model.query to db.session.query(Model) (#19509)

非法操作 2025-05-12 13:52:33 +08:00 committed by GitHub
parent 49af07f444
commit b00f94df64
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
21 changed files with 430 additions and 265 deletions
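Every hunk below follows the same pattern: the legacy Flask-SQLAlchemy Model.query property is dropped in favour of queries built explicitly on db.session, or of SQLAlchemy 2.0-style select() statements that are paginated with db.paginate(). A minimal sketch of the three styles, assuming a hypothetical Widget model on a throwaway in-memory app (not Dify code):

# Sketch only: hypothetical Widget model, in-memory SQLite, not part of Dify.
from flask import Flask
from flask_sqlalchemy import SQLAlchemy
from sqlalchemy import select

app = Flask(__name__)
app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///:memory:"
db = SQLAlchemy(app)

class Widget(db.Model):
    id = db.Column(db.Integer, primary_key=True)
    name = db.Column(db.String(80))
    hit_count = db.Column(db.Integer, default=0)

with app.app_context():
    db.create_all()
    db.session.add(Widget(name="a"))
    db.session.commit()

    # Legacy style being removed: implicit query property on the model class.
    legacy = Widget.query.filter(Widget.name == "a").first()

    # Replacement used in most hunks: the same Query API, built on the session.
    current = db.session.query(Widget).filter(Widget.name == "a").first()

    # Replacement used where pagination is involved: a select() statement.
    stmt = select(Widget).filter(Widget.name == "a")
    modern = db.session.execute(stmt).scalar_one_or_none()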

View File

@@ -6,6 +6,7 @@ from typing import Optional

 import click
 from flask import current_app
+from sqlalchemy import select
 from werkzeug.exceptions import NotFound

 from configs import dify_config

@@ -297,11 +298,11 @@ def migrate_knowledge_vector_database():
     page = 1
     while True:
         try:
-            datasets = (
-                Dataset.query.filter(Dataset.indexing_technique == "high_quality")
-                .order_by(Dataset.created_at.desc())
-                .paginate(page=page, per_page=50)
-            )
+            stmt = (
+                select(Dataset).filter(Dataset.indexing_technique == "high_quality").order_by(Dataset.created_at.desc())
+            )
+            datasets = db.paginate(select=stmt, page=page, per_page=50, max_per_page=50, error_out=False)
         except NotFound:
             break

@@ -592,11 +593,15 @@ def old_metadata_migration():
                         )
                         db.session.add(dataset_metadata_binding)
                     else:
-                        dataset_metadata_binding = DatasetMetadataBinding.query.filter(
-                            DatasetMetadataBinding.dataset_id == document.dataset_id,
-                            DatasetMetadataBinding.document_id == document.id,
-                            DatasetMetadataBinding.metadata_id == dataset_metadata.id,
-                        ).first()
+                        dataset_metadata_binding = (
+                            db.session.query(DatasetMetadataBinding)  # type: ignore
+                            .filter(
+                                DatasetMetadataBinding.dataset_id == document.dataset_id,
+                                DatasetMetadataBinding.document_id == document.id,
+                                DatasetMetadataBinding.metadata_id == dataset_metadata.id,
+                            )
+                            .first()
+                        )
                         if not dataset_metadata_binding:
                             dataset_metadata_binding = DatasetMetadataBinding(
                                 tenant_id=document.tenant_id,
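The pagination hunks above (and the similar ones further down) swap Query.paginate() for db.paginate(), which takes a select() statement instead of a Query. A sketch of the before and after, reusing the hypothetical Widget app from the note at the top; the keyword arguments mirror the ones used throughout this commit:

# Sketch only: reuses the hypothetical Widget app from the earlier note.
with app.app_context():
    # Legacy: pagination is a method on the implicit query property.
    page_old = Widget.query.order_by(Widget.name).paginate(page=1, per_page=50)

    # New: build a select() and hand it to db.paginate(); error_out=False returns
    # an empty page instead of aborting with 404 when the page is out of range.
    stmt = select(Widget).order_by(Widget.name)
    page_new = db.paginate(select=stmt, page=1, per_page=50, max_per_page=50, error_out=False)

    # Both return a Pagination object, so .items / .total call sites are unchanged.
    assert page_old.total == page_new.total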

View File

@@ -526,14 +526,20 @@ class DatasetIndexingStatusApi(Resource):
         )
         documents_status = []
         for document in documents:
-            completed_segments = DocumentSegment.query.filter(
-                DocumentSegment.completed_at.isnot(None),
-                DocumentSegment.document_id == str(document.id),
-                DocumentSegment.status != "re_segment",
-            ).count()
-            total_segments = DocumentSegment.query.filter(
-                DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment"
-            ).count()
+            completed_segments = (
+                db.session.query(DocumentSegment)
+                .filter(
+                    DocumentSegment.completed_at.isnot(None),
+                    DocumentSegment.document_id == str(document.id),
+                    DocumentSegment.status != "re_segment",
+                )
+                .count()
+            )
+            total_segments = (
+                db.session.query(DocumentSegment)
+                .filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
+                .count()
+            )
             document.completed_segments = completed_segments
             document.total_segments = total_segments
             documents_status.append(marshal(document, document_status_fields))

View File

@@ -6,7 +6,7 @@ from typing import cast
 from flask import request
 from flask_login import current_user
 from flask_restful import Resource, fields, marshal, marshal_with, reqparse
-from sqlalchemy import asc, desc
+from sqlalchemy import asc, desc, select
 from werkzeug.exceptions import Forbidden, NotFound

 import services

@@ -112,7 +112,7 @@ class GetProcessRuleApi(Resource):
         limits = DocumentService.DEFAULT_RULES["limits"]
         if document_id:
             # get the latest process rule
-            document = Document.query.get_or_404(document_id)
+            document = db.get_or_404(Document, document_id)

             dataset = DatasetService.get_dataset(document.dataset_id)

@@ -175,7 +175,7 @@ class DatasetDocumentListApi(Resource):
         except services.errors.account.NoPermissionError as e:
             raise Forbidden(str(e))

-        query = Document.query.filter_by(dataset_id=str(dataset_id), tenant_id=current_user.current_tenant_id)
+        query = select(Document).filter_by(dataset_id=str(dataset_id), tenant_id=current_user.current_tenant_id)

         if search:
             search = f"%{search}%"

@@ -209,18 +209,24 @@ class DatasetDocumentListApi(Resource):
                 desc(Document.position),
             )

-        paginated_documents = query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False)
+        paginated_documents = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
         documents = paginated_documents.items
         if fetch:
             for document in documents:
-                completed_segments = DocumentSegment.query.filter(
-                    DocumentSegment.completed_at.isnot(None),
-                    DocumentSegment.document_id == str(document.id),
-                    DocumentSegment.status != "re_segment",
-                ).count()
-                total_segments = DocumentSegment.query.filter(
-                    DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment"
-                ).count()
+                completed_segments = (
+                    db.session.query(DocumentSegment)
+                    .filter(
+                        DocumentSegment.completed_at.isnot(None),
+                        DocumentSegment.document_id == str(document.id),
+                        DocumentSegment.status != "re_segment",
+                    )
+                    .count()
+                )
+                total_segments = (
+                    db.session.query(DocumentSegment)
+                    .filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
+                    .count()
+                )
                 document.completed_segments = completed_segments
                 document.total_segments = total_segments
             data = marshal(documents, document_with_segments_fields)

@@ -563,14 +569,20 @@ class DocumentBatchIndexingStatusApi(DocumentResource):
         documents = self.get_batch_documents(dataset_id, batch)
         documents_status = []
         for document in documents:
-            completed_segments = DocumentSegment.query.filter(
-                DocumentSegment.completed_at.isnot(None),
-                DocumentSegment.document_id == str(document.id),
-                DocumentSegment.status != "re_segment",
-            ).count()
-            total_segments = DocumentSegment.query.filter(
-                DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment"
-            ).count()
+            completed_segments = (
+                db.session.query(DocumentSegment)
+                .filter(
+                    DocumentSegment.completed_at.isnot(None),
+                    DocumentSegment.document_id == str(document.id),
+                    DocumentSegment.status != "re_segment",
+                )
+                .count()
+            )
+            total_segments = (
+                db.session.query(DocumentSegment)
+                .filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
+                .count()
+            )
             document.completed_segments = completed_segments
             document.total_segments = total_segments
             if document.is_paused:

@@ -589,14 +601,20 @@ class DocumentIndexingStatusApi(DocumentResource):
         document_id = str(document_id)
         document = self.get_document(dataset_id, document_id)

-        completed_segments = DocumentSegment.query.filter(
-            DocumentSegment.completed_at.isnot(None),
-            DocumentSegment.document_id == str(document_id),
-            DocumentSegment.status != "re_segment",
-        ).count()
-        total_segments = DocumentSegment.query.filter(
-            DocumentSegment.document_id == str(document_id), DocumentSegment.status != "re_segment"
-        ).count()
+        completed_segments = (
+            db.session.query(DocumentSegment)
+            .filter(
+                DocumentSegment.completed_at.isnot(None),
+                DocumentSegment.document_id == str(document_id),
+                DocumentSegment.status != "re_segment",
+            )
+            .count()
+        )
+        total_segments = (
+            db.session.query(DocumentSegment)
+            .filter(DocumentSegment.document_id == str(document_id), DocumentSegment.status != "re_segment")
+            .count()
+        )

         document.completed_segments = completed_segments
         document.total_segments = total_segments
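Two smaller substitutions appear in this file alongside the segment-count rewrites: Document.query.get_or_404(id) becomes db.get_or_404(Document, id), and the count queries move onto db.session. A sketch of both, again against the hypothetical Widget app from the first note rather than the real Document/DocumentSegment models:

# Sketch only: hypothetical Widget app from the first note.
with app.app_context():
    # Legacy: Widget.query.get_or_404(1)
    # New: primary-key lookup on the extension object; aborts with 404 if missing.
    widget = db.get_or_404(Widget, 1)

    # Count queries keep the Query API, just rooted at db.session.
    completed = (
        db.session.query(Widget)
        .filter(Widget.hit_count > 0, Widget.name == "a")
        .count()
    )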

View File

@@ -4,6 +4,7 @@ import pandas as pd
 from flask import request
 from flask_login import current_user
 from flask_restful import Resource, marshal, reqparse
+from sqlalchemy import select
 from werkzeug.exceptions import Forbidden, NotFound

 import services

@@ -26,6 +27,7 @@ from controllers.console.wraps import (
 from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
 from core.model_manager import ModelManager
 from core.model_runtime.entities.model_entities import ModelType
+from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from fields.segment_fields import child_chunk_fields, segment_fields
 from libs.login import login_required

@@ -74,9 +76,14 @@ class DatasetDocumentSegmentListApi(Resource):
         hit_count_gte = args["hit_count_gte"]
         keyword = args["keyword"]

-        query = DocumentSegment.query.filter(
-            DocumentSegment.document_id == str(document_id), DocumentSegment.tenant_id == current_user.current_tenant_id
-        ).order_by(DocumentSegment.position.asc())
+        query = (
+            select(DocumentSegment)
+            .filter(
+                DocumentSegment.document_id == str(document_id),
+                DocumentSegment.tenant_id == current_user.current_tenant_id,
+            )
+            .order_by(DocumentSegment.position.asc())
+        )

         if status_list:
             query = query.filter(DocumentSegment.status.in_(status_list))

@@ -93,7 +100,7 @@ class DatasetDocumentSegmentListApi(Resource):
         elif args["enabled"].lower() == "false":
             query = query.filter(DocumentSegment.enabled == False)

-        segments = query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False)
+        segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)

         response = {
             "data": marshal(segments.items, segment_fields),

@@ -276,9 +283,11 @@ class DatasetDocumentSegmentUpdateApi(Resource):
             raise ProviderNotInitializeError(ex.description)
         # check segment
         segment_id = str(segment_id)
-        segment = DocumentSegment.query.filter(
-            DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
-        ).first()
+        segment = (
+            db.session.query(DocumentSegment)
+            .filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
+            .first()
+        )
         if not segment:
             raise NotFound("Segment not found.")
         # The role of the current user in the ta table must be admin, owner, dataset_operator, or editor

@@ -320,9 +329,11 @@ class DatasetDocumentSegmentUpdateApi(Resource):
             raise NotFound("Document not found.")
         # check segment
         segment_id = str(segment_id)
-        segment = DocumentSegment.query.filter(
-            DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
-        ).first()
+        segment = (
+            db.session.query(DocumentSegment)
+            .filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
+            .first()
+        )
         if not segment:
             raise NotFound("Segment not found.")
         # The role of the current user in the ta table must be admin, owner, dataset_operator, or editor

@@ -423,9 +434,11 @@ class ChildChunkAddApi(Resource):
             raise NotFound("Document not found.")
         # check segment
         segment_id = str(segment_id)
-        segment = DocumentSegment.query.filter(
-            DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
-        ).first()
+        segment = (
+            db.session.query(DocumentSegment)
+            .filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
+            .first()
+        )
         if not segment:
             raise NotFound("Segment not found.")
         if not current_user.is_dataset_editor:

@@ -478,9 +491,11 @@ class ChildChunkAddApi(Resource):
             raise NotFound("Document not found.")
         # check segment
         segment_id = str(segment_id)
-        segment = DocumentSegment.query.filter(
-            DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
-        ).first()
+        segment = (
+            db.session.query(DocumentSegment)
+            .filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
+            .first()
+        )
         if not segment:
             raise NotFound("Segment not found.")
         parser = reqparse.RequestParser()

@@ -523,9 +538,11 @@ class ChildChunkAddApi(Resource):
             raise NotFound("Document not found.")
         # check segment
         segment_id = str(segment_id)
-        segment = DocumentSegment.query.filter(
-            DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
-        ).first()
+        segment = (
+            db.session.query(DocumentSegment)
+            .filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
+            .first()
+        )
         if not segment:
             raise NotFound("Segment not found.")
         # The role of the current user in the ta table must be admin, owner, dataset_operator, or editor

@@ -567,16 +584,20 @@ class ChildChunkUpdateApi(Resource):
             raise NotFound("Document not found.")
         # check segment
         segment_id = str(segment_id)
-        segment = DocumentSegment.query.filter(
-            DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
-        ).first()
+        segment = (
+            db.session.query(DocumentSegment)
+            .filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
+            .first()
+        )
         if not segment:
             raise NotFound("Segment not found.")
         # check child chunk
         child_chunk_id = str(child_chunk_id)
-        child_chunk = ChildChunk.query.filter(
-            ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id
-        ).first()
+        child_chunk = (
+            db.session.query(ChildChunk)
+            .filter(ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id)
+            .first()
+        )
         if not child_chunk:
             raise NotFound("Child chunk not found.")
         # The role of the current user in the ta table must be admin, owner, dataset_operator, or editor

@@ -612,16 +633,20 @@ class ChildChunkUpdateApi(Resource):
             raise NotFound("Document not found.")
         # check segment
         segment_id = str(segment_id)
-        segment = DocumentSegment.query.filter(
-            DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
-        ).first()
+        segment = (
+            db.session.query(DocumentSegment)
+            .filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
+            .first()
+        )
         if not segment:
             raise NotFound("Segment not found.")
         # check child chunk
         child_chunk_id = str(child_chunk_id)
-        child_chunk = ChildChunk.query.filter(
-            ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id
-        ).first()
+        child_chunk = (
+            db.session.query(ChildChunk)
+            .filter(ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id)
+            .first()
+        )
         if not child_chunk:
             raise NotFound("Child chunk not found.")
         # The role of the current user in the ta table must be admin, owner, dataset_operator, or editor

View File

@@ -2,10 +2,10 @@ import json

 from flask import request
 from flask_restful import marshal, reqparse
-from sqlalchemy import desc
+from sqlalchemy import desc, select
 from werkzeug.exceptions import NotFound

-import services.dataset_service
+import services
 from controllers.common.errors import FilenameNotExistsError
 from controllers.service_api import api
 from controllers.service_api.app.error import (

@@ -337,7 +337,7 @@ class DocumentListApi(DatasetApiResource):
         if not dataset:
             raise NotFound("Dataset not found.")

-        query = Document.query.filter_by(dataset_id=str(dataset_id), tenant_id=tenant_id)
+        query = select(Document).filter_by(dataset_id=str(dataset_id), tenant_id=tenant_id)

         if search:
             search = f"%{search}%"

@@ -345,7 +345,7 @@ class DocumentListApi(DatasetApiResource):
         query = query.order_by(desc(Document.created_at), desc(Document.position))

-        paginated_documents = query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False)
+        paginated_documents = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
         documents = paginated_documents.items

         response = {

@@ -374,14 +374,20 @@ class DocumentIndexingStatusApi(DatasetApiResource):
             raise NotFound("Documents not found.")
         documents_status = []
         for document in documents:
-            completed_segments = DocumentSegment.query.filter(
-                DocumentSegment.completed_at.isnot(None),
-                DocumentSegment.document_id == str(document.id),
-                DocumentSegment.status != "re_segment",
-            ).count()
-            total_segments = DocumentSegment.query.filter(
-                DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment"
-            ).count()
+            completed_segments = (
+                db.session.query(DocumentSegment)
+                .filter(
+                    DocumentSegment.completed_at.isnot(None),
+                    DocumentSegment.document_id == str(document.id),
+                    DocumentSegment.status != "re_segment",
+                )
+                .count()
+            )
+            total_segments = (
+                db.session.query(DocumentSegment)
+                .filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
+                .count()
+            )
             document.completed_segments = completed_segments
             document.total_segments = total_segments
             if document.is_paused:

View File

@@ -46,14 +46,22 @@ class DatasetIndexToolCallbackHandler:
                 DatasetDocument.id == document.metadata["document_id"]
             ).first()
             if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
-                child_chunk = ChildChunk.query.filter(
-                    ChildChunk.index_node_id == document.metadata["doc_id"],
-                    ChildChunk.dataset_id == dataset_document.dataset_id,
-                    ChildChunk.document_id == dataset_document.id,
-                ).first()
+                child_chunk = (
+                    db.session.query(ChildChunk)
+                    .filter(
+                        ChildChunk.index_node_id == document.metadata["doc_id"],
+                        ChildChunk.dataset_id == dataset_document.dataset_id,
+                        ChildChunk.document_id == dataset_document.id,
+                    )
+                    .first()
+                )
                 if child_chunk:
-                    segment = DocumentSegment.query.filter(DocumentSegment.id == child_chunk.segment_id).update(
-                        {DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False
-                    )
+                    segment = (
+                        db.session.query(DocumentSegment)
+                        .filter(DocumentSegment.id == child_chunk.segment_id)
+                        .update(
+                            {DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False
+                        )
+                    )
             else:
                 query = db.session.query(DocumentSegment).filter(
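The hit-count hunk above (and the matching one in dataset_retrieval.py further down) keeps the bulk UPDATE semantics: Query.update() with synchronize_session=False issues a single UPDATE statement and returns the number of matched rows, which is what ends up bound to segment, not a model instance. A sketch of the pattern on the hypothetical Widget model from the first note:

# Sketch only: bulk UPDATE on the hypothetical Widget model.
with app.app_context():
    matched_rows = (
        db.session.query(Widget)
        .filter(Widget.name == "a")
        .update({Widget.hit_count: Widget.hit_count + 1}, synchronize_session=False)
    )
    db.session.commit()  # the UPDATE is persisted here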

View File

@@ -51,7 +51,7 @@ class IndexingRunner:
         for dataset_document in dataset_documents:
             try:
                 # get dataset
-                dataset = Dataset.query.filter_by(id=dataset_document.dataset_id).first()
+                dataset = db.session.query(Dataset).filter_by(id=dataset_document.dataset_id).first()
                 if not dataset:
                     raise ValueError("no dataset found")

@@ -103,15 +103,17 @@ class IndexingRunner:
         """Run the indexing process when the index_status is splitting."""
         try:
             # get dataset
-            dataset = Dataset.query.filter_by(id=dataset_document.dataset_id).first()
+            dataset = db.session.query(Dataset).filter_by(id=dataset_document.dataset_id).first()
             if not dataset:
                 raise ValueError("no dataset found")
             # get exist document_segment list and delete
-            document_segments = DocumentSegment.query.filter_by(
-                dataset_id=dataset.id, document_id=dataset_document.id
-            ).all()
+            document_segments = (
+                db.session.query(DocumentSegment)
+                .filter_by(dataset_id=dataset.id, document_id=dataset_document.id)
+                .all()
+            )
             for document_segment in document_segments:
                 db.session.delete(document_segment)

@@ -162,15 +164,17 @@ class IndexingRunner:
         """Run the indexing process when the index_status is indexing."""
         try:
             # get dataset
-            dataset = Dataset.query.filter_by(id=dataset_document.dataset_id).first()
+            dataset = db.session.query(Dataset).filter_by(id=dataset_document.dataset_id).first()
             if not dataset:
                 raise ValueError("no dataset found")
             # get exist document_segment list and delete
-            document_segments = DocumentSegment.query.filter_by(
-                dataset_id=dataset.id, document_id=dataset_document.id
-            ).all()
+            document_segments = (
+                db.session.query(DocumentSegment)
+                .filter_by(dataset_id=dataset.id, document_id=dataset_document.id)
+                .all()
+            )
             documents = []
             if document_segments:

@@ -254,7 +258,7 @@ class IndexingRunner:
         embedding_model_instance = None
         if dataset_id:
-            dataset = Dataset.query.filter_by(id=dataset_id).first()
+            dataset = db.session.query(Dataset).filter_by(id=dataset_id).first()
             if not dataset:
                 raise ValueError("Dataset not found.")
             if dataset.indexing_technique == "high_quality" or indexing_technique == "high_quality":

@@ -587,7 +591,7 @@ class IndexingRunner:
     @staticmethod
     def _process_keyword_index(flask_app, dataset_id, document_id, documents):
         with flask_app.app_context():
-            dataset = Dataset.query.filter_by(id=dataset_id).first()
+            dataset = db.session.query(Dataset).filter_by(id=dataset_id).first()
             if not dataset:
                 raise ValueError("no dataset found")
             keyword = Keyword(dataset)

@@ -676,7 +680,7 @@ class IndexingRunner:
         """
         Update the document segment by document id.
         """
-        DocumentSegment.query.filter_by(document_id=dataset_document_id).update(update_params)
+        db.session.query(DocumentSegment).filter_by(document_id=dataset_document_id).update(update_params)

         db.session.commit()

     def _transform(

View File

@@ -237,7 +237,7 @@ class DatasetRetrieval:
         if show_retrieve_source:
             for record in records:
                 segment = record.segment
-                dataset = Dataset.query.filter_by(id=segment.dataset_id).first()
+                dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
                 document = DatasetDocument.query.filter(
                     DatasetDocument.id == segment.document_id,
                     DatasetDocument.enabled == True,

@@ -511,14 +511,23 @@ class DatasetRetrieval:
                 ).first()
                 if dataset_document:
                     if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
-                        child_chunk = ChildChunk.query.filter(
-                            ChildChunk.index_node_id == document.metadata["doc_id"],
-                            ChildChunk.dataset_id == dataset_document.dataset_id,
-                            ChildChunk.document_id == dataset_document.id,
-                        ).first()
+                        child_chunk = (
+                            db.session.query(ChildChunk)
+                            .filter(
+                                ChildChunk.index_node_id == document.metadata["doc_id"],
+                                ChildChunk.dataset_id == dataset_document.dataset_id,
+                                ChildChunk.document_id == dataset_document.id,
+                            )
+                            .first()
+                        )
                         if child_chunk:
-                            segment = DocumentSegment.query.filter(DocumentSegment.id == child_chunk.segment_id).update(
-                                {DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False
-                            )
+                            segment = (
+                                db.session.query(DocumentSegment)
+                                .filter(DocumentSegment.id == child_chunk.segment_id)
+                                .update(
+                                    {DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
+                                    synchronize_session=False,
+                                )
+                            )
                             db.session.commit()
                 else:

View File

@@ -84,13 +84,17 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
         document_context_list = []
         index_node_ids = [document.metadata["doc_id"] for document in all_documents if document.metadata]
-        segments = DocumentSegment.query.filter(
-            DocumentSegment.dataset_id.in_(self.dataset_ids),
-            DocumentSegment.completed_at.isnot(None),
-            DocumentSegment.status == "completed",
-            DocumentSegment.enabled == True,
-            DocumentSegment.index_node_id.in_(index_node_ids),
-        ).all()
+        segments = (
+            db.session.query(DocumentSegment)
+            .filter(
+                DocumentSegment.dataset_id.in_(self.dataset_ids),
+                DocumentSegment.completed_at.isnot(None),
+                DocumentSegment.status == "completed",
+                DocumentSegment.enabled == True,
+                DocumentSegment.index_node_id.in_(index_node_ids),
+            )
+            .all()
+        )

         if segments:
             index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)}

@@ -106,12 +110,16 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
             context_list = []
             resource_number = 1
             for segment in sorted_segments:
-                dataset = Dataset.query.filter_by(id=segment.dataset_id).first()
-                document = Document.query.filter(
-                    Document.id == segment.document_id,
-                    Document.enabled == True,
-                    Document.archived == False,
-                ).first()
+                dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
+                document = (
+                    db.session.query(Document)
+                    .filter(
+                        Document.id == segment.document_id,
+                        Document.enabled == True,
+                        Document.archived == False,
+                    )
+                    .first()
+                )
                 if dataset and document:
                     source = {
                         "position": resource_number,

View File

@@ -185,7 +185,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
         if self.return_resource:
             for record in records:
                 segment = record.segment
-                dataset = Dataset.query.filter_by(id=segment.dataset_id).first()
+                dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
                 document = DatasetDocument.query.filter(
                     DatasetDocument.id == segment.document_id,
                     DatasetDocument.enabled == True,

View File

@@ -275,12 +275,16 @@ class KnowledgeRetrievalNode(LLMNode):
         if records:
             for record in records:
                 segment = record.segment
-                dataset = Dataset.query.filter_by(id=segment.dataset_id).first()
-                document = Document.query.filter(
-                    Document.id == segment.document_id,
-                    Document.enabled == True,
-                    Document.archived == False,
-                ).first()
+                dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()  # type: ignore
+                document = (
+                    db.session.query(Document)
+                    .filter(
+                        Document.id == segment.document_id,
+                        Document.enabled == True,
+                        Document.archived == False,
+                    )
+                    .first()
+                )
                 if dataset and document:
                     source = {
                         "metadata": {

View File

@@ -93,7 +93,8 @@ class Dataset(Base):
     @property
     def latest_process_rule(self):
         return (
-            DatasetProcessRule.query.filter(DatasetProcessRule.dataset_id == self.id)
+            db.session.query(DatasetProcessRule)
+            .filter(DatasetProcessRule.dataset_id == self.id)
             .order_by(DatasetProcessRule.created_at.desc())
             .first()
         )

@@ -138,7 +139,8 @@ class Dataset(Base):
     @property
     def word_count(self):
         return (
-            Document.query.with_entities(func.coalesce(func.sum(Document.word_count)))
+            db.session.query(Document)
+            .with_entities(func.coalesce(func.sum(Document.word_count)))
             .filter(Document.dataset_id == self.id)
             .scalar()
         )

@@ -440,12 +442,13 @@ class Document(Base):
     @property
     def segment_count(self):
-        return DocumentSegment.query.filter(DocumentSegment.document_id == self.id).count()
+        return db.session.query(DocumentSegment).filter(DocumentSegment.document_id == self.id).count()

     @property
     def hit_count(self):
         return (
-            DocumentSegment.query.with_entities(func.coalesce(func.sum(DocumentSegment.hit_count)))
+            db.session.query(DocumentSegment)
+            .with_entities(func.coalesce(func.sum(DocumentSegment.hit_count)))
             .filter(DocumentSegment.document_id == self.id)
             .scalar()
         )

@@ -892,7 +895,7 @@ class DatasetKeywordTable(Base):
             return dct

         # get dataset
-        dataset = Dataset.query.filter_by(id=self.dataset_id).first()
+        dataset = db.session.query(Dataset).filter_by(id=self.dataset_id).first()
         if not dataset:
             return None
         if self.data_source_type == "database":
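The model properties above compute scalar aggregates; only the query root changes, while the func.coalesce(func.sum(...)) / .scalar() chain stays as it was. A sketch of that aggregate pattern (with an explicit 0 fallback added for clarity, which the original leaves implicit), on the hypothetical Widget model from the first note:

# Sketch only: scalar aggregate in the style of the word_count / hit_count properties.
from sqlalchemy import func

with app.app_context():
    total_hits = (
        db.session.query(Widget)
        .with_entities(func.coalesce(func.sum(Widget.hit_count), 0))
        .filter(Widget.name == "a")
        .scalar()
    )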

View File

@@ -2,7 +2,7 @@ import datetime
 import time

 import click
-from sqlalchemy import func
+from sqlalchemy import func, select
 from werkzeug.exceptions import NotFound

 import app

@@ -51,8 +51,9 @@ def clean_unused_datasets_task():
                 )

                 # Main query with join and filter
-                datasets = (
-                    Dataset.query.outerjoin(document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id)
+                stmt = (
+                    select(Dataset)
+                    .outerjoin(document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id)
                     .outerjoin(document_subquery_old, Dataset.id == document_subquery_old.c.dataset_id)
                     .filter(
                         Dataset.created_at < plan_sandbox_clean_day,

@@ -60,9 +61,10 @@ def clean_unused_datasets_task():
                         func.coalesce(document_subquery_old.c.document_count, 0) > 0,
                     )
                     .order_by(Dataset.created_at.desc())
-                    .paginate(page=1, per_page=50)
                 )

+                datasets = db.paginate(stmt, page=1, per_page=50)
             except NotFound:
                 break
             if datasets.items is None or len(datasets.items) == 0:

@@ -99,7 +101,7 @@ def clean_unused_datasets_task():
                     # update document
                     update_params = {Document.enabled: False}

-                    Document.query.filter_by(dataset_id=dataset.id).update(update_params)
+                    db.session.query(Document).filter_by(dataset_id=dataset.id).update(update_params)
                     db.session.commit()
                     click.echo(click.style("Cleaned unused dataset {} from db success!".format(dataset.id), fg="green"))
                 except Exception as e:

@@ -135,8 +137,9 @@ def clean_unused_datasets_task():
                 )

                 # Main query with join and filter
-                datasets = (
-                    Dataset.query.outerjoin(document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id)
+                stmt = (
+                    select(Dataset)
+                    .outerjoin(document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id)
                     .outerjoin(document_subquery_old, Dataset.id == document_subquery_old.c.dataset_id)
                     .filter(
                         Dataset.created_at < plan_pro_clean_day,

@@ -144,8 +147,8 @@ def clean_unused_datasets_task():
                         func.coalesce(document_subquery_old.c.document_count, 0) > 0,
                     )
                     .order_by(Dataset.created_at.desc())
-                    .paginate(page=1, per_page=50)
                 )
+                datasets = db.paginate(stmt, page=1, per_page=50)
             except NotFound:
                 break

@@ -175,7 +178,7 @@ def clean_unused_datasets_task():
                     # update document
                     update_params = {Document.enabled: False}

-                    Document.query.filter_by(dataset_id=dataset.id).update(update_params)
+                    db.session.query(Document).filter_by(dataset_id=dataset.id).update(update_params)
                     db.session.commit()
                     click.echo(
                         click.style("Cleaned unused dataset {} from db success!".format(dataset.id), fg="green")

View File

@@ -19,7 +19,9 @@ def create_tidb_serverless_task():
     while True:
         try:
             # check the number of idle tidb serverless
-            idle_tidb_serverless_number = TidbAuthBinding.query.filter(TidbAuthBinding.active == False).count()
+            idle_tidb_serverless_number = (
+                db.session.query(TidbAuthBinding).filter(TidbAuthBinding.active == False).count()
+            )
             if idle_tidb_serverless_number >= tidb_serverless_number:
                 break
             # create tidb serverless

View File

@@ -29,7 +29,9 @@ def mail_clean_document_notify_task():
     # send document clean notify mail
     try:
-        dataset_auto_disable_logs = DatasetAutoDisableLog.query.filter(DatasetAutoDisableLog.notified == False).all()
+        dataset_auto_disable_logs = (
+            db.session.query(DatasetAutoDisableLog).filter(DatasetAutoDisableLog.notified == False).all()
+        )
         # group by tenant_id
         dataset_auto_disable_logs_map: dict[str, list[DatasetAutoDisableLog]] = defaultdict(list)
         for dataset_auto_disable_log in dataset_auto_disable_logs:

@@ -65,7 +67,7 @@ def mail_clean_document_notify_task():
                 )
                 for dataset_id, document_ids in dataset_auto_dataset_map.items():
-                    dataset = Dataset.query.filter(Dataset.id == dataset_id).first()
+                    dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
                     if dataset:
                         document_count = len(document_ids)
                         knowledge_details.append(rf"Knowledge base {dataset.name}: {document_count} documents")

View File

@@ -5,6 +5,7 @@ import click
 import app
 from configs import dify_config
 from core.rag.datasource.vdb.tidb_on_qdrant.tidb_service import TidbService
+from extensions.ext_database import db
 from models.dataset import TidbAuthBinding

@@ -14,9 +15,11 @@ def update_tidb_serverless_status_task():
     start_at = time.perf_counter()
     try:
         # check the number of idle tidb serverless
-        tidb_serverless_list = TidbAuthBinding.query.filter(
-            TidbAuthBinding.active == False, TidbAuthBinding.status == "CREATING"
-        ).all()
+        tidb_serverless_list = (
+            db.session.query(TidbAuthBinding)
+            .filter(TidbAuthBinding.active == False, TidbAuthBinding.status == "CREATING")
+            .all()
+        )
         if len(tidb_serverless_list) == 0:
             return
         # update tidb serverless status

View File

@@ -9,7 +9,7 @@ from collections import Counter
 from typing import Any, Optional

 from flask_login import current_user
-from sqlalchemy import func
+from sqlalchemy import func, select
 from sqlalchemy.orm import Session
 from werkzeug.exceptions import NotFound

@@ -77,11 +77,13 @@ from tasks.sync_website_document_indexing_task import sync_website_document_inde
 class DatasetService:
     @staticmethod
     def get_datasets(page, per_page, tenant_id=None, user=None, search=None, tag_ids=None, include_all=False):
-        query = Dataset.query.filter(Dataset.tenant_id == tenant_id).order_by(Dataset.created_at.desc())
+        query = select(Dataset).filter(Dataset.tenant_id == tenant_id).order_by(Dataset.created_at.desc())

         if user:
             # get permitted dataset ids
-            dataset_permission = DatasetPermission.query.filter_by(account_id=user.id, tenant_id=tenant_id).all()
+            dataset_permission = (
+                db.session.query(DatasetPermission).filter_by(account_id=user.id, tenant_id=tenant_id).all()
+            )
             permitted_dataset_ids = {dp.dataset_id for dp in dataset_permission} if dataset_permission else None

             if user.current_role == TenantAccountRole.DATASET_OPERATOR:

@@ -129,7 +131,7 @@ class DatasetService:
         else:
             return [], 0

-        datasets = query.paginate(page=page, per_page=per_page, max_per_page=100, error_out=False)
+        datasets = db.paginate(select=query, page=page, per_page=per_page, max_per_page=100, error_out=False)

         return datasets.items, datasets.total

@@ -153,9 +155,10 @@ class DatasetService:
     @staticmethod
     def get_datasets_by_ids(ids, tenant_id):
-        datasets = Dataset.query.filter(Dataset.id.in_(ids), Dataset.tenant_id == tenant_id).paginate(
-            page=1, per_page=len(ids), max_per_page=len(ids), error_out=False
-        )
+        stmt = select(Dataset).filter(Dataset.id.in_(ids), Dataset.tenant_id == tenant_id)
+
+        datasets = db.paginate(select=stmt, page=1, per_page=len(ids), max_per_page=len(ids), error_out=False)

         return datasets.items, datasets.total

     @staticmethod

@@ -174,7 +177,7 @@ class DatasetService:
         retrieval_model: Optional[RetrievalModel] = None,
     ):
         # check if dataset name already exists
-        if Dataset.query.filter_by(name=name, tenant_id=tenant_id).first():
+        if db.session.query(Dataset).filter_by(name=name, tenant_id=tenant_id).first():
             raise DatasetNameDuplicateError(f"Dataset with name {name} already exists.")
         embedding_model = None
         if indexing_technique == "high_quality":

@@ -235,7 +238,7 @@ class DatasetService:
     @staticmethod
     def get_dataset(dataset_id) -> Optional[Dataset]:
-        dataset: Optional[Dataset] = Dataset.query.filter_by(id=dataset_id).first()
+        dataset: Optional[Dataset] = db.session.query(Dataset).filter_by(id=dataset_id).first()
         return dataset

     @staticmethod

@@ -436,7 +439,7 @@ class DatasetService:
             # update Retrieval model
             filtered_data["retrieval_model"] = data["retrieval_model"]

-        dataset.query.filter_by(id=dataset_id).update(filtered_data)
+        db.session.query(Dataset).filter_by(id=dataset_id).update(filtered_data)

         db.session.commit()
         if action:

@@ -460,7 +463,7 @@ class DatasetService:
     @staticmethod
     def dataset_use_check(dataset_id) -> bool:
-        count = AppDatasetJoin.query.filter_by(dataset_id=dataset_id).count()
+        count = db.session.query(AppDatasetJoin).filter_by(dataset_id=dataset_id).count()
         if count > 0:
             return True
         return False

@@ -475,7 +478,9 @@ class DatasetService:
             logging.debug(f"User {user.id} does not have permission to access dataset {dataset.id}")
             raise NoPermissionError("You do not have permission to access this dataset.")
         if dataset.permission == "partial_members":
-            user_permission = DatasetPermission.query.filter_by(dataset_id=dataset.id, account_id=user.id).first()
+            user_permission = (
+                db.session.query(DatasetPermission).filter_by(dataset_id=dataset.id, account_id=user.id).first()
+            )
             if (
                 not user_permission
                 and dataset.tenant_id != user.current_tenant_id

@@ -499,23 +504,24 @@ class DatasetService:
             elif dataset.permission == DatasetPermissionEnum.PARTIAL_TEAM:
                 if not any(
-                    dp.dataset_id == dataset.id for dp in DatasetPermission.query.filter_by(account_id=user.id).all()
+                    dp.dataset_id == dataset.id
+                    for dp in db.session.query(DatasetPermission).filter_by(account_id=user.id).all()
                 ):
                     raise NoPermissionError("You do not have permission to access this dataset.")

     @staticmethod
     def get_dataset_queries(dataset_id: str, page: int, per_page: int):
-        dataset_queries = (
-            DatasetQuery.query.filter_by(dataset_id=dataset_id)
-            .order_by(db.desc(DatasetQuery.created_at))
-            .paginate(page=page, per_page=per_page, max_per_page=100, error_out=False)
-        )
+        stmt = select(DatasetQuery).filter_by(dataset_id=dataset_id).order_by(db.desc(DatasetQuery.created_at))
+
+        dataset_queries = db.paginate(select=stmt, page=page, per_page=per_page, max_per_page=100, error_out=False)
+
         return dataset_queries.items, dataset_queries.total

     @staticmethod
     def get_related_apps(dataset_id: str):
         return (
-            AppDatasetJoin.query.filter(AppDatasetJoin.dataset_id == dataset_id)
+            db.session.query(AppDatasetJoin)
+            .filter(AppDatasetJoin.dataset_id == dataset_id)
             .order_by(db.desc(AppDatasetJoin.created_at))
             .all()
         )

@@ -530,10 +536,14 @@ class DatasetService:
         }
         # get recent 30 days auto disable logs
         start_date = datetime.datetime.now() - datetime.timedelta(days=30)
-        dataset_auto_disable_logs = DatasetAutoDisableLog.query.filter(
-            DatasetAutoDisableLog.dataset_id == dataset_id,
-            DatasetAutoDisableLog.created_at >= start_date,
-        ).all()
+        dataset_auto_disable_logs = (
+            db.session.query(DatasetAutoDisableLog)
+            .filter(
+                DatasetAutoDisableLog.dataset_id == dataset_id,
+                DatasetAutoDisableLog.created_at >= start_date,
+            )
+            .all()
+        )
         if dataset_auto_disable_logs:
             return {
                 "document_ids": [log.document_id for log in dataset_auto_disable_logs],

@@ -873,7 +883,9 @@ class DocumentService:
     @staticmethod
     def get_documents_position(dataset_id):
-        document = Document.query.filter_by(dataset_id=dataset_id).order_by(Document.position.desc()).first()
+        document = (
+            db.session.query(Document).filter_by(dataset_id=dataset_id).order_by(Document.position.desc()).first()
+        )
         if document:
             return document.position + 1
         else:

@@ -1010,13 +1022,17 @@ class DocumentService:
                 }
             # check duplicate
             if knowledge_config.duplicate:
-                document = Document.query.filter_by(
-                    dataset_id=dataset.id,
-                    tenant_id=current_user.current_tenant_id,
-                    data_source_type="upload_file",
-                    enabled=True,
-                    name=file_name,
-                ).first()
+                document = (
+                    db.session.query(Document)
+                    .filter_by(
+                        dataset_id=dataset.id,
+                        tenant_id=current_user.current_tenant_id,
+                        data_source_type="upload_file",
+                        enabled=True,
+                        name=file_name,
+                    )
+                    .first()
+                )
                 if document:
                     document.dataset_process_rule_id = dataset_process_rule.id  # type: ignore
                     document.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)

@@ -1054,12 +1070,16 @@ class DocumentService:
                 raise ValueError("No notion info list found.")
             exist_page_ids = []
             exist_document = {}
-            documents = Document.query.filter_by(
-                dataset_id=dataset.id,
-                tenant_id=current_user.current_tenant_id,
-                data_source_type="notion_import",
-                enabled=True,
-            ).all()
+            documents = (
+                db.session.query(Document)
+                .filter_by(
+                    dataset_id=dataset.id,
+                    tenant_id=current_user.current_tenant_id,
+                    data_source_type="notion_import",
+                    enabled=True,
+                )
+                .all()
+            )
             if documents:
                 for document in documents:
                     data_source_info = json.loads(document.data_source_info)

@@ -1206,12 +1226,16 @@ class DocumentService:
     @staticmethod
     def get_tenant_documents_count():
-        documents_count = Document.query.filter(
-            Document.completed_at.isnot(None),
-            Document.enabled == True,
-            Document.archived == False,
-            Document.tenant_id == current_user.current_tenant_id,
-        ).count()
+        documents_count = (
+            db.session.query(Document)
+            .filter(
+                Document.completed_at.isnot(None),
+                Document.enabled == True,
+                Document.archived == False,
+                Document.tenant_id == current_user.current_tenant_id,
+            )
+            .count()
+        )
         return documents_count

     @staticmethod

@@ -1328,7 +1352,7 @@ class DocumentService:
             db.session.commit()
             # update document segment
             update_params = {DocumentSegment.status: "re_segment"}
-            DocumentSegment.query.filter_by(document_id=document.id).update(update_params)
+            db.session.query(DocumentSegment).filter_by(document_id=document.id).update(update_params)
             db.session.commit()
             # trigger async task
             document_indexing_update_task.delay(document.dataset_id, document.id)

@@ -1918,7 +1942,8 @@ class SegmentService:
     @classmethod
     def delete_segments(cls, segment_ids: list, document: Document, dataset: Dataset):
         index_node_ids = (
-            DocumentSegment.query.with_entities(DocumentSegment.index_node_id)
+            db.session.query(DocumentSegment)
+            .with_entities(DocumentSegment.index_node_id)
             .filter(
                 DocumentSegment.id.in_(segment_ids),
                 DocumentSegment.dataset_id == dataset.id,

@@ -2157,20 +2182,28 @@ class SegmentService:
     def get_child_chunks(
         cls, segment_id: str, document_id: str, dataset_id: str, page: int, limit: int, keyword: Optional[str] = None
     ):
-        query = ChildChunk.query.filter_by(
-            tenant_id=current_user.current_tenant_id,
-            dataset_id=dataset_id,
-            document_id=document_id,
-            segment_id=segment_id,
-        ).order_by(ChildChunk.position.asc())
+        query = (
+            select(ChildChunk)
+            .filter_by(
+                tenant_id=current_user.current_tenant_id,
+                dataset_id=dataset_id,
+                document_id=document_id,
+                segment_id=segment_id,
+            )
+            .order_by(ChildChunk.position.asc())
+        )
         if keyword:
             query = query.where(ChildChunk.content.ilike(f"%{keyword}%"))
-        return query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False)
+        return db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)

     @classmethod
     def get_child_chunk_by_id(cls, child_chunk_id: str, tenant_id: str) -> Optional[ChildChunk]:
         """Get a child chunk by its ID."""
-        result = ChildChunk.query.filter(ChildChunk.id == child_chunk_id, ChildChunk.tenant_id == tenant_id).first()
+        result = (
+            db.session.query(ChildChunk)
+            .filter(ChildChunk.id == child_chunk_id, ChildChunk.tenant_id == tenant_id)
+            .first()
+        )
         return result if isinstance(result, ChildChunk) else None

     @classmethod

@@ -2184,7 +2217,7 @@ class SegmentService:
         limit: int = 20,
     ):
         """Get segments for a document with optional filtering."""
-        query = DocumentSegment.query.filter(
+        query = select(DocumentSegment).filter(
             DocumentSegment.document_id == document_id, DocumentSegment.tenant_id == tenant_id
         )

@@ -2194,9 +2227,8 @@ class SegmentService:
         if keyword:
             query = query.filter(DocumentSegment.content.ilike(f"%{keyword}%"))

-        paginated_segments = query.order_by(DocumentSegment.position.asc()).paginate(
-            page=page, per_page=limit, max_per_page=100, error_out=False
-        )
+        query = query.order_by(DocumentSegment.position.asc())
+        paginated_segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)

         return paginated_segments.items, paginated_segments.total

@@ -2236,9 +2268,11 @@ class SegmentService:
                 raise ValueError(ex.description)

             # check segment
-            segment = DocumentSegment.query.filter(
-                DocumentSegment.id == segment_id, DocumentSegment.tenant_id == user_id
-            ).first()
+            segment = (
+                db.session.query(DocumentSegment)
+                .filter(DocumentSegment.id == segment_id, DocumentSegment.tenant_id == user_id)
+                .first()
+            )
             if not segment:
                 raise NotFound("Segment not found.")

@@ -2251,9 +2285,11 @@ class SegmentService:
     @classmethod
     def get_segment_by_id(cls, segment_id: str, tenant_id: str) -> Optional[DocumentSegment]:
         """Get a segment by its ID."""
-        result = DocumentSegment.query.filter(
-            DocumentSegment.id == segment_id, DocumentSegment.tenant_id == tenant_id
-        ).first()
+        result = (
+            db.session.query(DocumentSegment)
+            .filter(DocumentSegment.id == segment_id, DocumentSegment.tenant_id == tenant_id)
+            .first()
+        )
         return result if isinstance(result, DocumentSegment) else None
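The service-layer rewrites above build select() statements incrementally (filter_by, filter, and order_by chain the same way they did on Model.query) and only page them at the end via db.paginate(), whose Pagination result keeps the .items / .total interface used by the callers. A sketch of that flow with the hypothetical Widget model from the first note:

# Sketch only: incremental select() plus db.paginate(), hypothetical Widget model.
with app.app_context():
    stmt = select(Widget).filter_by(name="a")
    stmt = stmt.order_by(Widget.id.desc())
    page = db.paginate(select=stmt, page=1, per_page=20, max_per_page=100, error_out=False)
    items, total = page.items, page.total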

View File

@ -5,6 +5,7 @@ from typing import Any, Optional, Union, cast
from urllib.parse import urlparse from urllib.parse import urlparse
import httpx import httpx
from sqlalchemy import select
from constants import HIDDEN_VALUE from constants import HIDDEN_VALUE
from core.helper import ssrf_proxy from core.helper import ssrf_proxy
@ -24,14 +25,20 @@ from services.errors.dataset import DatasetNameDuplicateError
class ExternalDatasetService: class ExternalDatasetService:
@staticmethod @staticmethod
def get_external_knowledge_apis(page, per_page, tenant_id, search=None) -> tuple[list[ExternalKnowledgeApis], int]: def get_external_knowledge_apis(
query = ExternalKnowledgeApis.query.filter(ExternalKnowledgeApis.tenant_id == tenant_id).order_by( page, per_page, tenant_id, search=None
ExternalKnowledgeApis.created_at.desc() ) -> tuple[list[ExternalKnowledgeApis], int | None]:
query = (
select(ExternalKnowledgeApis)
.filter(ExternalKnowledgeApis.tenant_id == tenant_id)
.order_by(ExternalKnowledgeApis.created_at.desc())
) )
if search: if search:
query = query.filter(ExternalKnowledgeApis.name.ilike(f"%{search}%")) query = query.filter(ExternalKnowledgeApis.name.ilike(f"%{search}%"))
external_knowledge_apis = query.paginate(page=page, per_page=per_page, max_per_page=100, error_out=False) external_knowledge_apis = db.paginate(
select=query, page=page, per_page=per_page, max_per_page=100, error_out=False
)
return external_knowledge_apis.items, external_knowledge_apis.total return external_knowledge_apis.items, external_knowledge_apis.total
@@ -92,18 +99,18 @@ class ExternalDatasetService:
     @staticmethod
     def get_external_knowledge_api(external_knowledge_api_id: str) -> ExternalKnowledgeApis:
-        external_knowledge_api: Optional[ExternalKnowledgeApis] = ExternalKnowledgeApis.query.filter_by(
-            id=external_knowledge_api_id
-        ).first()
+        external_knowledge_api: Optional[ExternalKnowledgeApis] = (
+            db.session.query(ExternalKnowledgeApis).filter_by(id=external_knowledge_api_id).first()
+        )
         if external_knowledge_api is None:
             raise ValueError("api template not found")
         return external_knowledge_api

     @staticmethod
     def update_external_knowledge_api(tenant_id, user_id, external_knowledge_api_id, args) -> ExternalKnowledgeApis:
-        external_knowledge_api: Optional[ExternalKnowledgeApis] = ExternalKnowledgeApis.query.filter_by(
-            id=external_knowledge_api_id, tenant_id=tenant_id
-        ).first()
+        external_knowledge_api: Optional[ExternalKnowledgeApis] = (
+            db.session.query(ExternalKnowledgeApis).filter_by(id=external_knowledge_api_id, tenant_id=tenant_id).first()
+        )
         if external_knowledge_api is None:
             raise ValueError("api template not found")
         if args.get("settings") and args.get("settings").get("api_key") == HIDDEN_VALUE:

@@ -120,9 +127,9 @@ class ExternalDatasetService:
     @staticmethod
     def delete_external_knowledge_api(tenant_id: str, external_knowledge_api_id: str):
-        external_knowledge_api = ExternalKnowledgeApis.query.filter_by(
-            id=external_knowledge_api_id, tenant_id=tenant_id
-        ).first()
+        external_knowledge_api = (
+            db.session.query(ExternalKnowledgeApis).filter_by(id=external_knowledge_api_id, tenant_id=tenant_id).first()
+        )
         if external_knowledge_api is None:
             raise ValueError("api template not found")

@@ -131,25 +138,29 @@ class ExternalDatasetService:
     @staticmethod
     def external_knowledge_api_use_check(external_knowledge_api_id: str) -> tuple[bool, int]:
-        count = ExternalKnowledgeBindings.query.filter_by(external_knowledge_api_id=external_knowledge_api_id).count()
+        count = (
+            db.session.query(ExternalKnowledgeBindings)
+            .filter_by(external_knowledge_api_id=external_knowledge_api_id)
+            .count()
+        )
         if count > 0:
             return True, count
         return False, 0

     @staticmethod
     def get_external_knowledge_binding_with_dataset_id(tenant_id: str, dataset_id: str) -> ExternalKnowledgeBindings:
-        external_knowledge_binding: Optional[ExternalKnowledgeBindings] = ExternalKnowledgeBindings.query.filter_by(
-            dataset_id=dataset_id, tenant_id=tenant_id
-        ).first()
+        external_knowledge_binding: Optional[ExternalKnowledgeBindings] = (
+            db.session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id, tenant_id=tenant_id).first()
+        )
         if not external_knowledge_binding:
             raise ValueError("external knowledge binding not found")
         return external_knowledge_binding

     @staticmethod
     def document_create_args_validate(tenant_id: str, external_knowledge_api_id: str, process_parameter: dict):
-        external_knowledge_api = ExternalKnowledgeApis.query.filter_by(
-            id=external_knowledge_api_id, tenant_id=tenant_id
-        ).first()
+        external_knowledge_api = (
+            db.session.query(ExternalKnowledgeApis).filter_by(id=external_knowledge_api_id, tenant_id=tenant_id).first()
+        )
         if external_knowledge_api is None:
             raise ValueError("api template not found")
         settings = json.loads(external_knowledge_api.settings)

@@ -212,11 +223,13 @@ class ExternalDatasetService:
     @staticmethod
     def create_external_dataset(tenant_id: str, user_id: str, args: dict) -> Dataset:
         # check if dataset name already exists
-        if Dataset.query.filter_by(name=args.get("name"), tenant_id=tenant_id).first():
+        if db.session.query(Dataset).filter_by(name=args.get("name"), tenant_id=tenant_id).first():
             raise DatasetNameDuplicateError(f"Dataset with name {args.get('name')} already exists.")
-        external_knowledge_api = ExternalKnowledgeApis.query.filter_by(
-            id=args.get("external_knowledge_api_id"), tenant_id=tenant_id
-        ).first()
+        external_knowledge_api = (
+            db.session.query(ExternalKnowledgeApis)
+            .filter_by(id=args.get("external_knowledge_api_id"), tenant_id=tenant_id)
+            .first()
+        )
         if external_knowledge_api is None:
             raise ValueError("api template not found")

@@ -254,15 +267,17 @@ class ExternalDatasetService:
         external_retrieval_parameters: dict,
         metadata_condition: Optional[MetadataCondition] = None,
     ) -> list:
-        external_knowledge_binding = ExternalKnowledgeBindings.query.filter_by(
-            dataset_id=dataset_id, tenant_id=tenant_id
-        ).first()
+        external_knowledge_binding = (
+            db.session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id, tenant_id=tenant_id).first()
+        )
         if not external_knowledge_binding:
             raise ValueError("external knowledge binding not found")
-        external_knowledge_api = ExternalKnowledgeApis.query.filter_by(
-            id=external_knowledge_binding.external_knowledge_api_id
-        ).first()
+        external_knowledge_api = (
+            db.session.query(ExternalKnowledgeApis)
+            .filter_by(id=external_knowledge_binding.external_knowledge_api_id)
+            .first()
+        )
         if not external_knowledge_api:
             raise ValueError("external api template not found")
View File
@@ -20,9 +20,11 @@ class MetadataService:
     @staticmethod
     def create_metadata(dataset_id: str, metadata_args: MetadataArgs) -> DatasetMetadata:
         # check if metadata name already exists
-        if DatasetMetadata.query.filter_by(
-            tenant_id=current_user.current_tenant_id, dataset_id=dataset_id, name=metadata_args.name
-        ).first():
+        if (
+            db.session.query(DatasetMetadata)
+            .filter_by(tenant_id=current_user.current_tenant_id, dataset_id=dataset_id, name=metadata_args.name)
+            .first()
+        ):
             raise ValueError("Metadata name already exists.")
         for field in BuiltInField:
             if field.value == metadata_args.name:

@@ -42,16 +44,18 @@ class MetadataService:
     def update_metadata_name(dataset_id: str, metadata_id: str, name: str) -> DatasetMetadata:  # type: ignore
         lock_key = f"dataset_metadata_lock_{dataset_id}"
         # check if metadata name already exists
-        if DatasetMetadata.query.filter_by(
-            tenant_id=current_user.current_tenant_id, dataset_id=dataset_id, name=name
-        ).first():
+        if (
+            db.session.query(DatasetMetadata)
+            .filter_by(tenant_id=current_user.current_tenant_id, dataset_id=dataset_id, name=name)
+            .first()
+        ):
             raise ValueError("Metadata name already exists.")
         for field in BuiltInField:
             if field.value == name:
                 raise ValueError("Metadata name already exists in Built-in fields.")
         try:
             MetadataService.knowledge_base_metadata_lock_check(dataset_id, None)
-            metadata = DatasetMetadata.query.filter_by(id=metadata_id).first()
+            metadata = db.session.query(DatasetMetadata).filter_by(id=metadata_id).first()
             if metadata is None:
                 raise ValueError("Metadata not found.")
             old_name = metadata.name

@@ -60,7 +64,9 @@ class MetadataService:
             metadata.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)

             # update related documents
-            dataset_metadata_bindings = DatasetMetadataBinding.query.filter_by(metadata_id=metadata_id).all()
+            dataset_metadata_bindings = (
+                db.session.query(DatasetMetadataBinding).filter_by(metadata_id=metadata_id).all()
+            )
             if dataset_metadata_bindings:
                 document_ids = [binding.document_id for binding in dataset_metadata_bindings]
                 documents = DocumentService.get_document_by_ids(document_ids)

@@ -82,13 +88,15 @@ class MetadataService:
         lock_key = f"dataset_metadata_lock_{dataset_id}"
         try:
             MetadataService.knowledge_base_metadata_lock_check(dataset_id, None)
-            metadata = DatasetMetadata.query.filter_by(id=metadata_id).first()
+            metadata = db.session.query(DatasetMetadata).filter_by(id=metadata_id).first()
             if metadata is None:
                 raise ValueError("Metadata not found.")
             db.session.delete(metadata)

             # deal related documents
-            dataset_metadata_bindings = DatasetMetadataBinding.query.filter_by(metadata_id=metadata_id).all()
+            dataset_metadata_bindings = (
+                db.session.query(DatasetMetadataBinding).filter_by(metadata_id=metadata_id).all()
+            )
             if dataset_metadata_bindings:
                 document_ids = [binding.document_id for binding in dataset_metadata_bindings]
                 documents = DocumentService.get_document_by_ids(document_ids)

@@ -193,7 +201,7 @@ class MetadataService:
                 db.session.add(document)
             db.session.commit()
             # deal metadata binding
-            DatasetMetadataBinding.query.filter_by(document_id=operation.document_id).delete()
+            db.session.query(DatasetMetadataBinding).filter_by(document_id=operation.document_id).delete()
             for metadata_value in operation.metadata_list:
                 dataset_metadata_binding = DatasetMetadataBinding(
                     tenant_id=current_user.current_tenant_id,

@@ -230,9 +238,9 @@ class MetadataService:
                 "id": item.get("id"),
                 "name": item.get("name"),
                 "type": item.get("type"),
-                "count": DatasetMetadataBinding.query.filter_by(
-                    metadata_id=item.get("id"), dataset_id=dataset.id
-                ).count(),
+                "count": db.session.query(DatasetMetadataBinding)
+                .filter_by(metadata_id=item.get("id"), dataset_id=dataset.id)
+                .count(),
             }
             for item in dataset.doc_metadata or []
             if item.get("id") != "built-in"
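The metadata hunks reuse the same recipe for filter_by lookups, counts over bindings, and the bulk delete keyed on document_id, all routed through db.session instead of the model's query property. A small sketch of those session-level calls, with Binding standing in for DatasetMetadataBinding:

    # Binding is a stand-in model; the field names only mirror the shapes used above.
    from flask import Flask
    from flask_sqlalchemy import SQLAlchemy

    app = Flask(__name__)
    app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///:memory:"
    db = SQLAlchemy(app)

    class Binding(db.Model):
        id = db.Column(db.Integer, primary_key=True)
        metadata_id = db.Column(db.String(36), nullable=False)
        document_id = db.Column(db.String(36), nullable=False)

    with app.app_context():
        db.create_all()
        db.session.add(Binding(metadata_id="m1", document_id="d1"))
        db.session.commit()

        # Count the documents bound to one metadata entry.
        count = db.session.query(Binding).filter_by(metadata_id="m1").count()

        # Load the bindings to collect the affected document ids.
        bindings = db.session.query(Binding).filter_by(metadata_id="m1").all()
        document_ids = [b.document_id for b in bindings]

        # Bulk delete issues a single DELETE statement without loading rows first.
        db.session.query(Binding).filter_by(document_id="d1").delete()
        db.session.commit()
        print(count, document_ids)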
View File
@@ -41,7 +41,7 @@ def create_segment_to_index_task(segment_id: str, keywords: Optional[list[str]]
             DocumentSegment.status: "indexing",
             DocumentSegment.indexing_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
         }
-        DocumentSegment.query.filter_by(id=segment.id).update(update_params)
+        db.session.query(DocumentSegment).filter_by(id=segment.id).update(update_params)
         db.session.commit()
         document = Document(
             page_content=segment.content,

@@ -78,7 +78,7 @@ def create_segment_to_index_task(segment_id: str, keywords: Optional[list[str]]
             DocumentSegment.status: "completed",
             DocumentSegment.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
         }
-        DocumentSegment.query.filter_by(id=segment.id).update(update_params)
+        db.session.query(DocumentSegment).filter_by(id=segment.id).update(update_params)
         db.session.commit()
         end_at = time.perf_counter()
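Both hunks in this task swap the legacy Model.query.filter_by(...).update(...) call for its session-based equivalent; the update_params dict, keyed by mapped columns, is passed through unchanged. A self-contained sketch of that bulk-update pattern, with Segment standing in for DocumentSegment:

    # Segment is a stand-in model; the status values follow the task above.
    import datetime

    from flask import Flask
    from flask_sqlalchemy import SQLAlchemy

    app = Flask(__name__)
    app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///:memory:"
    db = SQLAlchemy(app)

    class Segment(db.Model):
        id = db.Column(db.Integer, primary_key=True)
        status = db.Column(db.String(32), default="waiting")
        completed_at = db.Column(db.DateTime, nullable=True)

    with app.app_context():
        db.create_all()
        db.session.add(Segment(id=1))
        db.session.commit()

        # Keys are mapped attributes, exactly like update_params in the task;
        # a single UPDATE is emitted instead of loading and mutating each row.
        update_params = {
            Segment.status: "completed",
            Segment.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
        }
        db.session.query(Segment).filter_by(id=1).update(update_params)
        db.session.commit()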
View File
@@ -24,7 +24,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
     start_at = time.perf_counter()

     try:
-        dataset = Dataset.query.filter_by(id=dataset_id).first()
+        dataset = db.session.query(Dataset).filter_by(id=dataset_id).first()
         if not dataset:
             raise Exception("Dataset not found")