diff --git a/api/apps/chunk_app.py b/api/apps/chunk_app.py
index 3f9fe62bf..49cb45b31 100644
--- a/api/apps/chunk_app.py
+++ b/api/apps/chunk_app.py
@@ -116,8 +116,7 @@ def get():
 
 @manager.route('/set', methods=['POST'])  # noqa: F821
 @login_required
-@validate_request("doc_id", "chunk_id", "content_with_weight",
-                  "important_kwd", "question_kwd")
+@validate_request("doc_id", "chunk_id", "content_with_weight")
 def set():
     req = request.json
     d = {
@@ -125,14 +124,16 @@ def set():
         "content_with_weight": req["content_with_weight"]}
     d["content_ltks"] = rag_tokenizer.tokenize(req["content_with_weight"])
     d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
-    if req.get("important_kwd"):
+    if "important_kwd" in req:
         d["important_kwd"] = req["important_kwd"]
         d["important_tks"] = rag_tokenizer.tokenize(" ".join(req["important_kwd"]))
-    if req.get("question_kwd"):
+    if "question_kwd" in req:
         d["question_kwd"] = req["question_kwd"]
         d["question_tks"] = rag_tokenizer.tokenize("\n".join(req["question_kwd"]))
-    if req.get("tag_kwd"):
+    if "tag_kwd" in req:
         d["tag_kwd"] = req["tag_kwd"]
+    if "tag_feas" in req:
+        d["tag_feas"] = req["tag_feas"]
     if "available_int" in req:
         d["available_int"] = req["available_int"]
 
@@ -157,7 +158,7 @@ def set():
             d = beAdoc(d, arr[0], arr[1], not any(
                 [rag_tokenizer.is_chinese(t) for t in q + a]))
 
-        v, c = embd_mdl.encode([doc.name, req["content_with_weight"] if not d["question_kwd"] else "\n".join(d["question_kwd"])])
+        v, c = embd_mdl.encode([doc.name, req["content_with_weight"] if not d.get("question_kwd") else "\n".join(d["question_kwd"])])
         v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
         d["q_%d_vec" % len(v)] = v.tolist()
         settings.docStoreConn.update({"id": req["chunk_id"]}, d, search.index_name(tenant_id), doc.kb_id)
diff --git a/api/apps/conversation_app.py b/api/apps/conversation_app.py
index 9032c38c9..b7aaee74a 100644
--- a/api/apps/conversation_app.py
+++ b/api/apps/conversation_app.py
@@ -27,12 +27,13 @@ from flask_login import login_required, current_user
 from api.db import LLMType
 from api.db.services.dialog_service import DialogService, chat, ask, label_question
 from api.db.services.knowledgebase_service import KnowledgebaseService
-from api.db.services.llm_service import LLMBundle, TenantService, TenantLLMService
+from api.db.services.llm_service import LLMBundle, TenantService
 from api import settings
 from api.utils.api_utils import get_json_result
 from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
 from graphrag.mind_map_extractor import MindMapExtractor
 
+
 @manager.route('/set', methods=['POST'])  # noqa: F821
 @login_required
 def set_conversation():
@@ -376,8 +377,7 @@ def mindmap():
     if not e:
         return get_data_error_result(message="Knowledgebase not found!")
 
-    embd_mdl = TenantLLMService.model_instance(
-        kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
+    embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING, llm_name=kb.embd_id)
     chat_mdl = LLMBundle(current_user.id, LLMType.CHAT)
     question = req["question"]
     ranks = settings.retrievaler.retrieval(question, embd_mdl, kb.tenant_id, kb_ids, 1, 12,
diff --git a/api/apps/document_app.py b/api/apps/document_app.py
index afae53652..81bf1ec3c 100644
--- a/api/apps/document_app.py
+++ b/api/apps/document_app.py
@@ -13,6 +13,7 @@
 #  See the License for the specific language governing permissions and
 #  limitations under the License
 #
+import json
 import os.path
 import pathlib
 import re
@@ -593,3 +594,44 @@ def parse():
     txt = FileService.parse_docs(file_objs, current_user.id)
 
     return get_json_result(data=txt)
+
+
+@manager.route('/set_meta', methods=['POST'])  # noqa: F821
+@login_required
+@validate_request("doc_id", "meta")
+def set_meta():
+    """Replace a document's ``meta_fields`` with the JSON object in ``meta``.
+
+    The request body carries ``doc_id`` and ``meta``; ``meta`` is a
+    JSON-encoded string and must decode to an object (a Python dict).
+    """
+    req = request.json
+    if not DocumentService.accessible(req["doc_id"], current_user.id):
+        return get_json_result(
+            data=False,
+            message='No authorization.',
+            code=settings.RetCode.AUTHENTICATION_ERROR
+        )
+    try:
+        meta = json.loads(req["meta"])
+    except Exception as e:
+        return get_json_result(
+            data=False, message=f'Json syntax error: {e}', code=settings.RetCode.ARGUMENT_ERROR)
+    # json.loads() also yields lists, strings, numbers and None for valid
+    # JSON; only an object matches the dict default of Document.meta_fields.
+    if not isinstance(meta, dict):
+        return get_json_result(
+            data=False, message='Meta must be a JSON object.', code=settings.RetCode.ARGUMENT_ERROR)
+    try:
+        e, doc = DocumentService.get_by_id(req["doc_id"])
+        if not e:
+            return get_data_error_result(message="Document not found!")
+
+        if not DocumentService.update_by_id(
+                req["doc_id"], {"meta_fields": meta}):
+            return get_data_error_result(
+                message="Database error (meta updates)!")
+
+        return get_json_result(data=True)
+    except Exception as e:
+        return server_error_response(e)
diff --git a/api/db/db_models.py b/api/db/db_models.py
index d894846c0..ffb44259a 100644
--- a/api/db/db_models.py
+++ b/api/db/db_models.py
@@ -760,6 +760,7 @@ class Document(DataBaseModel):
         default="")
     process_begin_at = DateTimeField(null=True, index=True)
     process_duation = FloatField(default=0)
+    meta_fields = JSONField(null=True, default={})
 
     run = CharField(
         max_length=1,
@@ -1112,3 +1113,10 @@ def migrate_db():
             )
         except Exception:
             pass
+        try:
+            migrate(
+                migrator.add_column("document", "meta_fields",
+                                    JSONField(null=True, default={}))
+            )
+        except Exception:
+            pass
diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py
index 95fba7d13..2da1bf285 100644
--- a/api/db/services/dialog_service.py
+++ b/api/db/services/dialog_service.py
@@ -122,15 +122,17 @@ def kb_prompt(kbinfos, max_tokens):
             knowledges = knowledges[:i]
             break
 
+    #docs = DocumentService.get_by_ids([ck["doc_id"] for ck in kbinfos["chunks"][:chunks_num]])
+    #docs = {d.id: d.meta_fields for d in docs}
+
     doc2chunks = defaultdict(list)
-    for i, ck in enumerate(kbinfos["chunks"]):
-        if i >= chunks_num:
-            break
+    for ck in kbinfos["chunks"][:chunks_num]:
         doc2chunks[ck["docnm_kwd"]].append(ck["content_with_weight"])
 
     knowledges = []
     for nm, chunks in doc2chunks.items():
-        txt = f"Document: {nm} \nContains the following relevant fragments:\n"
+        txt = f"Document: {nm} \n"
+        txt += "Contains the following relevant fragments:\n"
         for i, chunk in enumerate(chunks, 1):
             txt += f"{i}. {chunk}\n"
         knowledges.append(txt)