Mirror of https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git (synced 2025-06-01 16:03:38 +08:00)

commit b00f94df64 (parent 49af07f444)

fix: replace all dataset.Model.query to db.session.query(Model) (#19509)
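The change is mechanical but broad: every legacy `Model.query` accessor (a Flask-SQLAlchemy shortcut bound to an implicit session) becomes an explicit `db.session.query(Model)` call, and paginated queries move to SQLAlchemy 2.0-style `select()` statements driven by `db.paginate()`. As a minimal before/after sketch (not code from this commit; the `Dataset` model and its columns are stand-in assumptions for illustration), the styles read the same rows:

```python
# A minimal sketch contrasting the legacy and explicit query styles.
# The Dataset model here is an assumed stand-in, not the real schema.
from flask import Flask
from flask_sqlalchemy import SQLAlchemy
from sqlalchemy import select

app = Flask(__name__)
app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///:memory:"
db = SQLAlchemy(app)


class Dataset(db.Model):
    id = db.Column(db.Integer, primary_key=True)
    name = db.Column(db.String(80))
    indexing_technique = db.Column(db.String(20))


with app.app_context():
    db.create_all()
    db.session.add(Dataset(name="demo", indexing_technique="high_quality"))
    db.session.commit()

    # Legacy style: Model.query is a Query object bound to an implicit session.
    legacy = Dataset.query.filter_by(name="demo").first()

    # New style: the session is explicit, or a select() statement is executed.
    explicit = db.session.query(Dataset).filter_by(name="demo").first()
    stmt = select(Dataset).filter_by(name="demo")
    via_select = db.session.execute(stmt).scalar_one_or_none()

    assert legacy is explicit is via_select  # same identity-mapped row
```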
@@ -6,6 +6,7 @@ from typing import Optional
 
 import click
 from flask import current_app
+from sqlalchemy import select
 from werkzeug.exceptions import NotFound
 
 from configs import dify_config

@@ -297,11 +298,11 @@ def migrate_knowledge_vector_database():
     page = 1
     while True:
         try:
-            datasets = (
-                Dataset.query.filter(Dataset.indexing_technique == "high_quality")
-                .order_by(Dataset.created_at.desc())
-                .paginate(page=page, per_page=50)
+            stmt = (
+                select(Dataset).filter(Dataset.indexing_technique == "high_quality").order_by(Dataset.created_at.desc())
             )
 
+            datasets = db.paginate(select=stmt, page=page, per_page=50, max_per_page=50, error_out=False)
         except NotFound:
             break
 
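The hunk above is the recurring pagination rewrite: build a `select()` statement, then hand it to `db.paginate()`. A hedged, self-contained sketch of that loop follows (the model, columns, and data volume are assumptions made for the example). One design note: with `error_out=False`, an out-of-range page yields an empty page rather than raising `NotFound`, so the surrounding `except NotFound` guard kept in the diff should no longer trigger.

```python
# A hedged sketch of the select() + db.paginate() page loop; model and
# data are illustrative assumptions, not code from the repository.
from flask import Flask
from flask_sqlalchemy import SQLAlchemy
from sqlalchemy import select

app = Flask(__name__)
app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///:memory:"
db = SQLAlchemy(app)


class Dataset(db.Model):
    id = db.Column(db.Integer, primary_key=True)
    indexing_technique = db.Column(db.String(20))


with app.app_context():
    db.create_all()
    db.session.add_all([Dataset(indexing_technique="high_quality") for _ in range(120)])
    db.session.commit()

    stmt = select(Dataset).filter(Dataset.indexing_technique == "high_quality")
    page = 1
    while True:
        # error_out=False returns an empty page instead of raising NotFound,
        # so the loop can terminate on an empty items list.
        pagination = db.paginate(stmt, page=page, per_page=50, error_out=False)
        if not pagination.items:
            break
        print(f"page {page}: {len(pagination.items)} of {pagination.total} datasets")
        page += 1
```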
@@ -592,11 +593,15 @@ def old_metadata_migration():
                     )
                     db.session.add(dataset_metadata_binding)
                 else:
-                    dataset_metadata_binding = DatasetMetadataBinding.query.filter(
-                        DatasetMetadataBinding.dataset_id == document.dataset_id,
-                        DatasetMetadataBinding.document_id == document.id,
-                        DatasetMetadataBinding.metadata_id == dataset_metadata.id,
-                    ).first()
+                    dataset_metadata_binding = (
+                        db.session.query(DatasetMetadataBinding)  # type: ignore
+                        .filter(
+                            DatasetMetadataBinding.dataset_id == document.dataset_id,
+                            DatasetMetadataBinding.document_id == document.id,
+                            DatasetMetadataBinding.metadata_id == dataset_metadata.id,
+                        )
+                        .first()
+                    )
                     if not dataset_metadata_binding:
                         dataset_metadata_binding = DatasetMetadataBinding(
                             tenant_id=document.tenant_id,

@@ -526,14 +526,20 @@ class DatasetIndexingStatusApi(Resource):
         )
         documents_status = []
         for document in documents:
-            completed_segments = DocumentSegment.query.filter(
-                DocumentSegment.completed_at.isnot(None),
-                DocumentSegment.document_id == str(document.id),
-                DocumentSegment.status != "re_segment",
-            ).count()
-            total_segments = DocumentSegment.query.filter(
-                DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment"
-            ).count()
+            completed_segments = (
+                db.session.query(DocumentSegment)
+                .filter(
+                    DocumentSegment.completed_at.isnot(None),
+                    DocumentSegment.document_id == str(document.id),
+                    DocumentSegment.status != "re_segment",
+                )
+                .count()
+            )
+            total_segments = (
+                db.session.query(DocumentSegment)
+                .filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
+                .count()
+            )
             document.completed_segments = completed_segments
             document.total_segments = total_segments
             documents_status.append(marshal(document, document_status_fields))
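The two segment counters above follow one template that repeats throughout this commit: a pair of explicit `.count()` queries per document. A small runnable sketch of the same shape, with an assumed `DocumentSegment` stand-in model:

```python
# A runnable sketch of the two-count template above; DocumentSegment here
# is an assumed stand-in, not the repository's real model.
from flask import Flask
from flask_sqlalchemy import SQLAlchemy

app = Flask(__name__)
app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///:memory:"
db = SQLAlchemy(app)


class DocumentSegment(db.Model):
    id = db.Column(db.Integer, primary_key=True)
    document_id = db.Column(db.String(36))
    status = db.Column(db.String(20))
    completed_at = db.Column(db.DateTime)


def segment_progress(document_id: str) -> tuple[int, int]:
    """Return (completed, total) segment counts for one document."""
    completed = (
        db.session.query(DocumentSegment)
        .filter(
            DocumentSegment.completed_at.isnot(None),
            DocumentSegment.document_id == document_id,
            DocumentSegment.status != "re_segment",
        )
        .count()
    )
    total = (
        db.session.query(DocumentSegment)
        .filter(DocumentSegment.document_id == document_id, DocumentSegment.status != "re_segment")
        .count()
    )
    return completed, total


with app.app_context():
    db.create_all()
    print(segment_progress("doc-1"))  # (0, 0) on an empty database
```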
@@ -6,7 +6,7 @@ from typing import cast
 from flask import request
 from flask_login import current_user
 from flask_restful import Resource, fields, marshal, marshal_with, reqparse
-from sqlalchemy import asc, desc
+from sqlalchemy import asc, desc, select
 from werkzeug.exceptions import Forbidden, NotFound
 
 import services

@@ -112,7 +112,7 @@ class GetProcessRuleApi(Resource):
         limits = DocumentService.DEFAULT_RULES["limits"]
         if document_id:
             # get the latest process rule
-            document = Document.query.get_or_404(document_id)
+            document = db.get_or_404(Document, document_id)
 
             dataset = DatasetService.get_dataset(document.dataset_id)
 
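`get_or_404` moves from the query object to the extension: in Flask-SQLAlchemy 3.x the helper lives on `db` and takes the model class as its first argument. A short sketch under that assumption, with a hypothetical `Document` model and route:

```python
# Sketch only: the Document model and route here are hypothetical;
# db.get_or_404(Model, ident) aborts with HTTP 404 when the row is missing.
from flask import Flask
from flask_sqlalchemy import SQLAlchemy

app = Flask(__name__)
app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///:memory:"
db = SQLAlchemy(app)


class Document(db.Model):
    id = db.Column(db.Integer, primary_key=True)
    name = db.Column(db.String(80))


@app.route("/documents/<int:document_id>")
def show_document(document_id: int):
    # Old style: Document.query.get_or_404(document_id)
    document = db.get_or_404(Document, document_id)
    return {"id": document.id, "name": document.name}


with app.app_context():
    db.create_all()
```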
@@ -175,7 +175,7 @@ class DatasetDocumentListApi(Resource):
         except services.errors.account.NoPermissionError as e:
             raise Forbidden(str(e))
 
-        query = Document.query.filter_by(dataset_id=str(dataset_id), tenant_id=current_user.current_tenant_id)
+        query = select(Document).filter_by(dataset_id=str(dataset_id), tenant_id=current_user.current_tenant_id)
 
         if search:
             search = f"%{search}%"

@@ -209,18 +209,24 @@ class DatasetDocumentListApi(Resource):
                 desc(Document.position),
             )
 
-        paginated_documents = query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False)
+        paginated_documents = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
         documents = paginated_documents.items
         if fetch:
             for document in documents:
-                completed_segments = DocumentSegment.query.filter(
-                    DocumentSegment.completed_at.isnot(None),
-                    DocumentSegment.document_id == str(document.id),
-                    DocumentSegment.status != "re_segment",
-                ).count()
-                total_segments = DocumentSegment.query.filter(
-                    DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment"
-                ).count()
+                completed_segments = (
+                    db.session.query(DocumentSegment)
+                    .filter(
+                        DocumentSegment.completed_at.isnot(None),
+                        DocumentSegment.document_id == str(document.id),
+                        DocumentSegment.status != "re_segment",
+                    )
+                    .count()
+                )
+                total_segments = (
+                    db.session.query(DocumentSegment)
+                    .filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
+                    .count()
+                )
                 document.completed_segments = completed_segments
                 document.total_segments = total_segments
         data = marshal(documents, document_with_segments_fields)

@@ -563,14 +569,20 @@ class DocumentBatchIndexingStatusApi(DocumentResource):
         documents = self.get_batch_documents(dataset_id, batch)
         documents_status = []
         for document in documents:
-            completed_segments = DocumentSegment.query.filter(
-                DocumentSegment.completed_at.isnot(None),
-                DocumentSegment.document_id == str(document.id),
-                DocumentSegment.status != "re_segment",
-            ).count()
-            total_segments = DocumentSegment.query.filter(
-                DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment"
-            ).count()
+            completed_segments = (
+                db.session.query(DocumentSegment)
+                .filter(
+                    DocumentSegment.completed_at.isnot(None),
+                    DocumentSegment.document_id == str(document.id),
+                    DocumentSegment.status != "re_segment",
+                )
+                .count()
+            )
+            total_segments = (
+                db.session.query(DocumentSegment)
+                .filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
+                .count()
+            )
             document.completed_segments = completed_segments
             document.total_segments = total_segments
             if document.is_paused:

@@ -589,14 +601,20 @@ class DocumentIndexingStatusApi(DocumentResource):
         document_id = str(document_id)
         document = self.get_document(dataset_id, document_id)
 
-        completed_segments = DocumentSegment.query.filter(
-            DocumentSegment.completed_at.isnot(None),
-            DocumentSegment.document_id == str(document_id),
-            DocumentSegment.status != "re_segment",
-        ).count()
-        total_segments = DocumentSegment.query.filter(
-            DocumentSegment.document_id == str(document_id), DocumentSegment.status != "re_segment"
-        ).count()
+        completed_segments = (
+            db.session.query(DocumentSegment)
+            .filter(
+                DocumentSegment.completed_at.isnot(None),
+                DocumentSegment.document_id == str(document_id),
+                DocumentSegment.status != "re_segment",
+            )
+            .count()
+        )
+        total_segments = (
+            db.session.query(DocumentSegment)
+            .filter(DocumentSegment.document_id == str(document_id), DocumentSegment.status != "re_segment")
+            .count()
+        )
 
         document.completed_segments = completed_segments
         document.total_segments = total_segments

@@ -4,6 +4,7 @@ import pandas as pd
 from flask import request
 from flask_login import current_user
 from flask_restful import Resource, marshal, reqparse
+from sqlalchemy import select
 from werkzeug.exceptions import Forbidden, NotFound
 
 import services

@@ -26,6 +27,7 @@ from controllers.console.wraps import (
 from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
 from core.model_manager import ModelManager
 from core.model_runtime.entities.model_entities import ModelType
+from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from fields.segment_fields import child_chunk_fields, segment_fields
 from libs.login import login_required

@@ -74,9 +76,14 @@ class DatasetDocumentSegmentListApi(Resource):
         hit_count_gte = args["hit_count_gte"]
         keyword = args["keyword"]
 
-        query = DocumentSegment.query.filter(
-            DocumentSegment.document_id == str(document_id), DocumentSegment.tenant_id == current_user.current_tenant_id
-        ).order_by(DocumentSegment.position.asc())
+        query = (
+            select(DocumentSegment)
+            .filter(
+                DocumentSegment.document_id == str(document_id),
+                DocumentSegment.tenant_id == current_user.current_tenant_id,
+            )
+            .order_by(DocumentSegment.position.asc())
+        )
 
         if status_list:
             query = query.filter(DocumentSegment.status.in_(status_list))

@@ -93,7 +100,7 @@ class DatasetDocumentSegmentListApi(Resource):
         elif args["enabled"].lower() == "false":
             query = query.filter(DocumentSegment.enabled == False)
 
-        segments = query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False)
+        segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
 
         response = {
             "data": marshal(segments.items, segment_fields),

@@ -276,9 +283,11 @@ class DatasetDocumentSegmentUpdateApi(Resource):
             raise ProviderNotInitializeError(ex.description)
         # check segment
         segment_id = str(segment_id)
-        segment = DocumentSegment.query.filter(
-            DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
-        ).first()
+        segment = (
+            db.session.query(DocumentSegment)
+            .filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
+            .first()
+        )
         if not segment:
             raise NotFound("Segment not found.")
         # The role of the current user in the ta table must be admin, owner, dataset_operator, or editor

@@ -320,9 +329,11 @@ class DatasetDocumentSegmentUpdateApi(Resource):
             raise NotFound("Document not found.")
         # check segment
         segment_id = str(segment_id)
-        segment = DocumentSegment.query.filter(
-            DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
-        ).first()
+        segment = (
+            db.session.query(DocumentSegment)
+            .filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
+            .first()
+        )
         if not segment:
             raise NotFound("Segment not found.")
         # The role of the current user in the ta table must be admin, owner, dataset_operator, or editor

@@ -423,9 +434,11 @@ class ChildChunkAddApi(Resource):
             raise NotFound("Document not found.")
         # check segment
         segment_id = str(segment_id)
-        segment = DocumentSegment.query.filter(
-            DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
-        ).first()
+        segment = (
+            db.session.query(DocumentSegment)
+            .filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
+            .first()
+        )
         if not segment:
             raise NotFound("Segment not found.")
         if not current_user.is_dataset_editor:

@@ -478,9 +491,11 @@ class ChildChunkAddApi(Resource):
             raise NotFound("Document not found.")
         # check segment
         segment_id = str(segment_id)
-        segment = DocumentSegment.query.filter(
-            DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
-        ).first()
+        segment = (
+            db.session.query(DocumentSegment)
+            .filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
+            .first()
+        )
         if not segment:
             raise NotFound("Segment not found.")
         parser = reqparse.RequestParser()

@@ -523,9 +538,11 @@ class ChildChunkAddApi(Resource):
             raise NotFound("Document not found.")
         # check segment
         segment_id = str(segment_id)
-        segment = DocumentSegment.query.filter(
-            DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
-        ).first()
+        segment = (
+            db.session.query(DocumentSegment)
+            .filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
+            .first()
+        )
         if not segment:
             raise NotFound("Segment not found.")
         # The role of the current user in the ta table must be admin, owner, dataset_operator, or editor

@@ -567,16 +584,20 @@ class ChildChunkUpdateApi(Resource):
             raise NotFound("Document not found.")
         # check segment
         segment_id = str(segment_id)
-        segment = DocumentSegment.query.filter(
-            DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
-        ).first()
+        segment = (
+            db.session.query(DocumentSegment)
+            .filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
+            .first()
+        )
         if not segment:
             raise NotFound("Segment not found.")
         # check child chunk
         child_chunk_id = str(child_chunk_id)
-        child_chunk = ChildChunk.query.filter(
-            ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id
-        ).first()
+        child_chunk = (
+            db.session.query(ChildChunk)
+            .filter(ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id)
+            .first()
+        )
         if not child_chunk:
             raise NotFound("Child chunk not found.")
         # The role of the current user in the ta table must be admin, owner, dataset_operator, or editor

@@ -612,16 +633,20 @@ class ChildChunkUpdateApi(Resource):
             raise NotFound("Document not found.")
         # check segment
         segment_id = str(segment_id)
-        segment = DocumentSegment.query.filter(
-            DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
-        ).first()
+        segment = (
+            db.session.query(DocumentSegment)
+            .filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
+            .first()
+        )
         if not segment:
             raise NotFound("Segment not found.")
         # check child chunk
         child_chunk_id = str(child_chunk_id)
-        child_chunk = ChildChunk.query.filter(
-            ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id
-        ).first()
+        child_chunk = (
+            db.session.query(ChildChunk)
+            .filter(ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id)
+            .first()
+        )
         if not child_chunk:
             raise NotFound("Child chunk not found.")
         # The role of the current user in the ta table must be admin, owner, dataset_operator, or editor

@@ -2,10 +2,10 @@ import json
 
 from flask import request
 from flask_restful import marshal, reqparse
-from sqlalchemy import desc
+from sqlalchemy import desc, select
 from werkzeug.exceptions import NotFound
 
-import services.dataset_service
+import services
 from controllers.common.errors import FilenameNotExistsError
 from controllers.service_api import api
 from controllers.service_api.app.error import (

@@ -337,7 +337,7 @@ class DocumentListApi(DatasetApiResource):
         if not dataset:
             raise NotFound("Dataset not found.")
 
-        query = Document.query.filter_by(dataset_id=str(dataset_id), tenant_id=tenant_id)
+        query = select(Document).filter_by(dataset_id=str(dataset_id), tenant_id=tenant_id)
 
         if search:
             search = f"%{search}%"

@@ -345,7 +345,7 @@ class DocumentListApi(DatasetApiResource):
 
         query = query.order_by(desc(Document.created_at), desc(Document.position))
 
-        paginated_documents = query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False)
+        paginated_documents = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
         documents = paginated_documents.items
 
         response = {

@@ -374,14 +374,20 @@ class DocumentIndexingStatusApi(DatasetApiResource):
             raise NotFound("Documents not found.")
         documents_status = []
         for document in documents:
-            completed_segments = DocumentSegment.query.filter(
-                DocumentSegment.completed_at.isnot(None),
-                DocumentSegment.document_id == str(document.id),
-                DocumentSegment.status != "re_segment",
-            ).count()
-            total_segments = DocumentSegment.query.filter(
-                DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment"
-            ).count()
+            completed_segments = (
+                db.session.query(DocumentSegment)
+                .filter(
+                    DocumentSegment.completed_at.isnot(None),
+                    DocumentSegment.document_id == str(document.id),
+                    DocumentSegment.status != "re_segment",
+                )
+                .count()
+            )
+            total_segments = (
+                db.session.query(DocumentSegment)
+                .filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
+                .count()
+            )
             document.completed_segments = completed_segments
             document.total_segments = total_segments
             if document.is_paused:

@@ -46,14 +46,22 @@ class DatasetIndexToolCallbackHandler:
                 DatasetDocument.id == document.metadata["document_id"]
             ).first()
             if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
-                child_chunk = ChildChunk.query.filter(
-                    ChildChunk.index_node_id == document.metadata["doc_id"],
-                    ChildChunk.dataset_id == dataset_document.dataset_id,
-                    ChildChunk.document_id == dataset_document.id,
-                ).first()
+                child_chunk = (
+                    db.session.query(ChildChunk)
+                    .filter(
+                        ChildChunk.index_node_id == document.metadata["doc_id"],
+                        ChildChunk.dataset_id == dataset_document.dataset_id,
+                        ChildChunk.document_id == dataset_document.id,
+                    )
+                    .first()
+                )
                 if child_chunk:
-                    segment = DocumentSegment.query.filter(DocumentSegment.id == child_chunk.segment_id).update(
-                        {DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False
+                    segment = (
+                        db.session.query(DocumentSegment)
+                        .filter(DocumentSegment.id == child_chunk.segment_id)
+                        .update(
+                            {DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False
+                        )
                     )
             else:
                 query = db.session.query(DocumentSegment).filter(
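The hit-count bump above is a query-level UPDATE: it executes as a single SQL statement, and `synchronize_session=False` tells SQLAlchemy not to reconcile already-loaded objects with the new values. A hedged sketch of the same pattern against a stand-in model:

```python
# A hedged sketch of the query-level UPDATE pattern above; the model is
# an assumed stand-in. The increment runs as one SQL UPDATE statement.
from flask import Flask
from flask_sqlalchemy import SQLAlchemy

app = Flask(__name__)
app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///:memory:"
db = SQLAlchemy(app)


class DocumentSegment(db.Model):
    id = db.Column(db.Integer, primary_key=True)
    hit_count = db.Column(db.Integer, default=0, nullable=False)


with app.app_context():
    db.create_all()
    db.session.add(DocumentSegment(id=1, hit_count=0))
    db.session.commit()

    # Emits: UPDATE document_segment SET hit_count = hit_count + 1 WHERE id = 1
    db.session.query(DocumentSegment).filter(DocumentSegment.id == 1).update(
        {DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False
    )
    db.session.commit()

    # Re-read from the database; the in-memory object was not synchronized.
    print(db.session.get(DocumentSegment, 1).hit_count)  # 1
```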
@@ -51,7 +51,7 @@ class IndexingRunner:
         for dataset_document in dataset_documents:
             try:
                 # get dataset
-                dataset = Dataset.query.filter_by(id=dataset_document.dataset_id).first()
+                dataset = db.session.query(Dataset).filter_by(id=dataset_document.dataset_id).first()
 
                 if not dataset:
                     raise ValueError("no dataset found")

@@ -103,15 +103,17 @@ class IndexingRunner:
         """Run the indexing process when the index_status is splitting."""
         try:
             # get dataset
-            dataset = Dataset.query.filter_by(id=dataset_document.dataset_id).first()
+            dataset = db.session.query(Dataset).filter_by(id=dataset_document.dataset_id).first()
 
             if not dataset:
                 raise ValueError("no dataset found")
 
             # get exist document_segment list and delete
-            document_segments = DocumentSegment.query.filter_by(
-                dataset_id=dataset.id, document_id=dataset_document.id
-            ).all()
+            document_segments = (
+                db.session.query(DocumentSegment)
+                .filter_by(dataset_id=dataset.id, document_id=dataset_document.id)
+                .all()
+            )
 
             for document_segment in document_segments:
                 db.session.delete(document_segment)

@@ -162,15 +164,17 @@ class IndexingRunner:
         """Run the indexing process when the index_status is indexing."""
         try:
             # get dataset
-            dataset = Dataset.query.filter_by(id=dataset_document.dataset_id).first()
+            dataset = db.session.query(Dataset).filter_by(id=dataset_document.dataset_id).first()
 
             if not dataset:
                 raise ValueError("no dataset found")
 
             # get exist document_segment list and delete
-            document_segments = DocumentSegment.query.filter_by(
-                dataset_id=dataset.id, document_id=dataset_document.id
-            ).all()
+            document_segments = (
+                db.session.query(DocumentSegment)
+                .filter_by(dataset_id=dataset.id, document_id=dataset_document.id)
+                .all()
+            )
 
             documents = []
             if document_segments:

@@ -254,7 +258,7 @@ class IndexingRunner:
 
         embedding_model_instance = None
         if dataset_id:
-            dataset = Dataset.query.filter_by(id=dataset_id).first()
+            dataset = db.session.query(Dataset).filter_by(id=dataset_id).first()
             if not dataset:
                 raise ValueError("Dataset not found.")
             if dataset.indexing_technique == "high_quality" or indexing_technique == "high_quality":

@@ -587,7 +591,7 @@ class IndexingRunner:
     @staticmethod
     def _process_keyword_index(flask_app, dataset_id, document_id, documents):
         with flask_app.app_context():
-            dataset = Dataset.query.filter_by(id=dataset_id).first()
+            dataset = db.session.query(Dataset).filter_by(id=dataset_id).first()
             if not dataset:
                 raise ValueError("no dataset found")
             keyword = Keyword(dataset)

@@ -676,7 +680,7 @@ class IndexingRunner:
         """
         Update the document segment by document id.
         """
-        DocumentSegment.query.filter_by(document_id=dataset_document_id).update(update_params)
+        db.session.query(DocumentSegment).filter_by(document_id=dataset_document_id).update(update_params)
         db.session.commit()
 
     def _transform(

@@ -237,7 +237,7 @@ class DatasetRetrieval:
         if show_retrieve_source:
             for record in records:
                 segment = record.segment
-                dataset = Dataset.query.filter_by(id=segment.dataset_id).first()
+                dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
                 document = DatasetDocument.query.filter(
                     DatasetDocument.id == segment.document_id,
                     DatasetDocument.enabled == True,

@@ -511,14 +511,23 @@ class DatasetRetrieval:
                 ).first()
                 if dataset_document:
                     if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
-                        child_chunk = ChildChunk.query.filter(
-                            ChildChunk.index_node_id == document.metadata["doc_id"],
-                            ChildChunk.dataset_id == dataset_document.dataset_id,
-                            ChildChunk.document_id == dataset_document.id,
-                        ).first()
+                        child_chunk = (
+                            db.session.query(ChildChunk)
+                            .filter(
+                                ChildChunk.index_node_id == document.metadata["doc_id"],
+                                ChildChunk.dataset_id == dataset_document.dataset_id,
+                                ChildChunk.document_id == dataset_document.id,
+                            )
+                            .first()
+                        )
                         if child_chunk:
-                            segment = DocumentSegment.query.filter(DocumentSegment.id == child_chunk.segment_id).update(
-                                {DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False
+                            segment = (
+                                db.session.query(DocumentSegment)
+                                .filter(DocumentSegment.id == child_chunk.segment_id)
+                                .update(
+                                    {DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
+                                    synchronize_session=False,
+                                )
                             )
                             db.session.commit()
                 else:

@@ -84,13 +84,17 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
 
         document_context_list = []
         index_node_ids = [document.metadata["doc_id"] for document in all_documents if document.metadata]
-        segments = DocumentSegment.query.filter(
-            DocumentSegment.dataset_id.in_(self.dataset_ids),
-            DocumentSegment.completed_at.isnot(None),
-            DocumentSegment.status == "completed",
-            DocumentSegment.enabled == True,
-            DocumentSegment.index_node_id.in_(index_node_ids),
-        ).all()
+        segments = (
+            db.session.query(DocumentSegment)
+            .filter(
+                DocumentSegment.dataset_id.in_(self.dataset_ids),
+                DocumentSegment.completed_at.isnot(None),
+                DocumentSegment.status == "completed",
+                DocumentSegment.enabled == True,
+                DocumentSegment.index_node_id.in_(index_node_ids),
+            )
+            .all()
+        )
 
         if segments:
             index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)}

@@ -106,12 +110,16 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
             context_list = []
             resource_number = 1
             for segment in sorted_segments:
-                dataset = Dataset.query.filter_by(id=segment.dataset_id).first()
-                document = Document.query.filter(
-                    Document.id == segment.document_id,
-                    Document.enabled == True,
-                    Document.archived == False,
-                ).first()
+                dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
+                document = (
+                    db.session.query(Document)
+                    .filter(
+                        Document.id == segment.document_id,
+                        Document.enabled == True,
+                        Document.archived == False,
+                    )
+                    .first()
+                )
                 if dataset and document:
                     source = {
                         "position": resource_number,

@@ -185,7 +185,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
         if self.return_resource:
             for record in records:
                 segment = record.segment
-                dataset = Dataset.query.filter_by(id=segment.dataset_id).first()
+                dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
                 document = DatasetDocument.query.filter(
                     DatasetDocument.id == segment.document_id,
                     DatasetDocument.enabled == True,

@@ -275,12 +275,16 @@ class KnowledgeRetrievalNode(LLMNode):
         if records:
             for record in records:
                 segment = record.segment
-                dataset = Dataset.query.filter_by(id=segment.dataset_id).first()
-                document = Document.query.filter(
-                    Document.id == segment.document_id,
-                    Document.enabled == True,
-                    Document.archived == False,
-                ).first()
+                dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()  # type: ignore
+                document = (
+                    db.session.query(Document)
+                    .filter(
+                        Document.id == segment.document_id,
+                        Document.enabled == True,
+                        Document.archived == False,
+                    )
+                    .first()
+                )
                 if dataset and document:
                     source = {
                         "metadata": {

@@ -93,7 +93,8 @@ class Dataset(Base):
     @property
     def latest_process_rule(self):
         return (
-            DatasetProcessRule.query.filter(DatasetProcessRule.dataset_id == self.id)
+            db.session.query(DatasetProcessRule)
+            .filter(DatasetProcessRule.dataset_id == self.id)
             .order_by(DatasetProcessRule.created_at.desc())
             .first()
         )

@@ -138,7 +139,8 @@ class Dataset(Base):
     @property
     def word_count(self):
         return (
-            Document.query.with_entities(func.coalesce(func.sum(Document.word_count)))
+            db.session.query(Document)
+            .with_entities(func.coalesce(func.sum(Document.word_count)))
             .filter(Document.dataset_id == self.id)
             .scalar()
         )

@@ -440,12 +442,13 @@ class Document(Base):
 
     @property
     def segment_count(self):
-        return DocumentSegment.query.filter(DocumentSegment.document_id == self.id).count()
+        return db.session.query(DocumentSegment).filter(DocumentSegment.document_id == self.id).count()
 
     @property
     def hit_count(self):
         return (
-            DocumentSegment.query.with_entities(func.coalesce(func.sum(DocumentSegment.hit_count)))
+            db.session.query(DocumentSegment)
+            .with_entities(func.coalesce(func.sum(DocumentSegment.hit_count)))
            .filter(DocumentSegment.document_id == self.id)
             .scalar()
         )
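The model hunks convert aggregate properties the same way: the sum is computed in SQL via `with_entities(func.coalesce(func.sum(...)))` and fetched with `.scalar()`. A sketch of that pattern (the models and columns are assumptions, and the repository's own properties differ in detail):

```python
# Sketch of the aggregate-property pattern above with assumed models.
from flask import Flask
from flask_sqlalchemy import SQLAlchemy
from sqlalchemy import func

app = Flask(__name__)
app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///:memory:"
db = SQLAlchemy(app)


class DocumentSegment(db.Model):
    id = db.Column(db.Integer, primary_key=True)
    document_id = db.Column(db.Integer)
    hit_count = db.Column(db.Integer, default=0)


class Document(db.Model):
    id = db.Column(db.Integer, primary_key=True)

    @property
    def hit_count(self) -> int:
        # coalesce(sum(...), 0) keeps the result 0 instead of None when
        # the document has no segments.
        return (
            db.session.query(DocumentSegment)
            .with_entities(func.coalesce(func.sum(DocumentSegment.hit_count), 0))
            .filter(DocumentSegment.document_id == self.id)
            .scalar()
        )


with app.app_context():
    db.create_all()
    db.session.add(Document(id=1))
    db.session.add_all([DocumentSegment(document_id=1, hit_count=n) for n in (2, 3)])
    db.session.commit()
    print(db.session.get(Document, 1).hit_count)  # 5
```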
@ -892,7 +895,7 @@ class DatasetKeywordTable(Base):
|
|||||||
return dct
|
return dct
|
||||||
|
|
||||||
# get dataset
|
# get dataset
|
||||||
dataset = Dataset.query.filter_by(id=self.dataset_id).first()
|
dataset = db.session.query(Dataset).filter_by(id=self.dataset_id).first()
|
||||||
if not dataset:
|
if not dataset:
|
||||||
return None
|
return None
|
||||||
if self.data_source_type == "database":
|
if self.data_source_type == "database":
|
||||||
|
@ -2,7 +2,7 @@ import datetime
|
|||||||
import time
|
import time
|
||||||
|
|
||||||
import click
|
import click
|
||||||
from sqlalchemy import func
|
from sqlalchemy import func, select
|
||||||
from werkzeug.exceptions import NotFound
|
from werkzeug.exceptions import NotFound
|
||||||
|
|
||||||
import app
|
import app
|
||||||
@ -51,8 +51,9 @@ def clean_unused_datasets_task():
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Main query with join and filter
|
# Main query with join and filter
|
||||||
datasets = (
|
stmt = (
|
||||||
Dataset.query.outerjoin(document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id)
|
select(Dataset)
|
||||||
|
.outerjoin(document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id)
|
||||||
.outerjoin(document_subquery_old, Dataset.id == document_subquery_old.c.dataset_id)
|
.outerjoin(document_subquery_old, Dataset.id == document_subquery_old.c.dataset_id)
|
||||||
.filter(
|
.filter(
|
||||||
Dataset.created_at < plan_sandbox_clean_day,
|
Dataset.created_at < plan_sandbox_clean_day,
|
||||||
@ -60,9 +61,10 @@ def clean_unused_datasets_task():
|
|||||||
func.coalesce(document_subquery_old.c.document_count, 0) > 0,
|
func.coalesce(document_subquery_old.c.document_count, 0) > 0,
|
||||||
)
|
)
|
||||||
.order_by(Dataset.created_at.desc())
|
.order_by(Dataset.created_at.desc())
|
||||||
.paginate(page=1, per_page=50)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
datasets = db.paginate(stmt, page=1, per_page=50)
|
||||||
|
|
||||||
except NotFound:
|
except NotFound:
|
||||||
break
|
break
|
||||||
if datasets.items is None or len(datasets.items) == 0:
|
if datasets.items is None or len(datasets.items) == 0:
|
||||||
@ -99,7 +101,7 @@ def clean_unused_datasets_task():
|
|||||||
# update document
|
# update document
|
||||||
update_params = {Document.enabled: False}
|
update_params = {Document.enabled: False}
|
||||||
|
|
||||||
Document.query.filter_by(dataset_id=dataset.id).update(update_params)
|
db.session.query(Document).filter_by(dataset_id=dataset.id).update(update_params)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
click.echo(click.style("Cleaned unused dataset {} from db success!".format(dataset.id), fg="green"))
|
click.echo(click.style("Cleaned unused dataset {} from db success!".format(dataset.id), fg="green"))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -135,8 +137,9 @@ def clean_unused_datasets_task():
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Main query with join and filter
|
# Main query with join and filter
|
||||||
datasets = (
|
stmt = (
|
||||||
Dataset.query.outerjoin(document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id)
|
select(Dataset)
|
||||||
|
.outerjoin(document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id)
|
||||||
.outerjoin(document_subquery_old, Dataset.id == document_subquery_old.c.dataset_id)
|
.outerjoin(document_subquery_old, Dataset.id == document_subquery_old.c.dataset_id)
|
||||||
.filter(
|
.filter(
|
||||||
Dataset.created_at < plan_pro_clean_day,
|
Dataset.created_at < plan_pro_clean_day,
|
||||||
@ -144,8 +147,8 @@ def clean_unused_datasets_task():
|
|||||||
func.coalesce(document_subquery_old.c.document_count, 0) > 0,
|
func.coalesce(document_subquery_old.c.document_count, 0) > 0,
|
||||||
)
|
)
|
||||||
.order_by(Dataset.created_at.desc())
|
.order_by(Dataset.created_at.desc())
|
||||||
.paginate(page=1, per_page=50)
|
|
||||||
)
|
)
|
||||||
|
datasets = db.paginate(stmt, page=1, per_page=50)
|
||||||
|
|
||||||
except NotFound:
|
except NotFound:
|
||||||
break
|
break
|
||||||
@ -175,7 +178,7 @@ def clean_unused_datasets_task():
|
|||||||
# update document
|
# update document
|
||||||
update_params = {Document.enabled: False}
|
update_params = {Document.enabled: False}
|
||||||
|
|
||||||
Document.query.filter_by(dataset_id=dataset.id).update(update_params)
|
db.session.query(Document).filter_by(dataset_id=dataset.id).update(update_params)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
click.echo(
|
click.echo(
|
||||||
click.style("Cleaned unused dataset {} from db success!".format(dataset.id), fg="green")
|
click.style("Cleaned unused dataset {} from db success!".format(dataset.id), fg="green")
|
||||||
|
@ -19,7 +19,9 @@ def create_tidb_serverless_task():
|
|||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
# check the number of idle tidb serverless
|
# check the number of idle tidb serverless
|
||||||
idle_tidb_serverless_number = TidbAuthBinding.query.filter(TidbAuthBinding.active == False).count()
|
idle_tidb_serverless_number = (
|
||||||
|
db.session.query(TidbAuthBinding).filter(TidbAuthBinding.active == False).count()
|
||||||
|
)
|
||||||
if idle_tidb_serverless_number >= tidb_serverless_number:
|
if idle_tidb_serverless_number >= tidb_serverless_number:
|
||||||
break
|
break
|
||||||
# create tidb serverless
|
# create tidb serverless
|
||||||
|
@ -29,7 +29,9 @@ def mail_clean_document_notify_task():
|
|||||||
|
|
||||||
# send document clean notify mail
|
# send document clean notify mail
|
||||||
try:
|
try:
|
||||||
dataset_auto_disable_logs = DatasetAutoDisableLog.query.filter(DatasetAutoDisableLog.notified == False).all()
|
dataset_auto_disable_logs = (
|
||||||
|
db.session.query(DatasetAutoDisableLog).filter(DatasetAutoDisableLog.notified == False).all()
|
||||||
|
)
|
||||||
# group by tenant_id
|
# group by tenant_id
|
||||||
dataset_auto_disable_logs_map: dict[str, list[DatasetAutoDisableLog]] = defaultdict(list)
|
dataset_auto_disable_logs_map: dict[str, list[DatasetAutoDisableLog]] = defaultdict(list)
|
||||||
for dataset_auto_disable_log in dataset_auto_disable_logs:
|
for dataset_auto_disable_log in dataset_auto_disable_logs:
|
||||||
@ -65,7 +67,7 @@ def mail_clean_document_notify_task():
|
|||||||
)
|
)
|
||||||
|
|
||||||
for dataset_id, document_ids in dataset_auto_dataset_map.items():
|
for dataset_id, document_ids in dataset_auto_dataset_map.items():
|
||||||
dataset = Dataset.query.filter(Dataset.id == dataset_id).first()
|
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
|
||||||
if dataset:
|
if dataset:
|
||||||
document_count = len(document_ids)
|
document_count = len(document_ids)
|
||||||
knowledge_details.append(rf"Knowledge base {dataset.name}: {document_count} documents")
|
knowledge_details.append(rf"Knowledge base {dataset.name}: {document_count} documents")
|
||||||
|
@ -5,6 +5,7 @@ import click
|
|||||||
import app
|
import app
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from core.rag.datasource.vdb.tidb_on_qdrant.tidb_service import TidbService
|
from core.rag.datasource.vdb.tidb_on_qdrant.tidb_service import TidbService
|
||||||
|
from extensions.ext_database import db
|
||||||
from models.dataset import TidbAuthBinding
|
from models.dataset import TidbAuthBinding
|
||||||
|
|
||||||
|
|
||||||
@ -14,9 +15,11 @@ def update_tidb_serverless_status_task():
|
|||||||
start_at = time.perf_counter()
|
start_at = time.perf_counter()
|
||||||
try:
|
try:
|
||||||
# check the number of idle tidb serverless
|
# check the number of idle tidb serverless
|
||||||
tidb_serverless_list = TidbAuthBinding.query.filter(
|
tidb_serverless_list = (
|
||||||
TidbAuthBinding.active == False, TidbAuthBinding.status == "CREATING"
|
db.session.query(TidbAuthBinding)
|
||||||
).all()
|
.filter(TidbAuthBinding.active == False, TidbAuthBinding.status == "CREATING")
|
||||||
|
.all()
|
||||||
|
)
|
||||||
if len(tidb_serverless_list) == 0:
|
if len(tidb_serverless_list) == 0:
|
||||||
return
|
return
|
||||||
# update tidb serverless status
|
# update tidb serverless status
|
||||||
|
@ -9,7 +9,7 @@ from collections import Counter
|
|||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
from flask_login import current_user
|
from flask_login import current_user
|
||||||
from sqlalchemy import func
|
from sqlalchemy import func, select
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from werkzeug.exceptions import NotFound
|
from werkzeug.exceptions import NotFound
|
||||||
|
|
||||||
@ -77,11 +77,13 @@ from tasks.sync_website_document_indexing_task import sync_website_document_inde
|
|||||||
class DatasetService:
|
class DatasetService:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_datasets(page, per_page, tenant_id=None, user=None, search=None, tag_ids=None, include_all=False):
|
def get_datasets(page, per_page, tenant_id=None, user=None, search=None, tag_ids=None, include_all=False):
|
||||||
query = Dataset.query.filter(Dataset.tenant_id == tenant_id).order_by(Dataset.created_at.desc())
|
query = select(Dataset).filter(Dataset.tenant_id == tenant_id).order_by(Dataset.created_at.desc())
|
||||||
|
|
||||||
if user:
|
if user:
|
||||||
# get permitted dataset ids
|
# get permitted dataset ids
|
||||||
dataset_permission = DatasetPermission.query.filter_by(account_id=user.id, tenant_id=tenant_id).all()
|
dataset_permission = (
|
||||||
|
db.session.query(DatasetPermission).filter_by(account_id=user.id, tenant_id=tenant_id).all()
|
||||||
|
)
|
||||||
permitted_dataset_ids = {dp.dataset_id for dp in dataset_permission} if dataset_permission else None
|
permitted_dataset_ids = {dp.dataset_id for dp in dataset_permission} if dataset_permission else None
|
||||||
|
|
||||||
if user.current_role == TenantAccountRole.DATASET_OPERATOR:
|
if user.current_role == TenantAccountRole.DATASET_OPERATOR:
|
||||||
@ -129,7 +131,7 @@ class DatasetService:
|
|||||||
else:
|
else:
|
||||||
return [], 0
|
return [], 0
|
||||||
|
|
||||||
datasets = query.paginate(page=page, per_page=per_page, max_per_page=100, error_out=False)
|
datasets = db.paginate(select=query, page=page, per_page=per_page, max_per_page=100, error_out=False)
|
||||||
|
|
||||||
return datasets.items, datasets.total
|
return datasets.items, datasets.total
|
||||||
|
|
||||||
@ -153,9 +155,10 @@ class DatasetService:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_datasets_by_ids(ids, tenant_id):
|
def get_datasets_by_ids(ids, tenant_id):
|
||||||
datasets = Dataset.query.filter(Dataset.id.in_(ids), Dataset.tenant_id == tenant_id).paginate(
|
stmt = select(Dataset).filter(Dataset.id.in_(ids), Dataset.tenant_id == tenant_id)
|
||||||
page=1, per_page=len(ids), max_per_page=len(ids), error_out=False
|
|
||||||
)
|
datasets = db.paginate(select=stmt, page=1, per_page=len(ids), max_per_page=len(ids), error_out=False)
|
||||||
|
|
||||||
return datasets.items, datasets.total
|
return datasets.items, datasets.total
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -174,7 +177,7 @@ class DatasetService:
|
|||||||
retrieval_model: Optional[RetrievalModel] = None,
|
retrieval_model: Optional[RetrievalModel] = None,
|
||||||
):
|
):
|
||||||
# check if dataset name already exists
|
# check if dataset name already exists
|
||||||
if Dataset.query.filter_by(name=name, tenant_id=tenant_id).first():
|
if db.session.query(Dataset).filter_by(name=name, tenant_id=tenant_id).first():
|
||||||
raise DatasetNameDuplicateError(f"Dataset with name {name} already exists.")
|
raise DatasetNameDuplicateError(f"Dataset with name {name} already exists.")
|
||||||
embedding_model = None
|
embedding_model = None
|
||||||
if indexing_technique == "high_quality":
|
if indexing_technique == "high_quality":
|
||||||
@ -235,7 +238,7 @@ class DatasetService:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_dataset(dataset_id) -> Optional[Dataset]:
|
def get_dataset(dataset_id) -> Optional[Dataset]:
|
||||||
dataset: Optional[Dataset] = Dataset.query.filter_by(id=dataset_id).first()
|
dataset: Optional[Dataset] = db.session.query(Dataset).filter_by(id=dataset_id).first()
|
||||||
return dataset
|
return dataset
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -436,7 +439,7 @@ class DatasetService:
|
|||||||
# update Retrieval model
|
# update Retrieval model
|
||||||
filtered_data["retrieval_model"] = data["retrieval_model"]
|
filtered_data["retrieval_model"] = data["retrieval_model"]
|
||||||
|
|
||||||
dataset.query.filter_by(id=dataset_id).update(filtered_data)
|
db.session.query(Dataset).filter_by(id=dataset_id).update(filtered_data)
|
||||||
|
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
if action:
|
if action:
|
||||||
@ -460,7 +463,7 @@ class DatasetService:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def dataset_use_check(dataset_id) -> bool:
|
def dataset_use_check(dataset_id) -> bool:
|
||||||
count = AppDatasetJoin.query.filter_by(dataset_id=dataset_id).count()
|
count = db.session.query(AppDatasetJoin).filter_by(dataset_id=dataset_id).count()
|
||||||
if count > 0:
|
if count > 0:
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
@ -475,7 +478,9 @@ class DatasetService:
|
|||||||
logging.debug(f"User {user.id} does not have permission to access dataset {dataset.id}")
|
logging.debug(f"User {user.id} does not have permission to access dataset {dataset.id}")
|
||||||
raise NoPermissionError("You do not have permission to access this dataset.")
|
raise NoPermissionError("You do not have permission to access this dataset.")
|
||||||
if dataset.permission == "partial_members":
|
if dataset.permission == "partial_members":
|
||||||
user_permission = DatasetPermission.query.filter_by(dataset_id=dataset.id, account_id=user.id).first()
|
user_permission = (
|
||||||
|
db.session.query(DatasetPermission).filter_by(dataset_id=dataset.id, account_id=user.id).first()
|
||||||
|
)
|
||||||
if (
|
if (
|
||||||
not user_permission
|
not user_permission
|
||||||
and dataset.tenant_id != user.current_tenant_id
|
and dataset.tenant_id != user.current_tenant_id
|
||||||
@ -499,23 +504,24 @@ class DatasetService:
|
|||||||
|
|
||||||
elif dataset.permission == DatasetPermissionEnum.PARTIAL_TEAM:
|
elif dataset.permission == DatasetPermissionEnum.PARTIAL_TEAM:
|
||||||
if not any(
|
if not any(
|
||||||
dp.dataset_id == dataset.id for dp in DatasetPermission.query.filter_by(account_id=user.id).all()
|
dp.dataset_id == dataset.id
|
||||||
|
for dp in db.session.query(DatasetPermission).filter_by(account_id=user.id).all()
|
||||||
):
|
):
|
||||||
raise NoPermissionError("You do not have permission to access this dataset.")
|
raise NoPermissionError("You do not have permission to access this dataset.")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_dataset_queries(dataset_id: str, page: int, per_page: int):
|
def get_dataset_queries(dataset_id: str, page: int, per_page: int):
|
||||||
dataset_queries = (
|
stmt = select(DatasetQuery).filter_by(dataset_id=dataset_id).order_by(db.desc(DatasetQuery.created_at))
|
||||||
DatasetQuery.query.filter_by(dataset_id=dataset_id)
|
|
||||||
.order_by(db.desc(DatasetQuery.created_at))
|
dataset_queries = db.paginate(select=stmt, page=page, per_page=per_page, max_per_page=100, error_out=False)
|
||||||
.paginate(page=page, per_page=per_page, max_per_page=100, error_out=False)
|
|
||||||
)
|
|
||||||
return dataset_queries.items, dataset_queries.total
|
return dataset_queries.items, dataset_queries.total
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_related_apps(dataset_id: str):
|
def get_related_apps(dataset_id: str):
|
||||||
return (
|
return (
|
||||||
AppDatasetJoin.query.filter(AppDatasetJoin.dataset_id == dataset_id)
|
db.session.query(AppDatasetJoin)
|
||||||
|
.filter(AppDatasetJoin.dataset_id == dataset_id)
|
||||||
.order_by(db.desc(AppDatasetJoin.created_at))
|
.order_by(db.desc(AppDatasetJoin.created_at))
|
||||||
.all()
|
.all()
|
||||||
)
|
)
|
||||||
@ -530,10 +536,14 @@ class DatasetService:
|
|||||||
}
|
}
|
||||||
# get recent 30 days auto disable logs
|
# get recent 30 days auto disable logs
|
||||||
start_date = datetime.datetime.now() - datetime.timedelta(days=30)
|
start_date = datetime.datetime.now() - datetime.timedelta(days=30)
|
||||||
dataset_auto_disable_logs = DatasetAutoDisableLog.query.filter(
|
dataset_auto_disable_logs = (
|
||||||
DatasetAutoDisableLog.dataset_id == dataset_id,
|
db.session.query(DatasetAutoDisableLog)
|
||||||
DatasetAutoDisableLog.created_at >= start_date,
|
.filter(
|
||||||
).all()
|
DatasetAutoDisableLog.dataset_id == dataset_id,
|
||||||
|
DatasetAutoDisableLog.created_at >= start_date,
|
||||||
|
)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
if dataset_auto_disable_logs:
|
if dataset_auto_disable_logs:
|
||||||
return {
|
return {
|
||||||
"document_ids": [log.document_id for log in dataset_auto_disable_logs],
|
"document_ids": [log.document_id for log in dataset_auto_disable_logs],
|
||||||
@ -873,7 +883,9 @@ class DocumentService:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_documents_position(dataset_id):
|
def get_documents_position(dataset_id):
|
||||||
document = Document.query.filter_by(dataset_id=dataset_id).order_by(Document.position.desc()).first()
|
document = (
|
||||||
|
db.session.query(Document).filter_by(dataset_id=dataset_id).order_by(Document.position.desc()).first()
|
||||||
|
)
|
||||||
if document:
|
if document:
|
||||||
return document.position + 1
|
return document.position + 1
|
||||||
else:
|
else:
|
||||||
@ -1010,13 +1022,17 @@ class DocumentService:
|
|||||||
}
|
}
|
||||||
# check duplicate
|
# check duplicate
|
||||||
if knowledge_config.duplicate:
|
if knowledge_config.duplicate:
|
||||||
document = Document.query.filter_by(
|
document = (
|
||||||
dataset_id=dataset.id,
|
db.session.query(Document)
|
||||||
tenant_id=current_user.current_tenant_id,
|
.filter_by(
|
||||||
data_source_type="upload_file",
|
dataset_id=dataset.id,
|
||||||
enabled=True,
|
tenant_id=current_user.current_tenant_id,
|
||||||
name=file_name,
|
data_source_type="upload_file",
|
||||||
).first()
|
enabled=True,
|
||||||
|
name=file_name,
|
||||||
|
)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
if document:
|
if document:
|
||||||
document.dataset_process_rule_id = dataset_process_rule.id # type: ignore
|
document.dataset_process_rule_id = dataset_process_rule.id # type: ignore
|
||||||
document.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
|
document.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
|
||||||
@ -1054,12 +1070,16 @@ class DocumentService:
raise ValueError("No notion info list found.")
exist_page_ids = []
exist_document = {}
documents = Document.query.filter_by(
dataset_id=dataset.id,
tenant_id=current_user.current_tenant_id,
data_source_type="notion_import",
enabled=True,
).all()
documents = (
db.session.query(Document)
.filter_by(
dataset_id=dataset.id,
tenant_id=current_user.current_tenant_id,
data_source_type="notion_import",
enabled=True,
)
.all()
)
if documents:
for document in documents:
data_source_info = json.loads(document.data_source_info)
@ -1206,12 +1226,16 @@ class DocumentService:

@staticmethod
def get_tenant_documents_count():
documents_count = Document.query.filter(
Document.completed_at.isnot(None),
Document.enabled == True,
Document.archived == False,
Document.tenant_id == current_user.current_tenant_id,
).count()
documents_count = (
db.session.query(Document)
.filter(
Document.completed_at.isnot(None),
Document.enabled == True,
Document.archived == False,
Document.tenant_id == current_user.current_tenant_id,
)
.count()
)
return documents_count

@staticmethod
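
One note on the .count() hunks such as this one: Query.count() wraps the criteria in a subquery and counts that, and the rewrite preserves this behavior exactly. A 2.0-style alternative, not used by this commit, counts directly; a sketch with tenant_id standing in for current_user.current_tenant_id:

    from sqlalchemy import func, select

    # SELECT count(*) FROM documents WHERE ... -- no subquery wrapper.
    documents_count = db.session.scalar(
        select(func.count())
        .select_from(Document)
        .where(
            Document.completed_at.isnot(None),
            Document.enabled == True,
            Document.archived == False,
            Document.tenant_id == tenant_id,
        )
    )
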
@ -1328,7 +1352,7 @@ class DocumentService:
db.session.commit()
# update document segment
update_params = {DocumentSegment.status: "re_segment"}
DocumentSegment.query.filter_by(document_id=document.id).update(update_params)
db.session.query(DocumentSegment).filter_by(document_id=document.id).update(update_params)
db.session.commit()
# trigger async task
document_indexing_update_task.delay(document.dataset_id, document.id)
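
This hunk touches a bulk write rather than a read: Query.update() issues a single UPDATE statement without loading rows into the session, and it behaves identically when reached through db.session.query(...). A sketch under the same assumptions, with document_id standing in for document.id:

    # One UPDATE statement; no DocumentSegment objects are materialized.
    db.session.query(DocumentSegment).filter_by(document_id=document_id).update(
        {DocumentSegment.status: "re_segment"}
    )
    db.session.commit()  # in-session instances may be stale until refreshed
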
@ -1918,7 +1942,8 @@ class SegmentService:
@classmethod
def delete_segments(cls, segment_ids: list, document: Document, dataset: Dataset):
index_node_ids = (
DocumentSegment.query.with_entities(DocumentSegment.index_node_id)
db.session.query(DocumentSegment)
.with_entities(DocumentSegment.index_node_id)
.filter(
DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset.id,
@ -2157,20 +2182,28 @@ class SegmentService:
def get_child_chunks(
cls, segment_id: str, document_id: str, dataset_id: str, page: int, limit: int, keyword: Optional[str] = None
):
query = ChildChunk.query.filter_by(
tenant_id=current_user.current_tenant_id,
dataset_id=dataset_id,
document_id=document_id,
segment_id=segment_id,
).order_by(ChildChunk.position.asc())
query = (
select(ChildChunk)
.filter_by(
tenant_id=current_user.current_tenant_id,
dataset_id=dataset_id,
document_id=document_id,
segment_id=segment_id,
)
.order_by(ChildChunk.position.asc())
)
if keyword:
query = query.where(ChildChunk.content.ilike(f"%{keyword}%"))
return query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False)
return db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)

@classmethod
def get_child_chunk_by_id(cls, child_chunk_id: str, tenant_id: str) -> Optional[ChildChunk]:
"""Get a child chunk by its ID."""
result = ChildChunk.query.filter(ChildChunk.id == child_chunk_id, ChildChunk.tenant_id == tenant_id).first()
result = (
db.session.query(ChildChunk)
.filter(ChildChunk.id == child_chunk_id, ChildChunk.tenant_id == tenant_id)
.first()
)
return result if isinstance(result, ChildChunk) else None

@classmethod
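
get_child_chunks is the fullest example of the new style in this commit: the statement is a Core select() (which in SQLAlchemy 2.x supports .filter_by() and .where() much like Query), and pagination is delegated to db.paginate(). A condensed sketch under the same model assumptions:

    from sqlalchemy import select

    stmt = select(ChildChunk).filter_by(segment_id=segment_id).order_by(ChildChunk.position.asc())
    if keyword:
        # where() and filter() are synonyms on Select
        stmt = stmt.where(ChildChunk.content.ilike(f"%{keyword}%"))
    page_obj = db.paginate(select=stmt, page=page, per_page=limit, max_per_page=100, error_out=False)
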
@ -2184,7 +2217,7 @@ class SegmentService:
limit: int = 20,
):
"""Get segments for a document with optional filtering."""
query = DocumentSegment.query.filter(
query = select(DocumentSegment).filter(
DocumentSegment.document_id == document_id, DocumentSegment.tenant_id == tenant_id
)

@ -2194,9 +2227,8 @@ class SegmentService:
if keyword:
query = query.filter(DocumentSegment.content.ilike(f"%{keyword}%"))

paginated_segments = query.order_by(DocumentSegment.position.asc()).paginate(
page=page, per_page=limit, max_per_page=100, error_out=False
)
query = query.order_by(DocumentSegment.position.asc())
paginated_segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)

return paginated_segments.items, paginated_segments.total

@ -2236,9 +2268,11 @@ class SegmentService:
raise ValueError(ex.description)

# check segment
segment = DocumentSegment.query.filter(
DocumentSegment.id == segment_id, DocumentSegment.tenant_id == user_id
).first()
segment = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.id == segment_id, DocumentSegment.tenant_id == user_id)
.first()
)
if not segment:
raise NotFound("Segment not found.")

@ -2251,9 +2285,11 @@ class SegmentService:
@classmethod
def get_segment_by_id(cls, segment_id: str, tenant_id: str) -> Optional[DocumentSegment]:
"""Get a segment by its ID."""
result = DocumentSegment.query.filter(
DocumentSegment.id == segment_id, DocumentSegment.tenant_id == tenant_id
).first()
result = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.id == segment_id, DocumentSegment.tenant_id == tenant_id)
.first()
)
return result if isinstance(result, DocumentSegment) else None

@ -5,6 +5,7 @@ from typing import Any, Optional, Union, cast
from urllib.parse import urlparse

import httpx
from sqlalchemy import select

from constants import HIDDEN_VALUE
from core.helper import ssrf_proxy
@ -24,14 +25,20 @@ from services.errors.dataset import DatasetNameDuplicateError

class ExternalDatasetService:
@staticmethod
def get_external_knowledge_apis(page, per_page, tenant_id, search=None) -> tuple[list[ExternalKnowledgeApis], int]:
query = ExternalKnowledgeApis.query.filter(ExternalKnowledgeApis.tenant_id == tenant_id).order_by(
ExternalKnowledgeApis.created_at.desc()
def get_external_knowledge_apis(
page, per_page, tenant_id, search=None
) -> tuple[list[ExternalKnowledgeApis], int | None]:
query = (
select(ExternalKnowledgeApis)
.filter(ExternalKnowledgeApis.tenant_id == tenant_id)
.order_by(ExternalKnowledgeApis.created_at.desc())
)
if search:
query = query.filter(ExternalKnowledgeApis.name.ilike(f"%{search}%"))

external_knowledge_apis = query.paginate(page=page, per_page=per_page, max_per_page=100, error_out=False)
external_knowledge_apis = db.paginate(
select=query, page=page, per_page=per_page, max_per_page=100, error_out=False
)

return external_knowledge_apis.items, external_knowledge_apis.total

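
Note the signature change in this hunk: the return annotation widens from tuple[list[...], int] to tuple[list[...], int | None] because Pagination.total in Flask-SQLAlchemy 3.x is optional (it is None when the paginator is asked to skip the COUNT query). A hedged caller-side sketch, with tenant_id assumed to be in scope:

    apis, total = ExternalDatasetService.get_external_knowledge_apis(1, 20, tenant_id)
    # total may be None, so guard before formatting or doing arithmetic on it
    label = f"{len(apis)} of {total}" if total is not None else f"{len(apis)} of ?"
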
@ -92,18 +99,18 @@ class ExternalDatasetService:

@staticmethod
def get_external_knowledge_api(external_knowledge_api_id: str) -> ExternalKnowledgeApis:
external_knowledge_api: Optional[ExternalKnowledgeApis] = ExternalKnowledgeApis.query.filter_by(
id=external_knowledge_api_id
).first()
external_knowledge_api: Optional[ExternalKnowledgeApis] = (
db.session.query(ExternalKnowledgeApis).filter_by(id=external_knowledge_api_id).first()
)
if external_knowledge_api is None:
raise ValueError("api template not found")
return external_knowledge_api

@staticmethod
def update_external_knowledge_api(tenant_id, user_id, external_knowledge_api_id, args) -> ExternalKnowledgeApis:
external_knowledge_api: Optional[ExternalKnowledgeApis] = ExternalKnowledgeApis.query.filter_by(
id=external_knowledge_api_id, tenant_id=tenant_id
).first()
external_knowledge_api: Optional[ExternalKnowledgeApis] = (
db.session.query(ExternalKnowledgeApis).filter_by(id=external_knowledge_api_id, tenant_id=tenant_id).first()
)
if external_knowledge_api is None:
raise ValueError("api template not found")
if args.get("settings") and args.get("settings").get("api_key") == HIDDEN_VALUE:
@ -120,9 +127,9 @@ class ExternalDatasetService:

@staticmethod
def delete_external_knowledge_api(tenant_id: str, external_knowledge_api_id: str):
external_knowledge_api = ExternalKnowledgeApis.query.filter_by(
id=external_knowledge_api_id, tenant_id=tenant_id
).first()
external_knowledge_api = (
db.session.query(ExternalKnowledgeApis).filter_by(id=external_knowledge_api_id, tenant_id=tenant_id).first()
)
if external_knowledge_api is None:
raise ValueError("api template not found")

@ -131,25 +138,29 @@ class ExternalDatasetService:

@staticmethod
def external_knowledge_api_use_check(external_knowledge_api_id: str) -> tuple[bool, int]:
count = ExternalKnowledgeBindings.query.filter_by(external_knowledge_api_id=external_knowledge_api_id).count()
count = (
db.session.query(ExternalKnowledgeBindings)
.filter_by(external_knowledge_api_id=external_knowledge_api_id)
.count()
)
if count > 0:
return True, count
return False, 0

@staticmethod
def get_external_knowledge_binding_with_dataset_id(tenant_id: str, dataset_id: str) -> ExternalKnowledgeBindings:
external_knowledge_binding: Optional[ExternalKnowledgeBindings] = ExternalKnowledgeBindings.query.filter_by(
dataset_id=dataset_id, tenant_id=tenant_id
).first()
external_knowledge_binding: Optional[ExternalKnowledgeBindings] = (
db.session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id, tenant_id=tenant_id).first()
)
if not external_knowledge_binding:
raise ValueError("external knowledge binding not found")
return external_knowledge_binding

@staticmethod
def document_create_args_validate(tenant_id: str, external_knowledge_api_id: str, process_parameter: dict):
external_knowledge_api = ExternalKnowledgeApis.query.filter_by(
id=external_knowledge_api_id, tenant_id=tenant_id
).first()
external_knowledge_api = (
db.session.query(ExternalKnowledgeApis).filter_by(id=external_knowledge_api_id, tenant_id=tenant_id).first()
)
if external_knowledge_api is None:
raise ValueError("api template not found")
settings = json.loads(external_knowledge_api.settings)
@ -212,11 +223,13 @@ class ExternalDatasetService:
@staticmethod
def create_external_dataset(tenant_id: str, user_id: str, args: dict) -> Dataset:
# check if dataset name already exists
if Dataset.query.filter_by(name=args.get("name"), tenant_id=tenant_id).first():
if db.session.query(Dataset).filter_by(name=args.get("name"), tenant_id=tenant_id).first():
raise DatasetNameDuplicateError(f"Dataset with name {args.get('name')} already exists.")
external_knowledge_api = ExternalKnowledgeApis.query.filter_by(
id=args.get("external_knowledge_api_id"), tenant_id=tenant_id
).first()
external_knowledge_api = (
db.session.query(ExternalKnowledgeApis)
.filter_by(id=args.get("external_knowledge_api_id"), tenant_id=tenant_id)
.first()
)

if external_knowledge_api is None:
raise ValueError("api template not found")
@ -254,15 +267,17 @@ class ExternalDatasetService:
external_retrieval_parameters: dict,
metadata_condition: Optional[MetadataCondition] = None,
) -> list:
external_knowledge_binding = ExternalKnowledgeBindings.query.filter_by(
dataset_id=dataset_id, tenant_id=tenant_id
).first()
external_knowledge_binding = (
db.session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id, tenant_id=tenant_id).first()
)
if not external_knowledge_binding:
raise ValueError("external knowledge binding not found")

external_knowledge_api = ExternalKnowledgeApis.query.filter_by(
id=external_knowledge_binding.external_knowledge_api_id
).first()
external_knowledge_api = (
db.session.query(ExternalKnowledgeApis)
.filter_by(id=external_knowledge_binding.external_knowledge_api_id)
.first()
)
if not external_knowledge_api:
raise ValueError("external api template not found")

@ -20,9 +20,11 @@ class MetadataService:
@staticmethod
def create_metadata(dataset_id: str, metadata_args: MetadataArgs) -> DatasetMetadata:
# check if metadata name already exists
if DatasetMetadata.query.filter_by(
tenant_id=current_user.current_tenant_id, dataset_id=dataset_id, name=metadata_args.name
).first():
if (
db.session.query(DatasetMetadata)
.filter_by(tenant_id=current_user.current_tenant_id, dataset_id=dataset_id, name=metadata_args.name)
.first()
):
raise ValueError("Metadata name already exists.")
for field in BuiltInField:
if field.value == metadata_args.name:
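
The duplicate-name guard above fetches a whole row just to test existence. An exists() probe, not part of this commit, would avoid that; a sketch under the same model assumptions, with tenant_id and dataset_id in scope:

    from sqlalchemy import exists, select

    # SELECT EXISTS (SELECT ... WHERE ...) -- returns a plain bool
    name_taken = db.session.scalar(
        select(
            exists().where(
                DatasetMetadata.tenant_id == tenant_id,
                DatasetMetadata.dataset_id == dataset_id,
                DatasetMetadata.name == metadata_args.name,
            )
        )
    )
    if name_taken:
        raise ValueError("Metadata name already exists.")
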
@ -42,16 +44,18 @@ class MetadataService:
def update_metadata_name(dataset_id: str, metadata_id: str, name: str) -> DatasetMetadata:  # type: ignore
lock_key = f"dataset_metadata_lock_{dataset_id}"
# check if metadata name already exists
if DatasetMetadata.query.filter_by(
tenant_id=current_user.current_tenant_id, dataset_id=dataset_id, name=name
).first():
if (
db.session.query(DatasetMetadata)
.filter_by(tenant_id=current_user.current_tenant_id, dataset_id=dataset_id, name=name)
.first()
):
raise ValueError("Metadata name already exists.")
for field in BuiltInField:
if field.value == name:
raise ValueError("Metadata name already exists in Built-in fields.")
try:
MetadataService.knowledge_base_metadata_lock_check(dataset_id, None)
metadata = DatasetMetadata.query.filter_by(id=metadata_id).first()
metadata = db.session.query(DatasetMetadata).filter_by(id=metadata_id).first()
if metadata is None:
raise ValueError("Metadata not found.")
old_name = metadata.name
@ -60,7 +64,9 @@ class MetadataService:
metadata.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)

# update related documents
dataset_metadata_bindings = DatasetMetadataBinding.query.filter_by(metadata_id=metadata_id).all()
dataset_metadata_bindings = (
db.session.query(DatasetMetadataBinding).filter_by(metadata_id=metadata_id).all()
)
if dataset_metadata_bindings:
document_ids = [binding.document_id for binding in dataset_metadata_bindings]
documents = DocumentService.get_document_by_ids(document_ids)
@ -82,13 +88,15 @@ class MetadataService:
lock_key = f"dataset_metadata_lock_{dataset_id}"
try:
MetadataService.knowledge_base_metadata_lock_check(dataset_id, None)
metadata = DatasetMetadata.query.filter_by(id=metadata_id).first()
metadata = db.session.query(DatasetMetadata).filter_by(id=metadata_id).first()
if metadata is None:
raise ValueError("Metadata not found.")
db.session.delete(metadata)

# deal related documents
dataset_metadata_bindings = DatasetMetadataBinding.query.filter_by(metadata_id=metadata_id).all()
dataset_metadata_bindings = (
db.session.query(DatasetMetadataBinding).filter_by(metadata_id=metadata_id).all()
)
if dataset_metadata_bindings:
document_ids = [binding.document_id for binding in dataset_metadata_bindings]
documents = DocumentService.get_document_by_ids(document_ids)
@ -193,7 +201,7 @@ class MetadataService:
db.session.add(document)
db.session.commit()
# deal metadata binding
DatasetMetadataBinding.query.filter_by(document_id=operation.document_id).delete()
db.session.query(DatasetMetadataBinding).filter_by(document_id=operation.document_id).delete()
for metadata_value in operation.metadata_list:
dataset_metadata_binding = DatasetMetadataBinding(
tenant_id=current_user.current_tenant_id,
@ -230,9 +238,9 @@ class MetadataService:
"id": item.get("id"),
"name": item.get("name"),
"type": item.get("type"),
"count": DatasetMetadataBinding.query.filter_by(
metadata_id=item.get("id"), dataset_id=dataset.id
).count(),
"count": db.session.query(DatasetMetadataBinding)
.filter_by(metadata_id=item.get("id"), dataset_id=dataset.id)
.count(),
}
for item in dataset.doc_metadata or []
if item.get("id") != "built-in"
@ -41,7 +41,7 @@ def create_segment_to_index_task(segment_id: str, keywords: Optional[list[str]]
DocumentSegment.status: "indexing",
DocumentSegment.indexing_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
}
DocumentSegment.query.filter_by(id=segment.id).update(update_params)
db.session.query(DocumentSegment).filter_by(id=segment.id).update(update_params)
db.session.commit()
document = Document(
page_content=segment.content,
@ -78,7 +78,7 @@ def create_segment_to_index_task(segment_id: str, keywords: Optional[list[str]]
DocumentSegment.status: "completed",
DocumentSegment.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
}
DocumentSegment.query.filter_by(id=segment.id).update(update_params)
db.session.query(DocumentSegment).filter_by(id=segment.id).update(update_params)
db.session.commit()

end_at = time.perf_counter()
@ -24,7 +24,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
start_at = time.perf_counter()

try:
dataset = Dataset.query.filter_by(id=dataset_id).first()
dataset = db.session.query(Dataset).filter_by(id=dataset_id).first()

if not dataset:
raise Exception("Dataset not found")