diff --git a/api/apps/api_app.py b/api/apps/api_app.py index 3c257e4cc..fbaafb346 100644 --- a/api/apps/api_app.py +++ b/api/apps/api_app.py @@ -31,7 +31,7 @@ from api.db.services.file_service import FileService from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.task_service import queue_tasks, TaskService from api.db.services.user_service import UserTenantService -from api.settings import RetCode +from api.settings import RetCode, retrievaler from api.utils import get_uuid, current_timestamp, datetime_format from api.utils.api_utils import server_error_response, get_data_error_result, get_json_result, validate_request from itsdangerous import URLSafeTimedSerializer @@ -39,9 +39,6 @@ from itsdangerous import URLSafeTimedSerializer from api.utils.file_utils import filename_type, thumbnail from rag.utils.minio_conn import MINIO -from rag.utils.es_conn import ELASTICSEARCH -from rag.nlp import search -from elasticsearch_dsl import Q def generate_confirmation_token(tenent_id): serializer = URLSafeTimedSerializer(tenent_id) @@ -369,27 +366,65 @@ def list_chunks(): try: if "doc_name" in form_data.keys(): tenant_id = DocumentService.get_tenant_id_by_name(form_data['doc_name']) - q = Q("match", docnm_kwd=form_data['doc_name']) + doc_id = DocumentService.get_doc_id_by_doc_name(form_data['doc_name']) elif "doc_id" in form_data.keys(): tenant_id = DocumentService.get_tenant_id(form_data['doc_id']) - q = Q("match", doc_id=form_data['doc_id']) + doc_id = form_data['doc_id'] else: return get_json_result( data=False,retmsg="Can't find doc_name or doc_id" ) - res_es_search = ELASTICSEARCH.search(q,idxnm=search.index_name(tenant_id),timeout="600s") - - res = [{} for _ in range(len(res_es_search['hits']['hits']))] - - for index , chunk in enumerate(res_es_search['hits']['hits']): - res[index]['doc_name'] = chunk['_source']['docnm_kwd'] - res[index]['content'] = chunk['_source']['content_with_weight'] - if 'img_id' in chunk['_source'].keys(): - res[index]['img_id'] = chunk['_source']['img_id'] + res = retrievaler.chunk_list(doc_id=doc_id, tenant_id=tenant_id) + res = [ + { + "content": res_item["content_with_weight"], + "doc_name": res_item["docnm_kwd"], + "img_id": res_item["img_id"] + } for res_item in res + ] except Exception as e: return server_error_response(e) return get_json_result(data=res) + + +@manager.route('/list_kb_docs', methods=['POST']) +# @login_required +def list_kb_docs(): + token = request.headers.get('Authorization').split()[1] + objs = APIToken.query(token=token) + if not objs: + return get_json_result( + data=False, retmsg='Token is not valid!"', retcode=RetCode.AUTHENTICATION_ERROR) + + tenant_id = objs[0].tenant_id + kb_name = request.form.get("kb_name").strip() + + try: + e, kb = KnowledgebaseService.get_by_name(kb_name, tenant_id) + if not e: + return get_data_error_result( + retmsg="Can't find this knowledgebase!") + kb_id = kb.id + + except Exception as e: + return server_error_response(e) + + page_number = int(request.form.get("page", 1)) + items_per_page = int(request.form.get("page_size", 15)) + orderby = request.form.get("orderby", "create_time") + desc = request.form.get("desc", True) + keywords = request.form.get("keywords", "") + + try: + docs, tol = DocumentService.get_by_kb_id( + kb_id, page_number, items_per_page, orderby, desc, keywords) + docs = [{"doc_id": doc['id'], "doc_name": doc['name']} for doc in docs] + + return get_json_result(data={"total": tol, "docs": docs}) + + except Exception as e: + return server_error_response(e) diff --git a/api/db/services/document_service.py b/api/db/services/document_service.py index c4dc4db04..ac569563f 100644 --- a/api/db/services/document_service.py +++ b/api/db/services/document_service.py @@ -179,6 +179,17 @@ class DocumentService(CommonService): return return docs[0]["tenant_id"] + @classmethod + @DB.connection_context() + def get_doc_id_by_doc_name(cls, doc_name): + fields = [cls.model.id] + doc_id = cls.model.select(*fields) \ + .where(cls.model.name == doc_name) + doc_id = doc_id.dicts() + if not doc_id: + return + return doc_id[0]["id"] + @classmethod @DB.connection_context() def get_thumbnails(cls, docids): diff --git a/rag/nlp/search.py b/rag/nlp/search.py index 216e9b747..e0c7c8553 100644 --- a/rag/nlp/search.py +++ b/rag/nlp/search.py @@ -407,3 +407,13 @@ class Dealer: except Exception as e: chat_logger.error(f"SQL failure: {sql} =>" + str(e)) return {"error": str(e)} + + def chunk_list(self, doc_id, tenant_id, max_count=1024, fields=["docnm_kwd", "content_with_weight", "img_id"]): + s = Search() + s = s.query(Q("match", doc_id=doc_id))[0:max_count] + s = s.to_dict() + es_res = self.es.search(s, idxnm=index_name(tenant_id), timeout="600s", src=fields) + res = [] + for index, chunk in enumerate(es_res['hits']['hits']): + res.append({fld: chunk['_source'].get(fld) for fld in fields}) + return res