mirror of
https://git.mirrors.martin98.com/https://github.com/infiniflow/ragflow.git
synced 2025-08-12 13:48:59 +08:00
Improve API Documentation, Standardize Error Handling, and Enhance Comments (#5990)
### What problem does this PR solve?

- The API documentation lacks detailed error code explanations. Added error code tables to `python_api_reference.md` and `http_api_reference.md` to clarify possible error codes and their meanings.
- Error handling in the codebase is inconsistent. Standardized error handling logic in `sdk/python/ragflow_sdk/modules/chunk.py`.
- Improved API comments by adding standardized docstrings to enhance code readability and maintainability.

### Type of change

- [x] Documentation Update
- [x] Refactoring
This commit is contained in:
parent
940072592f
commit
47926f7d21
@ -22,17 +22,56 @@ from api.utils import datetime_format, current_timestamp, get_uuid
|
|||||||
|
|
||||||
|
|
||||||
class CommonService:
|
class CommonService:
|
||||||
|
"""Base service class that provides common database operations.
|
||||||
|
|
||||||
|
This class serves as a foundation for all service classes in the application,
|
||||||
|
implementing standard CRUD operations and common database query patterns.
|
||||||
|
It uses the Peewee ORM for database interactions and provides a consistent
|
||||||
|
interface for database operations across all derived service classes.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
model: The Peewee model class that this service operates on. Must be set by subclasses.
|
||||||
|
"""
|
||||||
model = None
|
model = None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def query(cls, cols=None, reverse=None, order_by=None, **kwargs):
|
def query(cls, cols=None, reverse=None, order_by=None, **kwargs):
|
||||||
|
"""Execute a database query with optional column selection and ordering.
|
||||||
|
|
||||||
|
This method provides a flexible way to query the database with various filters
|
||||||
|
and sorting options. It supports column selection, sort order control, and
|
||||||
|
additional filter conditions.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cols (list, optional): List of column names to select. If None, selects all columns.
|
||||||
|
reverse (bool, optional): If True, sorts in descending order. If False, sorts in ascending order.
|
||||||
|
order_by (str, optional): Column name to sort results by.
|
||||||
|
**kwargs: Additional filter conditions passed as keyword arguments.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
peewee.ModelSelect: A query result containing matching records.
|
||||||
|
"""
|
||||||
return cls.model.query(cols=cols, reverse=reverse,
|
return cls.model.query(cols=cols, reverse=reverse,
|
||||||
order_by=order_by, **kwargs)
|
order_by=order_by, **kwargs)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def get_all(cls, cols=None, reverse=None, order_by=None):
|
def get_all(cls, cols=None, reverse=None, order_by=None):
|
||||||
|
"""Retrieve all records from the database with optional column selection and ordering.
|
||||||
|
|
||||||
|
This method fetches all records from the model's table with support for
|
||||||
|
column selection and result ordering. If no order_by is specified and reverse
|
||||||
|
is True, it defaults to ordering by create_time.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cols (list, optional): List of column names to select. If None, selects all columns.
|
||||||
|
reverse (bool, optional): If True, sorts in descending order. If False, sorts in ascending order.
|
||||||
|
order_by (str, optional): Column name to sort results by. Defaults to 'create_time' if reverse is specified.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
peewee.ModelSelect: A query containing all matching records.
|
||||||
|
"""
|
||||||
if cols:
|
if cols:
|
||||||
query_records = cls.model.select(*cols)
|
query_records = cls.model.select(*cols)
|
||||||
else:
|
else:
|
||||||
@ -51,11 +90,36 @@ class CommonService:
|
|||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def get(cls, **kwargs):
|
def get(cls, **kwargs):
|
||||||
|
"""Get a single record matching the given criteria.
|
||||||
|
|
||||||
|
This method retrieves a single record from the database that matches
|
||||||
|
the specified filter conditions.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
**kwargs: Filter conditions as keyword arguments.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Model instance: Single matching record.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
peewee.DoesNotExist: If no matching record is found.
|
||||||
|
"""
|
||||||
return cls.model.get(**kwargs)
|
return cls.model.get(**kwargs)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def get_or_none(cls, **kwargs):
|
def get_or_none(cls, **kwargs):
|
||||||
|
"""Get a single record or None if not found.
|
||||||
|
|
||||||
|
This method attempts to retrieve a single record matching the given criteria,
|
||||||
|
returning None if no match is found instead of raising an exception.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
**kwargs: Filter conditions as keyword arguments.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Model instance or None: Matching record if found, None otherwise.
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
return cls.model.get(**kwargs)
|
return cls.model.get(**kwargs)
|
||||||
except peewee.DoesNotExist:
|
except peewee.DoesNotExist:
|
||||||
@ -64,14 +128,34 @@ class CommonService:
|
|||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def save(cls, **kwargs):
|
def save(cls, **kwargs):
|
||||||
# if "id" not in kwargs:
|
"""Save a new record to database.
|
||||||
# kwargs["id"] = get_uuid()
|
|
||||||
|
This method creates a new record in the database with the provided field values,
|
||||||
|
forcing an insert operation rather than an update.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
**kwargs: Record field values as keyword arguments.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: Number of rows modified, as returned by the model's save(force_insert=True) call (not the model instance).
|
||||||
|
"""
|
||||||
sample_obj = cls.model(**kwargs).save(force_insert=True)
|
sample_obj = cls.model(**kwargs).save(force_insert=True)
|
||||||
return sample_obj
|
return sample_obj
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def insert(cls, **kwargs):
|
def insert(cls, **kwargs):
|
||||||
|
"""Insert a new record with automatic ID and timestamps.
|
||||||
|
|
||||||
|
This method creates a new record with automatically generated ID and timestamp fields.
|
||||||
|
It handles the creation of create_time, create_date, update_time, and update_date fields.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
**kwargs: Record field values as keyword arguments.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Model instance: The newly created record object.
|
||||||
|
"""
|
||||||
if "id" not in kwargs:
|
if "id" not in kwargs:
|
||||||
kwargs["id"] = get_uuid()
|
kwargs["id"] = get_uuid()
|
||||||
kwargs["create_time"] = current_timestamp()
|
kwargs["create_time"] = current_timestamp()
|
||||||
@ -84,6 +168,15 @@ class CommonService:
|
|||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def insert_many(cls, data_list, batch_size=100):
|
def insert_many(cls, data_list, batch_size=100):
|
||||||
|
"""Insert multiple records in batches.
|
||||||
|
|
||||||
|
This method efficiently inserts multiple records into the database using batch processing.
|
||||||
|
It automatically sets creation timestamps for all records.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data_list (list): List of dictionaries containing record data to insert.
|
||||||
|
batch_size (int, optional): Number of records to insert in each batch. Defaults to 100.
|
||||||
|
"""
|
||||||
with DB.atomic():
|
with DB.atomic():
|
||||||
for d in data_list:
|
for d in data_list:
|
||||||
d["create_time"] = current_timestamp()
|
d["create_time"] = current_timestamp()
|
||||||
@ -94,6 +187,15 @@ class CommonService:
|
|||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def update_many_by_id(cls, data_list):
|
def update_many_by_id(cls, data_list):
|
||||||
|
"""Update multiple records by their IDs.
|
||||||
|
|
||||||
|
This method updates multiple records in the database, identified by their IDs.
|
||||||
|
It automatically updates the update_time and update_date fields for each record.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data_list (list): List of dictionaries containing record data to update.
|
||||||
|
Each dictionary must include an 'id' field.
|
||||||
|
"""
|
||||||
with DB.atomic():
|
with DB.atomic():
|
||||||
for data in data_list:
|
for data in data_list:
|
||||||
data["update_time"] = current_timestamp()
|
data["update_time"] = current_timestamp()
|
||||||
@ -104,6 +206,12 @@ class CommonService:
|
|||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def update_by_id(cls, pid, data):
|
def update_by_id(cls, pid, data):
|
||||||
|
# Update a single record by ID
|
||||||
|
# Args:
|
||||||
|
# pid: Record ID
|
||||||
|
# data: Updated field values
|
||||||
|
# Returns:
|
||||||
|
# Number of records updated
|
||||||
data["update_time"] = current_timestamp()
|
data["update_time"] = current_timestamp()
|
||||||
data["update_date"] = datetime_format(datetime.now())
|
data["update_date"] = datetime_format(datetime.now())
|
||||||
num = cls.model.update(data).where(cls.model.id == pid).execute()
|
num = cls.model.update(data).where(cls.model.id == pid).execute()
|
||||||
@ -112,6 +220,11 @@ class CommonService:
|
|||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def get_by_id(cls, pid):
|
def get_by_id(cls, pid):
|
||||||
|
# Get a record by ID
|
||||||
|
# Args:
|
||||||
|
# pid: Record ID
|
||||||
|
# Returns:
|
||||||
|
# Tuple of (success, record)
|
||||||
try:
|
try:
|
||||||
obj = cls.model.query(id=pid)[0]
|
obj = cls.model.query(id=pid)[0]
|
||||||
return True, obj
|
return True, obj
|
||||||
@ -121,6 +234,12 @@ class CommonService:
|
|||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def get_by_ids(cls, pids, cols=None):
|
def get_by_ids(cls, pids, cols=None):
|
||||||
|
# Get multiple records by their IDs
|
||||||
|
# Args:
|
||||||
|
# pids: List of record IDs
|
||||||
|
# cols: List of columns to select
|
||||||
|
# Returns:
|
||||||
|
# Query of matching records
|
||||||
if cols:
|
if cols:
|
||||||
objs = cls.model.select(*cols)
|
objs = cls.model.select(*cols)
|
||||||
else:
|
else:
|
||||||
@ -130,11 +249,21 @@ class CommonService:
|
|||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def delete_by_id(cls, pid):
|
def delete_by_id(cls, pid):
|
||||||
|
# Delete a record by ID
|
||||||
|
# Args:
|
||||||
|
# pid: Record ID
|
||||||
|
# Returns:
|
||||||
|
# Number of records deleted
|
||||||
return cls.model.delete().where(cls.model.id == pid).execute()
|
return cls.model.delete().where(cls.model.id == pid).execute()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def filter_delete(cls, filters):
|
def filter_delete(cls, filters):
|
||||||
|
# Delete records matching given filters
|
||||||
|
# Args:
|
||||||
|
# filters: List of filter conditions
|
||||||
|
# Returns:
|
||||||
|
# Number of records deleted
|
||||||
with DB.atomic():
|
with DB.atomic():
|
||||||
num = cls.model.delete().where(*filters).execute()
|
num = cls.model.delete().where(*filters).execute()
|
||||||
return num
|
return num
|
||||||
@ -142,11 +271,23 @@ class CommonService:
|
|||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def filter_update(cls, filters, update_data):
|
def filter_update(cls, filters, update_data):
|
||||||
|
# Update records matching given filters
|
||||||
|
# Args:
|
||||||
|
# filters: List of filter conditions
|
||||||
|
# update_data: Updated field values
|
||||||
|
# Returns:
|
||||||
|
# Number of records updated
|
||||||
with DB.atomic():
|
with DB.atomic():
|
||||||
return cls.model.update(update_data).where(*filters).execute()
|
return cls.model.update(update_data).where(*filters).execute()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def cut_list(tar_list, n):
|
def cut_list(tar_list, n):
|
||||||
|
# Split a list into chunks of size n
|
||||||
|
# Args:
|
||||||
|
# tar_list: List to split
|
||||||
|
# n: Chunk size
|
||||||
|
# Returns:
|
||||||
|
# List of tuples containing chunks
|
||||||
length = len(tar_list)
|
length = len(tar_list)
|
||||||
arr = range(length)
|
arr = range(length)
|
||||||
result = [tuple(tar_list[x:(x + n)]) for x in arr[::n]]
|
result = [tuple(tar_list[x:(x + n)]) for x in arr[::n]]
|
||||||
@ -156,6 +297,14 @@ class CommonService:
|
|||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def filter_scope_list(cls, in_key, in_filters_list,
|
def filter_scope_list(cls, in_key, in_filters_list,
|
||||||
filters=None, cols=None):
|
filters=None, cols=None):
|
||||||
|
# Get records matching IN clause filters with optional column selection
|
||||||
|
# Args:
|
||||||
|
# in_key: Field name for IN clause
|
||||||
|
# in_filters_list: List of values for IN clause
|
||||||
|
# filters: Additional filter conditions
|
||||||
|
# cols: List of columns to select
|
||||||
|
# Returns:
|
||||||
|
# List of matching records
|
||||||
in_filters_tuple_list = cls.cut_list(in_filters_list, 20)
|
in_filters_tuple_list = cls.cut_list(in_filters_list, 20)
|
||||||
if not filters:
|
if not filters:
|
||||||
filters = []
|
filters = []
|
||||||
|
@ -34,12 +34,24 @@ from rag.utils.storage_factory import STORAGE_IMPL
|
|||||||
|
|
||||||
|
|
||||||
class FileService(CommonService):
|
class FileService(CommonService):
|
||||||
|
# Service class for managing file operations and storage
|
||||||
model = File
|
model = File
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def get_by_pf_id(cls, tenant_id, pf_id, page_number, items_per_page,
|
def get_by_pf_id(cls, tenant_id, pf_id, page_number, items_per_page,
|
||||||
orderby, desc, keywords):
|
orderby, desc, keywords):
|
||||||
|
# Get files by parent folder ID with pagination and filtering
|
||||||
|
# Args:
|
||||||
|
# tenant_id: ID of the tenant
|
||||||
|
# pf_id: Parent folder ID
|
||||||
|
# page_number: Page number for pagination
|
||||||
|
# items_per_page: Number of items per page
|
||||||
|
# orderby: Field to order by
|
||||||
|
# desc: Boolean indicating descending order
|
||||||
|
# keywords: Search keywords
|
||||||
|
# Returns:
|
||||||
|
# Tuple of (file_list, total_count)
|
||||||
if keywords:
|
if keywords:
|
||||||
files = cls.model.select().where(
|
files = cls.model.select().where(
|
||||||
(cls.model.tenant_id == tenant_id),
|
(cls.model.tenant_id == tenant_id),
|
||||||
@ -80,6 +92,11 @@ class FileService(CommonService):
|
|||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def get_kb_id_by_file_id(cls, file_id):
|
def get_kb_id_by_file_id(cls, file_id):
|
||||||
|
# Get knowledge base IDs associated with a file
|
||||||
|
# Args:
|
||||||
|
# file_id: File ID
|
||||||
|
# Returns:
|
||||||
|
# List of dictionaries containing knowledge base IDs and names
|
||||||
kbs = (cls.model.select(*[Knowledgebase.id, Knowledgebase.name])
|
kbs = (cls.model.select(*[Knowledgebase.id, Knowledgebase.name])
|
||||||
.join(File2Document, on=(File2Document.file_id == file_id))
|
.join(File2Document, on=(File2Document.file_id == file_id))
|
||||||
.join(Document, on=(File2Document.document_id == Document.id))
|
.join(Document, on=(File2Document.document_id == Document.id))
|
||||||
@ -95,6 +112,12 @@ class FileService(CommonService):
|
|||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def get_by_pf_id_name(cls, id, name):
|
def get_by_pf_id_name(cls, id, name):
|
||||||
|
# Get file by parent folder ID and name
|
||||||
|
# Args:
|
||||||
|
# id: Parent folder ID
|
||||||
|
# name: File name
|
||||||
|
# Returns:
|
||||||
|
# File object or None if not found
|
||||||
file = cls.model.select().where((cls.model.parent_id == id) & (cls.model.name == name))
|
file = cls.model.select().where((cls.model.parent_id == id) & (cls.model.name == name))
|
||||||
if file.count():
|
if file.count():
|
||||||
e, file = cls.get_by_id(file[0].id)
|
e, file = cls.get_by_id(file[0].id)
|
||||||
@ -106,6 +129,14 @@ class FileService(CommonService):
|
|||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def get_id_list_by_id(cls, id, name, count, res):
|
def get_id_list_by_id(cls, id, name, count, res):
|
||||||
|
# Recursively get list of file IDs by traversing folder structure
|
||||||
|
# Args:
|
||||||
|
# id: Starting folder ID
|
||||||
|
# name: List of folder names to traverse
|
||||||
|
# count: Current depth in traversal
|
||||||
|
# res: List to store results
|
||||||
|
# Returns:
|
||||||
|
# List of file IDs
|
||||||
if count < len(name):
|
if count < len(name):
|
||||||
file = cls.get_by_pf_id_name(id, name[count])
|
file = cls.get_by_pf_id_name(id, name[count])
|
||||||
if file:
|
if file:
|
||||||
@ -119,6 +150,12 @@ class FileService(CommonService):
|
|||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def get_all_innermost_file_ids(cls, folder_id, result_ids):
|
def get_all_innermost_file_ids(cls, folder_id, result_ids):
|
||||||
|
# Get IDs of all files in the deepest level of folders
|
||||||
|
# Args:
|
||||||
|
# folder_id: Starting folder ID
|
||||||
|
# result_ids: List to store results
|
||||||
|
# Returns:
|
||||||
|
# List of file IDs
|
||||||
subfolders = cls.model.select().where(cls.model.parent_id == folder_id)
|
subfolders = cls.model.select().where(cls.model.parent_id == folder_id)
|
||||||
if subfolders.exists():
|
if subfolders.exists():
|
||||||
for subfolder in subfolders:
|
for subfolder in subfolders:
|
||||||
@ -130,6 +167,14 @@ class FileService(CommonService):
|
|||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def create_folder(cls, file, parent_id, name, count):
|
def create_folder(cls, file, parent_id, name, count):
|
||||||
|
# Recursively create folder structure
|
||||||
|
# Args:
|
||||||
|
# file: Current file object
|
||||||
|
# parent_id: Parent folder ID
|
||||||
|
# name: List of folder names to create
|
||||||
|
# count: Current depth in creation
|
||||||
|
# Returns:
|
||||||
|
# Created file object
|
||||||
if count > len(name) - 2:
|
if count > len(name) - 2:
|
||||||
return file
|
return file
|
||||||
else:
|
else:
|
||||||
@ -148,6 +193,11 @@ class FileService(CommonService):
|
|||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def is_parent_folder_exist(cls, parent_id):
|
def is_parent_folder_exist(cls, parent_id):
|
||||||
|
# Check if parent folder exists
|
||||||
|
# Args:
|
||||||
|
# parent_id: Parent folder ID
|
||||||
|
# Returns:
|
||||||
|
# Boolean indicating if folder exists
|
||||||
parent_files = cls.model.select().where(cls.model.id == parent_id)
|
parent_files = cls.model.select().where(cls.model.id == parent_id)
|
||||||
if parent_files.count():
|
if parent_files.count():
|
||||||
return True
|
return True
|
||||||
@ -157,6 +207,11 @@ class FileService(CommonService):
|
|||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def get_root_folder(cls, tenant_id):
|
def get_root_folder(cls, tenant_id):
|
||||||
|
# Get or create root folder for tenant
|
||||||
|
# Args:
|
||||||
|
# tenant_id: Tenant ID
|
||||||
|
# Returns:
|
||||||
|
# Root folder dictionary
|
||||||
for file in cls.model.select().where((cls.model.tenant_id == tenant_id),
|
for file in cls.model.select().where((cls.model.tenant_id == tenant_id),
|
||||||
(cls.model.parent_id == cls.model.id)
|
(cls.model.parent_id == cls.model.id)
|
||||||
):
|
):
|
||||||
@ -179,6 +234,11 @@ class FileService(CommonService):
|
|||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def get_kb_folder(cls, tenant_id):
|
def get_kb_folder(cls, tenant_id):
|
||||||
|
# Get knowledge base folder for tenant
|
||||||
|
# Args:
|
||||||
|
# tenant_id: Tenant ID
|
||||||
|
# Returns:
|
||||||
|
# Knowledge base folder dictionary
|
||||||
for root in cls.model.select().where(
|
for root in cls.model.select().where(
|
||||||
(cls.model.tenant_id == tenant_id), (cls.model.parent_id == cls.model.id)):
|
(cls.model.tenant_id == tenant_id), (cls.model.parent_id == cls.model.id)):
|
||||||
for folder in cls.model.select().where(
|
for folder in cls.model.select().where(
|
||||||
@ -190,6 +250,16 @@ class FileService(CommonService):
|
|||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def new_a_file_from_kb(cls, tenant_id, name, parent_id, ty=FileType.FOLDER.value, size=0, location=""):
|
def new_a_file_from_kb(cls, tenant_id, name, parent_id, ty=FileType.FOLDER.value, size=0, location=""):
|
||||||
|
# Return the existing file with the same tenant, parent folder and name if there is one; otherwise create a new file from knowledge base
|
||||||
|
# Args:
|
||||||
|
# tenant_id: Tenant ID
|
||||||
|
# name: File name
|
||||||
|
# parent_id: Parent folder ID
|
||||||
|
# ty: File type
|
||||||
|
# size: File size
|
||||||
|
# location: File location
|
||||||
|
# Returns:
|
||||||
|
# Created file dictionary
|
||||||
for file in cls.query(tenant_id=tenant_id, parent_id=parent_id, name=name):
|
for file in cls.query(tenant_id=tenant_id, parent_id=parent_id, name=name):
|
||||||
return file.to_dict()
|
return file.to_dict()
|
||||||
file = {
|
file = {
|
||||||
@ -209,6 +279,10 @@ class FileService(CommonService):
|
|||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def init_knowledgebase_docs(cls, root_id, tenant_id):
|
def init_knowledgebase_docs(cls, root_id, tenant_id):
|
||||||
|
# Initialize knowledge base documents
|
||||||
|
# Args:
|
||||||
|
# root_id: Root folder ID
|
||||||
|
# tenant_id: Tenant ID
|
||||||
for _ in cls.model.select().where((cls.model.name == KNOWLEDGEBASE_FOLDER_NAME)\
|
for _ in cls.model.select().where((cls.model.name == KNOWLEDGEBASE_FOLDER_NAME)\
|
||||||
& (cls.model.parent_id == root_id)):
|
& (cls.model.parent_id == root_id)):
|
||||||
return
|
return
|
||||||
@ -222,6 +296,11 @@ class FileService(CommonService):
|
|||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def get_parent_folder(cls, file_id):
|
def get_parent_folder(cls, file_id):
|
||||||
|
# Get parent folder of a file
|
||||||
|
# Args:
|
||||||
|
# file_id: File ID
|
||||||
|
# Returns:
|
||||||
|
# Parent folder object
|
||||||
file = cls.model.select().where(cls.model.id == file_id)
|
file = cls.model.select().where(cls.model.id == file_id)
|
||||||
if file.count():
|
if file.count():
|
||||||
e, file = cls.get_by_id(file[0].parent_id)
|
e, file = cls.get_by_id(file[0].parent_id)
|
||||||
@ -234,6 +313,11 @@ class FileService(CommonService):
|
|||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def get_all_parent_folders(cls, start_id):
|
def get_all_parent_folders(cls, start_id):
|
||||||
|
# Get all parent folders in path
|
||||||
|
# Args:
|
||||||
|
# start_id: Starting file ID
|
||||||
|
# Returns:
|
||||||
|
# List of parent folder objects
|
||||||
parent_folders = []
|
parent_folders = []
|
||||||
current_id = start_id
|
current_id = start_id
|
||||||
while current_id:
|
while current_id:
|
||||||
@ -249,6 +333,11 @@ class FileService(CommonService):
|
|||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def insert(cls, file):
|
def insert(cls, file):
|
||||||
|
# Insert a new file record
|
||||||
|
# Args:
|
||||||
|
# file: File data dictionary
|
||||||
|
# Returns:
|
||||||
|
# Created file object
|
||||||
if not cls.save(**file):
|
if not cls.save(**file):
|
||||||
raise RuntimeError("Database error (File)!")
|
raise RuntimeError("Database error (File)!")
|
||||||
return File(**file)
|
return File(**file)
|
||||||
@ -256,6 +345,7 @@ class FileService(CommonService):
|
|||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def delete(cls, file):
|
def delete(cls, file):
|
||||||
|
# Delete the given file's record by its ID (delegates to delete_by_id)
|
||||||
return cls.delete_by_id(file.id)
|
return cls.delete_by_id(file.id)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ -20,21 +20,69 @@ from peewee import fn
|
|||||||
|
|
||||||
|
|
||||||
class KnowledgebaseService(CommonService):
|
class KnowledgebaseService(CommonService):
|
||||||
|
"""Service class for managing knowledge base operations.
|
||||||
|
|
||||||
|
This class extends CommonService to provide specialized functionality for knowledge base
|
||||||
|
management, including document parsing status tracking, access control, and configuration
|
||||||
|
management. It handles operations such as listing, creating, updating, and deleting
|
||||||
|
knowledge bases, as well as managing their associated documents and permissions.
|
||||||
|
|
||||||
|
The class implements a comprehensive set of methods for:
|
||||||
|
- Document parsing status verification
|
||||||
|
- Knowledge base access control
|
||||||
|
- Parser configuration management
|
||||||
|
- Tenant-based knowledge base organization
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
model: The Knowledgebase model class for database operations.
|
||||||
|
"""
|
||||||
model = Knowledgebase
|
model = Knowledgebase
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def is_parsed_done(cls, kb_id):
|
def accessible4deletion(cls, kb_id, user_id):
|
||||||
"""
|
"""Check if a knowledge base can be deleted by a specific user.
|
||||||
Check if all documents in the knowledge base have completed parsing
|
|
||||||
|
This method verifies whether a user has permission to delete a knowledge base
|
||||||
|
by checking if they are the creator of that knowledge base.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
kb_id: Knowledge base ID
|
kb_id (str): The unique identifier of the knowledge base to check.
|
||||||
|
user_id (str): The unique identifier of the user attempting the deletion.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
If all documents are parsed successfully, returns (True, None)
|
bool: True if the user has permission to delete the knowledge base,
|
||||||
If any document is not fully parsed, returns (False, error_message)
|
False if the user doesn't have permission or the knowledge base doesn't exist.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> KnowledgebaseService.accessible4deletion("kb123", "user456")
|
||||||
|
True
|
||||||
|
|
||||||
|
Note:
|
||||||
|
- This method only checks creator permissions
|
||||||
|
- A return value of False can mean either:
|
||||||
|
1. The knowledge base doesn't exist
|
||||||
|
2. The user is not the creator of the knowledge base
|
||||||
"""
|
"""
|
||||||
|
# Check if a knowledge base can be deleted by a user
|
||||||
|
docs = cls.model.select(
|
||||||
|
cls.model.id).where(cls.model.id == kb_id, cls.model.created_by == user_id).paginate(0, 1)
|
||||||
|
docs = docs.dicts()
|
||||||
|
if not docs:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@DB.connection_context()
|
||||||
|
def is_parsed_done(cls, kb_id):
|
||||||
|
# Check if all documents in the knowledge base have completed parsing
|
||||||
|
#
|
||||||
|
# Args:
|
||||||
|
# kb_id: Knowledge base ID
|
||||||
|
#
|
||||||
|
# Returns:
|
||||||
|
# If all documents are parsed successfully, returns (True, None)
|
||||||
|
# If any document is not fully parsed, returns (False, error_message)
|
||||||
from api.db import TaskStatus
|
from api.db import TaskStatus
|
||||||
from api.db.services.document_service import DocumentService
|
from api.db.services.document_service import DocumentService
|
||||||
|
|
||||||
@ -61,6 +109,11 @@ class KnowledgebaseService(CommonService):
|
|||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def list_documents_by_ids(cls,kb_ids):
|
def list_documents_by_ids(cls,kb_ids):
|
||||||
|
# Get document IDs associated with given knowledge base IDs
|
||||||
|
# Args:
|
||||||
|
# kb_ids: List of knowledge base IDs
|
||||||
|
# Returns:
|
||||||
|
# List of document IDs
|
||||||
doc_ids=cls.model.select(Document.id.alias("document_id")).join(Document,on=(cls.model.id == Document.kb_id)).where(
|
doc_ids=cls.model.select(Document.id.alias("document_id")).join(Document,on=(cls.model.id == Document.kb_id)).where(
|
||||||
cls.model.id.in_(kb_ids)
|
cls.model.id.in_(kb_ids)
|
||||||
)
|
)
|
||||||
@ -75,6 +128,18 @@ class KnowledgebaseService(CommonService):
|
|||||||
orderby, desc, keywords,
|
orderby, desc, keywords,
|
||||||
parser_id=None
|
parser_id=None
|
||||||
):
|
):
|
||||||
|
# Get knowledge bases by tenant IDs with pagination and filtering
|
||||||
|
# Args:
|
||||||
|
# joined_tenant_ids: List of tenant IDs
|
||||||
|
# user_id: Current user ID
|
||||||
|
# page_number: Page number for pagination
|
||||||
|
# items_per_page: Number of items per page
|
||||||
|
# orderby: Field to order by
|
||||||
|
# desc: Boolean indicating descending order
|
||||||
|
# keywords: Search keywords
|
||||||
|
# parser_id: Optional parser ID filter
|
||||||
|
# Returns:
|
||||||
|
# Tuple of (knowledge_base_list, total_count)
|
||||||
fields = [
|
fields = [
|
||||||
cls.model.id,
|
cls.model.id,
|
||||||
cls.model.avatar,
|
cls.model.avatar,
|
||||||
@ -122,6 +187,11 @@ class KnowledgebaseService(CommonService):
|
|||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def get_kb_ids(cls, tenant_id):
|
def get_kb_ids(cls, tenant_id):
|
||||||
|
# Get all knowledge base IDs for a tenant
|
||||||
|
# Args:
|
||||||
|
# tenant_id: Tenant ID
|
||||||
|
# Returns:
|
||||||
|
# List of knowledge base IDs
|
||||||
fields = [
|
fields = [
|
||||||
cls.model.id,
|
cls.model.id,
|
||||||
]
|
]
|
||||||
@ -132,9 +202,13 @@ class KnowledgebaseService(CommonService):
|
|||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def get_detail(cls, kb_id):
|
def get_detail(cls, kb_id):
|
||||||
|
# Get detailed information about a knowledge base
|
||||||
|
# Args:
|
||||||
|
# kb_id: Knowledge base ID
|
||||||
|
# Returns:
|
||||||
|
# Dictionary containing knowledge base details
|
||||||
fields = [
|
fields = [
|
||||||
cls.model.id,
|
cls.model.id,
|
||||||
# Tenant.embd_id,
|
|
||||||
cls.model.embd_id,
|
cls.model.embd_id,
|
||||||
cls.model.avatar,
|
cls.model.avatar,
|
||||||
cls.model.name,
|
cls.model.name,
|
||||||
@ -155,17 +229,21 @@ class KnowledgebaseService(CommonService):
|
|||||||
if not kbs:
|
if not kbs:
|
||||||
return
|
return
|
||||||
d = kbs[0].to_dict()
|
d = kbs[0].to_dict()
|
||||||
# d["embd_id"] = kbs[0].tenant.embd_id
|
|
||||||
return d
|
return d
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def update_parser_config(cls, id, config):
|
def update_parser_config(cls, id, config):
|
||||||
|
# Update parser configuration for a knowledge base
|
||||||
|
# Args:
|
||||||
|
# id: Knowledge base ID
|
||||||
|
# config: New parser configuration
|
||||||
e, m = cls.get_by_id(id)
|
e, m = cls.get_by_id(id)
|
||||||
if not e:
|
if not e:
|
||||||
raise LookupError(f"knowledgebase({id}) not found.")
|
raise LookupError(f"knowledgebase({id}) not found.")
|
||||||
|
|
||||||
def dfs_update(old, new):
|
def dfs_update(old, new):
|
||||||
|
# Deep update of nested configuration
|
||||||
for k, v in new.items():
|
for k, v in new.items():
|
||||||
if k not in old:
|
if k not in old:
|
||||||
old[k] = v
|
old[k] = v
|
||||||
@ -185,6 +263,11 @@ class KnowledgebaseService(CommonService):
|
|||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def get_field_map(cls, ids):
|
def get_field_map(cls, ids):
|
||||||
|
# Get field mappings for knowledge bases
|
||||||
|
# Args:
|
||||||
|
# ids: List of knowledge base IDs
|
||||||
|
# Returns:
|
||||||
|
# Dictionary of field mappings
|
||||||
conf = {}
|
conf = {}
|
||||||
for k in cls.get_by_ids(ids):
|
for k in cls.get_by_ids(ids):
|
||||||
if k.parser_config and "field_map" in k.parser_config:
|
if k.parser_config and "field_map" in k.parser_config:
|
||||||
@ -194,6 +277,12 @@ class KnowledgebaseService(CommonService):
|
|||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def get_by_name(cls, kb_name, tenant_id):
|
def get_by_name(cls, kb_name, tenant_id):
|
||||||
|
# Get knowledge base by name and tenant ID
|
||||||
|
# Args:
|
||||||
|
# kb_name: Knowledge base name
|
||||||
|
# tenant_id: Tenant ID
|
||||||
|
# Returns:
|
||||||
|
# Tuple of (exists, knowledge_base)
|
||||||
kb = cls.model.select().where(
|
kb = cls.model.select().where(
|
||||||
(cls.model.name == kb_name)
|
(cls.model.name == kb_name)
|
||||||
& (cls.model.tenant_id == tenant_id)
|
& (cls.model.tenant_id == tenant_id)
|
||||||
@ -206,12 +295,27 @@ class KnowledgebaseService(CommonService):
|
|||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def get_all_ids(cls):
|
def get_all_ids(cls):
|
||||||
|
# Get all knowledge base IDs
|
||||||
|
# Returns:
|
||||||
|
# List of all knowledge base IDs
|
||||||
return [m["id"] for m in cls.model.select(cls.model.id).dicts()]
|
return [m["id"] for m in cls.model.select(cls.model.id).dicts()]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def get_list(cls, joined_tenant_ids, user_id,
|
def get_list(cls, joined_tenant_ids, user_id,
|
||||||
page_number, items_per_page, orderby, desc, id, name):
|
page_number, items_per_page, orderby, desc, id, name):
|
||||||
|
# Get list of knowledge bases with filtering and pagination
|
||||||
|
# Args:
|
||||||
|
# joined_tenant_ids: List of tenant IDs
|
||||||
|
# user_id: Current user ID
|
||||||
|
# page_number: Page number for pagination
|
||||||
|
# items_per_page: Number of items per page
|
||||||
|
# orderby: Field to order by
|
||||||
|
# desc: Boolean indicating descending order
|
||||||
|
# id: Optional ID filter
|
||||||
|
# name: Optional name filter
|
||||||
|
# Returns:
|
||||||
|
# List of knowledge bases
|
||||||
kbs = cls.model.select()
|
kbs = cls.model.select()
|
||||||
if id:
|
if id:
|
||||||
kbs = kbs.where(cls.model.id == id)
|
kbs = kbs.where(cls.model.id == id)
|
||||||
@ -235,6 +339,12 @@ class KnowledgebaseService(CommonService):
|
|||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def accessible(cls, kb_id, user_id):
|
def accessible(cls, kb_id, user_id):
|
||||||
|
# Check if a knowledge base is accessible by a user
|
||||||
|
# Args:
|
||||||
|
# kb_id: Knowledge base ID
|
||||||
|
# user_id: User ID
|
||||||
|
# Returns:
|
||||||
|
# Boolean indicating accessibility
|
||||||
docs = cls.model.select(
|
docs = cls.model.select(
|
||||||
cls.model.id).join(UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id)
|
cls.model.id).join(UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id)
|
||||||
).where(cls.model.id == kb_id, UserTenant.user_id == user_id).paginate(0, 1)
|
).where(cls.model.id == kb_id, UserTenant.user_id == user_id).paginate(0, 1)
|
||||||
@ -246,6 +356,12 @@ class KnowledgebaseService(CommonService):
|
|||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def get_kb_by_id(cls, kb_id, user_id):
|
def get_kb_by_id(cls, kb_id, user_id):
|
||||||
|
# Get knowledge base by ID and user ID
|
||||||
|
# Args:
|
||||||
|
# kb_id: Knowledge base ID
|
||||||
|
# user_id: User ID
|
||||||
|
# Returns:
|
||||||
|
# List containing knowledge base information
|
||||||
kbs = cls.model.select().join(UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id)
|
kbs = cls.model.select().join(UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id)
|
||||||
).where(cls.model.id == kb_id, UserTenant.user_id == user_id).paginate(0, 1)
|
).where(cls.model.id == kb_id, UserTenant.user_id == user_id).paginate(0, 1)
|
||||||
kbs = kbs.dicts()
|
kbs = kbs.dicts()
|
||||||
@ -254,18 +370,14 @@ class KnowledgebaseService(CommonService):
|
|||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def get_kb_by_name(cls, kb_name, user_id):
|
def get_kb_by_name(cls, kb_name, user_id):
|
||||||
|
# Get knowledge base by name and user ID
|
||||||
|
# Args:
|
||||||
|
# kb_name: Knowledge base name
|
||||||
|
# user_id: User ID
|
||||||
|
# Returns:
|
||||||
|
# List containing knowledge base information
|
||||||
kbs = cls.model.select().join(UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id)
|
kbs = cls.model.select().join(UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id)
|
||||||
).where(cls.model.name == kb_name, UserTenant.user_id == user_id).paginate(0, 1)
|
).where(cls.model.name == kb_name, UserTenant.user_id == user_id).paginate(0, 1)
|
||||||
kbs = kbs.dicts()
|
kbs = kbs.dicts()
|
||||||
return list(kbs)
|
return list(kbs)
|
||||||
|
|
||||||
@classmethod
|
|
||||||
@DB.connection_context()
|
|
||||||
def accessible4deletion(cls, kb_id, user_id):
|
|
||||||
docs = cls.model.select(
|
|
||||||
cls.model.id).where(cls.model.id == kb_id, cls.model.created_by == user_id).paginate(0, 1)
|
|
||||||
docs = docs.dicts()
|
|
||||||
if not docs:
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
@ -36,6 +36,12 @@ from rag.nlp import search
|
|||||||
|
|
||||||
|
|
||||||
def trim_header_by_lines(text: str, max_length) -> str:
|
def trim_header_by_lines(text: str, max_length) -> str:
|
||||||
|
# Trim header text to maximum length while preserving line breaks
|
||||||
|
# Args:
|
||||||
|
# text: Input text to trim
|
||||||
|
# max_length: Maximum allowed length
|
||||||
|
# Returns:
|
||||||
|
# Trimmed text
|
||||||
len_text = len(text)
|
len_text = len(text)
|
||||||
if len_text <= max_length:
|
if len_text <= max_length:
|
||||||
return text
|
return text
|
||||||
@ -46,11 +52,37 @@ def trim_header_by_lines(text: str, max_length) -> str:
|
|||||||
|
|
||||||
|
|
||||||
class TaskService(CommonService):
|
class TaskService(CommonService):
|
||||||
|
"""Service class for managing document processing tasks.
|
||||||
|
|
||||||
|
This class extends CommonService to provide specialized functionality for document
|
||||||
|
processing task management, including task creation, progress tracking, and chunk
|
||||||
|
management. It handles various document types (PDF, Excel, etc.) and manages their
|
||||||
|
processing lifecycle.
|
||||||
|
|
||||||
|
The class implements a robust task queue system with retry mechanisms and progress
|
||||||
|
tracking, supporting both synchronous and asynchronous task execution.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
model: The Task model class for database operations.
|
||||||
|
"""
|
||||||
model = Task
|
model = Task
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def get_task(cls, task_id):
|
def get_task(cls, task_id):
|
||||||
|
"""Retrieve detailed task information by task ID.
|
||||||
|
|
||||||
|
This method fetches comprehensive task details including associated document,
|
||||||
|
knowledge base, and tenant information. It also handles task retry logic and
|
||||||
|
progress updates.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_id (str): The unique identifier of the task to retrieve.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: Task details dictionary containing all task information and related metadata.
|
||||||
|
Returns None if task is not found or has exceeded retry limit.
|
||||||
|
"""
|
||||||
fields = [
|
fields = [
|
||||||
cls.model.id,
|
cls.model.id,
|
||||||
cls.model.doc_id,
|
cls.model.doc_id,
|
||||||
@ -105,6 +137,18 @@ class TaskService(CommonService):
|
|||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def get_tasks(cls, doc_id: str):
|
def get_tasks(cls, doc_id: str):
|
||||||
|
"""Retrieve all tasks associated with a document.
|
||||||
|
|
||||||
|
This method fetches all processing tasks for a given document, ordered by page
|
||||||
|
number and creation time. It includes task progress and chunk information.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
doc_id (str): The unique identifier of the document.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list[dict]: List of task dictionaries containing task details.
|
||||||
|
Returns None if no tasks are found.
|
||||||
|
"""
|
||||||
fields = [
|
fields = [
|
||||||
cls.model.id,
|
cls.model.id,
|
||||||
cls.model.from_page,
|
cls.model.from_page,
|
||||||
@ -124,11 +168,31 @@ class TaskService(CommonService):
|
|||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def update_chunk_ids(cls, id: str, chunk_ids: str):
|
def update_chunk_ids(cls, id: str, chunk_ids: str):
|
||||||
|
"""Update the chunk IDs associated with a task.
|
||||||
|
|
||||||
|
This method updates the chunk_ids field of a task, which stores the IDs of
|
||||||
|
processed document chunks in a space-separated string format.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
id (str): The unique identifier of the task.
|
||||||
|
chunk_ids (str): Space-separated string of chunk identifiers.
|
||||||
|
"""
|
||||||
cls.model.update(chunk_ids=chunk_ids).where(cls.model.id == id).execute()
|
cls.model.update(chunk_ids=chunk_ids).where(cls.model.id == id).execute()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def get_ongoing_doc_name(cls):
|
def get_ongoing_doc_name(cls):
|
||||||
|
"""Get names of documents that are currently being processed.
|
||||||
|
|
||||||
|
This method retrieves information about documents that are in the processing state,
|
||||||
|
including their locations and associated IDs. It uses database locking to ensure
|
||||||
|
thread safety when accessing the task information.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list[tuple]: A list of tuples, each containing (parent_id/kb_id, location)
|
||||||
|
for documents currently being processed. Returns empty list if
|
||||||
|
no documents are being processed.
|
||||||
|
"""
|
||||||
with DB.lock("get_task", -1):
|
with DB.lock("get_task", -1):
|
||||||
docs = (
|
docs = (
|
||||||
cls.model.select(
|
cls.model.select(
|
||||||
@ -172,6 +236,18 @@ class TaskService(CommonService):
|
|||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def do_cancel(cls, id):
|
def do_cancel(cls, id):
|
||||||
|
"""Check if a task should be cancelled based on its document status.
|
||||||
|
|
||||||
|
This method determines whether a task should be cancelled by checking the
|
||||||
|
associated document's run status and progress. A task should be cancelled
|
||||||
|
if its document is marked for cancellation or has negative progress.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
id (str): The unique identifier of the task to check.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if the task should be cancelled, False otherwise.
|
||||||
|
"""
|
||||||
task = cls.model.get_by_id(id)
|
task = cls.model.get_by_id(id)
|
||||||
_, doc = DocumentService.get_by_id(task.doc_id)
|
_, doc = DocumentService.get_by_id(task.doc_id)
|
||||||
return doc.run == TaskStatus.CANCEL.value or doc.progress < 0
|
return doc.run == TaskStatus.CANCEL.value or doc.progress < 0
|
||||||
@ -179,6 +255,18 @@ class TaskService(CommonService):
|
|||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def update_progress(cls, id, info):
|
def update_progress(cls, id, info):
|
||||||
|
"""Update the progress information for a task.
|
||||||
|
|
||||||
|
This method updates both the progress message and completion percentage of a task.
|
||||||
|
It handles platform-specific behavior (macOS vs others) and uses database locking
|
||||||
|
when necessary to ensure thread safety.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
id (str): The unique identifier of the task to update.
|
||||||
|
info (dict): Dictionary containing progress information with keys:
|
||||||
|
- progress_msg (str): Progress message to append. May be empty, but the key itself is required: the code reads info["progress_msg"] directly.
|
||||||
|
- progress (float, optional): Completion fraction from 0.0 to 1.0 (1.0 means finished), not a 0-100 percentage
|
||||||
|
"""
|
||||||
if os.environ.get("MACOS"):
|
if os.environ.get("MACOS"):
|
||||||
if info["progress_msg"]:
|
if info["progress_msg"]:
|
||||||
task = cls.model.get_by_id(id)
|
task = cls.model.get_by_id(id)
|
||||||
@ -202,6 +290,24 @@ class TaskService(CommonService):
|
|||||||
|
|
||||||
|
|
||||||
def queue_tasks(doc: dict, bucket: str, name: str):
|
def queue_tasks(doc: dict, bucket: str, name: str):
|
||||||
|
"""Create and queue document processing tasks.
|
||||||
|
|
||||||
|
This function creates processing tasks for a document based on its type and configuration.
|
||||||
|
It handles different document types (PDF, Excel, etc.) differently and manages task
|
||||||
|
chunking and configuration. It also implements task reuse optimization by checking
|
||||||
|
for previously completed tasks.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
doc (dict): Document dictionary containing metadata and configuration.
|
||||||
|
bucket (str): Storage bucket name where the document is stored.
|
||||||
|
name (str): File name of the document.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
- For PDF documents, tasks are created per page range based on configuration
|
||||||
|
- For Excel documents, tasks are created per row range
|
||||||
|
- Task digests are calculated for optimization and reuse
|
||||||
|
- Previous task chunks may be reused if available
|
||||||
|
"""
|
||||||
def new_task():
|
def new_task():
|
||||||
return {"id": get_uuid(), "doc_id": doc["id"], "progress": 0.0, "from_page": 0, "to_page": 100000000}
|
return {"id": get_uuid(), "doc_id": doc["id"], "progress": 0.0, "from_page": 0, "to_page": 100000000}
|
||||||
|
|
||||||
@ -279,6 +385,26 @@ def queue_tasks(doc: dict, bucket: str, name: str):
|
|||||||
|
|
||||||
|
|
||||||
def reuse_prev_task_chunks(task: dict, prev_tasks: list[dict], chunking_config: dict):
|
def reuse_prev_task_chunks(task: dict, prev_tasks: list[dict], chunking_config: dict):
|
||||||
|
"""Attempt to reuse chunks from previous tasks for optimization.
|
||||||
|
|
||||||
|
This function checks if chunks from previously completed tasks can be reused for
|
||||||
|
the current task, which can significantly improve processing efficiency. It matches
|
||||||
|
tasks based on page ranges and configuration digests.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task (dict): Current task dictionary to potentially reuse chunks for.
|
||||||
|
prev_tasks (list[dict]): List of previous task dictionaries to check for reuse.
|
||||||
|
chunking_config (dict): Configuration dictionary for chunk processing.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: Number of chunks successfully reused. Returns 0 if no chunks could be reused.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
Chunks can only be reused if:
|
||||||
|
- A previous task exists with matching page range and configuration digest
|
||||||
|
- The previous task was completed successfully (progress = 1.0)
|
||||||
|
- The previous task has valid chunk IDs
|
||||||
|
"""
|
||||||
idx = 0
|
idx = 0
|
||||||
while idx < len(prev_tasks):
|
while idx < len(prev_tasks):
|
||||||
prev_task = prev_tasks[idx]
|
prev_task = prev_tasks[idx]
|
||||||
|
@ -29,11 +29,27 @@ from rag.settings import MINIO
|
|||||||
|
|
||||||
|
|
||||||
class UserService(CommonService):
|
class UserService(CommonService):
|
||||||
|
"""Service class for managing user-related database operations.
|
||||||
|
|
||||||
|
This class extends CommonService to provide specialized functionality for user management,
|
||||||
|
including authentication, user creation, updates, and deletions.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
model: The User model class for database operations.
|
||||||
|
"""
|
||||||
model = User
|
model = User
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def filter_by_id(cls, user_id):
|
def filter_by_id(cls, user_id):
|
||||||
|
"""Retrieve a user by their ID.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The unique identifier of the user.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
User object if found, None otherwise.
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
user = cls.model.select().where(cls.model.id == user_id).get()
|
user = cls.model.select().where(cls.model.id == user_id).get()
|
||||||
return user
|
return user
|
||||||
@ -43,6 +59,15 @@ class UserService(CommonService):
|
|||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def query_user(cls, email, password):
|
def query_user(cls, email, password):
|
||||||
|
"""Authenticate a user with email and password.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
email: User's email address.
|
||||||
|
password: User's password in plain text.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
User object if authentication successful, None otherwise.
|
||||||
|
"""
|
||||||
user = cls.model.select().where((cls.model.email == email),
|
user = cls.model.select().where((cls.model.email == email),
|
||||||
(cls.model.status == StatusEnum.VALID.value)).first()
|
(cls.model.status == StatusEnum.VALID.value)).first()
|
||||||
if user and check_password_hash(str(user.password), password):
|
if user and check_password_hash(str(user.password), password):
|
||||||
@ -85,6 +110,14 @@ class UserService(CommonService):
|
|||||||
|
|
||||||
|
|
||||||
class TenantService(CommonService):
|
class TenantService(CommonService):
|
||||||
|
"""Service class for managing tenant-related database operations.
|
||||||
|
|
||||||
|
This class extends CommonService to provide functionality for tenant management,
|
||||||
|
including tenant information retrieval and credit management.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
model: The Tenant model class for database operations.
|
||||||
|
"""
|
||||||
model = Tenant
|
model = Tenant
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -136,6 +169,14 @@ class TenantService(CommonService):
|
|||||||
|
|
||||||
|
|
||||||
class UserTenantService(CommonService):
|
class UserTenantService(CommonService):
|
||||||
|
"""Service class for managing user-tenant relationship operations.
|
||||||
|
|
||||||
|
This class extends CommonService to handle the many-to-many relationship
|
||||||
|
between users and tenants, managing user roles and tenant memberships.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
model: The UserTenant model class for database operations.
|
||||||
|
"""
|
||||||
model = UserTenant
|
model = UserTenant
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ -531,6 +531,26 @@ Failure:
|
|||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
|
## Error Codes
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
| Code | Message | Description |
|
||||||
|
|------|---------|-------------|
|
||||||
|
| 400 | Bad Request | Invalid request parameters |
|
||||||
|
| 401 | Unauthorized | Unauthorized access |
|
||||||
|
| 403 | Forbidden | Access denied |
|
||||||
|
| 404 | Not Found | Resource not found |
|
||||||
|
| 500 | Internal Server Error | Server internal error |
|
||||||
|
| 1001 | Invalid Chunk ID | Invalid Chunk ID |
|
||||||
|
| 1002 | Chunk Update Failed | Chunk update failed |
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
## FILE MANAGEMENT WITHIN DATASET
|
## FILE MANAGEMENT WITHIN DATASET
|
||||||
|
|
||||||
---
|
---
|
||||||
|
@ -317,6 +317,23 @@ dataset = rag_object.list_datasets(name="kb_name")
|
|||||||
dataset.update({"embedding_model":"BAAI/bge-zh-v1.5", "chunk_method":"manual"})
|
dataset.update({"embedding_model":"BAAI/bge-zh-v1.5", "chunk_method":"manual"})
|
||||||
```
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Error Codes
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
| Code | Message | Description |
|
||||||
|
|------|---------|-------------|
|
||||||
|
| 400 | Bad Request | Invalid request parameters |
|
||||||
|
| 401 | Unauthorized | Unauthorized access |
|
||||||
|
| 403 | Forbidden | Access denied |
|
||||||
|
| 404 | Not Found | Resource not found |
|
||||||
|
| 500 | Internal Server Error | Server internal error |
|
||||||
|
| 1001 | Invalid Chunk ID | Invalid Chunk ID |
|
||||||
|
| 1002 | Chunk Update Failed | Chunk update failed |
|
||||||
|
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## FILE MANAGEMENT WITHIN DATASET
|
## FILE MANAGEMENT WITHIN DATASET
|
||||||
@ -1719,4 +1736,6 @@ for agent in rag_object.list_agents():
|
|||||||
print(agent)
|
print(agent)
|
||||||
```
|
```
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
|
|
||||||
|
@ -16,6 +16,12 @@
|
|||||||
|
|
||||||
from .base import Base
|
from .base import Base
|
||||||
|
|
||||||
|
class ChunkUpdateError(Exception):
|
||||||
|
def __init__(self, code=None, message=None, details=None):
|
||||||
|
self.code = code
|
||||||
|
self.message = message
|
||||||
|
self.details = details
|
||||||
|
super().__init__(message)
|
||||||
|
|
||||||
class Chunk(Base):
|
class Chunk(Base):
|
||||||
def __init__(self, rag, res_dict):
|
def __init__(self, rag, res_dict):
|
||||||
@ -38,4 +44,8 @@ class Chunk(Base):
|
|||||||
res = self.put(f"/datasets/{self.dataset_id}/documents/{self.document_id}/chunks/{self.id}", update_message)
|
res = self.put(f"/datasets/{self.dataset_id}/documents/{self.document_id}/chunks/{self.id}", update_message)
|
||||||
res = res.json()
|
res = res.json()
|
||||||
if res.get("code") != 0:
|
if res.get("code") != 0:
|
||||||
raise Exception(res["message"])
|
raise ChunkUpdateError(
|
||||||
|
code=res.get("code"),
|
||||||
|
message=res.get("message"),
|
||||||
|
details=res.get("details")
|
||||||
|
)
|
Loading…
x
Reference in New Issue
Block a user