From c5da3cdd97701c74d4a07160aff163b16d95a9f9 Mon Sep 17 00:00:00 2001 From: Kevin Hu Date: Thu, 9 Jan 2025 17:07:21 +0800 Subject: [PATCH] Tagging (#4426) ### What problem does this PR solve? #4367 ### Type of change - [x] New Feature (non-breaking change which adds functionality) --- agent/component/retrieval.py | 4 +- api/apps/api_app.py | 5 +- api/apps/chunk_app.py | 25 ++-- api/apps/conversation_app.py | 9 +- api/apps/kb_app.py | 80 ++++++++++- api/apps/sdk/dataset.py | 8 +- api/apps/sdk/dify_retrieval.py | 4 +- api/apps/sdk/doc.py | 4 +- api/db/__init__.py | 1 + api/db/init_data.py | 11 +- api/db/services/dialog_service.py | 88 +++++++++++- api/db/services/file2document_service.py | 10 +- api/db/services/file_service.py | 5 +- api/db/services/knowledgebase_service.py | 7 +- api/db/services/task_service.py | 1 + api/settings.py | 2 +- api/utils/api_utils.py | 71 +++++---- conf/infinity_mapping.json | 4 +- graphrag/utils.py | 21 ++- rag/app/qa.py | 33 +++-- rag/app/tag.py | 125 ++++++++++++++++ rag/nlp/query.py | 35 ++++- rag/nlp/search.py | 168 ++++++++++++++++------ rag/settings.py | 3 + rag/svr/task_executor.py | 71 +++++++-- rag/utils/__init__.py | 2 + rag/utils/doc_store_conn.py | 14 +- rag/utils/es_conn.py | 109 +++++++++----- rag/utils/infinity_conn.py | 16 ++- sdk/python/test/test_sdk_api/t_dataset.py | 2 +- 30 files changed, 736 insertions(+), 202 deletions(-) create mode 100644 rag/app/tag.py diff --git a/agent/component/retrieval.py b/agent/component/retrieval.py index f9f3e9fa2..69f4a27fa 100644 --- a/agent/component/retrieval.py +++ b/agent/component/retrieval.py @@ -19,6 +19,7 @@ from abc import ABC import pandas as pd from api.db import LLMType +from api.db.services.dialog_service import label_question from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.llm_service import LLMBundle from api import settings @@ -70,7 +71,8 @@ class Retrieval(ComponentBase, ABC): kbinfos = settings.retrievaler.retrieval(query, embd_mdl, kbs[0].tenant_id, self._param.kb_ids, 1, self._param.top_n, self._param.similarity_threshold, 1 - self._param.keywords_similarity_weight, - aggs=False, rerank_mdl=rerank_mdl) + aggs=False, rerank_mdl=rerank_mdl, + rank_feature=label_question(query, kbs)) if not kbinfos["chunks"]: df = Retrieval.be_output("") diff --git a/api/apps/api_app.py b/api/apps/api_app.py index 095c5b231..cb260b6b2 100644 --- a/api/apps/api_app.py +++ b/api/apps/api_app.py @@ -25,7 +25,7 @@ from api.db import FileType, LLMType, ParserType, FileSource from api.db.db_models import APIToken, Task, File from api.db.services import duplicate_name from api.db.services.api_service import APITokenService, API4ConversationService -from api.db.services.dialog_service import DialogService, chat, keyword_extraction +from api.db.services.dialog_service import DialogService, chat, keyword_extraction, label_question from api.db.services.document_service import DocumentService, doc_upload_and_parse from api.db.services.file2document_service import File2DocumentService from api.db.services.file_service import FileService @@ -840,7 +840,8 @@ def retrieval(): question += keyword_extraction(chat_mdl, question) ranks = settings.retrievaler.retrieval(question, embd_mdl, kbs[0].tenant_id, kb_ids, page, size, similarity_threshold, vector_similarity_weight, top, - doc_ids, rerank_mdl=rerank_mdl) + doc_ids, rerank_mdl=rerank_mdl, + rank_feature=label_question(question, kbs)) for c in ranks["chunks"]: c.pop("vector", None) return get_json_result(data=ranks) diff --git 
a/api/apps/chunk_app.py b/api/apps/chunk_app.py index 2edf69902..3f9fe62bf 100644 --- a/api/apps/chunk_app.py +++ b/api/apps/chunk_app.py @@ -19,9 +19,10 @@ import json from flask import request from flask_login import login_required, current_user -from api.db.services.dialog_service import keyword_extraction +from api.db.services.dialog_service import keyword_extraction, label_question from rag.app.qa import rmPrefix, beAdoc from rag.nlp import search, rag_tokenizer +from rag.settings import PAGERANK_FLD from rag.utils import rmSpace from api.db import LLMType, ParserType from api.db.services.knowledgebase_service import KnowledgebaseService @@ -124,10 +125,14 @@ def set(): "content_with_weight": req["content_with_weight"]} d["content_ltks"] = rag_tokenizer.tokenize(req["content_with_weight"]) d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"]) - d["important_kwd"] = req["important_kwd"] - d["important_tks"] = rag_tokenizer.tokenize(" ".join(req["important_kwd"])) - d["question_kwd"] = req["question_kwd"] - d["question_tks"] = rag_tokenizer.tokenize("\n".join(req["question_kwd"])) + if req.get("important_kwd"): + d["important_kwd"] = req["important_kwd"] + d["important_tks"] = rag_tokenizer.tokenize(" ".join(req["important_kwd"])) + if req.get("question_kwd"): + d["question_kwd"] = req["question_kwd"] + d["question_tks"] = rag_tokenizer.tokenize("\n".join(req["question_kwd"])) + if req.get("tag_kwd"): + d["tag_kwd"] = req["tag_kwd"] if "available_int" in req: d["available_int"] = req["available_int"] @@ -220,7 +225,7 @@ def create(): e, doc = DocumentService.get_by_id(req["doc_id"]) if not e: return get_data_error_result(message="Document not found!") - d["kb_id"] = doc.kb_id + d["kb_id"] = [doc.kb_id] d["docnm_kwd"] = doc.name d["title_tks"] = rag_tokenizer.tokenize(doc.name) d["doc_id"] = doc.id @@ -233,7 +238,7 @@ def create(): if not e: return get_data_error_result(message="Knowledgebase not found!") if kb.pagerank: - d["pagerank_fea"] = kb.pagerank + d[PAGERANK_FLD] = kb.pagerank embd_id = DocumentService.get_embd_id(req["doc_id"]) embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING.value, embd_id) @@ -294,12 +299,16 @@ def retrieval_test(): chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT) question += keyword_extraction(chat_mdl, question) + labels = label_question(question, [kb]) retr = settings.retrievaler if kb.parser_id != ParserType.KG else settings.kg_retrievaler ranks = retr.retrieval(question, embd_mdl, tenant_ids, kb_ids, page, size, similarity_threshold, vector_similarity_weight, top, - doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight")) + doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"), + rank_feature=labels + ) for c in ranks["chunks"]: c.pop("vector", None) + ranks["labels"] = labels return get_json_result(data=ranks) except Exception as e: diff --git a/api/apps/conversation_app.py b/api/apps/conversation_app.py index c98abd69c..9032c38c9 100644 --- a/api/apps/conversation_app.py +++ b/api/apps/conversation_app.py @@ -25,7 +25,7 @@ from flask import request, Response from flask_login import login_required, current_user from api.db import LLMType -from api.db.services.dialog_service import DialogService, chat, ask +from api.db.services.dialog_service import DialogService, chat, ask, label_question from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.llm_service import LLMBundle, TenantService, TenantLLMService from api import settings @@ -379,8 +379,11 @@ def mindmap(): embd_mdl = 
TenantLLMService.model_instance( kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id) chat_mdl = LLMBundle(current_user.id, LLMType.CHAT) - ranks = settings.retrievaler.retrieval(req["question"], embd_mdl, kb.tenant_id, kb_ids, 1, 12, - 0.3, 0.3, aggs=False) + question = req["question"] + ranks = settings.retrievaler.retrieval(question, embd_mdl, kb.tenant_id, kb_ids, 1, 12, + 0.3, 0.3, aggs=False, + rank_feature=label_question(question, [kb]) + ) mindmap = MindMapExtractor(chat_mdl) mind_map = mindmap([c["content_with_weight"] for c in ranks["chunks"]]).output if "error" in mind_map: diff --git a/api/apps/kb_app.py b/api/apps/kb_app.py index b120d3cfb..35cc48351 100644 --- a/api/apps/kb_app.py +++ b/api/apps/kb_app.py @@ -30,6 +30,7 @@ from api.utils.api_utils import get_json_result from api import settings from rag.nlp import search from api.constants import DATASET_NAME_LIMIT +from rag.settings import PAGERANK_FLD @manager.route('/create', methods=['post']) # noqa: F821 @@ -104,11 +105,11 @@ def update(): if kb.pagerank != req.get("pagerank", 0): if req.get("pagerank", 0) > 0: - settings.docStoreConn.update({"kb_id": kb.id}, {"pagerank_fea": req["pagerank"]}, + settings.docStoreConn.update({"kb_id": kb.id}, {PAGERANK_FLD: req["pagerank"]}, search.index_name(kb.tenant_id), kb.id) else: - # Elasticsearch requires pagerank_fea be non-zero! - settings.docStoreConn.update({"exist": "pagerank_fea"}, {"remove": "pagerank_fea"}, + # Elasticsearch requires PAGERANK_FLD be non-zero! + settings.docStoreConn.update({"exist": PAGERANK_FLD}, {"remove": PAGERANK_FLD}, search.index_name(kb.tenant_id), kb.id) e, kb = KnowledgebaseService.get_by_id(kb.id) @@ -150,12 +151,14 @@ def list_kbs(): keywords = request.args.get("keywords", "") page_number = int(request.args.get("page", 1)) items_per_page = int(request.args.get("page_size", 150)) + parser_id = request.args.get("parser_id") orderby = request.args.get("orderby", "create_time") desc = request.args.get("desc", True) try: tenants = TenantService.get_joined_tenants_by_user_id(current_user.id) kbs, total = KnowledgebaseService.get_by_tenant_ids( - [m["tenant_id"] for m in tenants], current_user.id, page_number, items_per_page, orderby, desc, keywords) + [m["tenant_id"] for m in tenants], current_user.id, page_number, + items_per_page, orderby, desc, keywords, parser_id) return get_json_result(data={"kbs": kbs, "total": total}) except Exception as e: return server_error_response(e) @@ -199,3 +202,72 @@ def rm(): return get_json_result(data=True) except Exception as e: return server_error_response(e) + + +@manager.route('//tags', methods=['GET']) # noqa: F821 +@login_required +def list_tags(kb_id): + if not KnowledgebaseService.accessible(kb_id, current_user.id): + return get_json_result( + data=False, + message='No authorization.', + code=settings.RetCode.AUTHENTICATION_ERROR + ) + + tags = settings.retrievaler.all_tags(current_user.id, [kb_id]) + return get_json_result(data=tags) + + +@manager.route('/tags', methods=['GET']) # noqa: F821 +@login_required +def list_tags_from_kbs(): + kb_ids = request.args.get("kb_ids", "").split(",") + for kb_id in kb_ids: + if not KnowledgebaseService.accessible(kb_id, current_user.id): + return get_json_result( + data=False, + message='No authorization.', + code=settings.RetCode.AUTHENTICATION_ERROR + ) + + tags = settings.retrievaler.all_tags(current_user.id, kb_ids) + return get_json_result(data=tags) + + +@manager.route('//rm_tags', methods=['POST']) # noqa: F821 +@login_required +def rm_tags(kb_id): + req = 
request.json + if not KnowledgebaseService.accessible(kb_id, current_user.id): + return get_json_result( + data=False, + message='No authorization.', + code=settings.RetCode.AUTHENTICATION_ERROR + ) + e, kb = KnowledgebaseService.get_by_id(kb_id) + + for t in req["tags"]: + settings.docStoreConn.update({"tag_kwd": t, "kb_id": [kb_id]}, + {"remove": {"tag_kwd": t}}, + search.index_name(kb.tenant_id), + kb_id) + return get_json_result(data=True) + + +@manager.route('//rename_tag', methods=['POST']) # noqa: F821 +@login_required +def rename_tags(kb_id): + req = request.json + if not KnowledgebaseService.accessible(kb_id, current_user.id): + return get_json_result( + data=False, + message='No authorization.', + code=settings.RetCode.AUTHENTICATION_ERROR + ) + e, kb = KnowledgebaseService.get_by_id(kb_id) + + settings.docStoreConn.update({"tag_kwd": req["from_tag"], "kb_id": [kb_id]}, + {"remove": {"tag_kwd": req["from_tag"].strip()}, "add": {"tag_kwd": req["to_tag"]}}, + search.index_name(kb.tenant_id), + kb_id) + return get_json_result(data=True) \ No newline at end of file diff --git a/api/apps/sdk/dataset.py b/api/apps/sdk/dataset.py index d44956948..3fdcfebc6 100644 --- a/api/apps/sdk/dataset.py +++ b/api/apps/sdk/dataset.py @@ -73,7 +73,8 @@ def create(tenant_id): chunk_method: type: string enum: ["naive", "manual", "qa", "table", "paper", "book", "laws", - "presentation", "picture", "one", "knowledge_graph", "email"] + "presentation", "picture", "one", "knowledge_graph", "email", "tag" + ] description: Chunking method. parser_config: type: object @@ -108,6 +109,7 @@ def create(tenant_id): "one", "knowledge_graph", "email", + "tag" ] check_validation = valid( permission, @@ -302,7 +304,8 @@ def update(tenant_id, dataset_id): chunk_method: type: string enum: ["naive", "manual", "qa", "table", "paper", "book", "laws", - "presentation", "picture", "one", "knowledge_graph", "email"] + "presentation", "picture", "one", "knowledge_graph", "email", "tag" + ] description: Updated chunking method. 
parser_config: type: object @@ -339,6 +342,7 @@ def update(tenant_id, dataset_id): "one", "knowledge_graph", "email", + "tag" ] check_validation = valid( permission, diff --git a/api/apps/sdk/dify_retrieval.py b/api/apps/sdk/dify_retrieval.py index 3a81c8534..d1c61c6b5 100644 --- a/api/apps/sdk/dify_retrieval.py +++ b/api/apps/sdk/dify_retrieval.py @@ -16,6 +16,7 @@ from flask import request, jsonify from api.db import LLMType, ParserType +from api.db.services.dialog_service import label_question from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.llm_service import LLMBundle from api import settings @@ -54,7 +55,8 @@ def retrieval(tenant_id): page_size=top, similarity_threshold=similarity_threshold, vector_similarity_weight=0.3, - top=top + top=top, + rank_feature=label_question(question, [kb]) ) records = [] for c in ranks["chunks"]: diff --git a/api/apps/sdk/doc.py b/api/apps/sdk/doc.py index 74b2e6c95..fa1ce58cf 100644 --- a/api/apps/sdk/doc.py +++ b/api/apps/sdk/doc.py @@ -16,7 +16,7 @@ import pathlib import datetime -from api.db.services.dialog_service import keyword_extraction +from api.db.services.dialog_service import keyword_extraction, label_question from rag.app.qa import rmPrefix, beAdoc from rag.nlp import rag_tokenizer from api.db import LLMType, ParserType @@ -276,6 +276,7 @@ def update_doc(tenant_id, dataset_id, document_id): "one", "knowledge_graph", "email", + "tag" } if req.get("chunk_method") not in valid_chunk_method: return get_error_data_result( @@ -1355,6 +1356,7 @@ def retrieval_test(tenant_id): doc_ids, rerank_mdl=rerank_mdl, highlight=highlight, + rank_feature=label_question(question, kbs) ) for c in ranks["chunks"]: c.pop("vector", None) diff --git a/api/db/__init__.py b/api/db/__init__.py index c4cee6b6f..8c8a6535b 100644 --- a/api/db/__init__.py +++ b/api/db/__init__.py @@ -89,6 +89,7 @@ class ParserType(StrEnum): AUDIO = "audio" EMAIL = "email" KG = "knowledge_graph" + TAG = "tag" class FileSource(StrEnum): diff --git a/api/db/init_data.py b/api/db/init_data.py index 4817b05fd..54993892d 100644 --- a/api/db/init_data.py +++ b/api/db/init_data.py @@ -133,7 +133,7 @@ def init_llm_factory(): TenantLLMService.filter_update([TenantLLMService.model.llm_factory == "QAnything"], {"llm_factory": "Youdao"}) TenantLLMService.filter_update([TenantLLMService.model.llm_factory == "cohere"], {"llm_factory": "Cohere"}) TenantService.filter_update([1 == 1], { - "parser_ids": "naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,knowledge_graph:Knowledge Graph,email:Email"}) + "parser_ids": "naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,knowledge_graph:Knowledge Graph,email:Email,tag:Tag"}) ## insert openai two embedding models to the current openai user. 
# print("Start to insert 2 OpenAI embedding models...") tenant_ids = set([row["tenant_id"] for row in TenantLLMService.get_openai_models()]) @@ -153,14 +153,7 @@ def init_llm_factory(): break for kb_id in KnowledgebaseService.get_all_ids(): KnowledgebaseService.update_by_id(kb_id, {"doc_num": DocumentService.get_kb_doc_count(kb_id)}) - """ - drop table llm; - drop table llm_factories; - update tenant set parser_ids='naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,knowledge_graph:Knowledge Graph'; - alter table knowledgebase modify avatar longtext; - alter table user modify avatar longtext; - alter table dialog modify icon longtext; - """ + def add_graph_templates(): diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py index a86e4aad8..95fba7d13 100644 --- a/api/db/services/dialog_service.py +++ b/api/db/services/dialog_service.py @@ -29,8 +29,10 @@ from api.db.services.common_service import CommonService from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.llm_service import LLMService, TenantLLMService, LLMBundle from api import settings +from graphrag.utils import get_tags_from_cache, set_tags_to_cache from rag.app.resume import forbidden_select_fields4resume from rag.nlp.search import index_name +from rag.settings import TAG_FLD from rag.utils import rmSpace, num_tokens_from_string, encoder from api.utils.file_utils import get_project_base_directory @@ -135,6 +137,29 @@ def kb_prompt(kbinfos, max_tokens): return knowledges +def label_question(question, kbs): + tags = None + tag_kb_ids = [] + for kb in kbs: + if kb.parser_config.get("tag_kb_ids"): + tag_kb_ids.extend(kb.parser_config["tag_kb_ids"]) + if tag_kb_ids: + all_tags = get_tags_from_cache(tag_kb_ids) + if not all_tags: + all_tags = settings.retrievaler.all_tags_in_portion(kb.tenant_id, tag_kb_ids) + set_tags_to_cache(all_tags, tag_kb_ids) + else: + all_tags = json.loads(all_tags) + tag_kbs = KnowledgebaseService.get_by_ids(tag_kb_ids) + tags = settings.retrievaler.tag_query(question, + list(set([kb.tenant_id for kb in tag_kbs])), + tag_kb_ids, + all_tags, + kb.parser_config.get("topn_tags", 3) + ) + return tags + + def chat(dialog, messages, stream=True, **kwargs): assert messages[-1]["role"] == "user", "The last content of this conversation is not from user." @@ -236,11 +261,14 @@ def chat(dialog, messages, stream=True, **kwargs): generate_keyword_ts = timer() tenant_ids = list(set([kb.tenant_id for kb in kbs])) + kbinfos = retriever.retrieval(" ".join(questions), embd_mdl, tenant_ids, dialog.kb_ids, 1, dialog.top_n, dialog.similarity_threshold, dialog.vector_similarity_weight, doc_ids=attachments, - top=dialog.top_k, aggs=False, rerank_mdl=rerank_mdl) + top=dialog.top_k, aggs=False, rerank_mdl=rerank_mdl, + rank_feature=label_question(" ".join(questions), kbs) + ) retrieval_ts = timer() @@ -650,7 +678,10 @@ def ask(question, kb_ids, tenant_id): chat_mdl = LLMBundle(tenant_id, LLMType.CHAT) max_tokens = chat_mdl.max_length tenant_ids = list(set([kb.tenant_id for kb in kbs])) - kbinfos = retriever.retrieval(question, embd_mdl, tenant_ids, kb_ids, 1, 12, 0.1, 0.3, aggs=False) + kbinfos = retriever.retrieval(question, embd_mdl, tenant_ids, kb_ids, + 1, 12, 0.1, 0.3, aggs=False, + rank_feature=label_question(question, kbs) + ) knowledges = kb_prompt(kbinfos, max_tokens) prompt = """ Role: You're a smart assistant. Your name is Miss R. 
@@ -700,3 +731,56 @@ def ask(question, kb_ids, tenant_id): answer = ans yield {"answer": answer, "reference": {}} yield decorate_answer(answer) + + +def content_tagging(chat_mdl, content, all_tags, examples, topn=3): + prompt = f""" +Role: You're a text analyzer. + +Task: Tag (put on some labels) to a given piece of text content based on the examples and the entire tag set. + +Steps:: + - Comprehend the tag/label set. + - Comprehend examples which all consist of both text content and assigned tags with relevance score in format of JSON. + - Summarize the text content, and tag it with top {topn} most relevant tags from the set of tag/label and the corresponding relevance score. + +Requirements + - The tags MUST be from the tag set. + - The output MUST be in JSON format only, the key is tag and the value is its relevance score. + - The relevance score must be range from 1 to 10. + - Keywords ONLY in output. + +# TAG SET +{", ".join(all_tags)} + +""" + for i, ex in enumerate(examples): + prompt += """ +# Examples {} +### Text Content +{} + +Output: +{} + + """.format(i, ex["content"], json.dumps(ex[TAG_FLD], indent=2, ensure_ascii=False)) + + prompt += f""" +# Real Data +### Text Content +{content} + +""" + msg = [ + {"role": "system", "content": prompt}, + {"role": "user", "content": "Output: "} + ] + _, msg = message_fit_in(msg, chat_mdl.max_length) + kwd = chat_mdl.chat(prompt, msg[1:], {"temperature": 0.5}) + if isinstance(kwd, tuple): + kwd = kwd[0] + if kwd.find("**ERROR**") >= 0: + raise Exception(kwd) + + kwd = re.sub(r".*?\{", "{", kwd) + return json.loads(kwd) \ No newline at end of file diff --git a/api/db/services/file2document_service.py b/api/db/services/file2document_service.py index f3f587e46..c03dbf928 100644 --- a/api/db/services/file2document_service.py +++ b/api/db/services/file2document_service.py @@ -43,10 +43,7 @@ class File2DocumentService(CommonService): def insert(cls, obj): if not cls.save(**obj): raise RuntimeError("Database error (File)!") - e, obj = cls.get_by_id(obj["id"]) - if not e: - raise RuntimeError("Database error (File retrieval)!") - return obj + return File2Document(**obj) @classmethod @DB.connection_context() @@ -63,9 +60,8 @@ class File2DocumentService(CommonService): def update_by_file_id(cls, file_id, obj): obj["update_time"] = current_timestamp() obj["update_date"] = datetime_format(datetime.now()) - # num = cls.model.update(obj).where(cls.model.id == file_id).execute() - e, obj = cls.get_by_id(cls.model.id) - return obj + cls.model.update(obj).where(cls.model.id == file_id).execute() + return File2Document(**obj) @classmethod @DB.connection_context() diff --git a/api/db/services/file_service.py b/api/db/services/file_service.py index 48a665f06..4b0b3e56a 100644 --- a/api/db/services/file_service.py +++ b/api/db/services/file_service.py @@ -251,10 +251,7 @@ class FileService(CommonService): def insert(cls, file): if not cls.save(**file): raise RuntimeError("Database error (File)!") - e, file = cls.get_by_id(file["id"]) - if not e: - raise RuntimeError("Database error (File retrieval)!") - return file + return File(**file) @classmethod @DB.connection_context() diff --git a/api/db/services/knowledgebase_service.py b/api/db/services/knowledgebase_service.py index 357849964..a4f5d0095 100644 --- a/api/db/services/knowledgebase_service.py +++ b/api/db/services/knowledgebase_service.py @@ -35,7 +35,10 @@ class KnowledgebaseService(CommonService): @classmethod @DB.connection_context() def get_by_tenant_ids(cls, joined_tenant_ids, user_id, - page_number, 
items_per_page, orderby, desc, keywords): + page_number, items_per_page, + orderby, desc, keywords, + parser_id=None + ): fields = [ cls.model.id, cls.model.avatar, @@ -67,6 +70,8 @@ class KnowledgebaseService(CommonService): cls.model.tenant_id == user_id)) & (cls.model.status == StatusEnum.VALID.value) ) + if parser_id: + kbs = kbs.where(cls.model.parser_id == parser_id) if desc: kbs = kbs.order_by(cls.model.getter_by(orderby).desc()) else: diff --git a/api/db/services/task_service.py b/api/db/services/task_service.py index 3991b2291..f317be358 100644 --- a/api/db/services/task_service.py +++ b/api/db/services/task_service.py @@ -69,6 +69,7 @@ class TaskService(CommonService): Knowledgebase.language, Knowledgebase.embd_id, Knowledgebase.pagerank, + Knowledgebase.parser_config.alias("kb_parser_config"), Tenant.img2txt_id, Tenant.asr_id, Tenant.llm_id, diff --git a/api/settings.py b/api/settings.py index cd9154864..3c3d29574 100644 --- a/api/settings.py +++ b/api/settings.py @@ -140,7 +140,7 @@ def init_settings(): API_KEY = LLM.get("api_key", "") PARSERS = LLM.get( "parsers", - "naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,knowledge_graph:Knowledge Graph,email:Email") + "naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,knowledge_graph:Knowledge Graph,email:Email,tag:Tag") HOST_IP = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("host", "127.0.0.1") HOST_PORT = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("http_port") diff --git a/api/utils/api_utils.py b/api/utils/api_utils.py index f33588584..02d6d701d 100644 --- a/api/utils/api_utils.py +++ b/api/utils/api_utils.py @@ -173,6 +173,7 @@ def validate_request(*args, **kwargs): return wrapper + def not_allowed_parameters(*params): def decorator(f): def wrapper(*args, **kwargs): @@ -182,7 +183,9 @@ def not_allowed_parameters(*params): return get_json_result( code=settings.RetCode.ARGUMENT_ERROR, message=f"Parameter {param} isn't allowed") return f(*args, **kwargs) + return wrapper + return decorator @@ -207,6 +210,7 @@ def get_json_result(code=settings.RetCode.SUCCESS, message='success', data=None) response = {"code": code, "message": message, "data": data} return jsonify(response) + def apikey_required(func): @wraps(func) def decorated_function(*args, **kwargs): @@ -282,17 +286,18 @@ def construct_error_response(e): def token_required(func): @wraps(func) def decorated_function(*args, **kwargs): - authorization_str=flask_request.headers.get('Authorization') + authorization_str = flask_request.headers.get('Authorization') if not authorization_str: - return get_json_result(data=False,message="`Authorization` can't be empty") - authorization_list=authorization_str.split() + return get_json_result(data=False, message="`Authorization` can't be empty") + authorization_list = authorization_str.split() if len(authorization_list) < 2: - return get_json_result(data=False,message="Please check your authorization format.") + return get_json_result(data=False, message="Please check your authorization format.") token = authorization_list[1] objs = APIToken.query(token=token) if not objs: return get_json_result( - data=False, message='Authentication error: API key is invalid!', code=settings.RetCode.AUTHENTICATION_ERROR + data=False, message='Authentication error: API key is invalid!', + code=settings.RetCode.AUTHENTICATION_ERROR ) 
kwargs['tenant_id'] = objs[0].tenant_id return func(*args, **kwargs) @@ -330,35 +335,41 @@ def generate_confirmation_token(tenent_id): return "ragflow-" + serializer.dumps(get_uuid(), salt=tenent_id)[2:34] -def valid(permission,valid_permission,language,valid_language,chunk_method,valid_chunk_method): - if valid_parameter(permission,valid_permission): - return valid_parameter(permission,valid_permission) - if valid_parameter(language,valid_language): - return valid_parameter(language,valid_language) - if valid_parameter(chunk_method,valid_chunk_method): - return valid_parameter(chunk_method,valid_chunk_method) +def valid(permission, valid_permission, language, valid_language, chunk_method, valid_chunk_method): + if valid_parameter(permission, valid_permission): + return valid_parameter(permission, valid_permission) + if valid_parameter(language, valid_language): + return valid_parameter(language, valid_language) + if valid_parameter(chunk_method, valid_chunk_method): + return valid_parameter(chunk_method, valid_chunk_method) -def valid_parameter(parameter,valid_values): + +def valid_parameter(parameter, valid_values): if parameter and parameter not in valid_values: - return get_error_data_result(f"'{parameter}' is not in {valid_values}") + return get_error_data_result(f"'{parameter}' is not in {valid_values}") -def get_parser_config(chunk_method,parser_config): + +def get_parser_config(chunk_method, parser_config): if parser_config: return parser_config if not chunk_method: chunk_method = "naive" - key_mapping={"naive":{"chunk_token_num": 128, "delimiter": "\\n!?;。;!?", "html4excel": False,"layout_recognize": True, "raptor": {"use_raptor": False}}, - "qa":{"raptor":{"use_raptor":False}}, - "resume":None, - "manual":{"raptor":{"use_raptor":False}}, - "table":None, - "paper":{"raptor":{"use_raptor":False}}, - "book":{"raptor":{"use_raptor":False}}, - "laws":{"raptor":{"use_raptor":False}}, - "presentation":{"raptor":{"use_raptor":False}}, - "one":None, - "knowledge_graph":{"chunk_token_num":8192,"delimiter":"\\n!?;。;!?","entity_types":["organization","person","location","event","time"]}, - "email":None, - "picture":None} - parser_config=key_mapping[chunk_method] - return parser_config \ No newline at end of file + key_mapping = { + "naive": {"chunk_token_num": 128, "delimiter": "\\n!?;。;!?", "html4excel": False, "layout_recognize": True, + "raptor": {"use_raptor": False}}, + "qa": {"raptor": {"use_raptor": False}}, + "tag": None, + "resume": None, + "manual": {"raptor": {"use_raptor": False}}, + "table": None, + "paper": {"raptor": {"use_raptor": False}}, + "book": {"raptor": {"use_raptor": False}}, + "laws": {"raptor": {"use_raptor": False}}, + "presentation": {"raptor": {"use_raptor": False}}, + "one": None, + "knowledge_graph": {"chunk_token_num": 8192, "delimiter": "\\n!?;。;!?", + "entity_types": ["organization", "person", "location", "event", "time"]}, + "email": None, + "picture": None} + parser_config = key_mapping[chunk_method] + return parser_config diff --git a/conf/infinity_mapping.json b/conf/infinity_mapping.json index b6c6642a1..61f9d1bae 100644 --- a/conf/infinity_mapping.json +++ b/conf/infinity_mapping.json @@ -10,6 +10,7 @@ "title_sm_tks": {"type": "varchar", "default": "", "analyzer": "whitespace"}, "name_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"}, "important_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"}, + "tag_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"}, "important_tks": {"type": "varchar", "default": 
"", "analyzer": "whitespace"}, "question_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"}, "question_tks": {"type": "varchar", "default": "", "analyzer": "whitespace"}, @@ -27,5 +28,6 @@ "available_int": {"type": "integer", "default": 1}, "knowledge_graph_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"}, "entities_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"}, - "pagerank_fea": {"type": "integer", "default": 0} + "pagerank_fea": {"type": "integer", "default": 0}, + "tag_fea": {"type": "integer", "default": 0} } diff --git a/graphrag/utils.py b/graphrag/utils.py index 37de11543..2873ab729 100644 --- a/graphrag/utils.py +++ b/graphrag/utils.py @@ -111,4 +111,23 @@ def set_embed_cache(llmnm, txt, arr): k = hasher.hexdigest() arr = json.dumps(arr.tolist() if isinstance(arr, np.ndarray) else arr) - REDIS_CONN.set(k, arr.encode("utf-8"), 24*3600) \ No newline at end of file + REDIS_CONN.set(k, arr.encode("utf-8"), 24*3600) + + +def get_tags_from_cache(kb_ids): + hasher = xxhash.xxh64() + hasher.update(str(kb_ids).encode("utf-8")) + + k = hasher.hexdigest() + bin = REDIS_CONN.get(k) + if not bin: + return + return bin + + +def set_tags_to_cache(kb_ids, tags): + hasher = xxhash.xxh64() + hasher.update(str(kb_ids).encode("utf-8")) + + k = hasher.hexdigest() + REDIS_CONN.set(k, json.dumps(tags).encode("utf-8"), 600) diff --git a/rag/app/qa.py b/rag/app/qa.py index 824a398bd..00cde2d64 100644 --- a/rag/app/qa.py +++ b/rag/app/qa.py @@ -26,6 +26,7 @@ from docx import Document from PIL import Image from markdown import markdown + class Excel(ExcelParser): def __call__(self, fnm, binary=None, callback=None): if not binary: @@ -58,11 +59,11 @@ class Excel(ExcelParser): if len(res) % 999 == 0: callback(len(res) * 0.6 / - total, ("Extract Q&A: {}".format(len(res)) + + total, ("Extract pairs: {}".format(len(res)) + (f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else ""))) - callback(0.6, ("Extract Q&A: {}. ".format(len(res)) + ( + callback(0.6, ("Extract pairs: {}. ".format(len(res)) + ( f"{len(fails)} failure, line: %s..." 
% (",".join(fails[:3])) if fails else ""))) self.is_english = is_english( [rmPrefix(q) for q, _ in random_choices(res, k=30) if len(q) > 1]) @@ -269,7 +270,7 @@ def beAdocPdf(d, q, a, eng, image, poss): return d -def beAdocDocx(d, q, a, eng, image): +def beAdocDocx(d, q, a, eng, image, row_num=-1): qprefix = "Question: " if eng else "问题:" aprefix = "Answer: " if eng else "回答:" d["content_with_weight"] = "\t".join( @@ -277,16 +278,20 @@ def beAdocDocx(d, q, a, eng, image): d["content_ltks"] = rag_tokenizer.tokenize(q) d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"]) d["image"] = image + if row_num >= 0: + d["top_int"] = [row_num] return d -def beAdoc(d, q, a, eng): +def beAdoc(d, q, a, eng, row_num=-1): qprefix = "Question: " if eng else "问题:" aprefix = "Answer: " if eng else "回答:" d["content_with_weight"] = "\t".join( [qprefix + rmPrefix(q), aprefix + rmPrefix(a)]) d["content_ltks"] = rag_tokenizer.tokenize(q) d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"]) + if row_num >= 0: + d["top_int"] = [row_num] return d @@ -316,8 +321,8 @@ def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs): if re.search(r"\.xlsx?$", filename, re.IGNORECASE): callback(0.1, "Start to parse.") excel_parser = Excel() - for q, a in excel_parser(filename, binary, callback): - res.append(beAdoc(deepcopy(doc), q, a, eng)) + for ii, (q, a) in enumerate(excel_parser(filename, binary, callback)): + res.append(beAdoc(deepcopy(doc), q, a, eng, ii)) return res elif re.search(r"\.(txt)$", filename, re.IGNORECASE): @@ -344,7 +349,7 @@ def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs): fails.append(str(i+1)) elif len(arr) == 2: if question and answer: - res.append(beAdoc(deepcopy(doc), question, answer, eng)) + res.append(beAdoc(deepcopy(doc), question, answer, eng, i)) question, answer = arr i += 1 if len(res) % 999 == 0: @@ -352,7 +357,7 @@ def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs): f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else ""))) if question: - res.append(beAdoc(deepcopy(doc), question, answer, eng)) + res.append(beAdoc(deepcopy(doc), question, answer, eng, len(lines))) callback(0.6, ("Extract Q&A: {}".format(len(res)) + ( f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else ""))) @@ -378,14 +383,14 @@ def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs): fails.append(str(i + 1)) elif len(row) == 2: if question and answer: - res.append(beAdoc(deepcopy(doc), question, answer, eng)) + res.append(beAdoc(deepcopy(doc), question, answer, eng, i)) question, answer = row if len(res) % 999 == 0: callback(len(res) * 0.6 / len(lines), ("Extract Q&A: {}".format(len(res)) + ( f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else ""))) if question: - res.append(beAdoc(deepcopy(doc), question, answer, eng)) + res.append(beAdoc(deepcopy(doc), question, answer, eng, len(reader))) callback(0.6, ("Extract Q&A: {}".format(len(res)) + ( f"{len(fails)} failure, line: %s..." 
% (",".join(fails[:3])) if fails else ""))) @@ -420,7 +425,7 @@ def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs): if last_answer.strip(): sum_question = '\n'.join(question_stack) if sum_question: - res.append(beAdoc(deepcopy(doc), sum_question, markdown(last_answer, extensions=['markdown.extensions.tables']), eng)) + res.append(beAdoc(deepcopy(doc), sum_question, markdown(last_answer, extensions=['markdown.extensions.tables']), eng, index)) last_answer = '' i = question_level @@ -432,7 +437,7 @@ def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs): if last_answer.strip(): sum_question = '\n'.join(question_stack) if sum_question: - res.append(beAdoc(deepcopy(doc), sum_question, markdown(last_answer, extensions=['markdown.extensions.tables']), eng)) + res.append(beAdoc(deepcopy(doc), sum_question, markdown(last_answer, extensions=['markdown.extensions.tables']), eng, index)) return res elif re.search(r"\.docx$", filename, re.IGNORECASE): @@ -440,8 +445,8 @@ def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs): qai_list, tbls = docx_parser(filename, binary, from_page=0, to_page=10000, callback=callback) res = tokenize_table(tbls, doc, eng) - for q, a, image in qai_list: - res.append(beAdocDocx(deepcopy(doc), q, a, eng, image)) + for i, (q, a, image) in enumerate(qai_list): + res.append(beAdocDocx(deepcopy(doc), q, a, eng, image, i)) return res raise NotImplementedError( diff --git a/rag/app/tag.py b/rag/app/tag.py new file mode 100644 index 000000000..cd59ff7bd --- /dev/null +++ b/rag/app/tag.py @@ -0,0 +1,125 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import re +import csv +from copy import deepcopy + +from deepdoc.parser.utils import get_text +from rag.app.qa import Excel +from rag.nlp import rag_tokenizer + + +def beAdoc(d, q, a, eng, row_num=-1): + d["content_with_weight"] = q + d["content_ltks"] = rag_tokenizer.tokenize(q) + d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"]) + d["tag_kwd"] = [t.strip() for t in a.split(",") if t.strip()] + if row_num >= 0: + d["top_int"] = [row_num] + return d + + +def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs): + """ + Excel and csv(txt) format files are supported. + If the file is in excel format, there should be 2 column content and tags without header. + And content column is ahead of tags column. + And it's O.K if it has multiple sheets as long as the columns are rightly composed. + + If it's in csv format, it should be UTF-8 encoded. Use TAB as delimiter to separate content and tags. + + All the deformed lines will be ignored. + Every pair will be treated as a chunk. 
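+
+        A hypothetical TAB separated sample (the second column holds comma separated tags;
+        <TAB> stands for a literal tab character):
+
+            How do I reset my password?<TAB>account, password
+            Estimated delivery time for overseas orders<TAB>shipping, delivery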
+ """ + eng = lang.lower() == "english" + res = [] + doc = { + "docnm_kwd": filename, + "title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", filename)) + } + if re.search(r"\.xlsx?$", filename, re.IGNORECASE): + callback(0.1, "Start to parse.") + excel_parser = Excel() + for ii, (q, a) in enumerate(excel_parser(filename, binary, callback)): + res.append(beAdoc(deepcopy(doc), q, a, eng, ii)) + return res + + elif re.search(r"\.(txt)$", filename, re.IGNORECASE): + callback(0.1, "Start to parse.") + txt = get_text(filename, binary) + lines = txt.split("\n") + comma, tab = 0, 0 + for line in lines: + if len(line.split(",")) == 2: + comma += 1 + if len(line.split("\t")) == 2: + tab += 1 + delimiter = "\t" if tab >= comma else "," + + fails = [] + content = "" + i = 0 + while i < len(lines): + arr = lines[i].split(delimiter) + if len(arr) != 2: + content += "\n" + lines[i] + elif len(arr) == 2: + content += "\n" + arr[0] + res.append(beAdoc(deepcopy(doc), content, arr[1], eng, i)) + content = "" + i += 1 + if len(res) % 999 == 0: + callback(len(res) * 0.6 / len(lines), ("Extract TAG: {}".format(len(res)) + ( + f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else ""))) + + callback(0.6, ("Extract TAG: {}".format(len(res)) + ( + f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else ""))) + + return res + + elif re.search(r"\.(csv)$", filename, re.IGNORECASE): + callback(0.1, "Start to parse.") + txt = get_text(filename, binary) + lines = txt.split("\n") + delimiter = "\t" if any("\t" in line for line in lines) else "," + + fails = [] + content = "" + res = [] + reader = csv.reader(lines, delimiter=delimiter) + + for i, row in enumerate(reader): + if len(row) != 2: + content += "\n" + lines[i] + elif len(row) == 2: + content += "\n" + row[0] + res.append(beAdoc(deepcopy(doc), content, row[1], eng, i)) + content = "" + if len(res) % 999 == 0: + callback(len(res) * 0.6 / len(lines), ("Extract Tags: {}".format(len(res)) + ( + f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else ""))) + + callback(0.6, ("Extract TAG : {}".format(len(res)) + ( + f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else ""))) + return res + + raise NotImplementedError( + "Excel, csv(txt) format files are supported.") + + +if __name__ == "__main__": + import sys + + def dummy(prog=None, msg=""): + pass + chunk(sys.argv[1], from_page=0, to_page=10, callback=dummy) \ No newline at end of file diff --git a/rag/nlp/query.py b/rag/nlp/query.py index ab5c22e32..af96a8722 100644 --- a/rag/nlp/query.py +++ b/rag/nlp/query.py @@ -59,13 +59,15 @@ class FulltextQueryer: "", ), (r"(^| )(what|who|how|which|where|why)('re|'s)? 
", " "), - (r"(^| )('s|'re|is|are|were|was|do|does|did|don't|doesn't|didn't|has|have|be|there|you|me|your|my|mine|just|please|may|i|should|would|wouldn't|will|won't|done|go|for|with|so|the|a|an|by|i'm|it's|he's|she's|they|they're|you're|as|by|on|in|at|up|out|down|of|to|or|and|if) ", " ") + ( + r"(^| )('s|'re|is|are|were|was|do|does|did|don't|doesn't|didn't|has|have|be|there|you|me|your|my|mine|just|please|may|i|should|would|wouldn't|will|won't|done|go|for|with|so|the|a|an|by|i'm|it's|he's|she's|they|they're|you're|as|by|on|in|at|up|out|down|of|to|or|and|if) ", + " ") ] for r, p in patts: txt = re.sub(r, p, txt, flags=re.IGNORECASE) return txt - def question(self, txt, tbl="qa", min_match:float=0.6): + def question(self, txt, tbl="qa", min_match: float = 0.6): txt = re.sub( r"[ :|\r\n\t,,。??/`!!&^%%()\[\]{}<>]+", " ", @@ -90,7 +92,8 @@ class FulltextQueryer: syn = ["\"{}\"^{:.4f}".format(s, w / 4.) for s in syn if s.strip()] syns.append(" ".join(syn)) - q = ["({}^{:.4f}".format(tk, w) + " {})".format(syn) for (tk, w), syn in zip(tks_w, syns) if tk and not re.match(r"[.^+\(\)-]", tk)] + q = ["({}^{:.4f}".format(tk, w) + " {})".format(syn) for (tk, w), syn in zip(tks_w, syns) if + tk and not re.match(r"[.^+\(\)-]", tk)] for i in range(1, len(tks_w)): left, right = tks_w[i - 1][0].strip(), tks_w[i][0].strip() if not left or not right: @@ -155,7 +158,7 @@ class FulltextQueryer: if len(keywords) < 32: keywords.extend([s for s in tk_syns if s]) tk_syns = [rag_tokenizer.fine_grained_tokenize(s) for s in tk_syns if s] - tk_syns = [f"\"{s}\"" if s.find(" ")>0 else s for s in tk_syns] + tk_syns = [f"\"{s}\"" if s.find(" ") > 0 else s for s in tk_syns] if len(keywords) >= 32: break @@ -174,8 +177,6 @@ class FulltextQueryer: if len(twts) > 1: tms += ' ("%s"~2)^1.5' % rag_tokenizer.tokenize(tt) - if re.match(r"[0-9a-z ]+$", tt): - tms = f'("{tt}" OR "%s")' % rag_tokenizer.tokenize(tt) syns = " OR ".join( [ @@ -232,3 +233,25 @@ class FulltextQueryer: for k, v in qtwt.items(): q += v return s / q + + def paragraph(self, content_tks: str, keywords: list = [], keywords_topn=30): + if isinstance(content_tks, str): + content_tks = [c.strip() for c in content_tks.strip() if c.strip()] + tks_w = self.tw.weights(content_tks, preprocess=False) + + keywords = [f'"{k.strip()}"' for k in keywords] + for tk, w in sorted(tks_w, key=lambda x: x[1] * -1)[:keywords_topn]: + tk_syns = self.syn.lookup(tk) + tk_syns = [FulltextQueryer.subSpecialChar(s) for s in tk_syns] + tk_syns = [rag_tokenizer.fine_grained_tokenize(s) for s in tk_syns if s] + tk_syns = [f"\"{s}\"" if s.find(" ") > 0 else s for s in tk_syns] + tk = FulltextQueryer.subSpecialChar(tk) + if tk.find(" ") > 0: + tk = '"%s"' % tk + if tk_syns: + tk = f"({tk} OR (%s)^0.2)" % " ".join(tk_syns) + if tk: + keywords.append(f"{tk}^{w}") + + return MatchTextExpr(self.query_fields, " ".join(keywords), 100, + {"minimum_should_match": min(3, len(keywords) / 10)}) diff --git a/rag/nlp/search.py b/rag/nlp/search.py index 388fe3fe9..d4cbb45c8 100644 --- a/rag/nlp/search.py +++ b/rag/nlp/search.py @@ -13,11 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# - import logging import re from dataclasses import dataclass +from rag.settings import TAG_FLD, PAGERANK_FLD from rag.utils import rmSpace from rag.nlp import rag_tokenizer, query import numpy as np @@ -47,7 +47,8 @@ class Dealer: qv, _ = emb_mdl.encode_queries(txt) shape = np.array(qv).shape if len(shape) > 1: - raise Exception(f"Dealer.get_vector returned array's shape {shape} doesn't match expectation(exact one dimension).") + raise Exception( + f"Dealer.get_vector returned array's shape {shape} doesn't match expectation(exact one dimension).") embedding_data = [float(v) for v in qv] vector_column_name = f"q_{len(embedding_data)}_vec" return MatchDenseExpr(vector_column_name, embedding_data, 'float', 'cosine', topk, {"similarity": similarity}) @@ -63,7 +64,12 @@ class Dealer: condition[key] = req[key] return condition - def search(self, req, idx_names: str | list[str], kb_ids: list[str], emb_mdl=None, highlight = False): + def search(self, req, idx_names: str | list[str], + kb_ids: list[str], + emb_mdl=None, + highlight=False, + rank_feature: dict | None = None + ): filters = self.get_filters(req) orderBy = OrderByExpr() @@ -72,9 +78,11 @@ class Dealer: ps = int(req.get("size", topk)) offset, limit = pg * ps, ps - src = req.get("fields", ["docnm_kwd", "content_ltks", "kb_id", "img_id", "title_tks", "important_kwd", "position_int", - "doc_id", "page_num_int", "top_int", "create_timestamp_flt", "knowledge_graph_kwd", "question_kwd", "question_tks", - "available_int", "content_with_weight", "pagerank_fea"]) + src = req.get("fields", + ["docnm_kwd", "content_ltks", "kb_id", "img_id", "title_tks", "important_kwd", "position_int", + "doc_id", "page_num_int", "top_int", "create_timestamp_flt", "knowledge_graph_kwd", + "question_kwd", "question_tks", + "available_int", "content_with_weight", PAGERANK_FLD, TAG_FLD]) kwds = set([]) qst = req.get("question", "") @@ -85,15 +93,16 @@ class Dealer: orderBy.asc("top_int") orderBy.desc("create_timestamp_flt") res = self.dataStore.search(src, [], filters, [], orderBy, offset, limit, idx_names, kb_ids) - total=self.dataStore.getTotal(res) + total = self.dataStore.getTotal(res) logging.debug("Dealer.search TOTAL: {}".format(total)) else: highlightFields = ["content_ltks", "title_tks"] if highlight else [] matchText, keywords = self.qryr.question(qst, min_match=0.3) if emb_mdl is None: matchExprs = [matchText] - res = self.dataStore.search(src, highlightFields, filters, matchExprs, orderBy, offset, limit, idx_names, kb_ids) - total=self.dataStore.getTotal(res) + res = self.dataStore.search(src, highlightFields, filters, matchExprs, orderBy, offset, limit, + idx_names, kb_ids, rank_feature=rank_feature) + total = self.dataStore.getTotal(res) logging.debug("Dealer.search TOTAL: {}".format(total)) else: matchDense = self.get_vector(qst, emb_mdl, topk, req.get("similarity", 0.1)) @@ -103,8 +112,9 @@ class Dealer: fusionExpr = FusionExpr("weighted_sum", topk, {"weights": "0.05, 0.95"}) matchExprs = [matchText, matchDense, fusionExpr] - res = self.dataStore.search(src, highlightFields, filters, matchExprs, orderBy, offset, limit, idx_names, kb_ids) - total=self.dataStore.getTotal(res) + res = self.dataStore.search(src, highlightFields, filters, matchExprs, orderBy, offset, limit, + idx_names, kb_ids, rank_feature=rank_feature) + total = self.dataStore.getTotal(res) logging.debug("Dealer.search TOTAL: {}".format(total)) # If result is empty, try again with lower min_match @@ -112,8 +122,9 @@ class Dealer: matchText, _ = self.qryr.question(qst, min_match=0.1) 
filters.pop("doc_ids", None) matchDense.extra_options["similarity"] = 0.17 - res = self.dataStore.search(src, highlightFields, filters, [matchText, matchDense, fusionExpr], orderBy, offset, limit, idx_names, kb_ids) - total=self.dataStore.getTotal(res) + res = self.dataStore.search(src, highlightFields, filters, [matchText, matchDense, fusionExpr], + orderBy, offset, limit, idx_names, kb_ids, rank_feature=rank_feature) + total = self.dataStore.getTotal(res) logging.debug("Dealer.search 2 TOTAL: {}".format(total)) for k in keywords: @@ -126,8 +137,8 @@ class Dealer: kwds.add(kk) logging.debug(f"TOTAL: {total}") - ids=self.dataStore.getChunkIds(res) - keywords=list(kwds) + ids = self.dataStore.getChunkIds(res) + keywords = list(kwds) highlight = self.dataStore.getHighlight(res, keywords, "content_with_weight") aggs = self.dataStore.getAggregation(res, "docnm_kwd") return self.SearchResult( @@ -188,13 +199,13 @@ class Dealer: ans_v, _ = embd_mdl.encode(pieces_) assert len(ans_v[0]) == len(chunk_v[0]), "The dimension of query and chunk do not match: {} vs. {}".format( - len(ans_v[0]), len(chunk_v[0])) + len(ans_v[0]), len(chunk_v[0])) chunks_tks = [rag_tokenizer.tokenize(self.qryr.rmWWW(ck)).split() for ck in chunks] cites = {} thr = 0.63 - while thr>0.3 and len(cites.keys()) == 0 and pieces_ and chunks_tks: + while thr > 0.3 and len(cites.keys()) == 0 and pieces_ and chunks_tks: for i, a in enumerate(pieces_): sim, tksim, vtsim = self.qryr.hybrid_similarity(ans_v[i], chunk_v, @@ -228,20 +239,44 @@ class Dealer: return res, seted + def _rank_feature_scores(self, query_rfea, search_res): + ## For rank feature(tag_fea) scores. + rank_fea = [] + pageranks = [] + for chunk_id in search_res.ids: + pageranks.append(search_res.field[chunk_id].get(PAGERANK_FLD, 0)) + pageranks = np.array(pageranks, dtype=float) + + if not query_rfea: + return np.array([0 for _ in range(len(search_res.ids))]) + pageranks + + q_denor = np.sqrt(np.sum([s*s for t,s in query_rfea.items() if t != PAGERANK_FLD])) + for i in search_res.ids: + nor, denor = 0, 0 + for t, sc in eval(search_res.field[i].get(TAG_FLD, "{}")).items(): + if t in query_rfea: + nor += query_rfea[t] * sc + denor += sc * sc + if denor == 0: + rank_fea.append(0) + else: + rank_fea.append(nor/np.sqrt(denor)/q_denor) + return np.array(rank_fea)*10. + pageranks + def rerank(self, sres, query, tkweight=0.3, - vtweight=0.7, cfield="content_ltks"): + vtweight=0.7, cfield="content_ltks", + rank_feature: dict | None = None + ): _, keywords = self.qryr.question(query) vector_size = len(sres.query_vector) vector_column = f"q_{vector_size}_vec" zero_vector = [0.0] * vector_size ins_embd = [] - pageranks = [] for chunk_id in sres.ids: vector = sres.field[chunk_id].get(vector_column, zero_vector) if isinstance(vector, str): vector = [float(v) for v in vector.split("\t")] ins_embd.append(vector) - pageranks.append(sres.field[chunk_id].get("pagerank_fea", 0)) if not ins_embd: return [], [], [] @@ -254,18 +289,22 @@ class Dealer: title_tks = [t for t in sres.field[i].get("title_tks", "").split() if t] question_tks = [t for t in sres.field[i].get("question_tks", "").split() if t] important_kwd = sres.field[i].get("important_kwd", []) - tks = content_ltks + title_tks*2 + important_kwd*5 + question_tks*6 + tks = content_ltks + title_tks * 2 + important_kwd * 5 + question_tks * 6 ins_tw.append(tks) + ## For rank feature(tag_fea) scores. 
+ rank_fea = self._rank_feature_scores(rank_feature, sres) + sim, tksim, vtsim = self.qryr.hybrid_similarity(sres.query_vector, ins_embd, keywords, ins_tw, tkweight, vtweight) - return sim+np.array(pageranks, dtype=float), tksim, vtsim + return sim + rank_fea, tksim, vtsim def rerank_by_model(self, rerank_mdl, sres, query, tkweight=0.3, - vtweight=0.7, cfield="content_ltks"): + vtweight=0.7, cfield="content_ltks", + rank_feature: dict | None = None): _, keywords = self.qryr.question(query) for i in sres.ids: @@ -280,9 +319,11 @@ class Dealer: ins_tw.append(tks) tksim = self.qryr.token_similarity(keywords, ins_tw) - vtsim,_ = rerank_mdl.similarity(query, [rmSpace(" ".join(tks)) for tks in ins_tw]) + vtsim, _ = rerank_mdl.similarity(query, [rmSpace(" ".join(tks)) for tks in ins_tw]) + ## For rank feature(tag_fea) scores. + rank_fea = self._rank_feature_scores(rank_feature, sres) - return tkweight*np.array(tksim) + vtweight*vtsim, tksim, vtsim + return tkweight * (np.array(tksim)+rank_fea) + vtweight * vtsim, tksim, vtsim def hybrid_similarity(self, ans_embd, ins_embd, ans, inst): return self.qryr.hybrid_similarity(ans_embd, @@ -291,13 +332,15 @@ class Dealer: rag_tokenizer.tokenize(inst).split()) def retrieval(self, question, embd_mdl, tenant_ids, kb_ids, page, page_size, similarity_threshold=0.2, - vector_similarity_weight=0.3, top=1024, doc_ids=None, aggs=True, rerank_mdl=None, highlight=False): + vector_similarity_weight=0.3, top=1024, doc_ids=None, aggs=True, + rerank_mdl=None, highlight=False, + rank_feature: dict | None = {PAGERANK_FLD: 10}): ranks = {"total": 0, "chunks": [], "doc_aggs": {}} if not question: return ranks RERANK_PAGE_LIMIT = 3 - req = {"kb_ids": kb_ids, "doc_ids": doc_ids, "size": max(page_size*RERANK_PAGE_LIMIT, 128), + req = {"kb_ids": kb_ids, "doc_ids": doc_ids, "size": max(page_size * RERANK_PAGE_LIMIT, 128), "question": question, "vector": True, "topk": top, "similarity": similarity_threshold, "available_int": 1} @@ -309,29 +352,30 @@ class Dealer: if isinstance(tenant_ids, str): tenant_ids = tenant_ids.split(",") - sres = self.search(req, [index_name(tid) for tid in tenant_ids], kb_ids, embd_mdl, highlight) + sres = self.search(req, [index_name(tid) for tid in tenant_ids], + kb_ids, embd_mdl, highlight, rank_feature=rank_feature) ranks["total"] = sres.total if page <= RERANK_PAGE_LIMIT: if rerank_mdl and sres.total > 0: sim, tsim, vsim = self.rerank_by_model(rerank_mdl, - sres, question, 1 - vector_similarity_weight, vector_similarity_weight) + sres, question, 1 - vector_similarity_weight, + vector_similarity_weight, + rank_feature=rank_feature) else: sim, tsim, vsim = self.rerank( - sres, question, 1 - vector_similarity_weight, vector_similarity_weight) - idx = np.argsort(sim * -1)[(page-1)*page_size:page*page_size] + sres, question, 1 - vector_similarity_weight, vector_similarity_weight, + rank_feature=rank_feature) + idx = np.argsort(sim * -1)[(page - 1) * page_size:page * page_size] else: - sim = tsim = vsim = [1]*len(sres.ids) + sim = tsim = vsim = [1] * len(sres.ids) idx = list(range(len(sres.ids))) - def floor_sim(score): - return (int(score * 100.)%100)/100. 
- dim = len(sres.query_vector) vector_column = f"q_{dim}_vec" zero_vector = [0.0] * dim for i in idx: - if floor_sim(sim[i]) < similarity_threshold: + if sim[i] < similarity_threshold: break if len(ranks["chunks"]) >= page_size: if aggs: @@ -369,8 +413,8 @@ class Dealer: ranks["doc_aggs"] = [{"doc_name": k, "doc_id": v["doc_id"], "count": v["count"]} for k, - v in sorted(ranks["doc_aggs"].items(), - key=lambda x:x[1]["count"] * -1)] + v in sorted(ranks["doc_aggs"].items(), + key=lambda x: x[1]["count"] * -1)] ranks["chunks"] = ranks["chunks"][:page_size] return ranks @@ -379,15 +423,57 @@ class Dealer: tbl = self.dataStore.sql(sql, fetch_size, format) return tbl - def chunk_list(self, doc_id: str, tenant_id: str, kb_ids: list[str], max_count=1024, fields=["docnm_kwd", "content_with_weight", "img_id"]): + def chunk_list(self, doc_id: str, tenant_id: str, + kb_ids: list[str], max_count=1024, + offset=0, + fields=["docnm_kwd", "content_with_weight", "img_id"]): condition = {"doc_id": doc_id} res = [] bs = 128 - for p in range(0, max_count, bs): - es_res = self.dataStore.search(fields, [], condition, [], OrderByExpr(), p, bs, index_name(tenant_id), kb_ids) + for p in range(offset, max_count, bs): + es_res = self.dataStore.search(fields, [], condition, [], OrderByExpr(), p, bs, index_name(tenant_id), + kb_ids) dict_chunks = self.dataStore.getFields(es_res, fields) if dict_chunks: res.extend(dict_chunks.values()) if len(dict_chunks.values()) < bs: break return res + + def all_tags(self, tenant_id: str, kb_ids: list[str], S=1000): + res = self.dataStore.search([], [], {}, [], OrderByExpr(), 0, 0, index_name(tenant_id), kb_ids, ["tag_kwd"]) + return self.dataStore.getAggregation(res, "tag_kwd") + + def all_tags_in_portion(self, tenant_id: str, kb_ids: list[str], S=1000): + res = self.dataStore.search([], [], {}, [], OrderByExpr(), 0, 0, index_name(tenant_id), kb_ids, ["tag_kwd"]) + res = self.dataStore.getAggregation(res, "tag_kwd") + total = np.sum([c for _, c in res]) + return {t: (c + 1) / (total + S) for t, c in res} + + def tag_content(self, tenant_id: str, kb_ids: list[str], doc, all_tags, topn_tags=3, keywords_topn=30, S=1000): + idx_nm = index_name(tenant_id) + match_txt = self.qryr.paragraph(doc["title_tks"] + " " + doc["content_ltks"], doc.get("important_kwd", []), keywords_topn) + res = self.dataStore.search([], [], {}, [match_txt], OrderByExpr(), 0, 0, idx_nm, kb_ids, ["tag_kwd"]) + aggs = self.dataStore.getAggregation(res, "tag_kwd") + if not aggs: + return False + cnt = np.sum([c for _, c in aggs]) + tag_fea = sorted([(a, round(0.1*(c + 1) / (cnt + S) / (all_tags.get(a, 0.0001)))) for a, c in aggs], + key=lambda x: x[1] * -1)[:topn_tags] + doc[TAG_FLD] = {a: c for a, c in tag_fea if c > 0} + return True + + def tag_query(self, question: str, tenant_ids: str | list[str], kb_ids: list[str], all_tags, topn_tags=3, S=1000): + if isinstance(tenant_ids, str): + idx_nms = index_name(tenant_ids) + else: + idx_nms = [index_name(tid) for tid in tenant_ids] + match_txt, _ = self.qryr.question(question, min_match=0.0) + res = self.dataStore.search([], [], {}, [match_txt], OrderByExpr(), 0, 0, idx_nms, kb_ids, ["tag_kwd"]) + aggs = self.dataStore.getAggregation(res, "tag_kwd") + if not aggs: + return {} + cnt = np.sum([c for _, c in aggs]) + tag_fea = sorted([(a, round(0.1*(c + 1) / (cnt + S) / (all_tags.get(a, 0.0001)))) for a, c in aggs], + key=lambda x: x[1] * -1)[:topn_tags] + return {a: c for a, c in tag_fea if c > 0} diff --git a/rag/settings.py b/rag/settings.py index af6075d85..83e087484 
100644 --- a/rag/settings.py +++ b/rag/settings.py @@ -38,6 +38,9 @@ SVR_QUEUE_RETENTION = 60*60 SVR_QUEUE_MAX_LEN = 1024 SVR_CONSUMER_NAME = "rag_flow_svr_consumer" SVR_CONSUMER_GROUP_NAME = "rag_flow_svr_consumer_group" +PAGERANK_FLD = "pagerank_fea" +TAG_FLD = "tag_feas" + def print_rag_settings(): logging.info(f"MAX_CONTENT_LENGTH: {DOC_MAXIMUM_SIZE}") diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py index 390a2dce1..a81ce55d5 100644 --- a/rag/svr/task_executor.py +++ b/rag/svr/task_executor.py @@ -16,10 +16,10 @@ # from beartype import BeartypeConf # from beartype.claw import beartype_all # <-- you didn't sign up for this # beartype_all(conf=BeartypeConf(violation_type=UserWarning)) # <-- emit warnings from all code - +import random import sys from api.utils.log_utils import initRootLogger -from graphrag.utils import get_llm_cache, set_llm_cache +from graphrag.utils import get_llm_cache, set_llm_cache, get_tags_from_cache, set_tags_to_cache CONSUMER_NO = "0" if len(sys.argv) < 2 else sys.argv[1] CONSUMER_NAME = "task_executor_" + CONSUMER_NO @@ -44,7 +44,7 @@ import numpy as np from peewee import DoesNotExist from api.db import LLMType, ParserType, TaskStatus -from api.db.services.dialog_service import keyword_extraction, question_proposal +from api.db.services.dialog_service import keyword_extraction, question_proposal, content_tagging from api.db.services.document_service import DocumentService from api.db.services.llm_service import LLMBundle from api.db.services.task_service import TaskService @@ -53,10 +53,10 @@ from api import settings from api.versions import get_ragflow_version from api.db.db_models import close_connection from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive, one, audio, \ - knowledge_graph, email + knowledge_graph, email, tag from rag.nlp import search, rag_tokenizer from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor -from rag.settings import DOC_MAXIMUM_SIZE, SVR_QUEUE_NAME, print_rag_settings +from rag.settings import DOC_MAXIMUM_SIZE, SVR_QUEUE_NAME, print_rag_settings, TAG_FLD, PAGERANK_FLD from rag.utils import num_tokens_from_string from rag.utils.redis_conn import REDIS_CONN, Payload from rag.utils.storage_factory import STORAGE_IMPL @@ -78,7 +78,8 @@ FACTORY = { ParserType.ONE.value: one, ParserType.AUDIO.value: audio, ParserType.EMAIL.value: email, - ParserType.KG.value: knowledge_graph + ParserType.KG.value: knowledge_graph, + ParserType.TAG.value: tag } CONSUMER_NAME = "task_consumer_" + CONSUMER_NO @@ -199,7 +200,8 @@ def build_chunks(task, progress_callback): logging.info("From minio({}) {}/{}".format(timer() - st, task["location"], task["name"])) except TimeoutError: progress_callback(-1, "Internal server error: Fetch file from minio timeout. 
Could you try it again.") - logging.exception("Minio {}/{} got timeout: Fetch file from minio timeout.".format(task["location"], task["name"])) + logging.exception( + "Minio {}/{} got timeout: Fetch file from minio timeout.".format(task["location"], task["name"])) raise except Exception as e: if re.search("(No such file|not found)", str(e)): @@ -227,7 +229,7 @@ def build_chunks(task, progress_callback): "kb_id": str(task["kb_id"]) } if task["pagerank"]: - doc["pagerank_fea"] = int(task["pagerank"]) + doc[PAGERANK_FLD] = int(task["pagerank"]) el = 0 for ck in cks: d = copy.deepcopy(doc) @@ -252,7 +254,8 @@ def build_chunks(task, progress_callback): STORAGE_IMPL.put(task["kb_id"], d["id"], output_buffer.getvalue()) el += timer() - st except Exception: - logging.exception("Saving image of chunk {}/{}/{} got exception".format(task["location"], task["name"], d["id"])) + logging.exception( + "Saving image of chunk {}/{}/{} got exception".format(task["location"], task["name"], d["id"])) raise d["img_id"] = "{}-{}".format(task["kb_id"], d["id"]) @@ -295,12 +298,43 @@ def build_chunks(task, progress_callback): d["question_tks"] = rag_tokenizer.tokenize("\n".join(d["question_kwd"])) progress_callback(msg="Question generation completed in {:.2f}s".format(timer() - st)) + if task["kb_parser_config"].get("tag_kb_ids", []): + progress_callback(msg="Start to tag for every chunk ...") + kb_ids = task["kb_parser_config"]["tag_kb_ids"] + tenant_id = task["tenant_id"] + topn_tags = task["kb_parser_config"].get("topn_tags", 3) + S = 1000 + st = timer() + examples = [] + all_tags = get_tags_from_cache(kb_ids) + if not all_tags: + all_tags = settings.retrievaler.all_tags_in_portion(tenant_id, kb_ids, S) + set_tags_to_cache(kb_ids, all_tags) + else: + all_tags = json.loads(all_tags) + + chat_mdl = LLMBundle(task["tenant_id"], LLMType.CHAT, llm_name=task["llm_id"], lang=task["language"]) + for d in docs: + if settings.retrievaler.tag_content(tenant_id, kb_ids, d, all_tags, topn_tags=topn_tags, S=S): + examples.append({"content": d["content_with_weight"], TAG_FLD: d[TAG_FLD]}) + continue + cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], all_tags, {"topn": topn_tags}) + if not cached: + cached = content_tagging(chat_mdl, d["content_with_weight"], all_tags, + random.choices(examples, k=2) if len(examples)>2 else examples, + topn=topn_tags) + if cached: + set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, all_tags, {"topn": topn_tags}) + d[TAG_FLD] = json.loads(cached) + + progress_callback(msg="Tagging completed in {:.2f}s".format(timer() - st)) + return docs def init_kb(row, vector_size: int): idxnm = search.index_name(row["tenant_id"]) - return settings.docStoreConn.createIdx(idxnm, row.get("kb_id",""), vector_size) + return settings.docStoreConn.createIdx(idxnm, row.get("kb_id", ""), vector_size) def embedding(docs, mdl, parser_config=None, callback=None): @@ -381,7 +415,7 @@ def run_raptor(row, chat_mdl, embd_mdl, callback=None): "title_tks": rag_tokenizer.tokenize(row["name"]) } if row["pagerank"]: - doc["pagerank_fea"] = int(row["pagerank"]) + doc[PAGERANK_FLD] = int(row["pagerank"]) res = [] tk_count = 0 for content, vctr in chunks[original_length:]: @@ -480,7 +514,8 @@ def do_handle_task(task): doc_store_result = "" es_bulk_size = 4 for b in range(0, len(chunks), es_bulk_size): - doc_store_result = settings.docStoreConn.insert(chunks[b:b + es_bulk_size], search.index_name(task_tenant_id), task_dataset_id) + doc_store_result = settings.docStoreConn.insert(chunks[b:b + 
es_bulk_size], search.index_name(task_tenant_id), + task_dataset_id) if b % 128 == 0: progress_callback(prog=0.8 + 0.1 * (b + 1) / len(chunks), msg="") if doc_store_result: @@ -493,15 +528,21 @@ def do_handle_task(task): TaskService.update_chunk_ids(task["id"], chunk_ids_str) except DoesNotExist: logging.warning(f"do_handle_task update_chunk_ids failed since task {task['id']} is unknown.") - doc_store_result = settings.docStoreConn.delete({"id": chunk_ids}, search.index_name(task_tenant_id), task_dataset_id) + doc_store_result = settings.docStoreConn.delete({"id": chunk_ids}, search.index_name(task_tenant_id), + task_dataset_id) return - logging.info("Indexing doc({}), page({}-{}), chunks({}), elapsed: {:.2f}".format(task_document_name, task_from_page, task_to_page, len(chunks), timer() - start_ts)) + logging.info("Indexing doc({}), page({}-{}), chunks({}), elapsed: {:.2f}".format(task_document_name, task_from_page, + task_to_page, len(chunks), + timer() - start_ts)) DocumentService.increment_chunk_num(task_doc_id, task_dataset_id, token_count, chunk_count, 0) time_cost = timer() - start_ts progress_callback(prog=1.0, msg="Done ({:.2f}s)".format(time_cost)) - logging.info("Chunk doc({}), page({}-{}), chunks({}), token({}), elapsed:{:.2f}".format(task_document_name, task_from_page, task_to_page, len(chunks), token_count, time_cost)) + logging.info( + "Chunk doc({}), page({}-{}), chunks({}), token({}), elapsed:{:.2f}".format(task_document_name, task_from_page, + task_to_page, len(chunks), + token_count, time_cost)) def handle_task(): diff --git a/rag/utils/__init__.py b/rag/utils/__init__.py index e68e43735..8a4c1e09f 100644 --- a/rag/utils/__init__.py +++ b/rag/utils/__init__.py @@ -71,11 +71,13 @@ def findMaxTm(fnm): pass return m + tiktoken_cache_dir = get_project_base_directory() os.environ["TIKTOKEN_CACHE_DIR"] = tiktoken_cache_dir # encoder = tiktoken.encoding_for_model("gpt-3.5-turbo") encoder = tiktoken.get_encoding("cl100k_base") + def num_tokens_from_string(string: str) -> int: """Returns the number of tokens in a text string.""" try: diff --git a/rag/utils/doc_store_conn.py b/rag/utils/doc_store_conn.py index ffd4a245e..1e6e69c62 100644 --- a/rag/utils/doc_store_conn.py +++ b/rag/utils/doc_store_conn.py @@ -176,7 +176,17 @@ class DocStoreConnection(ABC): @abstractmethod def search( - self, selectFields: list[str], highlight: list[str], condition: dict, matchExprs: list[MatchExpr], orderBy: OrderByExpr, offset: int, limit: int, indexNames: str|list[str], knowledgebaseIds: list[str] + self, selectFields: list[str], + highlightFields: list[str], + condition: dict, + matchExprs: list[MatchExpr], + orderBy: OrderByExpr, + offset: int, + limit: int, + indexNames: str|list[str], + knowledgebaseIds: list[str], + aggFields: list[str] = [], + rank_feature: dict | None = None ) -> list[dict] | pl.DataFrame: """ Search with given conjunctive equivalent filtering condition and return all fields of matched documents @@ -191,7 +201,7 @@ class DocStoreConnection(ABC): raise NotImplementedError("Not implemented") @abstractmethod - def insert(self, rows: list[dict], indexName: str, knowledgebaseId: str) -> list[str]: + def insert(self, rows: list[dict], indexName: str, knowledgebaseId: str = None) -> list[str]: """ Update or insert a bulk of rows """ diff --git a/rag/utils/es_conn.py b/rag/utils/es_conn.py index 2a27247ad..1f9590045 100644 --- a/rag/utils/es_conn.py +++ b/rag/utils/es_conn.py @@ -9,6 +9,7 @@ from elasticsearch import Elasticsearch, NotFoundError from elasticsearch_dsl import 
UpdateByQuery, Q, Search, Index from elastic_transport import ConnectionTimeout from rag import settings +from rag.settings import TAG_FLD, PAGERANK_FLD from rag.utils import singleton from api.utils.file_utils import get_project_base_directory import polars as pl @@ -20,6 +21,7 @@ ATTEMPT_TIME = 2 logger = logging.getLogger('ragflow.es_conn') + @singleton class ESConnection(DocStoreConnection): def __init__(self): @@ -111,9 +113,19 @@ class ESConnection(DocStoreConnection): CRUD operations """ - def search(self, selectFields: list[str], highlightFields: list[str], condition: dict, matchExprs: list[MatchExpr], - orderBy: OrderByExpr, offset: int, limit: int, indexNames: str | list[str], - knowledgebaseIds: list[str]) -> list[dict] | pl.DataFrame: + def search( + self, selectFields: list[str], + highlightFields: list[str], + condition: dict, + matchExprs: list[MatchExpr], + orderBy: OrderByExpr, + offset: int, + limit: int, + indexNames: str | list[str], + knowledgebaseIds: list[str], + aggFields: list[str] = [], + rank_feature: dict | None = None + ) -> list[dict] | pl.DataFrame: """ Refers to https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl.html """ @@ -175,8 +187,13 @@ class ESConnection(DocStoreConnection): similarity=similarity, ) + if bqry and rank_feature: + for fld, sc in rank_feature.items(): + if fld != PAGERANK_FLD: + fld = f"{TAG_FLD}.{fld}" + bqry.should.append(Q("rank_feature", field=fld, linear={}, boost=sc)) + if bqry: - bqry.should.append(Q("rank_feature", field="pagerank_fea", linear={}, boost=10)) s = s.query(bqry) for field in highlightFields: s = s.highlight(field) @@ -187,7 +204,7 @@ class ESConnection(DocStoreConnection): order = "asc" if order == 0 else "desc" if field in ["page_num_int", "top_int"]: order_info = {"order": order, "unmapped_type": "float", - "mode": "avg", "numeric_type": "double"} + "mode": "avg", "numeric_type": "double"} elif field.endswith("_int") or field.endswith("_flt"): order_info = {"order": order, "unmapped_type": "float"} else: @@ -195,8 +212,11 @@ class ESConnection(DocStoreConnection): orders.append({field: order_info}) s = s.sort(*orders) + for fld in aggFields: + s.aggs.bucket(f'aggs_{fld}', 'terms', field=fld, size=1000000) + if limit > 0: - s = s[offset:offset+limit] + s = s[offset:offset + limit] q = s.to_dict() logger.debug(f"ESConnection.search {str(indexNames)} query: " + json.dumps(q)) @@ -240,7 +260,7 @@ class ESConnection(DocStoreConnection): logger.error("ESConnection.get timeout for 3 times!") raise Exception("ESConnection.get timeout.") - def insert(self, documents: list[dict], indexName: str, knowledgebaseId: str) -> list[str]: + def insert(self, documents: list[dict], indexName: str, knowledgebaseId: str = None) -> list[str]: # Refers to https://www.elastic.co/guide/en/elasticsearch/reference/current/docs-bulk.html operations = [] for d in documents: @@ -292,44 +312,57 @@ class ESConnection(DocStoreConnection): if str(e).find("Timeout") > 0: continue return False - else: - # update unspecific maybe-multiple documents - bqry = Q("bool") - for k, v in condition.items(): - if not isinstance(k, str) or not v: - continue - if k == "exist": - bqry.filter.append(Q("exists", field=v)) - continue - if isinstance(v, list): - bqry.filter.append(Q("terms", **{k: v})) - elif isinstance(v, str) or isinstance(v, int): - bqry.filter.append(Q("term", **{k: v})) - else: - raise Exception( - f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.") - scripts = [] - for k, v in 
newValue.items(): - if k == "remove": - scripts.append(f"ctx._source.remove('{v}');") - continue - if (not isinstance(k, str) or not v) and k != "available_int": - continue + + # update unspecific maybe-multiple documents + bqry = Q("bool") + for k, v in condition.items(): + if not isinstance(k, str) or not v: + continue + if k == "exist": + bqry.filter.append(Q("exists", field=v)) + continue + if isinstance(v, list): + bqry.filter.append(Q("terms", **{k: v})) + elif isinstance(v, str) or isinstance(v, int): + bqry.filter.append(Q("term", **{k: v})) + else: + raise Exception( + f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.") + scripts = [] + params = {} + for k, v in newValue.items(): + if k == "remove": if isinstance(v, str): - scripts.append(f"ctx._source.{k} = '{v}'") - elif isinstance(v, int): - scripts.append(f"ctx._source.{k} = {v}") - else: - raise Exception( - f"newValue `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str.") + scripts.append(f"ctx._source.remove('{v}');") + if isinstance(v, dict): + for kk, vv in v.items(): + scripts.append(f"int i=ctx._source.{kk}.indexOf(params.p_{kk});ctx._source.{kk}.remove(i);") + params[f"p_{kk}"] = vv + continue + if k == "add": + if isinstance(v, dict): + for kk, vv in v.items(): + scripts.append(f"ctx._source.{kk}.add(params.pp_{kk});") + params[f"pp_{kk}"] = vv.strip() + continue + if (not isinstance(k, str) or not v) and k != "available_int": + continue + if isinstance(v, str): + scripts.append(f"ctx._source.{k} = '{v}'") + elif isinstance(v, int): + scripts.append(f"ctx._source.{k} = {v}") + else: + raise Exception( + f"newValue `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str.") ubq = UpdateByQuery( index=indexName).using( self.es).query(bqry) - ubq = ubq.script(source="; ".join(scripts)) + ubq = ubq.script(source="".join(scripts), params=params) ubq = ubq.params(refresh=True) ubq = ubq.params(slices=5) ubq = ubq.params(conflicts="proceed") - for i in range(3): + + for _ in range(ATTEMPT_TIME): try: _ = ubq.execute() return True diff --git a/rag/utils/infinity_conn.py b/rag/utils/infinity_conn.py index 38dd92484..dd11b9ece 100644 --- a/rag/utils/infinity_conn.py +++ b/rag/utils/infinity_conn.py @@ -10,6 +10,7 @@ from infinity.index import IndexInfo, IndexType from infinity.connection_pool import ConnectionPool from infinity.errors import ErrorCode from rag import settings +from rag.settings import PAGERANK_FLD from rag.utils import singleton import polars as pl from polars.series.series import Series @@ -231,8 +232,7 @@ class InfinityConnection(DocStoreConnection): """ def search( - self, - selectFields: list[str], + self, selectFields: list[str], highlightFields: list[str], condition: dict, matchExprs: list[MatchExpr], @@ -241,7 +241,9 @@ class InfinityConnection(DocStoreConnection): limit: int, indexNames: str | list[str], knowledgebaseIds: list[str], - ) -> tuple[pl.DataFrame, int]: + aggFields: list[str] = [], + rank_feature: dict | None = None + ) -> list[dict] | pl.DataFrame: """ TODO: Infinity doesn't provide highlight """ @@ -256,7 +258,7 @@ class InfinityConnection(DocStoreConnection): if essential_field not in selectFields: selectFields.append(essential_field) if matchExprs: - for essential_field in ["score()", "pagerank_fea"]: + for essential_field in ["score()", PAGERANK_FLD]: selectFields.append(essential_field) # Prepare expressions common to all tables @@ -346,7 +348,7 @@ class InfinityConnection(DocStoreConnection): 
self.connPool.release_conn(inf_conn) res = concat_dataframes(df_list, selectFields) if matchExprs: - res = res.sort(pl.col("SCORE") + pl.col("pagerank_fea"), descending=True, maintain_order=True) + res = res.sort(pl.col("SCORE") + pl.col(PAGERANK_FLD), descending=True, maintain_order=True) res = res.limit(limit) logger.debug(f"INFINITY search final result: {str(res)}") return res, total_hits_count @@ -378,7 +380,7 @@ class InfinityConnection(DocStoreConnection): return res_fields.get(chunkId, None) def insert( - self, documents: list[dict], indexName: str, knowledgebaseId: str + self, documents: list[dict], indexName: str, knowledgebaseId: str = None ) -> list[str]: inf_conn = self.connPool.get_conn() db_instance = inf_conn.get_database(self.dbName) @@ -456,7 +458,7 @@ class InfinityConnection(DocStoreConnection): elif k in ["page_num_int", "top_int"]: assert isinstance(v, list) newValue[k] = "_".join(f"{num:08x}" for num in v) - elif k == "remove" and v in ["pagerank_fea"]: + elif k == "remove" and v in [PAGERANK_FLD]: del newValue[k] newValue[v] = 0 logger.debug(f"INFINITY update table {table_name}, filter {filter}, newValue {newValue}.") diff --git a/sdk/python/test/test_sdk_api/t_dataset.py b/sdk/python/test/test_sdk_api/t_dataset.py index 183904bc9..288474409 100644 --- a/sdk/python/test/test_sdk_api/t_dataset.py +++ b/sdk/python/test/test_sdk_api/t_dataset.py @@ -27,7 +27,7 @@ def test_create_dataset_with_invalid_parameter(get_api_key_fixture): API_KEY = get_api_key_fixture rag = RAGFlow(API_KEY, HOST_ADDRESS) valid_chunk_methods = ["naive", "manual", "qa", "table", "paper", "book", "laws", "presentation", "picture", "one", - "knowledge_graph", "email"] + "knowledge_graph", "email", "tag"] chunk_method = "invalid_chunk_method" with pytest.raises(Exception) as exc_info: rag.create_dataset("test_create_dataset_with_invalid_chunk_method",chunk_method=chunk_method)
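
A few notes on how the new tagging pieces fit together, with small illustrative sketches (not part of the patch). The tag statistics behind `all_tags`, `all_tags_in_portion`, `tag_content` and `tag_query` all come from the new `aggFields` argument: for each requested field, `ESConnection.search` attaches a terms aggregation named `aggs_<field>`, and `getAggregation` then yields `(value, count)` pairs from its buckets. A minimal sketch of the request fragment this produces, using only `elasticsearch_dsl` and the field names from the patch:

```python
from elasticsearch_dsl import Search

aggFields = ["tag_kwd"]
s = Search()
for fld in aggFields:
    # One terms bucket per aggregated field; the huge size effectively means "all distinct tags".
    s.aggs.bucket(f"aggs_{fld}", "terms", field=fld, size=1000000)

print(s.to_dict())
# {'aggs': {'aggs_tag_kwd': {'terms': {'field': 'tag_kwd', 'size': 1000000}}}}
```

On the response side, `getAggregation` (not shown in these hunks) is expected to turn the `aggs_tag_kwd` buckets into a list such as `[("python", 12), ("database", 5)]`, which is what the scoring below consumes.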
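
Both `tag_content` and `tag_query` rank candidate tags by how over-represented they are among the matched chunks relative to the knowledge-base-wide distribution produced by `all_tags_in_portion` (a Laplace-smoothed share with constant S = 1000). A self-contained sketch of that scoring with made-up counts:

```python
# (tag, count) pairs as returned by the tag_kwd aggregation for one chunk or query.
aggs = [("python", 30), ("database", 12), ("devops", 6)]

# Background shares as produced by all_tags_in_portion(): smoothed frequency of
# each tag across the whole tag knowledge base (values here are invented).
all_tags = {"python": 0.0021, "database": 0.0105, "devops": 0.0004}

S, topn_tags = 1000, 3
cnt = sum(c for _, c in aggs)
# Score = 0.1 * smoothed local share / global share: a "lift" that favors tags
# that are rare overall but frequent near this content.
tag_fea = sorted(
    [(a, round(0.1 * (c + 1) / (cnt + S) / all_tags.get(a, 0.0001))) for a, c in aggs],
    key=lambda x: x[1] * -1,
)[:topn_tags]

print({a: c for a, c in tag_fea if c > 0})   # {'devops': 2, 'python': 1}
```

The same weights end up in `tag_feas` on the chunk (via `tag_content`) or in the `rank_feature` map at query time (via `tag_query`); `database` drops out because its lift rounds to zero.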
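
At search time the per-label weights are applied through Elasticsearch `rank_feature` should-clauses: every label is boosted on a sub-field of `tag_feas`, while `pagerank_fea` keeps its own top-level field. This assumes the index maps `tag_feas` as a `rank_features` field, which the mapping changes elsewhere in this PR are expected to provide. A rough sketch of the clause construction, with invented weights:

```python
from elasticsearch_dsl import Q

PAGERANK_FLD, TAG_FLD = "pagerank_fea", "tag_feas"
rank_feature = {"python": 2, "devops": 1, PAGERANK_FLD: 10}   # e.g. tag_query() output plus pagerank

bqry = Q("bool", must=[Q("query_string", fields=["content_ltks"], query="enable chunk tagging")])
for fld, sc in rank_feature.items():
    if fld != PAGERANK_FLD:
        fld = f"{TAG_FLD}.{fld}"          # ordinary labels live under the rank_features field
    bqry.should.append(Q("rank_feature", field=fld, linear={}, boost=sc))

# With the linear function, each should-clause adds roughly boost * stored feature
# value to a matching chunk's score, so chunks tagged "python"/"devops" float up.
print(bqry.to_dict())
```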
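
Putting the query-time pieces together: `all_tags_in_portion` supplies the background distribution, `tag_query` turns a user question into a few weighted labels, and those labels travel into the document store through the new `rank_feature` parameter of `search`. A wiring sketch, assuming `dealer` is an initialized `Dealer` and the tag knowledge bases are already populated (IDs are hypothetical):

```python
tenant_id, tag_kb_ids = "tenant-0", ["kb-tags"]

all_tags = dealer.all_tags_in_portion(tenant_id, tag_kb_ids)       # {tag: smoothed share}
labels = dealer.tag_query("how do I enable chunk tagging?",
                          tenant_id, tag_kb_ids, all_tags, topn_tags=3)
# labels might look like {"tagging": 2, "configuration": 1} -- illustrative only.

# Downstream, retrieval forwards `labels` as rank_feature, and ESConnection.search
# converts each entry into the rank_feature boost shown above. At indexing time,
# tag_content() stores the mirror-image weights on each chunk under tag_feas.
```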
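
Finally, the rewritten `ESConnection.update` bulk path now supports `add`/`remove` operations on list-valued fields such as `tag_kwd`, assembling one parameterized painless script plus a `params` map: `add` appends a value and `remove` deletes it by index. A sketch of what the loop builds for a hypothetical call (only the dict-valued branches are shown):

```python
newValue = {"remove": {"tag_kwd": "obsolete"}, "add": {"tag_kwd": "python "}}

scripts, params = [], {}
for k, v in newValue.items():
    if k == "remove" and isinstance(v, dict):
        for kk, vv in v.items():
            scripts.append(f"int i=ctx._source.{kk}.indexOf(params.p_{kk});ctx._source.{kk}.remove(i);")
            params[f"p_{kk}"] = vv
    elif k == "add" and isinstance(v, dict):
        for kk, vv in v.items():
            scripts.append(f"ctx._source.{kk}.add(params.pp_{kk});")
            params[f"pp_{kk}"] = vv.strip()

print("".join(scripts))
# int i=ctx._source.tag_kwd.indexOf(params.p_tag_kwd);ctx._source.tag_kwd.remove(i);ctx._source.tag_kwd.add(params.pp_tag_kwd);
print(params)   # {'p_tag_kwd': 'obsolete', 'pp_tag_kwd': 'python'}
```

Passing the tag values through `params` rather than interpolating them into the script source keeps the script compilable and cacheable and avoids escaping problems with unusual tag strings.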