Improve API Documentation, Standardize Error Handling, and Enhance Comments (#5990)

### What problem does this PR solve?  
- The API documentation lacks detailed error code explanations. Added
error code tables to `python_api_reference.md` and
`http_api_reference.md` to clarify possible error codes and their
meanings.
- Error handling in the codebase is inconsistent. Standardized error
handling logic in `sdk/python/ragflow_sdk/modules/chunk.py`.
- Improved API comments by adding standardized docstrings to enhance
code readability and maintainability.

### Type of change  
- [x] Documentation Update  
- [x] Refactoring
This commit is contained in:
Xinghan Pan 2025-03-13 19:06:50 +08:00 committed by GitHub
parent 940072592f
commit 47926f7d21
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 591 additions and 24 deletions

View File

@ -22,17 +22,56 @@ from api.utils import datetime_format, current_timestamp, get_uuid
class CommonService:
"""Base service class that provides common database operations.
This class serves as a foundation for all service classes in the application,
implementing standard CRUD operations and common database query patterns.
It uses the Peewee ORM for database interactions and provides a consistent
interface for database operations across all derived service classes.
Attributes:
model: The Peewee model class that this service operates on. Must be set by subclasses.
"""
model = None
@classmethod
@DB.connection_context()
def query(cls, cols=None, reverse=None, order_by=None, **kwargs):
"""Execute a database query with optional column selection and ordering.
This method provides a flexible way to query the database with various filters
and sorting options. It supports column selection, sort order control, and
additional filter conditions.
Args:
cols (list, optional): List of column names to select. If None, selects all columns.
reverse (bool, optional): If True, sorts in descending order. If False, sorts in ascending order.
order_by (str, optional): Column name to sort results by.
**kwargs: Additional filter conditions passed as keyword arguments.
Returns:
peewee.ModelSelect: A query result containing matching records.
"""
return cls.model.query(cols=cols, reverse=reverse,
order_by=order_by, **kwargs)
@classmethod
@DB.connection_context()
def get_all(cls, cols=None, reverse=None, order_by=None):
"""Retrieve all records from the database with optional column selection and ordering.
This method fetches all records from the model's table with support for
column selection and result ordering. If no order_by is specified and reverse
is True, it defaults to ordering by create_time.
Args:
cols (list, optional): List of column names to select. If None, selects all columns.
reverse (bool, optional): If True, sorts in descending order. If False, sorts in ascending order.
order_by (str, optional): Column name to sort results by. Defaults to 'create_time' if reverse is specified.
Returns:
peewee.ModelSelect: A query containing all matching records.
"""
if cols:
query_records = cls.model.select(*cols)
else:
@ -51,11 +90,36 @@ class CommonService:
@classmethod
@DB.connection_context()
def get(cls, **kwargs):
"""Get a single record matching the given criteria.
This method retrieves a single record from the database that matches
the specified filter conditions.
Args:
**kwargs: Filter conditions as keyword arguments.
Returns:
Model instance: Single matching record.
Raises:
peewee.DoesNotExist: If no matching record is found.
"""
return cls.model.get(**kwargs)
@classmethod
@DB.connection_context()
def get_or_none(cls, **kwargs):
"""Get a single record or None if not found.
This method attempts to retrieve a single record matching the given criteria,
returning None if no match is found instead of raising an exception.
Args:
**kwargs: Filter conditions as keyword arguments.
Returns:
Model instance or None: Matching record if found, None otherwise.
"""
try:
return cls.model.get(**kwargs)
except peewee.DoesNotExist:
@ -64,14 +128,34 @@ class CommonService:
@classmethod
@DB.connection_context()
def save(cls, **kwargs):
# if "id" not in kwargs:
# kwargs["id"] = get_uuid()
"""Save a new record to database.
This method creates a new record in the database with the provided field values,
forcing an insert operation rather than an update.
Args:
**kwargs: Record field values as keyword arguments.
Returns:
Model instance: The created record object.
"""
sample_obj = cls.model(**kwargs).save(force_insert=True)
return sample_obj
@classmethod
@DB.connection_context()
def insert(cls, **kwargs):
"""Insert a new record with automatic ID and timestamps.
This method creates a new record with automatically generated ID and timestamp fields.
It handles the creation of create_time, create_date, update_time, and update_date fields.
Args:
**kwargs: Record field values as keyword arguments.
Returns:
Model instance: The newly created record object.
"""
if "id" not in kwargs:
kwargs["id"] = get_uuid()
kwargs["create_time"] = current_timestamp()
@ -84,6 +168,15 @@ class CommonService:
@classmethod
@DB.connection_context()
def insert_many(cls, data_list, batch_size=100):
"""Insert multiple records in batches.
This method efficiently inserts multiple records into the database using batch processing.
It automatically sets creation timestamps for all records.
Args:
data_list (list): List of dictionaries containing record data to insert.
batch_size (int, optional): Number of records to insert in each batch. Defaults to 100.
"""
with DB.atomic():
for d in data_list:
d["create_time"] = current_timestamp()
@ -94,6 +187,15 @@ class CommonService:
@classmethod
@DB.connection_context()
def update_many_by_id(cls, data_list):
"""Update multiple records by their IDs.
This method updates multiple records in the database, identified by their IDs.
It automatically updates the update_time and update_date fields for each record.
Args:
data_list (list): List of dictionaries containing record data to update.
Each dictionary must include an 'id' field.
"""
with DB.atomic():
for data in data_list:
data["update_time"] = current_timestamp()
@ -104,6 +206,12 @@ class CommonService:
@classmethod
@DB.connection_context()
def update_by_id(cls, pid, data):
# Update a single record by ID
# Args:
# pid: Record ID
# data: Updated field values
# Returns:
# Number of records updated
data["update_time"] = current_timestamp()
data["update_date"] = datetime_format(datetime.now())
num = cls.model.update(data).where(cls.model.id == pid).execute()
@ -112,6 +220,11 @@ class CommonService:
@classmethod
@DB.connection_context()
def get_by_id(cls, pid):
# Get a record by ID
# Args:
# pid: Record ID
# Returns:
# Tuple of (success, record)
try:
obj = cls.model.query(id=pid)[0]
return True, obj
@ -121,6 +234,12 @@ class CommonService:
@classmethod
@DB.connection_context()
def get_by_ids(cls, pids, cols=None):
# Get multiple records by their IDs
# Args:
# pids: List of record IDs
# cols: List of columns to select
# Returns:
# Query of matching records
if cols:
objs = cls.model.select(*cols)
else:
@ -130,11 +249,21 @@ class CommonService:
@classmethod
@DB.connection_context()
def delete_by_id(cls, pid):
# Delete a record by ID
# Args:
# pid: Record ID
# Returns:
# Number of records deleted
return cls.model.delete().where(cls.model.id == pid).execute()
@classmethod
@DB.connection_context()
def filter_delete(cls, filters):
# Delete records matching given filters
# Args:
# filters: List of filter conditions
# Returns:
# Number of records deleted
with DB.atomic():
num = cls.model.delete().where(*filters).execute()
return num
@ -142,11 +271,23 @@ class CommonService:
@classmethod
@DB.connection_context()
def filter_update(cls, filters, update_data):
# Update records matching given filters
# Args:
# filters: List of filter conditions
# update_data: Updated field values
# Returns:
# Number of records updated
with DB.atomic():
return cls.model.update(update_data).where(*filters).execute()
@staticmethod
def cut_list(tar_list, n):
# Split a list into chunks of size n
# Args:
# tar_list: List to split
# n: Chunk size
# Returns:
# List of tuples containing chunks
length = len(tar_list)
arr = range(length)
result = [tuple(tar_list[x:(x + n)]) for x in arr[::n]]
@ -156,6 +297,14 @@ class CommonService:
@DB.connection_context()
def filter_scope_list(cls, in_key, in_filters_list,
filters=None, cols=None):
# Get records matching IN clause filters with optional column selection
# Args:
# in_key: Field name for IN clause
# in_filters_list: List of values for IN clause
# filters: Additional filter conditions
# cols: List of columns to select
# Returns:
# List of matching records
in_filters_tuple_list = cls.cut_list(in_filters_list, 20)
if not filters:
filters = []

View File

@ -34,12 +34,24 @@ from rag.utils.storage_factory import STORAGE_IMPL
class FileService(CommonService):
# Service class for managing file operations and storage
model = File
@classmethod
@DB.connection_context()
def get_by_pf_id(cls, tenant_id, pf_id, page_number, items_per_page,
orderby, desc, keywords):
# Get files by parent folder ID with pagination and filtering
# Args:
# tenant_id: ID of the tenant
# pf_id: Parent folder ID
# page_number: Page number for pagination
# items_per_page: Number of items per page
# orderby: Field to order by
# desc: Boolean indicating descending order
# keywords: Search keywords
# Returns:
# Tuple of (file_list, total_count)
if keywords:
files = cls.model.select().where(
(cls.model.tenant_id == tenant_id),
@ -80,6 +92,11 @@ class FileService(CommonService):
@classmethod
@DB.connection_context()
def get_kb_id_by_file_id(cls, file_id):
# Get knowledge base IDs associated with a file
# Args:
# file_id: File ID
# Returns:
# List of dictionaries containing knowledge base IDs and names
kbs = (cls.model.select(*[Knowledgebase.id, Knowledgebase.name])
.join(File2Document, on=(File2Document.file_id == file_id))
.join(Document, on=(File2Document.document_id == Document.id))
@ -95,6 +112,12 @@ class FileService(CommonService):
@classmethod
@DB.connection_context()
def get_by_pf_id_name(cls, id, name):
# Get file by parent folder ID and name
# Args:
# id: Parent folder ID
# name: File name
# Returns:
# File object or None if not found
file = cls.model.select().where((cls.model.parent_id == id) & (cls.model.name == name))
if file.count():
e, file = cls.get_by_id(file[0].id)
@ -106,6 +129,14 @@ class FileService(CommonService):
@classmethod
@DB.connection_context()
def get_id_list_by_id(cls, id, name, count, res):
# Recursively get list of file IDs by traversing folder structure
# Args:
# id: Starting folder ID
# name: List of folder names to traverse
# count: Current depth in traversal
# res: List to store results
# Returns:
# List of file IDs
if count < len(name):
file = cls.get_by_pf_id_name(id, name[count])
if file:
@ -119,6 +150,12 @@ class FileService(CommonService):
@classmethod
@DB.connection_context()
def get_all_innermost_file_ids(cls, folder_id, result_ids):
# Get IDs of all files in the deepest level of folders
# Args:
# folder_id: Starting folder ID
# result_ids: List to store results
# Returns:
# List of file IDs
subfolders = cls.model.select().where(cls.model.parent_id == folder_id)
if subfolders.exists():
for subfolder in subfolders:
@ -130,6 +167,14 @@ class FileService(CommonService):
@classmethod
@DB.connection_context()
def create_folder(cls, file, parent_id, name, count):
# Recursively create folder structure
# Args:
# file: Current file object
# parent_id: Parent folder ID
# name: List of folder names to create
# count: Current depth in creation
# Returns:
# Created file object
if count > len(name) - 2:
return file
else:
@ -148,6 +193,11 @@ class FileService(CommonService):
@classmethod
@DB.connection_context()
def is_parent_folder_exist(cls, parent_id):
# Check if parent folder exists
# Args:
# parent_id: Parent folder ID
# Returns:
# Boolean indicating if folder exists
parent_files = cls.model.select().where(cls.model.id == parent_id)
if parent_files.count():
return True
@ -157,6 +207,11 @@ class FileService(CommonService):
@classmethod
@DB.connection_context()
def get_root_folder(cls, tenant_id):
# Get or create root folder for tenant
# Args:
# tenant_id: Tenant ID
# Returns:
# Root folder dictionary
for file in cls.model.select().where((cls.model.tenant_id == tenant_id),
(cls.model.parent_id == cls.model.id)
):
@ -179,6 +234,11 @@ class FileService(CommonService):
@classmethod
@DB.connection_context()
def get_kb_folder(cls, tenant_id):
# Get knowledge base folder for tenant
# Args:
# tenant_id: Tenant ID
# Returns:
# Knowledge base folder dictionary
for root in cls.model.select().where(
(cls.model.tenant_id == tenant_id), (cls.model.parent_id == cls.model.id)):
for folder in cls.model.select().where(
@ -190,6 +250,16 @@ class FileService(CommonService):
@classmethod
@DB.connection_context()
def new_a_file_from_kb(cls, tenant_id, name, parent_id, ty=FileType.FOLDER.value, size=0, location=""):
# Create a new file from knowledge base
# Args:
# tenant_id: Tenant ID
# name: File name
# parent_id: Parent folder ID
# ty: File type
# size: File size
# location: File location
# Returns:
# Created file dictionary
for file in cls.query(tenant_id=tenant_id, parent_id=parent_id, name=name):
return file.to_dict()
file = {
@ -209,6 +279,10 @@ class FileService(CommonService):
@classmethod
@DB.connection_context()
def init_knowledgebase_docs(cls, root_id, tenant_id):
# Initialize knowledge base documents
# Args:
# root_id: Root folder ID
# tenant_id: Tenant ID
for _ in cls.model.select().where((cls.model.name == KNOWLEDGEBASE_FOLDER_NAME)\
& (cls.model.parent_id == root_id)):
return
@ -222,6 +296,11 @@ class FileService(CommonService):
@classmethod
@DB.connection_context()
def get_parent_folder(cls, file_id):
# Get parent folder of a file
# Args:
# file_id: File ID
# Returns:
# Parent folder object
file = cls.model.select().where(cls.model.id == file_id)
if file.count():
e, file = cls.get_by_id(file[0].parent_id)
@ -234,6 +313,11 @@ class FileService(CommonService):
@classmethod
@DB.connection_context()
def get_all_parent_folders(cls, start_id):
# Get all parent folders in path
# Args:
# start_id: Starting file ID
# Returns:
# List of parent folder objects
parent_folders = []
current_id = start_id
while current_id:
@ -249,6 +333,11 @@ class FileService(CommonService):
@classmethod
@DB.connection_context()
def insert(cls, file):
# Insert a new file record
# Args:
# file: File data dictionary
# Returns:
# Created file object
if not cls.save(**file):
raise RuntimeError("Database error (File)!")
return File(**file)
@ -256,6 +345,7 @@ class FileService(CommonService):
@classmethod
@DB.connection_context()
def delete(cls, file):
#
return cls.delete_by_id(file.id)
@classmethod

View File

@ -20,21 +20,69 @@ from peewee import fn
class KnowledgebaseService(CommonService):
"""Service class for managing knowledge base operations.
This class extends CommonService to provide specialized functionality for knowledge base
management, including document parsing status tracking, access control, and configuration
management. It handles operations such as listing, creating, updating, and deleting
knowledge bases, as well as managing their associated documents and permissions.
The class implements a comprehensive set of methods for:
- Document parsing status verification
- Knowledge base access control
- Parser configuration management
- Tenant-based knowledge base organization
Attributes:
model: The Knowledgebase model class for database operations.
"""
model = Knowledgebase
@classmethod
@DB.connection_context()
def is_parsed_done(cls, kb_id):
"""
Check if all documents in the knowledge base have completed parsing
def accessible4deletion(cls, kb_id, user_id):
"""Check if a knowledge base can be deleted by a specific user.
This method verifies whether a user has permission to delete a knowledge base
by checking if they are the creator of that knowledge base.
Args:
kb_id: Knowledge base ID
kb_id (str): The unique identifier of the knowledge base to check.
user_id (str): The unique identifier of the user attempting the deletion.
Returns:
If all documents are parsed successfully, returns (True, None)
If any document is not fully parsed, returns (False, error_message)
bool: True if the user has permission to delete the knowledge base,
False if the user doesn't have permission or the knowledge base doesn't exist.
Example:
>>> KnowledgebaseService.accessible4deletion("kb123", "user456")
True
Note:
- This method only checks creator permissions
- A return value of False can mean either:
1. The knowledge base doesn't exist
2. The user is not the creator of the knowledge base
"""
# Check if a knowledge base can be deleted by a user
docs = cls.model.select(
cls.model.id).where(cls.model.id == kb_id, cls.model.created_by == user_id).paginate(0, 1)
docs = docs.dicts()
if not docs:
return False
return True
@classmethod
@DB.connection_context()
def is_parsed_done(cls, kb_id):
# Check if all documents in the knowledge base have completed parsing
#
# Args:
# kb_id: Knowledge base ID
#
# Returns:
# If all documents are parsed successfully, returns (True, None)
# If any document is not fully parsed, returns (False, error_message)
from api.db import TaskStatus
from api.db.services.document_service import DocumentService
@ -61,6 +109,11 @@ class KnowledgebaseService(CommonService):
@classmethod
@DB.connection_context()
def list_documents_by_ids(cls,kb_ids):
# Get document IDs associated with given knowledge base IDs
# Args:
# kb_ids: List of knowledge base IDs
# Returns:
# List of document IDs
doc_ids=cls.model.select(Document.id.alias("document_id")).join(Document,on=(cls.model.id == Document.kb_id)).where(
cls.model.id.in_(kb_ids)
)
@ -75,6 +128,18 @@ class KnowledgebaseService(CommonService):
orderby, desc, keywords,
parser_id=None
):
# Get knowledge bases by tenant IDs with pagination and filtering
# Args:
# joined_tenant_ids: List of tenant IDs
# user_id: Current user ID
# page_number: Page number for pagination
# items_per_page: Number of items per page
# orderby: Field to order by
# desc: Boolean indicating descending order
# keywords: Search keywords
# parser_id: Optional parser ID filter
# Returns:
# Tuple of (knowledge_base_list, total_count)
fields = [
cls.model.id,
cls.model.avatar,
@ -122,6 +187,11 @@ class KnowledgebaseService(CommonService):
@classmethod
@DB.connection_context()
def get_kb_ids(cls, tenant_id):
# Get all knowledge base IDs for a tenant
# Args:
# tenant_id: Tenant ID
# Returns:
# List of knowledge base IDs
fields = [
cls.model.id,
]
@ -132,9 +202,13 @@ class KnowledgebaseService(CommonService):
@classmethod
@DB.connection_context()
def get_detail(cls, kb_id):
# Get detailed information about a knowledge base
# Args:
# kb_id: Knowledge base ID
# Returns:
# Dictionary containing knowledge base details
fields = [
cls.model.id,
# Tenant.embd_id,
cls.model.embd_id,
cls.model.avatar,
cls.model.name,
@ -155,17 +229,21 @@ class KnowledgebaseService(CommonService):
if not kbs:
return
d = kbs[0].to_dict()
# d["embd_id"] = kbs[0].tenant.embd_id
return d
@classmethod
@DB.connection_context()
def update_parser_config(cls, id, config):
# Update parser configuration for a knowledge base
# Args:
# id: Knowledge base ID
# config: New parser configuration
e, m = cls.get_by_id(id)
if not e:
raise LookupError(f"knowledgebase({id}) not found.")
def dfs_update(old, new):
# Deep update of nested configuration
for k, v in new.items():
if k not in old:
old[k] = v
@ -185,6 +263,11 @@ class KnowledgebaseService(CommonService):
@classmethod
@DB.connection_context()
def get_field_map(cls, ids):
# Get field mappings for knowledge bases
# Args:
# ids: List of knowledge base IDs
# Returns:
# Dictionary of field mappings
conf = {}
for k in cls.get_by_ids(ids):
if k.parser_config and "field_map" in k.parser_config:
@ -194,6 +277,12 @@ class KnowledgebaseService(CommonService):
@classmethod
@DB.connection_context()
def get_by_name(cls, kb_name, tenant_id):
# Get knowledge base by name and tenant ID
# Args:
# kb_name: Knowledge base name
# tenant_id: Tenant ID
# Returns:
# Tuple of (exists, knowledge_base)
kb = cls.model.select().where(
(cls.model.name == kb_name)
& (cls.model.tenant_id == tenant_id)
@ -206,12 +295,27 @@ class KnowledgebaseService(CommonService):
@classmethod
@DB.connection_context()
def get_all_ids(cls):
# Get all knowledge base IDs
# Returns:
# List of all knowledge base IDs
return [m["id"] for m in cls.model.select(cls.model.id).dicts()]
@classmethod
@DB.connection_context()
def get_list(cls, joined_tenant_ids, user_id,
page_number, items_per_page, orderby, desc, id, name):
# Get list of knowledge bases with filtering and pagination
# Args:
# joined_tenant_ids: List of tenant IDs
# user_id: Current user ID
# page_number: Page number for pagination
# items_per_page: Number of items per page
# orderby: Field to order by
# desc: Boolean indicating descending order
# id: Optional ID filter
# name: Optional name filter
# Returns:
# List of knowledge bases
kbs = cls.model.select()
if id:
kbs = kbs.where(cls.model.id == id)
@ -235,6 +339,12 @@ class KnowledgebaseService(CommonService):
@classmethod
@DB.connection_context()
def accessible(cls, kb_id, user_id):
# Check if a knowledge base is accessible by a user
# Args:
# kb_id: Knowledge base ID
# user_id: User ID
# Returns:
# Boolean indicating accessibility
docs = cls.model.select(
cls.model.id).join(UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id)
).where(cls.model.id == kb_id, UserTenant.user_id == user_id).paginate(0, 1)
@ -246,6 +356,12 @@ class KnowledgebaseService(CommonService):
@classmethod
@DB.connection_context()
def get_kb_by_id(cls, kb_id, user_id):
# Get knowledge base by ID and user ID
# Args:
# kb_id: Knowledge base ID
# user_id: User ID
# Returns:
# List containing knowledge base information
kbs = cls.model.select().join(UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id)
).where(cls.model.id == kb_id, UserTenant.user_id == user_id).paginate(0, 1)
kbs = kbs.dicts()
@ -254,18 +370,14 @@ class KnowledgebaseService(CommonService):
@classmethod
@DB.connection_context()
def get_kb_by_name(cls, kb_name, user_id):
# Get knowledge base by name and user ID
# Args:
# kb_name: Knowledge base name
# user_id: User ID
# Returns:
# List containing knowledge base information
kbs = cls.model.select().join(UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id)
).where(cls.model.name == kb_name, UserTenant.user_id == user_id).paginate(0, 1)
kbs = kbs.dicts()
return list(kbs)
@classmethod
@DB.connection_context()
def accessible4deletion(cls, kb_id, user_id):
docs = cls.model.select(
cls.model.id).where(cls.model.id == kb_id, cls.model.created_by == user_id).paginate(0, 1)
docs = docs.dicts()
if not docs:
return False
return True

View File

@ -36,6 +36,12 @@ from rag.nlp import search
def trim_header_by_lines(text: str, max_length) -> str:
# Trim header text to maximum length while preserving line breaks
# Args:
# text: Input text to trim
# max_length: Maximum allowed length
# Returns:
# Trimmed text
len_text = len(text)
if len_text <= max_length:
return text
@ -46,11 +52,37 @@ def trim_header_by_lines(text: str, max_length) -> str:
class TaskService(CommonService):
"""Service class for managing document processing tasks.
This class extends CommonService to provide specialized functionality for document
processing task management, including task creation, progress tracking, and chunk
management. It handles various document types (PDF, Excel, etc.) and manages their
processing lifecycle.
The class implements a robust task queue system with retry mechanisms and progress
tracking, supporting both synchronous and asynchronous task execution.
Attributes:
model: The Task model class for database operations.
"""
model = Task
@classmethod
@DB.connection_context()
def get_task(cls, task_id):
"""Retrieve detailed task information by task ID.
This method fetches comprehensive task details including associated document,
knowledge base, and tenant information. It also handles task retry logic and
progress updates.
Args:
task_id (str): The unique identifier of the task to retrieve.
Returns:
dict: Task details dictionary containing all task information and related metadata.
Returns None if task is not found or has exceeded retry limit.
"""
fields = [
cls.model.id,
cls.model.doc_id,
@ -105,6 +137,18 @@ class TaskService(CommonService):
@classmethod
@DB.connection_context()
def get_tasks(cls, doc_id: str):
"""Retrieve all tasks associated with a document.
This method fetches all processing tasks for a given document, ordered by page
number and creation time. It includes task progress and chunk information.
Args:
doc_id (str): The unique identifier of the document.
Returns:
list[dict]: List of task dictionaries containing task details.
Returns None if no tasks are found.
"""
fields = [
cls.model.id,
cls.model.from_page,
@ -124,11 +168,31 @@ class TaskService(CommonService):
@classmethod
@DB.connection_context()
def update_chunk_ids(cls, id: str, chunk_ids: str):
"""Update the chunk IDs associated with a task.
This method updates the chunk_ids field of a task, which stores the IDs of
processed document chunks in a space-separated string format.
Args:
id (str): The unique identifier of the task.
chunk_ids (str): Space-separated string of chunk identifiers.
"""
cls.model.update(chunk_ids=chunk_ids).where(cls.model.id == id).execute()
@classmethod
@DB.connection_context()
def get_ongoing_doc_name(cls):
"""Get names of documents that are currently being processed.
This method retrieves information about documents that are in the processing state,
including their locations and associated IDs. It uses database locking to ensure
thread safety when accessing the task information.
Returns:
list[tuple]: A list of tuples, each containing (parent_id/kb_id, location)
for documents currently being processed. Returns empty list if
no documents are being processed.
"""
with DB.lock("get_task", -1):
docs = (
cls.model.select(
@ -172,6 +236,18 @@ class TaskService(CommonService):
@classmethod
@DB.connection_context()
def do_cancel(cls, id):
"""Check if a task should be cancelled based on its document status.
This method determines whether a task should be cancelled by checking the
associated document's run status and progress. A task should be cancelled
if its document is marked for cancellation or has negative progress.
Args:
id (str): The unique identifier of the task to check.
Returns:
bool: True if the task should be cancelled, False otherwise.
"""
task = cls.model.get_by_id(id)
_, doc = DocumentService.get_by_id(task.doc_id)
return doc.run == TaskStatus.CANCEL.value or doc.progress < 0
@ -179,6 +255,18 @@ class TaskService(CommonService):
@classmethod
@DB.connection_context()
def update_progress(cls, id, info):
"""Update the progress information for a task.
This method updates both the progress message and completion percentage of a task.
It handles platform-specific behavior (macOS vs others) and uses database locking
when necessary to ensure thread safety.
Args:
id (str): The unique identifier of the task to update.
info (dict): Dictionary containing progress information with keys:
- progress_msg (str, optional): Progress message to append
- progress (float, optional): Progress percentage (0.0 to 1.0)
"""
if os.environ.get("MACOS"):
if info["progress_msg"]:
task = cls.model.get_by_id(id)
@ -202,6 +290,24 @@ class TaskService(CommonService):
def queue_tasks(doc: dict, bucket: str, name: str):
"""Create and queue document processing tasks.
This function creates processing tasks for a document based on its type and configuration.
It handles different document types (PDF, Excel, etc.) differently and manages task
chunking and configuration. It also implements task reuse optimization by checking
for previously completed tasks.
Args:
doc (dict): Document dictionary containing metadata and configuration.
bucket (str): Storage bucket name where the document is stored.
name (str): File name of the document.
Note:
- For PDF documents, tasks are created per page range based on configuration
- For Excel documents, tasks are created per row range
- Task digests are calculated for optimization and reuse
- Previous task chunks may be reused if available
"""
def new_task():
return {"id": get_uuid(), "doc_id": doc["id"], "progress": 0.0, "from_page": 0, "to_page": 100000000}
@ -279,6 +385,26 @@ def queue_tasks(doc: dict, bucket: str, name: str):
def reuse_prev_task_chunks(task: dict, prev_tasks: list[dict], chunking_config: dict):
"""Attempt to reuse chunks from previous tasks for optimization.
This function checks if chunks from previously completed tasks can be reused for
the current task, which can significantly improve processing efficiency. It matches
tasks based on page ranges and configuration digests.
Args:
task (dict): Current task dictionary to potentially reuse chunks for.
prev_tasks (list[dict]): List of previous task dictionaries to check for reuse.
chunking_config (dict): Configuration dictionary for chunk processing.
Returns:
int: Number of chunks successfully reused. Returns 0 if no chunks could be reused.
Note:
Chunks can only be reused if:
- A previous task exists with matching page range and configuration digest
- The previous task was completed successfully (progress = 1.0)
- The previous task has valid chunk IDs
"""
idx = 0
while idx < len(prev_tasks):
prev_task = prev_tasks[idx]

View File

@ -29,11 +29,27 @@ from rag.settings import MINIO
class UserService(CommonService):
"""Service class for managing user-related database operations.
This class extends CommonService to provide specialized functionality for user management,
including authentication, user creation, updates, and deletions.
Attributes:
model: The User model class for database operations.
"""
model = User
@classmethod
@DB.connection_context()
def filter_by_id(cls, user_id):
"""Retrieve a user by their ID.
Args:
user_id: The unique identifier of the user.
Returns:
User object if found, None otherwise.
"""
try:
user = cls.model.select().where(cls.model.id == user_id).get()
return user
@ -43,6 +59,15 @@ class UserService(CommonService):
@classmethod
@DB.connection_context()
def query_user(cls, email, password):
"""Authenticate a user with email and password.
Args:
email: User's email address.
password: User's password in plain text.
Returns:
User object if authentication successful, None otherwise.
"""
user = cls.model.select().where((cls.model.email == email),
(cls.model.status == StatusEnum.VALID.value)).first()
if user and check_password_hash(str(user.password), password):
@ -85,6 +110,14 @@ class UserService(CommonService):
class TenantService(CommonService):
"""Service class for managing tenant-related database operations.
This class extends CommonService to provide functionality for tenant management,
including tenant information retrieval and credit management.
Attributes:
model: The Tenant model class for database operations.
"""
model = Tenant
@classmethod
@ -136,6 +169,14 @@ class TenantService(CommonService):
class UserTenantService(CommonService):
"""Service class for managing user-tenant relationship operations.
This class extends CommonService to handle the many-to-many relationship
between users and tenants, managing user roles and tenant memberships.
Attributes:
model: The UserTenant model class for database operations.
"""
model = UserTenant
@classmethod

View File

@ -531,6 +531,26 @@ Failure:
---
## Error Codes
---
| Code | Message | Description |
|------|---------|-------------|
| 400 | Bad Request | Invalid request parameters |
| 401 | Unauthorized | Unauthorized access |
| 403 | Forbidden | Access denied |
| 404 | Not Found | Resource not found |
| 500 | Internal Server Error | Server internal error |
| 1001 | Invalid Chunk ID | Invalid Chunk ID |
| 1002 | Chunk Update Failed | Chunk update failed |
---
---
## FILE MANAGEMENT WITHIN DATASET
---

View File

@ -317,6 +317,23 @@ dataset = rag_object.list_datasets(name="kb_name")
dataset.update({"embedding_model":"BAAI/bge-zh-v1.5", "chunk_method":"manual"})
```
---
## Error Codes
---
| Code | Message | Description |
|------|---------|-------------|
| 400 | Bad Request | Invalid request parameters |
| 401 | Unauthorized | Unauthorized access |
| 403 | Forbidden | Access denied |
| 404 | Not Found | Resource not found |
| 500 | Internal Server Error | Server internal error |
| 1001 | Invalid Chunk ID | Invalid Chunk ID |
| 1002 | Chunk Update Failed | Chunk update failed |
---
## FILE MANAGEMENT WITHIN DATASET
@ -1719,4 +1736,6 @@ for agent in rag_object.list_agents():
print(agent)
```
---
---

View File

@ -16,6 +16,12 @@
from .base import Base
class ChunkUpdateError(Exception):
def __init__(self, code=None, message=None, details=None):
self.code = code
self.message = message
self.details = details
super().__init__(message)
class Chunk(Base):
def __init__(self, rag, res_dict):
@ -38,4 +44,8 @@ class Chunk(Base):
res = self.put(f"/datasets/{self.dataset_id}/documents/{self.document_id}/chunks/{self.id}", update_message)
res = res.json()
if res.get("code") != 0:
raise Exception(res["message"])
raise ChunkUpdateError(
code=res.get("code"),
message=res.get("message"),
details=res.get("details")
)