Fix: resolve concurrent document upload issue (#6095)

### What problem does this PR solve?

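Resolve the concurrent document upload issue (#6039): the knowledge base's `doc_num` is now incremented with a single in-database update instead of being read and then written back.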

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
This commit is contained in:
Yongteng Lei 2025-03-14 16:31:44 +08:00 committed by GitHub
parent 9d94acbedb
commit d7774cf049
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 68 additions and 59 deletions

View File

@ -13,33 +13,30 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# #
import logging
import xxhash
import json import json
import logging
import random import random
import re import re
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from copy import deepcopy from copy import deepcopy
from datetime import datetime from datetime import datetime
from io import BytesIO from io import BytesIO
import trio
import trio
import xxhash
from peewee import fn from peewee import fn
from api.db.db_utils import bulk_insert_into_db
from api import settings from api import settings
from api.utils import current_timestamp, get_format_time, get_uuid from api.db import FileType, LLMType, ParserType, StatusEnum, TaskStatus
from rag.settings import SVR_QUEUE_NAME from api.db.db_models import DB, Document, Knowledgebase, Task, Tenant, UserTenant
from rag.utils.storage_factory import STORAGE_IMPL from api.db.db_utils import bulk_insert_into_db
from rag.nlp import search, rag_tokenizer
from api.db import FileType, TaskStatus, ParserType, LLMType
from api.db.db_models import DB, Knowledgebase, Tenant, Task, UserTenant
from api.db.db_models import Document
from api.db.services.common_service import CommonService from api.db.services.common_service import CommonService
from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db import StatusEnum from api.utils import current_timestamp, get_format_time, get_uuid
from rag.nlp import rag_tokenizer, search
from rag.settings import SVR_QUEUE_NAME
from rag.utils.redis_conn import REDIS_CONN from rag.utils.redis_conn import REDIS_CONN
from rag.utils.storage_factory import STORAGE_IMPL
class DocumentService(CommonService): class DocumentService(CommonService):
@ -96,9 +93,7 @@ class DocumentService(CommonService):
def insert(cls, doc): def insert(cls, doc):
if not cls.save(**doc): if not cls.save(**doc):
raise RuntimeError("Database error (Document)!") raise RuntimeError("Database error (Document)!")
e, kb = KnowledgebaseService.get_by_id(doc["kb_id"]) if not KnowledgebaseService.atomic_increase_doc_num_by_id(doc["kb_id"]):
if not KnowledgebaseService.update_by_id(
kb.id, {"doc_num": kb.doc_num + 1}):
raise RuntimeError("Database error (Knowledgebase)!") raise RuntimeError("Database error (Knowledgebase)!")
return Document(**doc) return Document(**doc)
@ -174,9 +169,9 @@ class DocumentService(CommonService):
"Document not found which is supposed to be there") "Document not found which is supposed to be there")
num = Knowledgebase.update( num = Knowledgebase.update(
token_num=Knowledgebase.token_num + token_num=Knowledgebase.token_num +
token_num, token_num,
chunk_num=Knowledgebase.chunk_num + chunk_num=Knowledgebase.chunk_num +
chunk_num).where( chunk_num).where(
Knowledgebase.id == kb_id).execute() Knowledgebase.id == kb_id).execute()
return num return num
@ -192,9 +187,9 @@ class DocumentService(CommonService):
"Document not found which is supposed to be there") "Document not found which is supposed to be there")
num = Knowledgebase.update( num = Knowledgebase.update(
token_num=Knowledgebase.token_num - token_num=Knowledgebase.token_num -
token_num, token_num,
chunk_num=Knowledgebase.chunk_num - chunk_num=Knowledgebase.chunk_num -
chunk_num chunk_num
).where( ).where(
Knowledgebase.id == kb_id).execute() Knowledgebase.id == kb_id).execute()
return num return num
@ -207,9 +202,9 @@ class DocumentService(CommonService):
num = Knowledgebase.update( num = Knowledgebase.update(
token_num=Knowledgebase.token_num - token_num=Knowledgebase.token_num -
doc.token_num, doc.token_num,
chunk_num=Knowledgebase.chunk_num - chunk_num=Knowledgebase.chunk_num -
doc.chunk_num, doc.chunk_num,
doc_num=Knowledgebase.doc_num - 1 doc_num=Knowledgebase.doc_num - 1
).where( ).where(
Knowledgebase.id == doc.kb_id).execute() Knowledgebase.id == doc.kb_id).execute()
@ -221,7 +216,7 @@ class DocumentService(CommonService):
docs = cls.model.select( docs = cls.model.select(
Knowledgebase.tenant_id).join( Knowledgebase.tenant_id).join(
Knowledgebase, on=( Knowledgebase, on=(
Knowledgebase.id == cls.model.kb_id)).where( Knowledgebase.id == cls.model.kb_id)).where(
cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value) cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
docs = docs.dicts() docs = docs.dicts()
if not docs: if not docs:
@ -243,7 +238,7 @@ class DocumentService(CommonService):
docs = cls.model.select( docs = cls.model.select(
Knowledgebase.tenant_id).join( Knowledgebase.tenant_id).join(
Knowledgebase, on=( Knowledgebase, on=(
Knowledgebase.id == cls.model.kb_id)).where( Knowledgebase.id == cls.model.kb_id)).where(
cls.model.name == name, Knowledgebase.status == StatusEnum.VALID.value) cls.model.name == name, Knowledgebase.status == StatusEnum.VALID.value)
docs = docs.dicts() docs = docs.dicts()
if not docs: if not docs:
@ -256,7 +251,7 @@ class DocumentService(CommonService):
docs = cls.model.select( docs = cls.model.select(
cls.model.id).join( cls.model.id).join(
Knowledgebase, on=( Knowledgebase, on=(
Knowledgebase.id == cls.model.kb_id) Knowledgebase.id == cls.model.kb_id)
).join(UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id) ).join(UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id)
).where(cls.model.id == doc_id, UserTenant.user_id == user_id).paginate(0, 1) ).where(cls.model.id == doc_id, UserTenant.user_id == user_id).paginate(0, 1)
docs = docs.dicts() docs = docs.dicts()
@ -270,7 +265,7 @@ class DocumentService(CommonService):
docs = cls.model.select( docs = cls.model.select(
cls.model.id).join( cls.model.id).join(
Knowledgebase, on=( Knowledgebase, on=(
Knowledgebase.id == cls.model.kb_id) Knowledgebase.id == cls.model.kb_id)
).where(cls.model.id == doc_id, Knowledgebase.created_by == user_id).paginate(0, 1) ).where(cls.model.id == doc_id, Knowledgebase.created_by == user_id).paginate(0, 1)
docs = docs.dicts() docs = docs.dicts()
if not docs: if not docs:
@ -283,7 +278,7 @@ class DocumentService(CommonService):
docs = cls.model.select( docs = cls.model.select(
Knowledgebase.embd_id).join( Knowledgebase.embd_id).join(
Knowledgebase, on=( Knowledgebase, on=(
Knowledgebase.id == cls.model.kb_id)).where( Knowledgebase.id == cls.model.kb_id)).where(
cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value) cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
docs = docs.dicts() docs = docs.dicts()
if not docs: if not docs:
@ -306,9 +301,9 @@ class DocumentService(CommonService):
Tenant.asr_id, Tenant.asr_id,
Tenant.llm_id, Tenant.llm_id,
) )
.join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id)) .join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id))
.join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id)) .join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))
.where(cls.model.id == doc_id) .where(cls.model.id == doc_id)
) )
configs = configs.dicts() configs = configs.dicts()
if not configs: if not configs:
@ -374,6 +369,7 @@ class DocumentService(CommonService):
"progress_msg": "Task is queued...", "progress_msg": "Task is queued...",
"process_begin_at": get_format_time() "process_begin_at": get_format_time()
}) })
@classmethod @classmethod
@DB.connection_context() @DB.connection_context()
def update_meta_fields(cls, doc_id, meta_fields): def update_meta_fields(cls, doc_id, meta_fields):
@ -425,7 +421,7 @@ class DocumentService(CommonService):
info = { info = {
"process_duation": datetime.timestamp( "process_duation": datetime.timestamp(
datetime.now()) - datetime.now()) -
d["process_begin_at"].timestamp(), d["process_begin_at"].timestamp(),
"run": status} "run": status}
if prg != 0: if prg != 0:
info["progress"] = prg info["progress"] = prg
@ -480,13 +476,13 @@ def queue_raptor_o_graphrag_tasks(doc, ty):
def doc_upload_and_parse(conversation_id, file_objs, user_id): def doc_upload_and_parse(conversation_id, file_objs, user_id):
from rag.app import presentation, picture, naive, audio, email from api.db.services.api_service import API4ConversationService
from api.db.services.conversation_service import ConversationService
from api.db.services.dialog_service import DialogService from api.db.services.dialog_service import DialogService
from api.db.services.file_service import FileService from api.db.services.file_service import FileService
from api.db.services.llm_service import LLMBundle from api.db.services.llm_service import LLMBundle
from api.db.services.user_service import TenantService from api.db.services.user_service import TenantService
from api.db.services.api_service import API4ConversationService from rag.app import audio, email, naive, picture, presentation
from api.db.services.conversation_service import ConversationService
e, conv = ConversationService.get_by_id(conversation_id) e, conv = ConversationService.get_by_id(conversation_id)
if not e: if not e:

View File

@ -13,26 +13,30 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# #
from api.db import StatusEnum, TenantPermission from datetime import datetime
from api.db.db_models import Knowledgebase, DB, Tenant, User, UserTenant,Document
from api.db.services.common_service import CommonService
from peewee import fn from peewee import fn
from api.db import StatusEnum, TenantPermission
from api.db.db_models import DB, Document, Knowledgebase, Tenant, User, UserTenant
from api.db.services.common_service import CommonService
from api.utils import current_timestamp, datetime_format
class KnowledgebaseService(CommonService): class KnowledgebaseService(CommonService):
"""Service class for managing knowledge base operations. """Service class for managing knowledge base operations.
This class extends CommonService to provide specialized functionality for knowledge base This class extends CommonService to provide specialized functionality for knowledge base
management, including document parsing status tracking, access control, and configuration management, including document parsing status tracking, access control, and configuration
management. It handles operations such as listing, creating, updating, and deleting management. It handles operations such as listing, creating, updating, and deleting
knowledge bases, as well as managing their associated documents and permissions. knowledge bases, as well as managing their associated documents and permissions.
The class implements a comprehensive set of methods for: The class implements a comprehensive set of methods for:
- Document parsing status verification - Document parsing status verification
- Knowledge base access control - Knowledge base access control
- Parser configuration management - Parser configuration management
- Tenant-based knowledge base organization - Tenant-based knowledge base organization
Attributes: Attributes:
model: The Knowledgebase model class for database operations. model: The Knowledgebase model class for database operations.
""" """
@ -42,22 +46,22 @@ class KnowledgebaseService(CommonService):
@DB.connection_context() @DB.connection_context()
def accessible4deletion(cls, kb_id, user_id): def accessible4deletion(cls, kb_id, user_id):
"""Check if a knowledge base can be deleted by a specific user. """Check if a knowledge base can be deleted by a specific user.
This method verifies whether a user has permission to delete a knowledge base This method verifies whether a user has permission to delete a knowledge base
by checking if they are the creator of that knowledge base. by checking if they are the creator of that knowledge base.
Args: Args:
kb_id (str): The unique identifier of the knowledge base to check. kb_id (str): The unique identifier of the knowledge base to check.
user_id (str): The unique identifier of the user attempting the deletion. user_id (str): The unique identifier of the user attempting the deletion.
Returns: Returns:
bool: True if the user has permission to delete the knowledge base, bool: True if the user has permission to delete the knowledge base,
False if the user doesn't have permission or the knowledge base doesn't exist. False if the user doesn't have permission or the knowledge base doesn't exist.
Example: Example:
>>> KnowledgebaseService.accessible4deletion("kb123", "user456") >>> KnowledgebaseService.accessible4deletion("kb123", "user456")
True True
Note: Note:
- This method only checks creator permissions - This method only checks creator permissions
- A return value of False can mean either: - A return value of False can mean either:
@ -76,25 +80,25 @@ class KnowledgebaseService(CommonService):
@DB.connection_context() @DB.connection_context()
def is_parsed_done(cls, kb_id): def is_parsed_done(cls, kb_id):
# Check if all documents in the knowledge base have completed parsing # Check if all documents in the knowledge base have completed parsing
# #
# Args: # Args:
# kb_id: Knowledge base ID # kb_id: Knowledge base ID
# #
# Returns: # Returns:
# If all documents are parsed successfully, returns (True, None) # If all documents are parsed successfully, returns (True, None)
# If any document is not fully parsed, returns (False, error_message) # If any document is not fully parsed, returns (False, error_message)
from api.db import TaskStatus from api.db import TaskStatus
from api.db.services.document_service import DocumentService from api.db.services.document_service import DocumentService
# Get knowledge base information # Get knowledge base information
kbs = cls.query(id=kb_id) kbs = cls.query(id=kb_id)
if not kbs: if not kbs:
return False, "Knowledge base not found" return False, "Knowledge base not found"
kb = kbs[0] kb = kbs[0]
# Get all documents in the knowledge base # Get all documents in the knowledge base
docs, _ = DocumentService.get_by_kb_id(kb_id, 1, 1000, "create_time", True, "") docs, _ = DocumentService.get_by_kb_id(kb_id, 1, 1000, "create_time", True, "")
# Check parsing status of each document # Check parsing status of each document
for doc in docs: for doc in docs:
# If document is being parsed, don't allow chat creation # If document is being parsed, don't allow chat creation
@ -103,21 +107,21 @@ class KnowledgebaseService(CommonService):
# If document is not yet parsed and has no chunks, don't allow chat creation # If document is not yet parsed and has no chunks, don't allow chat creation
if doc['run'] == TaskStatus.UNSTART.value and doc['chunk_num'] == 0: if doc['run'] == TaskStatus.UNSTART.value and doc['chunk_num'] == 0:
return False, f"Document '{doc['name']}' in dataset '{kb.name}' has not been parsed yet. Please parse all documents before starting a chat." return False, f"Document '{doc['name']}' in dataset '{kb.name}' has not been parsed yet. Please parse all documents before starting a chat."
return True, None return True, None
@classmethod @classmethod
@DB.connection_context() @DB.connection_context()
def list_documents_by_ids(cls,kb_ids): def list_documents_by_ids(cls, kb_ids):
# Get document IDs associated with given knowledge base IDs # Get document IDs associated with given knowledge base IDs
# Args: # Args:
# kb_ids: List of knowledge base IDs # kb_ids: List of knowledge base IDs
# Returns: # Returns:
# List of document IDs # List of document IDs
doc_ids=cls.model.select(Document.id.alias("document_id")).join(Document,on=(cls.model.id == Document.kb_id)).where( doc_ids = cls.model.select(Document.id.alias("document_id")).join(Document, on=(cls.model.id == Document.kb_id)).where(
cls.model.id.in_(kb_ids) cls.model.id.in_(kb_ids)
) )
doc_ids =list(doc_ids.dicts()) doc_ids = list(doc_ids.dicts())
doc_ids = [doc["document_id"] for doc in doc_ids] doc_ids = [doc["document_id"] for doc in doc_ids]
return doc_ids return doc_ids
@ -222,7 +226,7 @@ class KnowledgebaseService(CommonService):
cls.model.parser_config, cls.model.parser_config,
cls.model.pagerank] cls.model.pagerank]
kbs = cls.model.select(*fields).join(Tenant, on=( kbs = cls.model.select(*fields).join(Tenant, on=(
(Tenant.id == cls.model.tenant_id) & (Tenant.status == StatusEnum.VALID.value))).where( (Tenant.id == cls.model.tenant_id) & (Tenant.status == StatusEnum.VALID.value))).where(
(cls.model.id == kb_id), (cls.model.id == kb_id),
(cls.model.status == StatusEnum.VALID.value) (cls.model.status == StatusEnum.VALID.value)
) )
@ -324,7 +328,7 @@ class KnowledgebaseService(CommonService):
kbs = kbs.where( kbs = kbs.where(
((cls.model.tenant_id.in_(joined_tenant_ids) & (cls.model.permission == ((cls.model.tenant_id.in_(joined_tenant_ids) & (cls.model.permission ==
TenantPermission.TEAM.value)) | ( TenantPermission.TEAM.value)) | (
cls.model.tenant_id == user_id)) cls.model.tenant_id == user_id))
& (cls.model.status == StatusEnum.VALID.value) & (cls.model.status == StatusEnum.VALID.value)
) )
if desc: if desc:
@ -347,7 +351,7 @@ class KnowledgebaseService(CommonService):
# Boolean indicating accessibility # Boolean indicating accessibility
docs = cls.model.select( docs = cls.model.select(
cls.model.id).join(UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id) cls.model.id).join(UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id)
).where(cls.model.id == kb_id, UserTenant.user_id == user_id).paginate(0, 1) ).where(cls.model.id == kb_id, UserTenant.user_id == user_id).paginate(0, 1)
docs = docs.dicts() docs = docs.dicts()
if not docs: if not docs:
return False return False
@ -363,7 +367,7 @@ class KnowledgebaseService(CommonService):
# Returns: # Returns:
# List containing knowledge base information # List containing knowledge base information
kbs = cls.model.select().join(UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id) kbs = cls.model.select().join(UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id)
).where(cls.model.id == kb_id, UserTenant.user_id == user_id).paginate(0, 1) ).where(cls.model.id == kb_id, UserTenant.user_id == user_id).paginate(0, 1)
kbs = kbs.dicts() kbs = kbs.dicts()
return list(kbs) return list(kbs)
@ -377,7 +381,16 @@ class KnowledgebaseService(CommonService):
# Returns: # Returns:
# List containing knowledge base information # List containing knowledge base information
kbs = cls.model.select().join(UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id) kbs = cls.model.select().join(UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id)
).where(cls.model.name == kb_name, UserTenant.user_id == user_id).paginate(0, 1) ).where(cls.model.name == kb_name, UserTenant.user_id == user_id).paginate(0, 1)
kbs = kbs.dicts() kbs = kbs.dicts()
return list(kbs) return list(kbs)
@classmethod
@DB.connection_context()
def atomic_increase_doc_num_by_id(cls, kb_id):
    """Bump a knowledge base's document counter by one in a single UPDATE.

    The new ``doc_num`` is passed to the query as the column expression
    ``doc_num + 1`` rather than as a value computed from a previously
    fetched row, so the addition is carried out by the UPDATE statement
    itself. The same statement also refreshes ``update_time`` (from
    ``current_timestamp()``) and ``update_date`` (from
    ``datetime_format(datetime.now())``).

    Args:
        kb_id: Identifier of the knowledge base row to update.

    Returns:
        Whatever ``execute()`` yields for the UPDATE query; with peewee
        this is the number of rows modified, so a falsy result means no
        row matched ``kb_id``.
    """
    # Keys are evaluated top to bottom, matching the timestamp-then-date
    # order in which the fields were originally populated.
    changes = {
        "update_time": current_timestamp(),
        "update_date": datetime_format(datetime.now()),
        "doc_num": cls.model.doc_num + 1,
    }
    update_query = cls.model.update(changes).where(cls.model.id == kb_id)
    return update_query.execute()