fix: replace all dataset Model.query usages with db.session.query(Model) (#19509)

非法操作 2025-05-12 13:52:33 +08:00 committed by GitHub
parent 49af07f444
commit b00f94df64
21 changed files with 430 additions and 265 deletions
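The change behind all 21 files is the same mechanical migration: the legacy Flask-SQLAlchemy Model.query property is swapped for explicit session queries, and query-property pagination is swapped for 2.0-style select() statements handed to db.paginate(). A minimal before/after sketch (Dataset and tenant_id stand in for whichever model and filter a given hunk touches):

from sqlalchemy import select
from extensions.ext_database import db
from models.dataset import Dataset

tenant_id = "example-tenant-id"  # placeholder value

# Before: the query property, deprecated in Flask-SQLAlchemy 3.x
datasets = Dataset.query.filter_by(tenant_id=tenant_id).all()

# After: an explicit session query
datasets = db.session.query(Dataset).filter_by(tenant_id=tenant_id).all()

# After, for pagination: build a Select and let db.paginate execute it
stmt = select(Dataset).filter_by(tenant_id=tenant_id).order_by(Dataset.created_at.desc())
page = db.paginate(select=stmt, page=1, per_page=50, error_out=False)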

View File

@@ -6,6 +6,7 @@ from typing import Optional
import click
from flask import current_app
from sqlalchemy import select
from werkzeug.exceptions import NotFound
from configs import dify_config
@@ -297,11 +298,11 @@ def migrate_knowledge_vector_database():
page = 1
while True:
try:
datasets = (
Dataset.query.filter(Dataset.indexing_technique == "high_quality")
.order_by(Dataset.created_at.desc())
.paginate(page=page, per_page=50)
stmt = (
select(Dataset).filter(Dataset.indexing_technique == "high_quality").order_by(Dataset.created_at.desc())
)
datasets = db.paginate(select=stmt, page=page, per_page=50, max_per_page=50, error_out=False)
except NotFound:
break
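One behavioral note on this hunk: with error_out=False, db.paginate returns a page object instead of aborting, so the except NotFound guard should rarely fire any more; an empty items list is the natural loop exit. A sketch of that shape, using the same Dataset model:

page = 1
while True:
    stmt = (
        select(Dataset)
        .filter(Dataset.indexing_technique == "high_quality")
        .order_by(Dataset.created_at.desc())
    )
    datasets = db.paginate(select=stmt, page=page, per_page=50, max_per_page=50, error_out=False)
    if not datasets.items:  # empty page; no exception is raised with error_out=False
        break
    # ... migrate each dataset in datasets.items ...
    page += 1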
@@ -592,11 +593,15 @@ def old_metadata_migration():
)
db.session.add(dataset_metadata_binding)
else:
dataset_metadata_binding = DatasetMetadataBinding.query.filter(
DatasetMetadataBinding.dataset_id == document.dataset_id,
DatasetMetadataBinding.document_id == document.id,
DatasetMetadataBinding.metadata_id == dataset_metadata.id,
).first()
dataset_metadata_binding = (
db.session.query(DatasetMetadataBinding) # type: ignore
.filter(
DatasetMetadataBinding.dataset_id == document.dataset_id,
DatasetMetadataBinding.document_id == document.id,
DatasetMetadataBinding.metadata_id == dataset_metadata.id,
)
.first()
)
if not dataset_metadata_binding:
dataset_metadata_binding = DatasetMetadataBinding(
tenant_id=document.tenant_id,

View File

@@ -526,14 +526,20 @@ class DatasetIndexingStatusApi(Resource):
)
documents_status = []
for document in documents:
completed_segments = DocumentSegment.query.filter(
DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document.id),
DocumentSegment.status != "re_segment",
).count()
total_segments = DocumentSegment.query.filter(
DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment"
).count()
completed_segments = (
db.session.query(DocumentSegment)
.filter(
DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document.id),
DocumentSegment.status != "re_segment",
)
.count()
)
total_segments = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
.count()
)
document.completed_segments = completed_segments
document.total_segments = total_segments
documents_status.append(marshal(document, document_status_fields))
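The count conversions here are one-to-one: Query.count() on the query property becomes .count() on a session query. For reference, the fully 2.0-style equivalent (a step this commit does not take) would issue SELECT count(*) through a Select statement:

from sqlalchemy import func, select

completed_segments = db.session.scalar(
    select(func.count())
    .select_from(DocumentSegment)
    .where(
        DocumentSegment.completed_at.isnot(None),
        DocumentSegment.document_id == str(document.id),
        DocumentSegment.status != "re_segment",
    )
)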

View File

@@ -6,7 +6,7 @@ from typing import cast
from flask import request
from flask_login import current_user
from flask_restful import Resource, fields, marshal, marshal_with, reqparse
from sqlalchemy import asc, desc
from sqlalchemy import asc, desc, select
from werkzeug.exceptions import Forbidden, NotFound
import services
@@ -112,7 +112,7 @@ class GetProcessRuleApi(Resource):
limits = DocumentService.DEFAULT_RULES["limits"]
if document_id:
# get the latest process rule
document = Document.query.get_or_404(document_id)
document = db.get_or_404(Document, document_id)
dataset = DatasetService.get_dataset(document.dataset_id)
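db.get_or_404 is Flask-SQLAlchemy 3.x's replacement for the deprecated Query.get_or_404: it fetches by primary key and aborts with a 404 response when the row is missing. Roughly equivalent, spelled out:

from werkzeug.exceptions import NotFound

document = db.session.get(Document, document_id)  # primary-key lookup
if document is None:
    raise NotFound()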
@@ -175,7 +175,7 @@ class DatasetDocumentListApi(Resource):
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
query = Document.query.filter_by(dataset_id=str(dataset_id), tenant_id=current_user.current_tenant_id)
query = select(Document).filter_by(dataset_id=str(dataset_id), tenant_id=current_user.current_tenant_id)
if search:
search = f"%{search}%"
@@ -209,18 +209,24 @@
desc(Document.position),
)
paginated_documents = query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False)
paginated_documents = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
documents = paginated_documents.items
if fetch:
for document in documents:
completed_segments = DocumentSegment.query.filter(
DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document.id),
DocumentSegment.status != "re_segment",
).count()
total_segments = DocumentSegment.query.filter(
DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment"
).count()
completed_segments = (
db.session.query(DocumentSegment)
.filter(
DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document.id),
DocumentSegment.status != "re_segment",
)
.count()
)
total_segments = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
.count()
)
document.completed_segments = completed_segments
document.total_segments = total_segments
data = marshal(documents, document_with_segments_fields)
@@ -563,14 +569,20 @@ class DocumentBatchIndexingStatusApi(DocumentResource):
documents = self.get_batch_documents(dataset_id, batch)
documents_status = []
for document in documents:
completed_segments = DocumentSegment.query.filter(
DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document.id),
DocumentSegment.status != "re_segment",
).count()
total_segments = DocumentSegment.query.filter(
DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment"
).count()
completed_segments = (
db.session.query(DocumentSegment)
.filter(
DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document.id),
DocumentSegment.status != "re_segment",
)
.count()
)
total_segments = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
.count()
)
document.completed_segments = completed_segments
document.total_segments = total_segments
if document.is_paused:
@@ -589,14 +601,20 @@ class DocumentIndexingStatusApi(DocumentResource):
document_id = str(document_id)
document = self.get_document(dataset_id, document_id)
completed_segments = DocumentSegment.query.filter(
DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document_id),
DocumentSegment.status != "re_segment",
).count()
total_segments = DocumentSegment.query.filter(
DocumentSegment.document_id == str(document_id), DocumentSegment.status != "re_segment"
).count()
completed_segments = (
db.session.query(DocumentSegment)
.filter(
DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document_id),
DocumentSegment.status != "re_segment",
)
.count()
)
total_segments = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.document_id == str(document_id), DocumentSegment.status != "re_segment")
.count()
)
document.completed_segments = completed_segments
document.total_segments = total_segments

View File

@@ -4,6 +4,7 @@ import pandas as pd
from flask import request
from flask_login import current_user
from flask_restful import Resource, marshal, reqparse
from sqlalchemy import select
from werkzeug.exceptions import Forbidden, NotFound
import services
@@ -26,6 +27,7 @@ from controllers.console.wraps import (
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from fields.segment_fields import child_chunk_fields, segment_fields
from libs.login import login_required
@@ -74,9 +76,14 @@ class DatasetDocumentSegmentListApi(Resource):
hit_count_gte = args["hit_count_gte"]
keyword = args["keyword"]
query = DocumentSegment.query.filter(
DocumentSegment.document_id == str(document_id), DocumentSegment.tenant_id == current_user.current_tenant_id
).order_by(DocumentSegment.position.asc())
query = (
select(DocumentSegment)
.filter(
DocumentSegment.document_id == str(document_id),
DocumentSegment.tenant_id == current_user.current_tenant_id,
)
.order_by(DocumentSegment.position.asc())
)
if status_list:
query = query.filter(DocumentSegment.status.in_(status_list))
@@ -93,7 +100,7 @@
elif args["enabled"].lower() == "false":
query = query.filter(DocumentSegment.enabled == False)
segments = query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False)
segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
response = {
"data": marshal(segments.items, segment_fields),
@@ -276,9 +283,11 @@ class DatasetDocumentSegmentUpdateApi(Resource):
raise ProviderNotInitializeError(ex.description)
# check segment
segment_id = str(segment_id)
segment = DocumentSegment.query.filter(
DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
).first()
segment = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.first()
)
if not segment:
raise NotFound("Segment not found.")
# The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
@@ -320,9 +329,11 @@ class DatasetDocumentSegmentUpdateApi(Resource):
raise NotFound("Document not found.")
# check segment
segment_id = str(segment_id)
segment = DocumentSegment.query.filter(
DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
).first()
segment = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.first()
)
if not segment:
raise NotFound("Segment not found.")
# The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
@@ -423,9 +434,11 @@ class ChildChunkAddApi(Resource):
raise NotFound("Document not found.")
# check segment
segment_id = str(segment_id)
segment = DocumentSegment.query.filter(
DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
).first()
segment = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.first()
)
if not segment:
raise NotFound("Segment not found.")
if not current_user.is_dataset_editor:
@@ -478,9 +491,11 @@ class ChildChunkAddApi(Resource):
raise NotFound("Document not found.")
# check segment
segment_id = str(segment_id)
segment = DocumentSegment.query.filter(
DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
).first()
segment = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.first()
)
if not segment:
raise NotFound("Segment not found.")
parser = reqparse.RequestParser()
@@ -523,9 +538,11 @@ class ChildChunkAddApi(Resource):
raise NotFound("Document not found.")
# check segment
segment_id = str(segment_id)
segment = DocumentSegment.query.filter(
DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
).first()
segment = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.first()
)
if not segment:
raise NotFound("Segment not found.")
# The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
@@ -567,16 +584,20 @@ class ChildChunkUpdateApi(Resource):
raise NotFound("Document not found.")
# check segment
segment_id = str(segment_id)
segment = DocumentSegment.query.filter(
DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
).first()
segment = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.first()
)
if not segment:
raise NotFound("Segment not found.")
# check child chunk
child_chunk_id = str(child_chunk_id)
child_chunk = ChildChunk.query.filter(
ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id
).first()
child_chunk = (
db.session.query(ChildChunk)
.filter(ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id)
.first()
)
if not child_chunk:
raise NotFound("Child chunk not found.")
# The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
@@ -612,16 +633,20 @@ class ChildChunkUpdateApi(Resource):
raise NotFound("Document not found.")
# check segment
segment_id = str(segment_id)
segment = DocumentSegment.query.filter(
DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
).first()
segment = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.first()
)
if not segment:
raise NotFound("Segment not found.")
# check child chunk
child_chunk_id = str(child_chunk_id)
child_chunk = ChildChunk.query.filter(
ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id
).first()
child_chunk = (
db.session.query(ChildChunk)
.filter(ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id)
.first()
)
if not child_chunk:
raise NotFound("Child chunk not found.")
# The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
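The tenant-scoped segment lookup above recurs in nearly every handler in this file; a small helper along these lines (hypothetical, not part of this commit) would remove the repetition:

from typing import Optional

def get_segment_for_tenant(segment_id: str, tenant_id: str) -> Optional[DocumentSegment]:
    # Return the segment only if it belongs to the given tenant.
    return (
        db.session.query(DocumentSegment)
        .filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == tenant_id)
        .first()
    )

segment = get_segment_for_tenant(segment_id, current_user.current_tenant_id)
if not segment:
    raise NotFound("Segment not found.")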

View File

@@ -2,10 +2,10 @@ import json
from flask import request
from flask_restful import marshal, reqparse
from sqlalchemy import desc
from sqlalchemy import desc, select
from werkzeug.exceptions import NotFound
import services.dataset_service
import services
from controllers.common.errors import FilenameNotExistsError
from controllers.service_api import api
from controllers.service_api.app.error import (
@@ -337,7 +337,7 @@ class DocumentListApi(DatasetApiResource):
if not dataset:
raise NotFound("Dataset not found.")
query = Document.query.filter_by(dataset_id=str(dataset_id), tenant_id=tenant_id)
query = select(Document).filter_by(dataset_id=str(dataset_id), tenant_id=tenant_id)
if search:
search = f"%{search}%"
@@ -345,7 +345,7 @@ class DocumentListApi(DatasetApiResource):
query = query.order_by(desc(Document.created_at), desc(Document.position))
paginated_documents = query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False)
paginated_documents = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
documents = paginated_documents.items
response = {
@@ -374,14 +374,20 @@ class DocumentIndexingStatusApi(DatasetApiResource):
raise NotFound("Documents not found.")
documents_status = []
for document in documents:
completed_segments = DocumentSegment.query.filter(
DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document.id),
DocumentSegment.status != "re_segment",
).count()
total_segments = DocumentSegment.query.filter(
DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment"
).count()
completed_segments = (
db.session.query(DocumentSegment)
.filter(
DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document.id),
DocumentSegment.status != "re_segment",
)
.count()
)
total_segments = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
.count()
)
document.completed_segments = completed_segments
document.total_segments = total_segments
if document.is_paused:

View File

@@ -46,14 +46,22 @@ class DatasetIndexToolCallbackHandler:
DatasetDocument.id == document.metadata["document_id"]
).first()
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
child_chunk = ChildChunk.query.filter(
ChildChunk.index_node_id == document.metadata["doc_id"],
ChildChunk.dataset_id == dataset_document.dataset_id,
ChildChunk.document_id == dataset_document.id,
).first()
child_chunk = (
db.session.query(ChildChunk)
.filter(
ChildChunk.index_node_id == document.metadata["doc_id"],
ChildChunk.dataset_id == dataset_document.dataset_id,
ChildChunk.document_id == dataset_document.id,
)
.first()
)
if child_chunk:
segment = DocumentSegment.query.filter(DocumentSegment.id == child_chunk.segment_id).update(
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False
segment = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.id == child_chunk.segment_id)
.update(
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False
)
)
else:
query = db.session.query(DocumentSegment).filter(

View File

@@ -51,7 +51,7 @@ class IndexingRunner:
for dataset_document in dataset_documents:
try:
# get dataset
dataset = Dataset.query.filter_by(id=dataset_document.dataset_id).first()
dataset = db.session.query(Dataset).filter_by(id=dataset_document.dataset_id).first()
if not dataset:
raise ValueError("no dataset found")
@@ -103,15 +103,17 @@
"""Run the indexing process when the index_status is splitting."""
try:
# get dataset
dataset = Dataset.query.filter_by(id=dataset_document.dataset_id).first()
dataset = db.session.query(Dataset).filter_by(id=dataset_document.dataset_id).first()
if not dataset:
raise ValueError("no dataset found")
# get exist document_segment list and delete
document_segments = DocumentSegment.query.filter_by(
dataset_id=dataset.id, document_id=dataset_document.id
).all()
document_segments = (
db.session.query(DocumentSegment)
.filter_by(dataset_id=dataset.id, document_id=dataset_document.id)
.all()
)
for document_segment in document_segments:
db.session.delete(document_segment)
@@ -162,15 +164,17 @@
"""Run the indexing process when the index_status is indexing."""
try:
# get dataset
dataset = Dataset.query.filter_by(id=dataset_document.dataset_id).first()
dataset = db.session.query(Dataset).filter_by(id=dataset_document.dataset_id).first()
if not dataset:
raise ValueError("no dataset found")
# get exist document_segment list and delete
document_segments = DocumentSegment.query.filter_by(
dataset_id=dataset.id, document_id=dataset_document.id
).all()
document_segments = (
db.session.query(DocumentSegment)
.filter_by(dataset_id=dataset.id, document_id=dataset_document.id)
.all()
)
documents = []
if document_segments:
@@ -254,7 +258,7 @@ class IndexingRunner:
embedding_model_instance = None
if dataset_id:
dataset = Dataset.query.filter_by(id=dataset_id).first()
dataset = db.session.query(Dataset).filter_by(id=dataset_id).first()
if not dataset:
raise ValueError("Dataset not found.")
if dataset.indexing_technique == "high_quality" or indexing_technique == "high_quality":
@@ -587,7 +591,7 @@
@staticmethod
def _process_keyword_index(flask_app, dataset_id, document_id, documents):
with flask_app.app_context():
dataset = Dataset.query.filter_by(id=dataset_id).first()
dataset = db.session.query(Dataset).filter_by(id=dataset_id).first()
if not dataset:
raise ValueError("no dataset found")
keyword = Keyword(dataset)
@@ -676,7 +680,7 @@
"""
Update the document segment by document id.
"""
DocumentSegment.query.filter_by(document_id=dataset_document_id).update(update_params)
db.session.query(DocumentSegment).filter_by(document_id=dataset_document_id).update(update_params)
db.session.commit()
def _transform(
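The last hunk above is a bulk UPDATE: it rewrites matching rows in SQL without loading DocumentSegment objects into the session, which matters for documents with many segments. Spelled out with illustrative parameters:

update_params = {DocumentSegment.status: "re_segment"}  # illustrative value
db.session.query(DocumentSegment).filter_by(document_id=dataset_document_id).update(update_params)
db.session.commit()  # the bulk UPDATE still needs an explicit commit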

View File

@@ -237,7 +237,7 @@ class DatasetRetrieval:
if show_retrieve_source:
for record in records:
segment = record.segment
dataset = Dataset.query.filter_by(id=segment.dataset_id).first()
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
document = DatasetDocument.query.filter(
DatasetDocument.id == segment.document_id,
DatasetDocument.enabled == True,
@@ -511,14 +511,23 @@
).first()
if dataset_document:
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
child_chunk = ChildChunk.query.filter(
ChildChunk.index_node_id == document.metadata["doc_id"],
ChildChunk.dataset_id == dataset_document.dataset_id,
ChildChunk.document_id == dataset_document.id,
).first()
child_chunk = (
db.session.query(ChildChunk)
.filter(
ChildChunk.index_node_id == document.metadata["doc_id"],
ChildChunk.dataset_id == dataset_document.dataset_id,
ChildChunk.document_id == dataset_document.id,
)
.first()
)
if child_chunk:
segment = DocumentSegment.query.filter(DocumentSegment.id == child_chunk.segment_id).update(
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False
segment = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.id == child_chunk.segment_id)
.update(
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
synchronize_session=False,
)
)
db.session.commit()
else:
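Two details are worth knowing about the hit-count update above. First, the increment is computed in SQL (hit_count = hit_count + 1), so concurrent retrievals do not lose updates to a read-modify-write race. Second, Query.update() returns the number of matched rows, so the variable named segment in the diff ends up holding an int, not a DocumentSegment; a name like rows_updated would read more accurately:

rows_updated = (
    db.session.query(DocumentSegment)
    .filter(DocumentSegment.id == child_chunk.segment_id)
    .update(
        {DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
        synchronize_session=False,  # skip syncing objects already loaded in the session
    )
)
db.session.commit()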

View File

@@ -84,13 +84,17 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
document_context_list = []
index_node_ids = [document.metadata["doc_id"] for document in all_documents if document.metadata]
segments = DocumentSegment.query.filter(
DocumentSegment.dataset_id.in_(self.dataset_ids),
DocumentSegment.completed_at.isnot(None),
DocumentSegment.status == "completed",
DocumentSegment.enabled == True,
DocumentSegment.index_node_id.in_(index_node_ids),
).all()
segments = (
db.session.query(DocumentSegment)
.filter(
DocumentSegment.dataset_id.in_(self.dataset_ids),
DocumentSegment.completed_at.isnot(None),
DocumentSegment.status == "completed",
DocumentSegment.enabled == True,
DocumentSegment.index_node_id.in_(index_node_ids),
)
.all()
)
if segments:
index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)}
@@ -106,12 +110,16 @@
context_list = []
resource_number = 1
for segment in sorted_segments:
dataset = Dataset.query.filter_by(id=segment.dataset_id).first()
document = Document.query.filter(
Document.id == segment.document_id,
Document.enabled == True,
Document.archived == False,
).first()
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
document = (
db.session.query(Document)
.filter(
Document.id == segment.document_id,
Document.enabled == True,
Document.archived == False,
)
.first()
)
if dataset and document:
source = {
"position": resource_number,

View File

@@ -185,7 +185,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
if self.return_resource:
for record in records:
segment = record.segment
dataset = Dataset.query.filter_by(id=segment.dataset_id).first()
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
document = DatasetDocument.query.filter(
DatasetDocument.id == segment.document_id,
DatasetDocument.enabled == True,

View File

@@ -275,12 +275,16 @@ class KnowledgeRetrievalNode(LLMNode):
if records:
for record in records:
segment = record.segment
dataset = Dataset.query.filter_by(id=segment.dataset_id).first()
document = Document.query.filter(
Document.id == segment.document_id,
Document.enabled == True,
Document.archived == False,
).first()
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first() # type: ignore
document = (
db.session.query(Document)
.filter(
Document.id == segment.document_id,
Document.enabled == True,
Document.archived == False,
)
.first()
)
if dataset and document:
source = {
"metadata": {

View File

@@ -93,7 +93,8 @@ class Dataset(Base):
@property
def latest_process_rule(self):
return (
DatasetProcessRule.query.filter(DatasetProcessRule.dataset_id == self.id)
db.session.query(DatasetProcessRule)
.filter(DatasetProcessRule.dataset_id == self.id)
.order_by(DatasetProcessRule.created_at.desc())
.first()
)
@@ -138,7 +139,8 @@
@property
def word_count(self):
return (
Document.query.with_entities(func.coalesce(func.sum(Document.word_count)))
db.session.query(Document)
.with_entities(func.coalesce(func.sum(Document.word_count)))
.filter(Document.dataset_id == self.id)
.scalar()
)
@@ -440,12 +442,13 @@ class Document(Base):
@property
def segment_count(self):
return DocumentSegment.query.filter(DocumentSegment.document_id == self.id).count()
return db.session.query(DocumentSegment).filter(DocumentSegment.document_id == self.id).count()
@property
def hit_count(self):
return (
DocumentSegment.query.with_entities(func.coalesce(func.sum(DocumentSegment.hit_count)))
db.session.query(DocumentSegment)
.with_entities(func.coalesce(func.sum(DocumentSegment.hit_count)))
.filter(DocumentSegment.document_id == self.id)
.scalar()
)
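These property rewrites keep the with_entities style for aggregates. A fully 2.0-style hit_count, with an explicit zero fallback added for the no-rows case (the committed coalesce call passes none), might look like this sketch:

from sqlalchemy import func, select

@property
def hit_count(self):
    return db.session.scalar(
        select(func.coalesce(func.sum(DocumentSegment.hit_count), 0)).where(
            DocumentSegment.document_id == self.id
        )
    )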
@@ -892,7 +895,7 @@ class DatasetKeywordTable(Base):
return dct
# get dataset
dataset = Dataset.query.filter_by(id=self.dataset_id).first()
dataset = db.session.query(Dataset).filter_by(id=self.dataset_id).first()
if not dataset:
return None
if self.data_source_type == "database":

View File

@@ -2,7 +2,7 @@ import datetime
import time
import click
from sqlalchemy import func
from sqlalchemy import func, select
from werkzeug.exceptions import NotFound
import app
@@ -51,8 +51,9 @@ def clean_unused_datasets_task():
)
# Main query with join and filter
datasets = (
Dataset.query.outerjoin(document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id)
stmt = (
select(Dataset)
.outerjoin(document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id)
.outerjoin(document_subquery_old, Dataset.id == document_subquery_old.c.dataset_id)
.filter(
Dataset.created_at < plan_sandbox_clean_day,
@@ -60,9 +61,10 @@
func.coalesce(document_subquery_old.c.document_count, 0) > 0,
)
.order_by(Dataset.created_at.desc())
.paginate(page=1, per_page=50)
)
datasets = db.paginate(stmt, page=1, per_page=50)
except NotFound:
break
if datasets.items is None or len(datasets.items) == 0:
@@ -99,7 +101,7 @@
# update document
update_params = {Document.enabled: False}
Document.query.filter_by(dataset_id=dataset.id).update(update_params)
db.session.query(Document).filter_by(dataset_id=dataset.id).update(update_params)
db.session.commit()
click.echo(click.style("Cleaned unused dataset {} from db success!".format(dataset.id), fg="green"))
except Exception as e:
@@ -135,8 +137,9 @@
)
# Main query with join and filter
datasets = (
Dataset.query.outerjoin(document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id)
stmt = (
select(Dataset)
.outerjoin(document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id)
.outerjoin(document_subquery_old, Dataset.id == document_subquery_old.c.dataset_id)
.filter(
Dataset.created_at < plan_pro_clean_day,
@@ -144,8 +147,8 @@
func.coalesce(document_subquery_old.c.document_count, 0) > 0,
)
.order_by(Dataset.created_at.desc())
.paginate(page=1, per_page=50)
)
datasets = db.paginate(stmt, page=1, per_page=50)
except NotFound:
break
@@ -175,7 +178,7 @@
# update document
update_params = {Document.enabled: False}
Document.query.filter_by(dataset_id=dataset.id).update(update_params)
db.session.query(Document).filter_by(dataset_id=dataset.id).update(update_params)
db.session.commit()
click.echo(
click.style("Cleaned unused dataset {} from db success!".format(dataset.id), fg="green")
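Unlike the controller hunks earlier in this commit, the two db.paginate calls in this task omit error_out=False, so Flask-SQLAlchemy's default of True applies and an out-of-range page aborts with a 404 that surfaces as werkzeug's NotFound; the pre-existing try/except therefore still has work to do:

try:
    datasets = db.paginate(stmt, page=1, per_page=50)  # error_out defaults to True
except NotFound:  # raised when the requested page is out of range
    break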

View File

@@ -19,7 +19,9 @@ def create_tidb_serverless_task():
while True:
try:
# check the number of idle tidb serverless
idle_tidb_serverless_number = TidbAuthBinding.query.filter(TidbAuthBinding.active == False).count()
idle_tidb_serverless_number = (
db.session.query(TidbAuthBinding).filter(TidbAuthBinding.active == False).count()
)
if idle_tidb_serverless_number >= tidb_serverless_number:
break
# create tidb serverless

View File

@@ -29,7 +29,9 @@ def mail_clean_document_notify_task():
# send document clean notify mail
try:
dataset_auto_disable_logs = DatasetAutoDisableLog.query.filter(DatasetAutoDisableLog.notified == False).all()
dataset_auto_disable_logs = (
db.session.query(DatasetAutoDisableLog).filter(DatasetAutoDisableLog.notified == False).all()
)
# group by tenant_id
dataset_auto_disable_logs_map: dict[str, list[DatasetAutoDisableLog]] = defaultdict(list)
for dataset_auto_disable_log in dataset_auto_disable_logs:
@@ -65,7 +67,7 @@
)
for dataset_id, document_ids in dataset_auto_dataset_map.items():
dataset = Dataset.query.filter(Dataset.id == dataset_id).first()
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
if dataset:
document_count = len(document_ids)
knowledge_details.append(rf"Knowledge base {dataset.name}: {document_count} documents")

View File

@@ -5,6 +5,7 @@ import click
import app
from configs import dify_config
from core.rag.datasource.vdb.tidb_on_qdrant.tidb_service import TidbService
from extensions.ext_database import db
from models.dataset import TidbAuthBinding
@@ -14,9 +15,11 @@ def update_tidb_serverless_status_task():
start_at = time.perf_counter()
try:
# check the number of idle tidb serverless
tidb_serverless_list = TidbAuthBinding.query.filter(
TidbAuthBinding.active == False, TidbAuthBinding.status == "CREATING"
).all()
tidb_serverless_list = (
db.session.query(TidbAuthBinding)
.filter(TidbAuthBinding.active == False, TidbAuthBinding.status == "CREATING")
.all()
)
if len(tidb_serverless_list) == 0:
return
# update tidb serverless status

View File

@@ -9,7 +9,7 @@ from collections import Counter
from typing import Any, Optional
from flask_login import current_user
from sqlalchemy import func
from sqlalchemy import func, select
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
@@ -77,11 +77,13 @@ from tasks.sync_website_document_indexing_task import sync_website_document_inde
class DatasetService:
@staticmethod
def get_datasets(page, per_page, tenant_id=None, user=None, search=None, tag_ids=None, include_all=False):
query = Dataset.query.filter(Dataset.tenant_id == tenant_id).order_by(Dataset.created_at.desc())
query = select(Dataset).filter(Dataset.tenant_id == tenant_id).order_by(Dataset.created_at.desc())
if user:
# get permitted dataset ids
dataset_permission = DatasetPermission.query.filter_by(account_id=user.id, tenant_id=tenant_id).all()
dataset_permission = (
db.session.query(DatasetPermission).filter_by(account_id=user.id, tenant_id=tenant_id).all()
)
permitted_dataset_ids = {dp.dataset_id for dp in dataset_permission} if dataset_permission else None
if user.current_role == TenantAccountRole.DATASET_OPERATOR:
@@ -129,7 +131,7 @@
else:
return [], 0
datasets = query.paginate(page=page, per_page=per_page, max_per_page=100, error_out=False)
datasets = db.paginate(select=query, page=page, per_page=per_page, max_per_page=100, error_out=False)
return datasets.items, datasets.total
@@ -153,9 +155,10 @@
@staticmethod
def get_datasets_by_ids(ids, tenant_id):
datasets = Dataset.query.filter(Dataset.id.in_(ids), Dataset.tenant_id == tenant_id).paginate(
page=1, per_page=len(ids), max_per_page=len(ids), error_out=False
)
stmt = select(Dataset).filter(Dataset.id.in_(ids), Dataset.tenant_id == tenant_id)
datasets = db.paginate(select=stmt, page=1, per_page=len(ids), max_per_page=len(ids), error_out=False)
return datasets.items, datasets.total
@staticmethod
@@ -174,7 +177,7 @@
retrieval_model: Optional[RetrievalModel] = None,
):
# check if dataset name already exists
if Dataset.query.filter_by(name=name, tenant_id=tenant_id).first():
if db.session.query(Dataset).filter_by(name=name, tenant_id=tenant_id).first():
raise DatasetNameDuplicateError(f"Dataset with name {name} already exists.")
embedding_model = None
if indexing_technique == "high_quality":
@@ -235,7 +238,7 @@
@staticmethod
def get_dataset(dataset_id) -> Optional[Dataset]:
dataset: Optional[Dataset] = Dataset.query.filter_by(id=dataset_id).first()
dataset: Optional[Dataset] = db.session.query(Dataset).filter_by(id=dataset_id).first()
return dataset
@staticmethod
@@ -436,7 +439,7 @@
# update Retrieval model
filtered_data["retrieval_model"] = data["retrieval_model"]
dataset.query.filter_by(id=dataset_id).update(filtered_data)
db.session.query(Dataset).filter_by(id=dataset_id).update(filtered_data)
db.session.commit()
if action:
@@ -460,7 +463,7 @@
@staticmethod
def dataset_use_check(dataset_id) -> bool:
count = AppDatasetJoin.query.filter_by(dataset_id=dataset_id).count()
count = db.session.query(AppDatasetJoin).filter_by(dataset_id=dataset_id).count()
if count > 0:
return True
return False
@@ -475,7 +478,9 @@
logging.debug(f"User {user.id} does not have permission to access dataset {dataset.id}")
raise NoPermissionError("You do not have permission to access this dataset.")
if dataset.permission == "partial_members":
user_permission = DatasetPermission.query.filter_by(dataset_id=dataset.id, account_id=user.id).first()
user_permission = (
db.session.query(DatasetPermission).filter_by(dataset_id=dataset.id, account_id=user.id).first()
)
if (
not user_permission
and dataset.tenant_id != user.current_tenant_id
@@ -499,23 +504,24 @@
elif dataset.permission == DatasetPermissionEnum.PARTIAL_TEAM:
if not any(
dp.dataset_id == dataset.id for dp in DatasetPermission.query.filter_by(account_id=user.id).all()
dp.dataset_id == dataset.id
for dp in db.session.query(DatasetPermission).filter_by(account_id=user.id).all()
):
raise NoPermissionError("You do not have permission to access this dataset.")
@staticmethod
def get_dataset_queries(dataset_id: str, page: int, per_page: int):
dataset_queries = (
DatasetQuery.query.filter_by(dataset_id=dataset_id)
.order_by(db.desc(DatasetQuery.created_at))
.paginate(page=page, per_page=per_page, max_per_page=100, error_out=False)
)
stmt = select(DatasetQuery).filter_by(dataset_id=dataset_id).order_by(db.desc(DatasetQuery.created_at))
dataset_queries = db.paginate(select=stmt, page=page, per_page=per_page, max_per_page=100, error_out=False)
return dataset_queries.items, dataset_queries.total
@staticmethod
def get_related_apps(dataset_id: str):
return (
AppDatasetJoin.query.filter(AppDatasetJoin.dataset_id == dataset_id)
db.session.query(AppDatasetJoin)
.filter(AppDatasetJoin.dataset_id == dataset_id)
.order_by(db.desc(AppDatasetJoin.created_at))
.all()
)
@@ -530,10 +536,14 @@
}
# get recent 30 days auto disable logs
start_date = datetime.datetime.now() - datetime.timedelta(days=30)
dataset_auto_disable_logs = DatasetAutoDisableLog.query.filter(
DatasetAutoDisableLog.dataset_id == dataset_id,
DatasetAutoDisableLog.created_at >= start_date,
).all()
dataset_auto_disable_logs = (
db.session.query(DatasetAutoDisableLog)
.filter(
DatasetAutoDisableLog.dataset_id == dataset_id,
DatasetAutoDisableLog.created_at >= start_date,
)
.all()
)
if dataset_auto_disable_logs:
return {
"document_ids": [log.document_id for log in dataset_auto_disable_logs],
@@ -873,7 +883,9 @@ class DocumentService:
@staticmethod
def get_documents_position(dataset_id):
document = Document.query.filter_by(dataset_id=dataset_id).order_by(Document.position.desc()).first()
document = (
db.session.query(Document).filter_by(dataset_id=dataset_id).order_by(Document.position.desc()).first()
)
if document:
return document.position + 1
else:
@@ -1010,13 +1022,17 @@
}
# check duplicate
if knowledge_config.duplicate:
document = Document.query.filter_by(
dataset_id=dataset.id,
tenant_id=current_user.current_tenant_id,
data_source_type="upload_file",
enabled=True,
name=file_name,
).first()
document = (
db.session.query(Document)
.filter_by(
dataset_id=dataset.id,
tenant_id=current_user.current_tenant_id,
data_source_type="upload_file",
enabled=True,
name=file_name,
)
.first()
)
if document:
document.dataset_process_rule_id = dataset_process_rule.id # type: ignore
document.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
@@ -1054,12 +1070,16 @@
raise ValueError("No notion info list found.")
exist_page_ids = []
exist_document = {}
documents = Document.query.filter_by(
dataset_id=dataset.id,
tenant_id=current_user.current_tenant_id,
data_source_type="notion_import",
enabled=True,
).all()
documents = (
db.session.query(Document)
.filter_by(
dataset_id=dataset.id,
tenant_id=current_user.current_tenant_id,
data_source_type="notion_import",
enabled=True,
)
.all()
)
if documents:
for document in documents:
data_source_info = json.loads(document.data_source_info)
@@ -1206,12 +1226,16 @@
@staticmethod
def get_tenant_documents_count():
documents_count = Document.query.filter(
Document.completed_at.isnot(None),
Document.enabled == True,
Document.archived == False,
Document.tenant_id == current_user.current_tenant_id,
).count()
documents_count = (
db.session.query(Document)
.filter(
Document.completed_at.isnot(None),
Document.enabled == True,
Document.archived == False,
Document.tenant_id == current_user.current_tenant_id,
)
.count()
)
return documents_count
@staticmethod
@@ -1328,7 +1352,7 @@
db.session.commit()
# update document segment
update_params = {DocumentSegment.status: "re_segment"}
DocumentSegment.query.filter_by(document_id=document.id).update(update_params)
db.session.query(DocumentSegment).filter_by(document_id=document.id).update(update_params)
db.session.commit()
# trigger async task
document_indexing_update_task.delay(document.dataset_id, document.id)
@@ -1918,7 +1942,8 @@
@classmethod
def delete_segments(cls, segment_ids: list, document: Document, dataset: Dataset):
index_node_ids = (
DocumentSegment.query.with_entities(DocumentSegment.index_node_id)
db.session.query(DocumentSegment)
.with_entities(DocumentSegment.index_node_id)
.filter(
DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset.id,
@@ -2157,20 +2182,28 @@
def get_child_chunks(
cls, segment_id: str, document_id: str, dataset_id: str, page: int, limit: int, keyword: Optional[str] = None
):
query = ChildChunk.query.filter_by(
tenant_id=current_user.current_tenant_id,
dataset_id=dataset_id,
document_id=document_id,
segment_id=segment_id,
).order_by(ChildChunk.position.asc())
query = (
select(ChildChunk)
.filter_by(
tenant_id=current_user.current_tenant_id,
dataset_id=dataset_id,
document_id=document_id,
segment_id=segment_id,
)
.order_by(ChildChunk.position.asc())
)
if keyword:
query = query.where(ChildChunk.content.ilike(f"%{keyword}%"))
return query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False)
return db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
@classmethod
def get_child_chunk_by_id(cls, child_chunk_id: str, tenant_id: str) -> Optional[ChildChunk]:
"""Get a child chunk by its ID."""
result = ChildChunk.query.filter(ChildChunk.id == child_chunk_id, ChildChunk.tenant_id == tenant_id).first()
result = (
db.session.query(ChildChunk)
.filter(ChildChunk.id == child_chunk_id, ChildChunk.tenant_id == tenant_id)
.first()
)
return result if isinstance(result, ChildChunk) else None
@classmethod
@@ -2184,7 +2217,7 @@
limit: int = 20,
):
"""Get segments for a document with optional filtering."""
query = DocumentSegment.query.filter(
query = select(DocumentSegment).filter(
DocumentSegment.document_id == document_id, DocumentSegment.tenant_id == tenant_id
)
@@ -2194,9 +2227,8 @@
if keyword:
query = query.filter(DocumentSegment.content.ilike(f"%{keyword}%"))
paginated_segments = query.order_by(DocumentSegment.position.asc()).paginate(
page=page, per_page=limit, max_per_page=100, error_out=False
)
query = query.order_by(DocumentSegment.position.asc())
paginated_segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
return paginated_segments.items, paginated_segments.total
@@ -2236,9 +2268,11 @@
raise ValueError(ex.description)
# check segment
segment = DocumentSegment.query.filter(
DocumentSegment.id == segment_id, DocumentSegment.tenant_id == user_id
).first()
segment = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.id == segment_id, DocumentSegment.tenant_id == user_id)
.first()
)
if not segment:
raise NotFound("Segment not found.")
@@ -2251,9 +2285,11 @@
@classmethod
def get_segment_by_id(cls, segment_id: str, tenant_id: str) -> Optional[DocumentSegment]:
"""Get a segment by its ID."""
result = DocumentSegment.query.filter(
DocumentSegment.id == segment_id, DocumentSegment.tenant_id == tenant_id
).first()
result = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.id == segment_id, DocumentSegment.tenant_id == tenant_id)
.first()
)
return result if isinstance(result, DocumentSegment) else None

View File

@@ -5,6 +5,7 @@ from typing import Any, Optional, Union, cast
from urllib.parse import urlparse
import httpx
from sqlalchemy import select
from constants import HIDDEN_VALUE
from core.helper import ssrf_proxy
@@ -24,14 +25,20 @@ from services.errors.dataset import DatasetNameDuplicateError
class ExternalDatasetService:
@staticmethod
def get_external_knowledge_apis(page, per_page, tenant_id, search=None) -> tuple[list[ExternalKnowledgeApis], int]:
query = ExternalKnowledgeApis.query.filter(ExternalKnowledgeApis.tenant_id == tenant_id).order_by(
ExternalKnowledgeApis.created_at.desc()
def get_external_knowledge_apis(
page, per_page, tenant_id, search=None
) -> tuple[list[ExternalKnowledgeApis], int | None]:
query = (
select(ExternalKnowledgeApis)
.filter(ExternalKnowledgeApis.tenant_id == tenant_id)
.order_by(ExternalKnowledgeApis.created_at.desc())
)
if search:
query = query.filter(ExternalKnowledgeApis.name.ilike(f"%{search}%"))
external_knowledge_apis = query.paginate(page=page, per_page=per_page, max_per_page=100, error_out=False)
external_knowledge_apis = db.paginate(
select=query, page=page, per_page=per_page, max_per_page=100, error_out=False
)
return external_knowledge_apis.items, external_knowledge_apis.total
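The widened return annotation, int | None for the total, tracks Flask-SQLAlchemy's Pagination typing: pagination.total is declared Optional (it can be None, for instance when counting is disabled), so callers of this method now have to tolerate a missing total:

pagination = db.paginate(select=query, page=page, per_page=per_page, max_per_page=100, error_out=False)
items: list[ExternalKnowledgeApis] = pagination.items
total: int | None = pagination.total  # Optional in Flask-SQLAlchemy's typing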
@@ -92,18 +99,18 @@
@staticmethod
def get_external_knowledge_api(external_knowledge_api_id: str) -> ExternalKnowledgeApis:
external_knowledge_api: Optional[ExternalKnowledgeApis] = ExternalKnowledgeApis.query.filter_by(
id=external_knowledge_api_id
).first()
external_knowledge_api: Optional[ExternalKnowledgeApis] = (
db.session.query(ExternalKnowledgeApis).filter_by(id=external_knowledge_api_id).first()
)
if external_knowledge_api is None:
raise ValueError("api template not found")
return external_knowledge_api
@staticmethod
def update_external_knowledge_api(tenant_id, user_id, external_knowledge_api_id, args) -> ExternalKnowledgeApis:
external_knowledge_api: Optional[ExternalKnowledgeApis] = ExternalKnowledgeApis.query.filter_by(
id=external_knowledge_api_id, tenant_id=tenant_id
).first()
external_knowledge_api: Optional[ExternalKnowledgeApis] = (
db.session.query(ExternalKnowledgeApis).filter_by(id=external_knowledge_api_id, tenant_id=tenant_id).first()
)
if external_knowledge_api is None:
raise ValueError("api template not found")
if args.get("settings") and args.get("settings").get("api_key") == HIDDEN_VALUE:
@@ -120,9 +127,9 @@
@staticmethod
def delete_external_knowledge_api(tenant_id: str, external_knowledge_api_id: str):
external_knowledge_api = ExternalKnowledgeApis.query.filter_by(
id=external_knowledge_api_id, tenant_id=tenant_id
).first()
external_knowledge_api = (
db.session.query(ExternalKnowledgeApis).filter_by(id=external_knowledge_api_id, tenant_id=tenant_id).first()
)
if external_knowledge_api is None:
raise ValueError("api template not found")
@@ -131,25 +138,29 @@
@staticmethod
def external_knowledge_api_use_check(external_knowledge_api_id: str) -> tuple[bool, int]:
count = ExternalKnowledgeBindings.query.filter_by(external_knowledge_api_id=external_knowledge_api_id).count()
count = (
db.session.query(ExternalKnowledgeBindings)
.filter_by(external_knowledge_api_id=external_knowledge_api_id)
.count()
)
if count > 0:
return True, count
return False, 0
@staticmethod
def get_external_knowledge_binding_with_dataset_id(tenant_id: str, dataset_id: str) -> ExternalKnowledgeBindings:
external_knowledge_binding: Optional[ExternalKnowledgeBindings] = ExternalKnowledgeBindings.query.filter_by(
dataset_id=dataset_id, tenant_id=tenant_id
).first()
external_knowledge_binding: Optional[ExternalKnowledgeBindings] = (
db.session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id, tenant_id=tenant_id).first()
)
if not external_knowledge_binding:
raise ValueError("external knowledge binding not found")
return external_knowledge_binding
@staticmethod
def document_create_args_validate(tenant_id: str, external_knowledge_api_id: str, process_parameter: dict):
external_knowledge_api = ExternalKnowledgeApis.query.filter_by(
id=external_knowledge_api_id, tenant_id=tenant_id
).first()
external_knowledge_api = (
db.session.query(ExternalKnowledgeApis).filter_by(id=external_knowledge_api_id, tenant_id=tenant_id).first()
)
if external_knowledge_api is None:
raise ValueError("api template not found")
settings = json.loads(external_knowledge_api.settings)
@@ -212,11 +223,13 @@
@staticmethod
def create_external_dataset(tenant_id: str, user_id: str, args: dict) -> Dataset:
# check if dataset name already exists
if Dataset.query.filter_by(name=args.get("name"), tenant_id=tenant_id).first():
if db.session.query(Dataset).filter_by(name=args.get("name"), tenant_id=tenant_id).first():
raise DatasetNameDuplicateError(f"Dataset with name {args.get('name')} already exists.")
external_knowledge_api = ExternalKnowledgeApis.query.filter_by(
id=args.get("external_knowledge_api_id"), tenant_id=tenant_id
).first()
external_knowledge_api = (
db.session.query(ExternalKnowledgeApis)
.filter_by(id=args.get("external_knowledge_api_id"), tenant_id=tenant_id)
.first()
)
if external_knowledge_api is None:
raise ValueError("api template not found")
@@ -254,15 +267,17 @@
external_retrieval_parameters: dict,
metadata_condition: Optional[MetadataCondition] = None,
) -> list:
external_knowledge_binding = ExternalKnowledgeBindings.query.filter_by(
dataset_id=dataset_id, tenant_id=tenant_id
).first()
external_knowledge_binding = (
db.session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id, tenant_id=tenant_id).first()
)
if not external_knowledge_binding:
raise ValueError("external knowledge binding not found")
external_knowledge_api = ExternalKnowledgeApis.query.filter_by(
id=external_knowledge_binding.external_knowledge_api_id
).first()
external_knowledge_api = (
db.session.query(ExternalKnowledgeApis)
.filter_by(id=external_knowledge_binding.external_knowledge_api_id)
.first()
)
if not external_knowledge_api:
raise ValueError("external api template not found")

View File

@@ -20,9 +20,11 @@ class MetadataService:
@staticmethod
def create_metadata(dataset_id: str, metadata_args: MetadataArgs) -> DatasetMetadata:
# check if metadata name already exists
if DatasetMetadata.query.filter_by(
tenant_id=current_user.current_tenant_id, dataset_id=dataset_id, name=metadata_args.name
).first():
if (
db.session.query(DatasetMetadata)
.filter_by(tenant_id=current_user.current_tenant_id, dataset_id=dataset_id, name=metadata_args.name)
.first()
):
raise ValueError("Metadata name already exists.")
for field in BuiltInField:
if field.value == metadata_args.name:
@@ -42,16 +44,18 @@
def update_metadata_name(dataset_id: str, metadata_id: str, name: str) -> DatasetMetadata: # type: ignore
lock_key = f"dataset_metadata_lock_{dataset_id}"
# check if metadata name already exists
if DatasetMetadata.query.filter_by(
tenant_id=current_user.current_tenant_id, dataset_id=dataset_id, name=name
).first():
if (
db.session.query(DatasetMetadata)
.filter_by(tenant_id=current_user.current_tenant_id, dataset_id=dataset_id, name=name)
.first()
):
raise ValueError("Metadata name already exists.")
for field in BuiltInField:
if field.value == name:
raise ValueError("Metadata name already exists in Built-in fields.")
try:
MetadataService.knowledge_base_metadata_lock_check(dataset_id, None)
metadata = DatasetMetadata.query.filter_by(id=metadata_id).first()
metadata = db.session.query(DatasetMetadata).filter_by(id=metadata_id).first()
if metadata is None:
raise ValueError("Metadata not found.")
old_name = metadata.name
@@ -60,7 +64,9 @@
metadata.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
# update related documents
dataset_metadata_bindings = DatasetMetadataBinding.query.filter_by(metadata_id=metadata_id).all()
dataset_metadata_bindings = (
db.session.query(DatasetMetadataBinding).filter_by(metadata_id=metadata_id).all()
)
if dataset_metadata_bindings:
document_ids = [binding.document_id for binding in dataset_metadata_bindings]
documents = DocumentService.get_document_by_ids(document_ids)
@@ -82,13 +88,15 @@
lock_key = f"dataset_metadata_lock_{dataset_id}"
try:
MetadataService.knowledge_base_metadata_lock_check(dataset_id, None)
metadata = DatasetMetadata.query.filter_by(id=metadata_id).first()
metadata = db.session.query(DatasetMetadata).filter_by(id=metadata_id).first()
if metadata is None:
raise ValueError("Metadata not found.")
db.session.delete(metadata)
# deal related documents
dataset_metadata_bindings = DatasetMetadataBinding.query.filter_by(metadata_id=metadata_id).all()
dataset_metadata_bindings = (
db.session.query(DatasetMetadataBinding).filter_by(metadata_id=metadata_id).all()
)
if dataset_metadata_bindings:
document_ids = [binding.document_id for binding in dataset_metadata_bindings]
documents = DocumentService.get_document_by_ids(document_ids)
@@ -193,7 +201,7 @@ class MetadataService:
db.session.add(document)
db.session.commit()
# deal metadata binding
DatasetMetadataBinding.query.filter_by(document_id=operation.document_id).delete()
db.session.query(DatasetMetadataBinding).filter_by(document_id=operation.document_id).delete()
for metadata_value in operation.metadata_list:
dataset_metadata_binding = DatasetMetadataBinding(
tenant_id=current_user.current_tenant_id,
@@ -230,9 +238,9 @@
"id": item.get("id"),
"name": item.get("name"),
"type": item.get("type"),
"count": DatasetMetadataBinding.query.filter_by(
metadata_id=item.get("id"), dataset_id=dataset.id
).count(),
"count": db.session.query(DatasetMetadataBinding)
.filter_by(metadata_id=item.get("id"), dataset_id=dataset.id)
.count(),
}
for item in dataset.doc_metadata or []
if item.get("id") != "built-in"
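One more bulk pattern appears in this service: filter_by(...).delete() issues a single DELETE without loading the bindings first, and returns the number of rows removed. A sketch of the call above with that count captured:

deleted_count = (
    db.session.query(DatasetMetadataBinding)
    .filter_by(document_id=operation.document_id)
    .delete()
)
db.session.commit()  # deleted_count holds the number of rows removed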

View File

@@ -41,7 +41,7 @@ def create_segment_to_index_task(segment_id: str, keywords: Optional[list[str]]
DocumentSegment.status: "indexing",
DocumentSegment.indexing_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
}
DocumentSegment.query.filter_by(id=segment.id).update(update_params)
db.session.query(DocumentSegment).filter_by(id=segment.id).update(update_params)
db.session.commit()
document = Document(
page_content=segment.content,
@@ -78,7 +78,7 @@ def create_segment_to_index_task(segment_id: str, keywords: Optional[list[str]]
DocumentSegment.status: "completed",
DocumentSegment.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
}
DocumentSegment.query.filter_by(id=segment.id).update(update_params)
db.session.query(DocumentSegment).filter_by(id=segment.id).update(update_params)
db.session.commit()
end_at = time.perf_counter()

View File

@@ -24,7 +24,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
start_at = time.perf_counter()
try:
dataset = Dataset.query.filter_by(id=dataset_id).first()
dataset = db.session.query(Dataset).filter_by(id=dataset_id).first()
if not dataset:
raise Exception("Dataset not found")