Make files in knowledge bases visible in the file manager (#714)

### What problem does this PR solve?

Make files in knowledge bases visible in the file manager.
#162 

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
KevinHuSh 2024-05-11 16:04:28 +08:00 committed by GitHub
parent 91b4a18c47
commit 04a9e95161
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 187 additions and 64 deletions

View File

@ -23,7 +23,7 @@ from elasticsearch_dsl import Q
from flask import request from flask import request
from flask_login import login_required, current_user from flask_login import login_required, current_user
from api.db.db_models import Task from api.db.db_models import Task, File
from api.db.services.file2document_service import File2DocumentService from api.db.services.file2document_service import File2DocumentService
from api.db.services.file_service import FileService from api.db.services.file_service import FileService
from api.db.services.task_service import TaskService, queue_tasks from api.db.services.task_service import TaskService, queue_tasks
@ -33,7 +33,7 @@ from api.db.services import duplicate_name
from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.knowledgebase_service import KnowledgebaseService
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
from api.utils import get_uuid from api.utils import get_uuid
from api.db import FileType, TaskStatus, ParserType from api.db import FileType, TaskStatus, ParserType, FileSource
from api.db.services.document_service import DocumentService from api.db.services.document_service import DocumentService
from api.settings import RetCode from api.settings import RetCode
from api.utils.api_utils import get_json_result from api.utils.api_utils import get_json_result
@ -59,12 +59,19 @@ def upload():
return get_json_result( return get_json_result(
data=False, retmsg='No file selected!', retcode=RetCode.ARGUMENT_ERROR) data=False, retmsg='No file selected!', retcode=RetCode.ARGUMENT_ERROR)
e, kb = KnowledgebaseService.get_by_id(kb_id)
if not e:
raise LookupError("Can't find this knowledgebase!")
root_folder = FileService.get_root_folder(current_user.id)
pf_id = root_folder["id"]
FileService.init_knowledgebase_docs(pf_id, current_user.id)
kb_root_folder = FileService.get_kb_folder(current_user.id)
kb_folder = FileService.new_a_file_from_kb(kb.tenant_id, kb.name, kb_root_folder["id"])
err = [] err = []
for file in file_objs: for file in file_objs:
try: try:
e, kb = KnowledgebaseService.get_by_id(kb_id)
if not e:
raise LookupError("Can't find this knowledgebase!")
MAX_FILE_NUM_PER_USER = int(os.environ.get('MAX_FILE_NUM_PER_USER', 0)) MAX_FILE_NUM_PER_USER = int(os.environ.get('MAX_FILE_NUM_PER_USER', 0))
if MAX_FILE_NUM_PER_USER > 0 and DocumentService.get_doc_count(kb.tenant_id) >= MAX_FILE_NUM_PER_USER: if MAX_FILE_NUM_PER_USER > 0 and DocumentService.get_doc_count(kb.tenant_id) >= MAX_FILE_NUM_PER_USER:
raise RuntimeError("Exceed the maximum file number of a free user!") raise RuntimeError("Exceed the maximum file number of a free user!")
@ -99,6 +106,8 @@ def upload():
if re.search(r"\.(ppt|pptx|pages)$", filename): if re.search(r"\.(ppt|pptx|pages)$", filename):
doc["parser_id"] = ParserType.PRESENTATION.value doc["parser_id"] = ParserType.PRESENTATION.value
DocumentService.insert(doc) DocumentService.insert(doc)
FileService.add_file_from_kb(doc, kb_folder["id"], kb.tenant_id)
except Exception as e: except Exception as e:
err.append(file.filename + ": " + str(e)) err.append(file.filename + ": " + str(e))
if err: if err:
@ -228,11 +237,13 @@ def rm():
req = request.json req = request.json
doc_ids = req["doc_id"] doc_ids = req["doc_id"]
if isinstance(doc_ids, str): doc_ids = [doc_ids] if isinstance(doc_ids, str): doc_ids = [doc_ids]
root_folder = FileService.get_root_folder(current_user.id)
pf_id = root_folder["id"]
FileService.init_knowledgebase_docs(pf_id, current_user.id)
errors = "" errors = ""
for doc_id in doc_ids: for doc_id in doc_ids:
try: try:
e, doc = DocumentService.get_by_id(doc_id) e, doc = DocumentService.get_by_id(doc_id)
if not e: if not e:
return get_data_error_result(retmsg="Document not found!") return get_data_error_result(retmsg="Document not found!")
tenant_id = DocumentService.get_tenant_id(doc_id) tenant_id = DocumentService.get_tenant_id(doc_id)
@ -241,21 +252,25 @@ def rm():
ELASTICSEARCH.deleteByQuery( ELASTICSEARCH.deleteByQuery(
Q("match", doc_id=doc.id), idxnm=search.index_name(tenant_id)) Q("match", doc_id=doc.id), idxnm=search.index_name(tenant_id))
DocumentService.increment_chunk_num(
doc.id, doc.kb_id, doc.token_num * -1, doc.chunk_num * -1, 0) DocumentService.clear_chunk_num(doc_id)
b, n = File2DocumentService.get_minio_address(doc_id=doc_id)
if not DocumentService.delete(doc): if not DocumentService.delete(doc):
return get_data_error_result( return get_data_error_result(
retmsg="Database error (Document removal)!") retmsg="Database error (Document removal)!")
informs = File2DocumentService.get_by_document_id(doc_id) f2d = File2DocumentService.get_by_document_id(doc_id)
if not informs: FileService.filter_delete([File.source_type == FileSource.KNOWLEDGEBASE, File.id == f2d[0].file_id])
MINIO.rm(doc.kb_id, doc.location) File2DocumentService.delete_by_document_id(doc_id)
else:
File2DocumentService.delete_by_document_id(doc_id) MINIO.rm(b, n)
except Exception as e: except Exception as e:
errors += str(e) errors += str(e)
if errors: return server_error_response(e) if errors:
return get_json_result(data=False, retmsg=errors, retcode=RetCode.SERVER_ERROR)
return get_json_result(data=True) return get_json_result(data=True)

View File

@ -26,7 +26,7 @@ from api.db.services.document_service import DocumentService
from api.db.services.file2document_service import File2DocumentService from api.db.services.file2document_service import File2DocumentService
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
from api.utils import get_uuid from api.utils import get_uuid
from api.db import FileType from api.db import FileType, FileSource
from api.db.services import duplicate_name from api.db.services import duplicate_name
from api.db.services.file_service import FileService from api.db.services.file_service import FileService
from api.settings import RetCode from api.settings import RetCode
@ -45,7 +45,7 @@ def upload():
if not pf_id: if not pf_id:
root_folder = FileService.get_root_folder(current_user.id) root_folder = FileService.get_root_folder(current_user.id)
pf_id = root_folder.id pf_id = root_folder["id"]
if 'file' not in request.files: if 'file' not in request.files:
return get_json_result( return get_json_result(
@ -132,7 +132,7 @@ def create():
input_file_type = request.json.get("type") input_file_type = request.json.get("type")
if not pf_id: if not pf_id:
root_folder = FileService.get_root_folder(current_user.id) root_folder = FileService.get_root_folder(current_user.id)
pf_id = root_folder.id pf_id = root_folder["id"]
try: try:
if not FileService.is_parent_folder_exist(pf_id): if not FileService.is_parent_folder_exist(pf_id):
@ -176,7 +176,8 @@ def list():
desc = request.args.get("desc", True) desc = request.args.get("desc", True)
if not pf_id: if not pf_id:
root_folder = FileService.get_root_folder(current_user.id) root_folder = FileService.get_root_folder(current_user.id)
pf_id = root_folder.id pf_id = root_folder["id"]
FileService.init_knowledgebase_docs(pf_id, current_user.id)
try: try:
e, file = FileService.get_by_id(pf_id) e, file = FileService.get_by_id(pf_id)
if not e: if not e:
@ -199,7 +200,7 @@ def list():
def get_root_folder(): def get_root_folder():
try: try:
root_folder = FileService.get_root_folder(current_user.id) root_folder = FileService.get_root_folder(current_user.id)
return get_json_result(data={"root_folder": root_folder.to_json()}) return get_json_result(data={"root_folder": root_folder})
except Exception as e: except Exception as e:
return server_error_response(e) return server_error_response(e)
@ -250,6 +251,8 @@ def rm():
return get_data_error_result(retmsg="File or Folder not found!") return get_data_error_result(retmsg="File or Folder not found!")
if not file.tenant_id: if not file.tenant_id:
return get_data_error_result(retmsg="Tenant not found!") return get_data_error_result(retmsg="Tenant not found!")
if file.source_type == FileSource.KNOWLEDGEBASE:
continue
if file.type == FileType.FOLDER.value: if file.type == FileType.FOLDER.value:
file_id_list = FileService.get_all_innermost_file_ids(file_id, []) file_id_list = FileService.get_all_innermost_file_ids(file_id, [])

View File

@ -83,3 +83,11 @@ class ParserType(StrEnum):
NAIVE = "naive" NAIVE = "naive"
PICTURE = "picture" PICTURE = "picture"
ONE = "one" ONE = "one"
class FileSource(StrEnum):
LOCAL = ""
KNOWLEDGEBASE = "knowledgebase"
S3 = "s3"
# Name of the folder, created directly under a tenant's root folder, that
# holds the file-manager entries generated from knowledge-base documents.
KNOWLEDGEBASE_FOLDER_NAME=".knowledgebase"

View File

@ -21,14 +21,13 @@ import operator
from functools import wraps from functools import wraps
from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
from flask_login import UserMixin from flask_login import UserMixin
from playhouse.migrate import MySQLMigrator, migrate
from peewee import ( from peewee import (
BigAutoField, BigIntegerField, BooleanField, CharField, BigIntegerField, BooleanField, CharField,
CompositeKey, Insert, IntegerField, TextField, FloatField, DateTimeField, CompositeKey, IntegerField, TextField, FloatField, DateTimeField,
Field, Model, Metadata Field, Model, Metadata
) )
from playhouse.pool import PooledMySQLDatabase from playhouse.pool import PooledMySQLDatabase
from api.db import SerializedType, ParserType from api.db import SerializedType, ParserType
from api.settings import DATABASE, stat_logger, SECRET_KEY from api.settings import DATABASE, stat_logger, SECRET_KEY
from api.utils.log_utils import getLogger from api.utils.log_utils import getLogger
@ -344,7 +343,7 @@ class DataBaseModel(BaseModel):
@DB.connection_context() @DB.connection_context()
def init_database_tables(): def init_database_tables(alter_fields=[]):
members = inspect.getmembers(sys.modules[__name__], inspect.isclass) members = inspect.getmembers(sys.modules[__name__], inspect.isclass)
table_objs = [] table_objs = []
create_failed_list = [] create_failed_list = []
@ -361,6 +360,7 @@ def init_database_tables():
if create_failed_list: if create_failed_list:
LOGGER.info(f"create tables failed: {create_failed_list}") LOGGER.info(f"create tables failed: {create_failed_list}")
raise Exception(f"create tables failed: {create_failed_list}") raise Exception(f"create tables failed: {create_failed_list}")
migrate_db()
def fill_db_model_object(model_object, human_model_dict): def fill_db_model_object(model_object, human_model_dict):
@ -699,6 +699,11 @@ class File(DataBaseModel):
help_text="where dose it store") help_text="where dose it store")
size = IntegerField(default=0) size = IntegerField(default=0)
type = CharField(max_length=32, null=False, help_text="file extension") type = CharField(max_length=32, null=False, help_text="file extension")
source_type = CharField(
max_length=128,
null=False,
default="",
help_text="where dose this document come from")
class Meta: class Meta:
db_table = "file" db_table = "file"
@ -817,3 +822,14 @@ class API4Conversation(DataBaseModel):
class Meta: class Meta:
db_table = "api_4_conversation" db_table = "api_4_conversation"
def migrate_db():
    """Apply in-place schema changes to tables that already exist.

    Currently adds the ``source_type`` column to the ``file`` table inside a
    single transaction.

    The migration is best-effort: any exception is caught so that start-up
    continues (presumably the common failure is that the column already
    exists on an upgraded database -- TODO confirm). Unlike a bare ``pass``,
    the failure is logged so that unexpected errors stay visible.
    """
    try:
        with DB.transaction():
            migrator = MySQLMigrator(DB)
            source_type = CharField(
                max_length=128,
                null=False,
                default="",
                help_text="where dose this document come from")
            migrate(migrator.add_column('file', 'source_type', source_type))
    except Exception as e:
        # Deliberately non-fatal; record why the migration was not applied.
        LOGGER.info(f"migrate_db: adding column file.source_type skipped: {e}")

View File

@ -150,6 +150,22 @@ class DocumentService(CommonService):
Knowledgebase.id == kb_id).execute() Knowledgebase.id == kb_id).execute()
return num return num
@classmethod
@DB.connection_context()
def clear_chunk_num(cls, doc_id):
    """Remove one document's contribution from its knowledge base counters.

    Looks the document up by ``doc_id`` and decrements the owning
    knowledge base's ``token_num`` and ``chunk_num`` by the document's
    values, and ``doc_num`` by one.

    :param doc_id: primary key of the document being removed.
    :return: the value returned by the UPDATE's ``execute()`` (presumably
        the number of rows modified).
    :raises AssertionError: if the lookup yields a falsy document.
    """
    doc = cls.model.get_by_id(doc_id)
    # Explicit raise instead of ``assert`` so the check survives ``python -O``;
    # the exception type is kept for callers that relied on it.
    if not doc:
        raise AssertionError("Can't find document in database.")
    return Knowledgebase.update(
        token_num=Knowledgebase.token_num - doc.token_num,
        chunk_num=Knowledgebase.chunk_num - doc.chunk_num,
        doc_num=Knowledgebase.doc_num - 1,
    ).where(Knowledgebase.id == doc.kb_id).execute()
@classmethod @classmethod
@DB.connection_context() @DB.connection_context()
def get_tenant_id(cls, doc_id): def get_tenant_id(cls, doc_id):

View File

@ -15,12 +15,12 @@
# #
from datetime import datetime from datetime import datetime
from api.db import FileSource
from api.db.db_models import DB from api.db.db_models import DB
from api.db.db_models import File, Document, File2Document from api.db.db_models import File, File2Document
from api.db.services.common_service import CommonService from api.db.services.common_service import CommonService
from api.db.services.document_service import DocumentService from api.db.services.document_service import DocumentService
from api.db.services.file_service import FileService from api.utils import current_timestamp, datetime_format, get_uuid
from api.utils import current_timestamp, datetime_format
class File2DocumentService(CommonService): class File2DocumentService(CommonService):
@ -71,13 +71,15 @@ class File2DocumentService(CommonService):
@DB.connection_context() @DB.connection_context()
def get_minio_address(cls, doc_id=None, file_id=None): def get_minio_address(cls, doc_id=None, file_id=None):
if doc_id: if doc_id:
ids = File2DocumentService.get_by_document_id(doc_id) f2d = cls.get_by_document_id(doc_id)
else: else:
ids = File2DocumentService.get_by_file_id(file_id) f2d = cls.get_by_file_id(file_id)
if ids: if f2d:
e, file = FileService.get_by_id(ids[0].file_id) file = File.get_by_id(f2d[0].file_id)
return file.parent_id, file.location if file.source_type == FileSource.LOCAL:
else: return file.parent_id, file.location
assert doc_id, "please specify doc_id" doc_id = f2d[0].document_id
e, doc = DocumentService.get_by_id(doc_id)
return doc.kb_id, doc.location assert doc_id, "please specify doc_id"
e, doc = DocumentService.get_by_id(doc_id)
return doc.kb_id, doc.location

View File

@ -16,10 +16,12 @@
from flask_login import current_user from flask_login import current_user
from peewee import fn from peewee import fn
from api.db import FileType from api.db import FileType, KNOWLEDGEBASE_FOLDER_NAME, FileSource
from api.db.db_models import DB, File2Document, Knowledgebase from api.db.db_models import DB, File2Document, Knowledgebase
from api.db.db_models import File, Document from api.db.db_models import File, Document
from api.db.services.common_service import CommonService from api.db.services.common_service import CommonService
from api.db.services.document_service import DocumentService
from api.db.services.file2document_service import File2DocumentService
from api.utils import get_uuid from api.utils import get_uuid
@ -33,10 +35,15 @@ class FileService(CommonService):
if keywords: if keywords:
files = cls.model.select().where( files = cls.model.select().where(
(cls.model.tenant_id == tenant_id) (cls.model.tenant_id == tenant_id)
& (cls.model.parent_id == pf_id), (fn.LOWER(cls.model.name).like(f"%%{keywords.lower()}%%"))) (cls.model.parent_id == pf_id),
(fn.LOWER(cls.model.name).like(f"%%{keywords.lower()}%%")),
~(cls.model.id == pf_id)
)
else: else:
files = cls.model.select().where((cls.model.tenant_id == tenant_id) files = cls.model.select().where((cls.model.tenant_id == tenant_id),
& (cls.model.parent_id == pf_id)) (cls.model.parent_id == pf_id),
~(cls.model.id == pf_id)
)
count = files.count() count = files.count()
if desc: if desc:
files = files.order_by(cls.model.getter_by(orderby).desc()) files = files.order_by(cls.model.getter_by(orderby).desc())
@ -135,29 +142,69 @@ class FileService(CommonService):
@classmethod @classmethod
@DB.connection_context() @DB.connection_context()
def get_root_folder(cls, tenant_id): def get_root_folder(cls, tenant_id):
file = cls.model.select().where(cls.model.tenant_id == tenant_id and for file in cls.model.select().where((cls.model.tenant_id == tenant_id),
cls.model.parent_id == cls.model.id) (cls.model.parent_id == cls.model.id)
if not file: ):
file_id = get_uuid() return file.to_dict()
file = {
"id": file_id,
"parent_id": file_id,
"tenant_id": tenant_id,
"created_by": tenant_id,
"name": "/",
"type": FileType.FOLDER.value,
"size": 0,
"location": "",
}
cls.save(**file)
else:
file_id = file[0].id
e, file = cls.get_by_id(file_id) file_id = get_uuid()
if not e: file = {
raise RuntimeError("Database error (File retrieval)!") "id": file_id,
"parent_id": file_id,
"tenant_id": tenant_id,
"created_by": tenant_id,
"name": "/",
"type": FileType.FOLDER.value,
"size": 0,
"location": "",
}
cls.save(**file)
return file return file
@classmethod
@DB.connection_context()
def get_kb_folder(cls, tenant_id):
    """Return the tenant's knowledge-base folder as a dict.

    The folder is the one named ``KNOWLEDGEBASE_FOLDER_NAME`` whose parent is
    the tenant's root folder (the record whose ``parent_id`` equals its own
    ``id``).

    :param tenant_id: owner of the folder tree to search.
    :return: ``to_dict()`` of the first matching folder.
    :raises AssertionError: if no such folder exists for this tenant.
    """
    # Conditions are passed as separate ``where`` arguments so that all of
    # them are applied. Joining peewee expressions with Python's ``and``
    # evaluates to the last operand only, which silently dropped the tenant
    # and parent filters and could return another tenant's folder.
    roots = cls.model.select().where(
        (cls.model.tenant_id == tenant_id),
        (cls.model.parent_id == cls.model.id)
    )
    for root in roots:
        kb_folders = cls.model.select().where(
            (cls.model.tenant_id == tenant_id),
            (cls.model.parent_id == root.id),
            (cls.model.name == KNOWLEDGEBASE_FOLDER_NAME)
        )
        for folder in kb_folders:
            return folder.to_dict()
    # Explicit raise (not ``assert False``) so the failure is still reported
    # when Python runs with assertions disabled.
    raise AssertionError("Can't find the KB folder. Database init error.")
@classmethod
@DB.connection_context()
def new_a_file_from_kb(cls, tenant_id, name, parent_id, ty=FileType.FOLDER.value, size=0, location=""):
    """Get or create a knowledge-base sourced file record.

    If a record with the same tenant, parent and name already exists, its
    ``to_dict()`` is returned unchanged. Otherwise a new record tagged with
    ``FileSource.KNOWLEDGEBASE`` is saved and the dict that was saved is
    returned.

    :param tenant_id: owner of the record; also stored as ``created_by``.
    :param name: file or folder name.
    :param parent_id: id of the containing folder.
    :param ty: value for the ``type`` field; defaults to the folder type.
    :param size: value for the ``size`` field.
    :param location: value for the ``location`` field.
    :return: dict describing the existing or newly created record.
    """
    matches = cls.query(tenant_id=tenant_id, parent_id=parent_id, name=name)
    first = next(iter(matches), None)
    if first is not None:
        return first.to_dict()

    record = dict(
        id=get_uuid(),
        parent_id=parent_id,
        tenant_id=tenant_id,
        created_by=tenant_id,
        name=name,
        type=ty,
        size=size,
        location=location,
        source_type=FileSource.KNOWLEDGEBASE,
    )
    cls.save(**record)
    return record
@classmethod
@DB.connection_context()
def init_knowledgebase_docs(cls, root_id, tenant_id):
    """Populate the knowledge-base folder tree on first use.

    Does nothing when a folder named ``KNOWLEDGEBASE_FOLDER_NAME`` already
    exists under ``root_id``. Otherwise creates that folder, one sub-folder
    per knowledge base of the tenant, and one file record per document in
    each knowledge base.

    :param root_id: id of the tenant's root folder.
    :param tenant_id: tenant whose knowledge bases are mirrored.
    """
    existing = cls.model.select().where(
        (cls.model.name == KNOWLEDGEBASE_FOLDER_NAME)
        & (cls.model.parent_id == root_id))
    for _ in existing:
        # Already initialised: leave the existing tree untouched.
        return

    top = cls.new_a_file_from_kb(tenant_id, KNOWLEDGEBASE_FOLDER_NAME, root_id)
    kbs = Knowledgebase.select(Knowledgebase.id, Knowledgebase.name).where(
        Knowledgebase.tenant_id == tenant_id)
    for kb in kbs:
        sub_folder = cls.new_a_file_from_kb(tenant_id, kb.name, top["id"])
        for document in DocumentService.query(kb_id=kb.id):
            FileService.add_file_from_kb(document.to_dict(), sub_folder["id"], tenant_id)
@classmethod @classmethod
@DB.connection_context() @DB.connection_context()
def get_parent_folder(cls, file_id): def get_parent_folder(cls, file_id):
@ -241,3 +288,20 @@ class FileService(CommonService):
dfs(folder_id) dfs(folder_id)
return size return size
@classmethod
@DB.connection_context()
def add_file_from_kb(cls, doc, kb_folder_id, tenant_id):
    """Create a file record for a knowledge-base document and link the two.

    Returns immediately when the document already has at least one
    file-to-document link. Otherwise saves a new file record (tagged
    ``FileSource.KNOWLEDGEBASE``) under ``kb_folder_id`` and a
    ``File2Document`` row connecting it to the document.

    :param doc: document as a dict; reads ``id``, ``name``, ``type``,
        ``size`` and ``location``.
    :param kb_folder_id: id of the folder that will contain the file.
    :param tenant_id: owner of the record; also stored as ``created_by``.
    """
    for _ in File2DocumentService.get_by_document_id(doc["id"]):
        return

    new_file_id = get_uuid()
    record = dict(
        id=new_file_id,
        parent_id=kb_folder_id,
        tenant_id=tenant_id,
        created_by=tenant_id,
        name=doc["name"],
        type=doc["type"],
        size=doc["size"],
        location=doc["location"],
        source_type=FileSource.KNOWLEDGEBASE,
    )
    cls.save(**record)

    link = dict(id=get_uuid(), file_id=record["id"], document_id=doc["id"])
    File2DocumentService.save(**link)

View File

@ -8,14 +8,14 @@ PY=/root/miniconda3/envs/py11/bin/python
function task_exe(){ function task_exe(){
while [ 1 -eq 1 ];do while [ 1 -eq 1 ];do
$PY rag/svr/task_executor.py $1 $2; $PY rag/svr/task_executor.py ;
done done
} }
WS=1 WS=1
for ((i=0;i<WS;i++)) for ((i=0;i<WS;i++))
do do
task_exe $i $WS & task_exe &
done done
while [ 1 -eq 1 ];do while [ 1 -eq 1 ];do

View File

@ -109,6 +109,7 @@ def collect():
if not msg: return pd.DataFrame() if not msg: return pd.DataFrame()
if TaskService.do_cancel(msg["id"]): if TaskService.do_cancel(msg["id"]):
cron_logger.info("Task {} has been canceled.".format(msg["id"]))
return pd.DataFrame() return pd.DataFrame()
tasks = TaskService.get_tasks(msg["id"]) tasks = TaskService.get_tasks(msg["id"])
assert tasks, "{} empty task!".format(msg["id"]) assert tasks, "{} empty task!".format(msg["id"])

View File

@ -78,8 +78,6 @@ pycryptodomex==3.20.0
pydantic==2.6.2 pydantic==2.6.2
pydantic_core==2.16.3 pydantic_core==2.16.3
PyJWT==2.8.0 PyJWT==2.8.0
PyMuPDF==1.23.25
PyMuPDFb==1.23.22
PyMySQL==1.1.0 PyMySQL==1.1.0
PyPDF2==3.0.1 PyPDF2==3.0.1
pypdfium2==4.27.0 pypdfium2==4.27.0