diff --git a/api/db/services/document_service.py b/api/db/services/document_service.py index 10cba99c2..b0c560e32 100644 --- a/api/db/services/document_service.py +++ b/api/db/services/document_service.py @@ -13,33 +13,30 @@ # See the License for the specific language governing permissions and # limitations under the License. # -import logging -import xxhash import json +import logging import random import re from concurrent.futures import ThreadPoolExecutor from copy import deepcopy from datetime import datetime from io import BytesIO -import trio +import trio +import xxhash from peewee import fn -from api.db.db_utils import bulk_insert_into_db from api import settings -from api.utils import current_timestamp, get_format_time, get_uuid -from rag.settings import SVR_QUEUE_NAME -from rag.utils.storage_factory import STORAGE_IMPL -from rag.nlp import search, rag_tokenizer - -from api.db import FileType, TaskStatus, ParserType, LLMType -from api.db.db_models import DB, Knowledgebase, Tenant, Task, UserTenant -from api.db.db_models import Document +from api.db import FileType, LLMType, ParserType, StatusEnum, TaskStatus +from api.db.db_models import DB, Document, Knowledgebase, Task, Tenant, UserTenant +from api.db.db_utils import bulk_insert_into_db from api.db.services.common_service import CommonService from api.db.services.knowledgebase_service import KnowledgebaseService -from api.db import StatusEnum +from api.utils import current_timestamp, get_format_time, get_uuid +from rag.nlp import rag_tokenizer, search +from rag.settings import SVR_QUEUE_NAME from rag.utils.redis_conn import REDIS_CONN +from rag.utils.storage_factory import STORAGE_IMPL class DocumentService(CommonService): @@ -96,9 +93,7 @@ class DocumentService(CommonService): def insert(cls, doc): if not cls.save(**doc): raise RuntimeError("Database error (Document)!") - e, kb = KnowledgebaseService.get_by_id(doc["kb_id"]) - if not KnowledgebaseService.update_by_id( - kb.id, {"doc_num": kb.doc_num + 1}): + if not KnowledgebaseService.atomic_increase_doc_num_by_id(doc["kb_id"]): raise RuntimeError("Database error (Knowledgebase)!") return Document(**doc) @@ -174,9 +169,9 @@ class DocumentService(CommonService): "Document not found which is supposed to be there") num = Knowledgebase.update( token_num=Knowledgebase.token_num + - token_num, + token_num, chunk_num=Knowledgebase.chunk_num + - chunk_num).where( + chunk_num).where( Knowledgebase.id == kb_id).execute() return num @@ -192,9 +187,9 @@ class DocumentService(CommonService): "Document not found which is supposed to be there") num = Knowledgebase.update( token_num=Knowledgebase.token_num - - token_num, + token_num, chunk_num=Knowledgebase.chunk_num - - chunk_num + chunk_num ).where( Knowledgebase.id == kb_id).execute() return num @@ -207,9 +202,9 @@ class DocumentService(CommonService): num = Knowledgebase.update( token_num=Knowledgebase.token_num - - doc.token_num, + doc.token_num, chunk_num=Knowledgebase.chunk_num - - doc.chunk_num, + doc.chunk_num, doc_num=Knowledgebase.doc_num - 1 ).where( Knowledgebase.id == doc.kb_id).execute() @@ -221,7 +216,7 @@ class DocumentService(CommonService): docs = cls.model.select( Knowledgebase.tenant_id).join( Knowledgebase, on=( - Knowledgebase.id == cls.model.kb_id)).where( + Knowledgebase.id == cls.model.kb_id)).where( cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value) docs = docs.dicts() if not docs: @@ -243,7 +238,7 @@ class DocumentService(CommonService): docs = cls.model.select( Knowledgebase.tenant_id).join( Knowledgebase, on=( - Knowledgebase.id == cls.model.kb_id)).where( + Knowledgebase.id == cls.model.kb_id)).where( cls.model.name == name, Knowledgebase.status == StatusEnum.VALID.value) docs = docs.dicts() if not docs: @@ -256,7 +251,7 @@ class DocumentService(CommonService): docs = cls.model.select( cls.model.id).join( Knowledgebase, on=( - Knowledgebase.id == cls.model.kb_id) + Knowledgebase.id == cls.model.kb_id) ).join(UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id) ).where(cls.model.id == doc_id, UserTenant.user_id == user_id).paginate(0, 1) docs = docs.dicts() @@ -270,7 +265,7 @@ class DocumentService(CommonService): docs = cls.model.select( cls.model.id).join( Knowledgebase, on=( - Knowledgebase.id == cls.model.kb_id) + Knowledgebase.id == cls.model.kb_id) ).where(cls.model.id == doc_id, Knowledgebase.created_by == user_id).paginate(0, 1) docs = docs.dicts() if not docs: @@ -283,7 +278,7 @@ class DocumentService(CommonService): docs = cls.model.select( Knowledgebase.embd_id).join( Knowledgebase, on=( - Knowledgebase.id == cls.model.kb_id)).where( + Knowledgebase.id == cls.model.kb_id)).where( cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value) docs = docs.dicts() if not docs: @@ -306,9 +301,9 @@ class DocumentService(CommonService): Tenant.asr_id, Tenant.llm_id, ) - .join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id)) - .join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id)) - .where(cls.model.id == doc_id) + .join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id)) + .join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id)) + .where(cls.model.id == doc_id) ) configs = configs.dicts() if not configs: @@ -374,6 +369,7 @@ class DocumentService(CommonService): "progress_msg": "Task is queued...", "process_begin_at": get_format_time() }) + @classmethod @DB.connection_context() def update_meta_fields(cls, doc_id, meta_fields): @@ -425,7 +421,7 @@ class DocumentService(CommonService): info = { "process_duation": datetime.timestamp( datetime.now()) - - d["process_begin_at"].timestamp(), + d["process_begin_at"].timestamp(), "run": status} if prg != 0: info["progress"] = prg @@ -480,13 +476,13 @@ def queue_raptor_o_graphrag_tasks(doc, ty): def doc_upload_and_parse(conversation_id, file_objs, user_id): - from rag.app import presentation, picture, naive, audio, email + from api.db.services.api_service import API4ConversationService + from api.db.services.conversation_service import ConversationService from api.db.services.dialog_service import DialogService from api.db.services.file_service import FileService from api.db.services.llm_service import LLMBundle from api.db.services.user_service import TenantService - from api.db.services.api_service import API4ConversationService - from api.db.services.conversation_service import ConversationService + from rag.app import audio, email, naive, picture, presentation e, conv = ConversationService.get_by_id(conversation_id) if not e: diff --git a/api/db/services/knowledgebase_service.py b/api/db/services/knowledgebase_service.py index b9fa56e03..979cba60b 100644 --- a/api/db/services/knowledgebase_service.py +++ b/api/db/services/knowledgebase_service.py @@ -13,26 +13,30 @@ # See the License for the specific language governing permissions and # limitations under the License. # -from api.db import StatusEnum, TenantPermission -from api.db.db_models import Knowledgebase, DB, Tenant, User, UserTenant,Document -from api.db.services.common_service import CommonService +from datetime import datetime + from peewee import fn +from api.db import StatusEnum, TenantPermission +from api.db.db_models import DB, Document, Knowledgebase, Tenant, User, UserTenant +from api.db.services.common_service import CommonService +from api.utils import current_timestamp, datetime_format + class KnowledgebaseService(CommonService): """Service class for managing knowledge base operations. - + This class extends CommonService to provide specialized functionality for knowledge base management, including document parsing status tracking, access control, and configuration management. It handles operations such as listing, creating, updating, and deleting knowledge bases, as well as managing their associated documents and permissions. - + The class implements a comprehensive set of methods for: - Document parsing status verification - Knowledge base access control - Parser configuration management - Tenant-based knowledge base organization - + Attributes: model: The Knowledgebase model class for database operations. """ @@ -42,22 +46,22 @@ class KnowledgebaseService(CommonService): @DB.connection_context() def accessible4deletion(cls, kb_id, user_id): """Check if a knowledge base can be deleted by a specific user. - + This method verifies whether a user has permission to delete a knowledge base by checking if they are the creator of that knowledge base. - + Args: kb_id (str): The unique identifier of the knowledge base to check. user_id (str): The unique identifier of the user attempting the deletion. - + Returns: bool: True if the user has permission to delete the knowledge base, False if the user doesn't have permission or the knowledge base doesn't exist. - + Example: >>> KnowledgebaseService.accessible4deletion("kb123", "user456") True - + Note: - This method only checks creator permissions - A return value of False can mean either: @@ -76,25 +80,25 @@ class KnowledgebaseService(CommonService): @DB.connection_context() def is_parsed_done(cls, kb_id): # Check if all documents in the knowledge base have completed parsing - # + # # Args: # kb_id: Knowledge base ID - # + # # Returns: # If all documents are parsed successfully, returns (True, None) # If any document is not fully parsed, returns (False, error_message) from api.db import TaskStatus from api.db.services.document_service import DocumentService - + # Get knowledge base information kbs = cls.query(id=kb_id) if not kbs: return False, "Knowledge base not found" kb = kbs[0] - + # Get all documents in the knowledge base docs, _ = DocumentService.get_by_kb_id(kb_id, 1, 1000, "create_time", True, "") - + # Check parsing status of each document for doc in docs: # If document is being parsed, don't allow chat creation @@ -103,21 +107,21 @@ class KnowledgebaseService(CommonService): # If document is not yet parsed and has no chunks, don't allow chat creation if doc['run'] == TaskStatus.UNSTART.value and doc['chunk_num'] == 0: return False, f"Document '{doc['name']}' in dataset '{kb.name}' has not been parsed yet. Please parse all documents before starting a chat." - + return True, None @classmethod @DB.connection_context() - def list_documents_by_ids(cls,kb_ids): + def list_documents_by_ids(cls, kb_ids): # Get document IDs associated with given knowledge base IDs # Args: # kb_ids: List of knowledge base IDs # Returns: # List of document IDs - doc_ids=cls.model.select(Document.id.alias("document_id")).join(Document,on=(cls.model.id == Document.kb_id)).where( + doc_ids = cls.model.select(Document.id.alias("document_id")).join(Document, on=(cls.model.id == Document.kb_id)).where( cls.model.id.in_(kb_ids) ) - doc_ids =list(doc_ids.dicts()) + doc_ids = list(doc_ids.dicts()) doc_ids = [doc["document_id"] for doc in doc_ids] return doc_ids @@ -222,7 +226,7 @@ class KnowledgebaseService(CommonService): cls.model.parser_config, cls.model.pagerank] kbs = cls.model.select(*fields).join(Tenant, on=( - (Tenant.id == cls.model.tenant_id) & (Tenant.status == StatusEnum.VALID.value))).where( + (Tenant.id == cls.model.tenant_id) & (Tenant.status == StatusEnum.VALID.value))).where( (cls.model.id == kb_id), (cls.model.status == StatusEnum.VALID.value) ) @@ -324,7 +328,7 @@ class KnowledgebaseService(CommonService): kbs = kbs.where( ((cls.model.tenant_id.in_(joined_tenant_ids) & (cls.model.permission == TenantPermission.TEAM.value)) | ( - cls.model.tenant_id == user_id)) + cls.model.tenant_id == user_id)) & (cls.model.status == StatusEnum.VALID.value) ) if desc: @@ -347,7 +351,7 @@ class KnowledgebaseService(CommonService): # Boolean indicating accessibility docs = cls.model.select( cls.model.id).join(UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id) - ).where(cls.model.id == kb_id, UserTenant.user_id == user_id).paginate(0, 1) + ).where(cls.model.id == kb_id, UserTenant.user_id == user_id).paginate(0, 1) docs = docs.dicts() if not docs: return False @@ -363,7 +367,7 @@ class KnowledgebaseService(CommonService): # Returns: # List containing knowledge base information kbs = cls.model.select().join(UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id) - ).where(cls.model.id == kb_id, UserTenant.user_id == user_id).paginate(0, 1) + ).where(cls.model.id == kb_id, UserTenant.user_id == user_id).paginate(0, 1) kbs = kbs.dicts() return list(kbs) @@ -377,7 +381,16 @@ class KnowledgebaseService(CommonService): # Returns: # List containing knowledge base information kbs = cls.model.select().join(UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id) - ).where(cls.model.name == kb_name, UserTenant.user_id == user_id).paginate(0, 1) + ).where(cls.model.name == kb_name, UserTenant.user_id == user_id).paginate(0, 1) kbs = kbs.dicts() return list(kbs) + @classmethod + @DB.connection_context() + def atomic_increase_doc_num_by_id(cls, kb_id): + data = {} + data["update_time"] = current_timestamp() + data["update_date"] = datetime_format(datetime.now()) + data["doc_num"] = cls.model.doc_num + 1 + num = cls.model.update(data).where(cls.model.id == kb_id).execute() + return num