From 47926f7d2137cc3437d0d9dadb317d7113454112 Mon Sep 17 00:00:00 2001 From: Xinghan Pan Date: Thu, 13 Mar 2025 19:06:50 +0800 Subject: [PATCH] Improve API Documentation, Standardize Error Handling, and Enhance Comments (#5990) ### What problem does this PR solve? - The API documentation lacks detailed error code explanations. Added error code tables to `python_api_reference.md` and `http_api_reference.md` to clarify possible error codes and their meanings. - Error handling in the codebase is inconsistent. Standardized error handling logic in `sdk/python/ragflow_sdk/modules/chunk.py`. - Improved API comments by adding standardized docstrings to enhance code readability and maintainability. ### Type of change - [x] Documentation Update - [x] Refactoring --- api/db/services/common_service.py | 153 ++++++++++++++++++++++- api/db/services/file_service.py | 90 +++++++++++++ api/db/services/knowledgebase_service.py | 152 +++++++++++++++++++--- api/db/services/task_service.py | 126 +++++++++++++++++++ api/db/services/user_service.py | 41 ++++++ docs/references/http_api_reference.md | 20 +++ docs/references/python_api_reference.md | 21 +++- sdk/python/ragflow_sdk/modules/chunk.py | 12 +- 8 files changed, 591 insertions(+), 24 deletions(-) diff --git a/api/db/services/common_service.py b/api/db/services/common_service.py index dcbe28cda..184f30d5c 100644 --- a/api/db/services/common_service.py +++ b/api/db/services/common_service.py @@ -22,17 +22,56 @@ from api.utils import datetime_format, current_timestamp, get_uuid class CommonService: + """Base service class that provides common database operations. + + This class serves as a foundation for all service classes in the application, + implementing standard CRUD operations and common database query patterns. + It uses the Peewee ORM for database interactions and provides a consistent + interface for database operations across all derived service classes. + + Attributes: + model: The Peewee model class that this service operates on. Must be set by subclasses. + """ model = None @classmethod @DB.connection_context() def query(cls, cols=None, reverse=None, order_by=None, **kwargs): + """Execute a database query with optional column selection and ordering. + + This method provides a flexible way to query the database with various filters + and sorting options. It supports column selection, sort order control, and + additional filter conditions. + + Args: + cols (list, optional): List of column names to select. If None, selects all columns. + reverse (bool, optional): If True, sorts in descending order. If False, sorts in ascending order. + order_by (str, optional): Column name to sort results by. + **kwargs: Additional filter conditions passed as keyword arguments. + + Returns: + peewee.ModelSelect: A query result containing matching records. + """ return cls.model.query(cols=cols, reverse=reverse, order_by=order_by, **kwargs) @classmethod @DB.connection_context() def get_all(cls, cols=None, reverse=None, order_by=None): + """Retrieve all records from the database with optional column selection and ordering. + + This method fetches all records from the model's table with support for + column selection and result ordering. If no order_by is specified and reverse + is True, it defaults to ordering by create_time. + + Args: + cols (list, optional): List of column names to select. If None, selects all columns. + reverse (bool, optional): If True, sorts in descending order. If False, sorts in ascending order. + order_by (str, optional): Column name to sort results by. Defaults to 'create_time' if reverse is specified. + + Returns: + peewee.ModelSelect: A query containing all matching records. + """ if cols: query_records = cls.model.select(*cols) else: @@ -51,11 +90,36 @@ class CommonService: @classmethod @DB.connection_context() def get(cls, **kwargs): + """Get a single record matching the given criteria. + + This method retrieves a single record from the database that matches + the specified filter conditions. + + Args: + **kwargs: Filter conditions as keyword arguments. + + Returns: + Model instance: Single matching record. + + Raises: + peewee.DoesNotExist: If no matching record is found. + """ return cls.model.get(**kwargs) @classmethod @DB.connection_context() def get_or_none(cls, **kwargs): + """Get a single record or None if not found. + + This method attempts to retrieve a single record matching the given criteria, + returning None if no match is found instead of raising an exception. + + Args: + **kwargs: Filter conditions as keyword arguments. + + Returns: + Model instance or None: Matching record if found, None otherwise. + """ try: return cls.model.get(**kwargs) except peewee.DoesNotExist: @@ -64,14 +128,34 @@ class CommonService: @classmethod @DB.connection_context() def save(cls, **kwargs): - # if "id" not in kwargs: - # kwargs["id"] = get_uuid() + """Save a new record to database. + + This method creates a new record in the database with the provided field values, + forcing an insert operation rather than an update. + + Args: + **kwargs: Record field values as keyword arguments. + + Returns: + Model instance: The created record object. + """ sample_obj = cls.model(**kwargs).save(force_insert=True) return sample_obj @classmethod @DB.connection_context() def insert(cls, **kwargs): + """Insert a new record with automatic ID and timestamps. + + This method creates a new record with automatically generated ID and timestamp fields. + It handles the creation of create_time, create_date, update_time, and update_date fields. + + Args: + **kwargs: Record field values as keyword arguments. + + Returns: + Model instance: The newly created record object. + """ if "id" not in kwargs: kwargs["id"] = get_uuid() kwargs["create_time"] = current_timestamp() @@ -84,6 +168,15 @@ class CommonService: @classmethod @DB.connection_context() def insert_many(cls, data_list, batch_size=100): + """Insert multiple records in batches. + + This method efficiently inserts multiple records into the database using batch processing. + It automatically sets creation timestamps for all records. + + Args: + data_list (list): List of dictionaries containing record data to insert. + batch_size (int, optional): Number of records to insert in each batch. Defaults to 100. + """ with DB.atomic(): for d in data_list: d["create_time"] = current_timestamp() @@ -94,6 +187,15 @@ class CommonService: @classmethod @DB.connection_context() def update_many_by_id(cls, data_list): + """Update multiple records by their IDs. + + This method updates multiple records in the database, identified by their IDs. + It automatically updates the update_time and update_date fields for each record. + + Args: + data_list (list): List of dictionaries containing record data to update. + Each dictionary must include an 'id' field. + """ with DB.atomic(): for data in data_list: data["update_time"] = current_timestamp() @@ -104,6 +206,12 @@ class CommonService: @classmethod @DB.connection_context() def update_by_id(cls, pid, data): + # Update a single record by ID + # Args: + # pid: Record ID + # data: Updated field values + # Returns: + # Number of records updated data["update_time"] = current_timestamp() data["update_date"] = datetime_format(datetime.now()) num = cls.model.update(data).where(cls.model.id == pid).execute() @@ -112,6 +220,11 @@ class CommonService: @classmethod @DB.connection_context() def get_by_id(cls, pid): + # Get a record by ID + # Args: + # pid: Record ID + # Returns: + # Tuple of (success, record) try: obj = cls.model.query(id=pid)[0] return True, obj @@ -121,6 +234,12 @@ class CommonService: @classmethod @DB.connection_context() def get_by_ids(cls, pids, cols=None): + # Get multiple records by their IDs + # Args: + # pids: List of record IDs + # cols: List of columns to select + # Returns: + # Query of matching records if cols: objs = cls.model.select(*cols) else: @@ -130,11 +249,21 @@ class CommonService: @classmethod @DB.connection_context() def delete_by_id(cls, pid): + # Delete a record by ID + # Args: + # pid: Record ID + # Returns: + # Number of records deleted return cls.model.delete().where(cls.model.id == pid).execute() @classmethod @DB.connection_context() def filter_delete(cls, filters): + # Delete records matching given filters + # Args: + # filters: List of filter conditions + # Returns: + # Number of records deleted with DB.atomic(): num = cls.model.delete().where(*filters).execute() return num @@ -142,11 +271,23 @@ class CommonService: @classmethod @DB.connection_context() def filter_update(cls, filters, update_data): + # Update records matching given filters + # Args: + # filters: List of filter conditions + # update_data: Updated field values + # Returns: + # Number of records updated with DB.atomic(): return cls.model.update(update_data).where(*filters).execute() @staticmethod def cut_list(tar_list, n): + # Split a list into chunks of size n + # Args: + # tar_list: List to split + # n: Chunk size + # Returns: + # List of tuples containing chunks length = len(tar_list) arr = range(length) result = [tuple(tar_list[x:(x + n)]) for x in arr[::n]] @@ -156,6 +297,14 @@ class CommonService: @DB.connection_context() def filter_scope_list(cls, in_key, in_filters_list, filters=None, cols=None): + # Get records matching IN clause filters with optional column selection + # Args: + # in_key: Field name for IN clause + # in_filters_list: List of values for IN clause + # filters: Additional filter conditions + # cols: List of columns to select + # Returns: + # List of matching records in_filters_tuple_list = cls.cut_list(in_filters_list, 20) if not filters: filters = [] diff --git a/api/db/services/file_service.py b/api/db/services/file_service.py index be6220bc9..367f39d82 100644 --- a/api/db/services/file_service.py +++ b/api/db/services/file_service.py @@ -34,12 +34,24 @@ from rag.utils.storage_factory import STORAGE_IMPL class FileService(CommonService): + # Service class for managing file operations and storage model = File @classmethod @DB.connection_context() def get_by_pf_id(cls, tenant_id, pf_id, page_number, items_per_page, orderby, desc, keywords): + # Get files by parent folder ID with pagination and filtering + # Args: + # tenant_id: ID of the tenant + # pf_id: Parent folder ID + # page_number: Page number for pagination + # items_per_page: Number of items per page + # orderby: Field to order by + # desc: Boolean indicating descending order + # keywords: Search keywords + # Returns: + # Tuple of (file_list, total_count) if keywords: files = cls.model.select().where( (cls.model.tenant_id == tenant_id), @@ -80,6 +92,11 @@ class FileService(CommonService): @classmethod @DB.connection_context() def get_kb_id_by_file_id(cls, file_id): + # Get knowledge base IDs associated with a file + # Args: + # file_id: File ID + # Returns: + # List of dictionaries containing knowledge base IDs and names kbs = (cls.model.select(*[Knowledgebase.id, Knowledgebase.name]) .join(File2Document, on=(File2Document.file_id == file_id)) .join(Document, on=(File2Document.document_id == Document.id)) @@ -95,6 +112,12 @@ class FileService(CommonService): @classmethod @DB.connection_context() def get_by_pf_id_name(cls, id, name): + # Get file by parent folder ID and name + # Args: + # id: Parent folder ID + # name: File name + # Returns: + # File object or None if not found file = cls.model.select().where((cls.model.parent_id == id) & (cls.model.name == name)) if file.count(): e, file = cls.get_by_id(file[0].id) @@ -106,6 +129,14 @@ class FileService(CommonService): @classmethod @DB.connection_context() def get_id_list_by_id(cls, id, name, count, res): + # Recursively get list of file IDs by traversing folder structure + # Args: + # id: Starting folder ID + # name: List of folder names to traverse + # count: Current depth in traversal + # res: List to store results + # Returns: + # List of file IDs if count < len(name): file = cls.get_by_pf_id_name(id, name[count]) if file: @@ -119,6 +150,12 @@ class FileService(CommonService): @classmethod @DB.connection_context() def get_all_innermost_file_ids(cls, folder_id, result_ids): + # Get IDs of all files in the deepest level of folders + # Args: + # folder_id: Starting folder ID + # result_ids: List to store results + # Returns: + # List of file IDs subfolders = cls.model.select().where(cls.model.parent_id == folder_id) if subfolders.exists(): for subfolder in subfolders: @@ -130,6 +167,14 @@ class FileService(CommonService): @classmethod @DB.connection_context() def create_folder(cls, file, parent_id, name, count): + # Recursively create folder structure + # Args: + # file: Current file object + # parent_id: Parent folder ID + # name: List of folder names to create + # count: Current depth in creation + # Returns: + # Created file object if count > len(name) - 2: return file else: @@ -148,6 +193,11 @@ class FileService(CommonService): @classmethod @DB.connection_context() def is_parent_folder_exist(cls, parent_id): + # Check if parent folder exists + # Args: + # parent_id: Parent folder ID + # Returns: + # Boolean indicating if folder exists parent_files = cls.model.select().where(cls.model.id == parent_id) if parent_files.count(): return True @@ -157,6 +207,11 @@ class FileService(CommonService): @classmethod @DB.connection_context() def get_root_folder(cls, tenant_id): + # Get or create root folder for tenant + # Args: + # tenant_id: Tenant ID + # Returns: + # Root folder dictionary for file in cls.model.select().where((cls.model.tenant_id == tenant_id), (cls.model.parent_id == cls.model.id) ): @@ -179,6 +234,11 @@ class FileService(CommonService): @classmethod @DB.connection_context() def get_kb_folder(cls, tenant_id): + # Get knowledge base folder for tenant + # Args: + # tenant_id: Tenant ID + # Returns: + # Knowledge base folder dictionary for root in cls.model.select().where( (cls.model.tenant_id == tenant_id), (cls.model.parent_id == cls.model.id)): for folder in cls.model.select().where( @@ -190,6 +250,16 @@ class FileService(CommonService): @classmethod @DB.connection_context() def new_a_file_from_kb(cls, tenant_id, name, parent_id, ty=FileType.FOLDER.value, size=0, location=""): + # Create a new file from knowledge base + # Args: + # tenant_id: Tenant ID + # name: File name + # parent_id: Parent folder ID + # ty: File type + # size: File size + # location: File location + # Returns: + # Created file dictionary for file in cls.query(tenant_id=tenant_id, parent_id=parent_id, name=name): return file.to_dict() file = { @@ -209,6 +279,10 @@ class FileService(CommonService): @classmethod @DB.connection_context() def init_knowledgebase_docs(cls, root_id, tenant_id): + # Initialize knowledge base documents + # Args: + # root_id: Root folder ID + # tenant_id: Tenant ID for _ in cls.model.select().where((cls.model.name == KNOWLEDGEBASE_FOLDER_NAME)\ & (cls.model.parent_id == root_id)): return @@ -222,6 +296,11 @@ class FileService(CommonService): @classmethod @DB.connection_context() def get_parent_folder(cls, file_id): + # Get parent folder of a file + # Args: + # file_id: File ID + # Returns: + # Parent folder object file = cls.model.select().where(cls.model.id == file_id) if file.count(): e, file = cls.get_by_id(file[0].parent_id) @@ -234,6 +313,11 @@ class FileService(CommonService): @classmethod @DB.connection_context() def get_all_parent_folders(cls, start_id): + # Get all parent folders in path + # Args: + # start_id: Starting file ID + # Returns: + # List of parent folder objects parent_folders = [] current_id = start_id while current_id: @@ -249,6 +333,11 @@ class FileService(CommonService): @classmethod @DB.connection_context() def insert(cls, file): + # Insert a new file record + # Args: + # file: File data dictionary + # Returns: + # Created file object if not cls.save(**file): raise RuntimeError("Database error (File)!") return File(**file) @@ -256,6 +345,7 @@ class FileService(CommonService): @classmethod @DB.connection_context() def delete(cls, file): + # return cls.delete_by_id(file.id) @classmethod diff --git a/api/db/services/knowledgebase_service.py b/api/db/services/knowledgebase_service.py index f4567cddf..b9fa56e03 100644 --- a/api/db/services/knowledgebase_service.py +++ b/api/db/services/knowledgebase_service.py @@ -20,21 +20,69 @@ from peewee import fn class KnowledgebaseService(CommonService): + """Service class for managing knowledge base operations. + + This class extends CommonService to provide specialized functionality for knowledge base + management, including document parsing status tracking, access control, and configuration + management. It handles operations such as listing, creating, updating, and deleting + knowledge bases, as well as managing their associated documents and permissions. + + The class implements a comprehensive set of methods for: + - Document parsing status verification + - Knowledge base access control + - Parser configuration management + - Tenant-based knowledge base organization + + Attributes: + model: The Knowledgebase model class for database operations. + """ model = Knowledgebase @classmethod @DB.connection_context() - def is_parsed_done(cls, kb_id): - """ - Check if all documents in the knowledge base have completed parsing - + def accessible4deletion(cls, kb_id, user_id): + """Check if a knowledge base can be deleted by a specific user. + + This method verifies whether a user has permission to delete a knowledge base + by checking if they are the creator of that knowledge base. + Args: - kb_id: Knowledge base ID - + kb_id (str): The unique identifier of the knowledge base to check. + user_id (str): The unique identifier of the user attempting the deletion. + Returns: - If all documents are parsed successfully, returns (True, None) - If any document is not fully parsed, returns (False, error_message) + bool: True if the user has permission to delete the knowledge base, + False if the user doesn't have permission or the knowledge base doesn't exist. + + Example: + >>> KnowledgebaseService.accessible4deletion("kb123", "user456") + True + + Note: + - This method only checks creator permissions + - A return value of False can mean either: + 1. The knowledge base doesn't exist + 2. The user is not the creator of the knowledge base """ + # Check if a knowledge base can be deleted by a user + docs = cls.model.select( + cls.model.id).where(cls.model.id == kb_id, cls.model.created_by == user_id).paginate(0, 1) + docs = docs.dicts() + if not docs: + return False + return True + + @classmethod + @DB.connection_context() + def is_parsed_done(cls, kb_id): + # Check if all documents in the knowledge base have completed parsing + # + # Args: + # kb_id: Knowledge base ID + # + # Returns: + # If all documents are parsed successfully, returns (True, None) + # If any document is not fully parsed, returns (False, error_message) from api.db import TaskStatus from api.db.services.document_service import DocumentService @@ -61,6 +109,11 @@ class KnowledgebaseService(CommonService): @classmethod @DB.connection_context() def list_documents_by_ids(cls,kb_ids): + # Get document IDs associated with given knowledge base IDs + # Args: + # kb_ids: List of knowledge base IDs + # Returns: + # List of document IDs doc_ids=cls.model.select(Document.id.alias("document_id")).join(Document,on=(cls.model.id == Document.kb_id)).where( cls.model.id.in_(kb_ids) ) @@ -75,6 +128,18 @@ class KnowledgebaseService(CommonService): orderby, desc, keywords, parser_id=None ): + # Get knowledge bases by tenant IDs with pagination and filtering + # Args: + # joined_tenant_ids: List of tenant IDs + # user_id: Current user ID + # page_number: Page number for pagination + # items_per_page: Number of items per page + # orderby: Field to order by + # desc: Boolean indicating descending order + # keywords: Search keywords + # parser_id: Optional parser ID filter + # Returns: + # Tuple of (knowledge_base_list, total_count) fields = [ cls.model.id, cls.model.avatar, @@ -122,6 +187,11 @@ class KnowledgebaseService(CommonService): @classmethod @DB.connection_context() def get_kb_ids(cls, tenant_id): + # Get all knowledge base IDs for a tenant + # Args: + # tenant_id: Tenant ID + # Returns: + # List of knowledge base IDs fields = [ cls.model.id, ] @@ -132,9 +202,13 @@ class KnowledgebaseService(CommonService): @classmethod @DB.connection_context() def get_detail(cls, kb_id): + # Get detailed information about a knowledge base + # Args: + # kb_id: Knowledge base ID + # Returns: + # Dictionary containing knowledge base details fields = [ cls.model.id, - # Tenant.embd_id, cls.model.embd_id, cls.model.avatar, cls.model.name, @@ -155,17 +229,21 @@ class KnowledgebaseService(CommonService): if not kbs: return d = kbs[0].to_dict() - # d["embd_id"] = kbs[0].tenant.embd_id return d @classmethod @DB.connection_context() def update_parser_config(cls, id, config): + # Update parser configuration for a knowledge base + # Args: + # id: Knowledge base ID + # config: New parser configuration e, m = cls.get_by_id(id) if not e: raise LookupError(f"knowledgebase({id}) not found.") def dfs_update(old, new): + # Deep update of nested configuration for k, v in new.items(): if k not in old: old[k] = v @@ -185,6 +263,11 @@ class KnowledgebaseService(CommonService): @classmethod @DB.connection_context() def get_field_map(cls, ids): + # Get field mappings for knowledge bases + # Args: + # ids: List of knowledge base IDs + # Returns: + # Dictionary of field mappings conf = {} for k in cls.get_by_ids(ids): if k.parser_config and "field_map" in k.parser_config: @@ -194,6 +277,12 @@ class KnowledgebaseService(CommonService): @classmethod @DB.connection_context() def get_by_name(cls, kb_name, tenant_id): + # Get knowledge base by name and tenant ID + # Args: + # kb_name: Knowledge base name + # tenant_id: Tenant ID + # Returns: + # Tuple of (exists, knowledge_base) kb = cls.model.select().where( (cls.model.name == kb_name) & (cls.model.tenant_id == tenant_id) @@ -206,12 +295,27 @@ class KnowledgebaseService(CommonService): @classmethod @DB.connection_context() def get_all_ids(cls): + # Get all knowledge base IDs + # Returns: + # List of all knowledge base IDs return [m["id"] for m in cls.model.select(cls.model.id).dicts()] @classmethod @DB.connection_context() def get_list(cls, joined_tenant_ids, user_id, page_number, items_per_page, orderby, desc, id, name): + # Get list of knowledge bases with filtering and pagination + # Args: + # joined_tenant_ids: List of tenant IDs + # user_id: Current user ID + # page_number: Page number for pagination + # items_per_page: Number of items per page + # orderby: Field to order by + # desc: Boolean indicating descending order + # id: Optional ID filter + # name: Optional name filter + # Returns: + # List of knowledge bases kbs = cls.model.select() if id: kbs = kbs.where(cls.model.id == id) @@ -235,6 +339,12 @@ class KnowledgebaseService(CommonService): @classmethod @DB.connection_context() def accessible(cls, kb_id, user_id): + # Check if a knowledge base is accessible by a user + # Args: + # kb_id: Knowledge base ID + # user_id: User ID + # Returns: + # Boolean indicating accessibility docs = cls.model.select( cls.model.id).join(UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id) ).where(cls.model.id == kb_id, UserTenant.user_id == user_id).paginate(0, 1) @@ -246,6 +356,12 @@ class KnowledgebaseService(CommonService): @classmethod @DB.connection_context() def get_kb_by_id(cls, kb_id, user_id): + # Get knowledge base by ID and user ID + # Args: + # kb_id: Knowledge base ID + # user_id: User ID + # Returns: + # List containing knowledge base information kbs = cls.model.select().join(UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id) ).where(cls.model.id == kb_id, UserTenant.user_id == user_id).paginate(0, 1) kbs = kbs.dicts() @@ -254,18 +370,14 @@ class KnowledgebaseService(CommonService): @classmethod @DB.connection_context() def get_kb_by_name(cls, kb_name, user_id): + # Get knowledge base by name and user ID + # Args: + # kb_name: Knowledge base name + # user_id: User ID + # Returns: + # List containing knowledge base information kbs = cls.model.select().join(UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id) ).where(cls.model.name == kb_name, UserTenant.user_id == user_id).paginate(0, 1) kbs = kbs.dicts() return list(kbs) - @classmethod - @DB.connection_context() - def accessible4deletion(cls, kb_id, user_id): - docs = cls.model.select( - cls.model.id).where(cls.model.id == kb_id, cls.model.created_by == user_id).paginate(0, 1) - docs = docs.dicts() - if not docs: - return False - return True - diff --git a/api/db/services/task_service.py b/api/db/services/task_service.py index 2d3147c65..d75954365 100644 --- a/api/db/services/task_service.py +++ b/api/db/services/task_service.py @@ -36,6 +36,12 @@ from rag.nlp import search def trim_header_by_lines(text: str, max_length) -> str: + # Trim header text to maximum length while preserving line breaks + # Args: + # text: Input text to trim + # max_length: Maximum allowed length + # Returns: + # Trimmed text len_text = len(text) if len_text <= max_length: return text @@ -46,11 +52,37 @@ def trim_header_by_lines(text: str, max_length) -> str: class TaskService(CommonService): + """Service class for managing document processing tasks. + + This class extends CommonService to provide specialized functionality for document + processing task management, including task creation, progress tracking, and chunk + management. It handles various document types (PDF, Excel, etc.) and manages their + processing lifecycle. + + The class implements a robust task queue system with retry mechanisms and progress + tracking, supporting both synchronous and asynchronous task execution. + + Attributes: + model: The Task model class for database operations. + """ model = Task @classmethod @DB.connection_context() def get_task(cls, task_id): + """Retrieve detailed task information by task ID. + + This method fetches comprehensive task details including associated document, + knowledge base, and tenant information. It also handles task retry logic and + progress updates. + + Args: + task_id (str): The unique identifier of the task to retrieve. + + Returns: + dict: Task details dictionary containing all task information and related metadata. + Returns None if task is not found or has exceeded retry limit. + """ fields = [ cls.model.id, cls.model.doc_id, @@ -105,6 +137,18 @@ class TaskService(CommonService): @classmethod @DB.connection_context() def get_tasks(cls, doc_id: str): + """Retrieve all tasks associated with a document. + + This method fetches all processing tasks for a given document, ordered by page + number and creation time. It includes task progress and chunk information. + + Args: + doc_id (str): The unique identifier of the document. + + Returns: + list[dict]: List of task dictionaries containing task details. + Returns None if no tasks are found. + """ fields = [ cls.model.id, cls.model.from_page, @@ -124,11 +168,31 @@ class TaskService(CommonService): @classmethod @DB.connection_context() def update_chunk_ids(cls, id: str, chunk_ids: str): + """Update the chunk IDs associated with a task. + + This method updates the chunk_ids field of a task, which stores the IDs of + processed document chunks in a space-separated string format. + + Args: + id (str): The unique identifier of the task. + chunk_ids (str): Space-separated string of chunk identifiers. + """ cls.model.update(chunk_ids=chunk_ids).where(cls.model.id == id).execute() @classmethod @DB.connection_context() def get_ongoing_doc_name(cls): + """Get names of documents that are currently being processed. + + This method retrieves information about documents that are in the processing state, + including their locations and associated IDs. It uses database locking to ensure + thread safety when accessing the task information. + + Returns: + list[tuple]: A list of tuples, each containing (parent_id/kb_id, location) + for documents currently being processed. Returns empty list if + no documents are being processed. + """ with DB.lock("get_task", -1): docs = ( cls.model.select( @@ -172,6 +236,18 @@ class TaskService(CommonService): @classmethod @DB.connection_context() def do_cancel(cls, id): + """Check if a task should be cancelled based on its document status. + + This method determines whether a task should be cancelled by checking the + associated document's run status and progress. A task should be cancelled + if its document is marked for cancellation or has negative progress. + + Args: + id (str): The unique identifier of the task to check. + + Returns: + bool: True if the task should be cancelled, False otherwise. + """ task = cls.model.get_by_id(id) _, doc = DocumentService.get_by_id(task.doc_id) return doc.run == TaskStatus.CANCEL.value or doc.progress < 0 @@ -179,6 +255,18 @@ class TaskService(CommonService): @classmethod @DB.connection_context() def update_progress(cls, id, info): + """Update the progress information for a task. + + This method updates both the progress message and completion percentage of a task. + It handles platform-specific behavior (macOS vs others) and uses database locking + when necessary to ensure thread safety. + + Args: + id (str): The unique identifier of the task to update. + info (dict): Dictionary containing progress information with keys: + - progress_msg (str, optional): Progress message to append + - progress (float, optional): Progress percentage (0.0 to 1.0) + """ if os.environ.get("MACOS"): if info["progress_msg"]: task = cls.model.get_by_id(id) @@ -202,6 +290,24 @@ class TaskService(CommonService): def queue_tasks(doc: dict, bucket: str, name: str): + """Create and queue document processing tasks. + + This function creates processing tasks for a document based on its type and configuration. + It handles different document types (PDF, Excel, etc.) differently and manages task + chunking and configuration. It also implements task reuse optimization by checking + for previously completed tasks. + + Args: + doc (dict): Document dictionary containing metadata and configuration. + bucket (str): Storage bucket name where the document is stored. + name (str): File name of the document. + + Note: + - For PDF documents, tasks are created per page range based on configuration + - For Excel documents, tasks are created per row range + - Task digests are calculated for optimization and reuse + - Previous task chunks may be reused if available + """ def new_task(): return {"id": get_uuid(), "doc_id": doc["id"], "progress": 0.0, "from_page": 0, "to_page": 100000000} @@ -279,6 +385,26 @@ def queue_tasks(doc: dict, bucket: str, name: str): def reuse_prev_task_chunks(task: dict, prev_tasks: list[dict], chunking_config: dict): + """Attempt to reuse chunks from previous tasks for optimization. + + This function checks if chunks from previously completed tasks can be reused for + the current task, which can significantly improve processing efficiency. It matches + tasks based on page ranges and configuration digests. + + Args: + task (dict): Current task dictionary to potentially reuse chunks for. + prev_tasks (list[dict]): List of previous task dictionaries to check for reuse. + chunking_config (dict): Configuration dictionary for chunk processing. + + Returns: + int: Number of chunks successfully reused. Returns 0 if no chunks could be reused. + + Note: + Chunks can only be reused if: + - A previous task exists with matching page range and configuration digest + - The previous task was completed successfully (progress = 1.0) + - The previous task has valid chunk IDs + """ idx = 0 while idx < len(prev_tasks): prev_task = prev_tasks[idx] diff --git a/api/db/services/user_service.py b/api/db/services/user_service.py index 56a4352ce..facab201c 100644 --- a/api/db/services/user_service.py +++ b/api/db/services/user_service.py @@ -29,11 +29,27 @@ from rag.settings import MINIO class UserService(CommonService): + """Service class for managing user-related database operations. + + This class extends CommonService to provide specialized functionality for user management, + including authentication, user creation, updates, and deletions. + + Attributes: + model: The User model class for database operations. + """ model = User @classmethod @DB.connection_context() def filter_by_id(cls, user_id): + """Retrieve a user by their ID. + + Args: + user_id: The unique identifier of the user. + + Returns: + User object if found, None otherwise. + """ try: user = cls.model.select().where(cls.model.id == user_id).get() return user @@ -43,6 +59,15 @@ class UserService(CommonService): @classmethod @DB.connection_context() def query_user(cls, email, password): + """Authenticate a user with email and password. + + Args: + email: User's email address. + password: User's password in plain text. + + Returns: + User object if authentication successful, None otherwise. + """ user = cls.model.select().where((cls.model.email == email), (cls.model.status == StatusEnum.VALID.value)).first() if user and check_password_hash(str(user.password), password): @@ -85,6 +110,14 @@ class UserService(CommonService): class TenantService(CommonService): + """Service class for managing tenant-related database operations. + + This class extends CommonService to provide functionality for tenant management, + including tenant information retrieval and credit management. + + Attributes: + model: The Tenant model class for database operations. + """ model = Tenant @classmethod @@ -136,6 +169,14 @@ class TenantService(CommonService): class UserTenantService(CommonService): + """Service class for managing user-tenant relationship operations. + + This class extends CommonService to handle the many-to-many relationship + between users and tenants, managing user roles and tenant memberships. + + Attributes: + model: The UserTenant model class for database operations. + """ model = UserTenant @classmethod diff --git a/docs/references/http_api_reference.md b/docs/references/http_api_reference.md index 216a1c713..75df94a77 100644 --- a/docs/references/http_api_reference.md +++ b/docs/references/http_api_reference.md @@ -531,6 +531,26 @@ Failure: --- +## Error Codes + +--- + +| Code | Message | Description | +|------|---------|-------------| +| 400 | Bad Request | Invalid request parameters | +| 401 | Unauthorized | Unauthorized access | +| 403 | Forbidden | Access denied | +| 404 | Not Found | Resource not found | +| 500 | Internal Server Error | Server internal error | +| 1001 | Invalid Chunk ID | Invalid Chunk ID | +| 1002 | Chunk Update Failed | Chunk update failed | + + + +--- + +--- + ## FILE MANAGEMENT WITHIN DATASET --- diff --git a/docs/references/python_api_reference.md b/docs/references/python_api_reference.md index 533ef8020..b3c9e3dc6 100644 --- a/docs/references/python_api_reference.md +++ b/docs/references/python_api_reference.md @@ -317,6 +317,23 @@ dataset = rag_object.list_datasets(name="kb_name") dataset.update({"embedding_model":"BAAI/bge-zh-v1.5", "chunk_method":"manual"}) ``` +--- + +## Error Codes + +--- + +| Code | Message | Description | +|------|---------|-------------| +| 400 | Bad Request | Invalid request parameters | +| 401 | Unauthorized | Unauthorized access | +| 403 | Forbidden | Access denied | +| 404 | Not Found | Resource not found | +| 500 | Internal Server Error | Server internal error | +| 1001 | Invalid Chunk ID | Invalid Chunk ID | +| 1002 | Chunk Update Failed | Chunk update failed | + + --- ## FILE MANAGEMENT WITHIN DATASET @@ -1719,4 +1736,6 @@ for agent in rag_object.list_agents(): print(agent) ``` ---- \ No newline at end of file +--- + + diff --git a/sdk/python/ragflow_sdk/modules/chunk.py b/sdk/python/ragflow_sdk/modules/chunk.py index 4471b58af..f943b8865 100644 --- a/sdk/python/ragflow_sdk/modules/chunk.py +++ b/sdk/python/ragflow_sdk/modules/chunk.py @@ -16,6 +16,12 @@ from .base import Base +class ChunkUpdateError(Exception): + def __init__(self, code=None, message=None, details=None): + self.code = code + self.message = message + self.details = details + super().__init__(message) class Chunk(Base): def __init__(self, rag, res_dict): @@ -38,4 +44,8 @@ class Chunk(Base): res = self.put(f"/datasets/{self.dataset_id}/documents/{self.document_id}/chunks/{self.id}", update_message) res = res.json() if res.get("code") != 0: - raise Exception(res["message"]) + raise ChunkUpdateError( + code=res.get("code"), + message=res.get("message"), + details=res.get("details") + ) \ No newline at end of file