diff --git a/api/apps/sdk/chat.py b/api/apps/sdk/chat.py
index 32b30341c..9bbe827e2 100644
--- a/api/apps/sdk/chat.py
+++ b/api/apps/sdk/chat.py
@@ -40,6 +40,12 @@ def create(tenant_id):
         kb = kbs[0]
         if kb.chunk_num == 0:
             return get_error_data_result(f"The dataset {kb_id} doesn't own parsed file")
+
+        # Check if all documents in the knowledge base have been parsed
+        is_done, error_msg = KnowledgebaseService.is_parsed_done(kb_id)
+        if not is_done:
+            return get_error_data_result(error_msg)
+
     kbs = KnowledgebaseService.get_by_ids(ids) if ids else []
     embd_ids = [TenantLLMService.split_model_name_and_factory(kb.embd_id)[0] for kb in kbs]  # remove vendor suffix for comparison
     embd_count = list(set(embd_ids))
@@ -176,6 +182,12 @@ def update(tenant_id, chat_id):
             kb = kbs[0]
             if kb.chunk_num == 0:
                 return get_error_data_result(f"The dataset {kb_id} doesn't own parsed file")
+
+            # Check if all documents in the knowledge base have been parsed
+            is_done, error_msg = KnowledgebaseService.is_parsed_done(kb_id)
+            if not is_done:
+                return get_error_data_result(error_msg)
+
         kbs = KnowledgebaseService.get_by_ids(ids)
         embd_ids = [TenantLLMService.split_model_name_and_factory(kb.embd_id)[0] for kb in kbs]  # remove vendor suffix for comparison
         embd_count = list(set(embd_ids))
diff --git a/api/db/services/knowledgebase_service.py b/api/db/services/knowledgebase_service.py
index a4f5d0095..f4567cddf 100644
--- a/api/db/services/knowledgebase_service.py
+++ b/api/db/services/knowledgebase_service.py
@@ -22,6 +22,49 @@ from peewee import fn
 class KnowledgebaseService(CommonService):
     model = Knowledgebase
 
+    @classmethod
+    @DB.connection_context()
+    def is_parsed_done(cls, kb_id):
+        """
+        Check whether the documents of a knowledge base are ready for chat.
+
+        Args:
+            kb_id: Knowledge base ID
+
+        Returns:
+            (True, None) if no inspected document blocks chat creation.
+            (False, error_message) if the knowledge base is not found, or a
+            document is running, cancelled, failed, or unstarted with no chunks.
+        """
+        from api.db import TaskStatus
+        from api.db.services.document_service import DocumentService
+
+        # Get knowledge base information
+        kbs = cls.query(id=kb_id)
+        if not kbs:
+            return False, "Knowledge base not found"
+        kb = kbs[0]
+
+        # NOTE(review): only one call is made, so presumably just the first 1000
+        # documents are inspected; confirm get_by_kb_id's paging arguments.
+        docs, _ = DocumentService.get_by_kb_id(kb_id, 1, 1000, "create_time", True, "")
+
+        # Cancelled and failed documents are not "being parsed", so each
+        # blocking state gets its own message instead of a shared "please wait".
+        for doc in docs:
+            status = doc['run']
+            if status == TaskStatus.RUNNING.value:
+                return False, f"Document '{doc['name']}' in dataset '{kb.name}' is still being parsed. Please wait until all documents are parsed before starting a chat."
+            if status == TaskStatus.CANCEL.value:
+                return False, f"Parsing of document '{doc['name']}' in dataset '{kb.name}' was cancelled. Please parse all documents before starting a chat."
+            if status == TaskStatus.FAIL.value:
+                return False, f"Parsing of document '{doc['name']}' in dataset '{kb.name}' failed. Please parse all documents before starting a chat."
+            # If document is not yet parsed and has no chunks, don't allow chat creation
+            if status == TaskStatus.UNSTART.value and doc['chunk_num'] == 0:
+                return False, f"Document '{doc['name']}' in dataset '{kb.name}' has not been parsed yet. Please parse all documents before starting a chat."
+
+        return True, None
+
     @classmethod
     @DB.connection_context()
     def list_documents_by_ids(cls,kb_ids):