diff --git a/api/apps/chunk_app.py b/api/apps/chunk_app.py
index f4e1b67c2..301445c74 100644
--- a/api/apps/chunk_app.py
+++ b/api/apps/chunk_app.py
@@ -58,7 +58,7 @@ def list_chunk():
         }
         if "available_int" in req:
             query["available_int"] = int(req["available_int"])
-        sres = retrievaler.search(query, search.index_name(tenant_id))
+        sres = retrievaler.search(query, search.index_name(tenant_id), highlight=True)
         res = {"total": sres.total, "chunks": [], "doc": doc.to_dict()}
         for id in sres.ids:
             d = {
@@ -259,12 +259,25 @@ def retrieval_test():
     size = int(req.get("size", 30))
     question = req["question"]
     kb_id = req["kb_id"]
+    if isinstance(kb_id, str): kb_id = [kb_id]
     doc_ids = req.get("doc_ids", [])
     similarity_threshold = float(req.get("similarity_threshold", 0.2))
     vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3))
     top = int(req.get("top_k", 1024))
+
     try:
-        e, kb = KnowledgebaseService.get_by_id(kb_id)
+        tenants = UserTenantService.query(user_id=current_user.id)
+        for kid in kb_id:
+            for tenant in tenants:
+                if KnowledgebaseService.query(
+                        tenant_id=tenant.tenant_id, id=kid):
+                    break
+            else:
+                return get_json_result(
+                    data=False, retmsg=f'Only owner of knowledgebase authorized for this operation.',
+                    retcode=RetCode.OPERATING_ERROR)
+
+        e, kb = KnowledgebaseService.get_by_id(kb_id[0])
         if not e:
             return get_data_error_result(retmsg="Knowledgebase not found!")
 
@@ -281,9 +294,9 @@ def retrieval_test():
             question += keyword_extraction(chat_mdl, question)
 
         retr = retrievaler if kb.parser_id != ParserType.KG else kg_retrievaler
-        ranks = retr.retrieval(question, embd_mdl, kb.tenant_id, [kb_id], page, size,
+        ranks = retr.retrieval(question, embd_mdl, kb.tenant_id, kb_id, page, size,
                                similarity_threshold, vector_similarity_weight, top,
-                               doc_ids, rerank_mdl=rerank_mdl)
+                               doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"))
         for c in ranks["chunks"]:
             if "vector" in c:
                 del c["vector"]
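The permission check added to retrieval_test() above leans on Python's for/else: the else branch runs only when the inner loop finishes without hitting break. A minimal standalone sketch of the idiom (authorized, has_access, and the arguments are illustrative names, not part of the patch):

    def authorized(kb_ids, tenants, has_access):
        # For every requested knowledgebase, look for one tenant that owns it.
        for kid in kb_ids:
            for tenant in tenants:
                if has_access(tenant, kid):
                    break          # owner found; move on to the next kb id
            else:
                return False       # loop ended without break: nobody owns this kb
        return True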
diff --git a/api/apps/conversation_app.py b/api/apps/conversation_app.py
index c3e05e87b..01cbbd9db 100644
--- a/api/apps/conversation_app.py
+++ b/api/apps/conversation_app.py
@@ -14,19 +14,22 @@
 #  limitations under the License.
 #
 import json
+import re
 from copy import deepcopy
 
-from db.services.user_service import UserTenantService
+from api.db.services.user_service import UserTenantService
 from flask import request, Response
 from flask_login import login_required, current_user
 
 from api.db import LLMType
-from api.db.services.dialog_service import DialogService, ConversationService, chat
-from api.db.services.llm_service import LLMBundle, TenantService
-from api.settings import RetCode
+from api.db.services.dialog_service import DialogService, ConversationService, chat, ask
+from api.db.services.knowledgebase_service import KnowledgebaseService
+from api.db.services.llm_service import LLMBundle, TenantService, TenantLLMService
+from api.settings import RetCode, retrievaler
 from api.utils import get_uuid
 from api.utils.api_utils import get_json_result
 from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
+from graphrag.mind_map_extractor import MindMapExtractor
 
 
 @manager.route('/set', methods=['POST'])
@@ -286,3 +289,86 @@ def thumbup():
     ConversationService.update_by_id(conv["id"], conv)
 
     return get_json_result(data=conv)
+
+
+@manager.route('/ask', methods=['POST'])
+@login_required
+@validate_request("question", "kb_ids")
+def ask_about():
+    req = request.json
+    uid = current_user.id
+    def stream():
+        nonlocal req, uid
+        try:
+            for ans in ask(req["question"], req["kb_ids"], uid):
+                yield "data:" + json.dumps({"retcode": 0, "retmsg": "", "data": ans}, ensure_ascii=False) + "\n\n"
+        except Exception as e:
+            yield "data:" + json.dumps({"retcode": 500, "retmsg": str(e),
+                                        "data": {"answer": "**ERROR**: " + str(e), "reference": []}},
+                                       ensure_ascii=False) + "\n\n"
+        yield "data:" + json.dumps({"retcode": 0, "retmsg": "", "data": True}, ensure_ascii=False) + "\n\n"
+
+    resp = Response(stream(), mimetype="text/event-stream")
+    resp.headers.add_header("Cache-control", "no-cache")
+    resp.headers.add_header("Connection", "keep-alive")
+    resp.headers.add_header("X-Accel-Buffering", "no")
+    resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
+    return resp
+
+
+@manager.route('/mindmap', methods=['POST'])
+@login_required
+@validate_request("question", "kb_ids")
+def mindmap():
+    req = request.json
+    kb_ids = req["kb_ids"]
+    e, kb = KnowledgebaseService.get_by_id(kb_ids[0])
+    if not e:
+        return get_data_error_result(retmsg="Knowledgebase not found!")
+
+    embd_mdl = TenantLLMService.model_instance(
+        kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
+    chat_mdl = LLMBundle(current_user.id, LLMType.CHAT)
+    ranks = retrievaler.retrieval(req["question"], embd_mdl, kb.tenant_id, kb_ids, 1, 12,
+                                  0.3, 0.3, aggs=False)
+    mindmap = MindMapExtractor(chat_mdl)
+    mind_map = mindmap([c["content_with_weight"] for c in ranks["chunks"]]).output
+    return get_json_result(data=mind_map)
+
+
+@manager.route('/related_questions', methods=['POST'])
+@login_required
+@validate_request("question")
+def related_questions():
+    req = request.json
+    question = req["question"]
+    chat_mdl = LLMBundle(current_user.id, LLMType.CHAT)
+    prompt = """
+Objective: To generate search terms related to the user's search keywords, helping users find more valuable information.
+Instructions:
+ - Based on the keywords provided by the user, generate 5-10 related search terms.
+ - Each search term should be directly or indirectly related to the keyword, guiding the user to find more valuable information.
+ - Use common, general terms as much as possible, avoiding obscure words or technical jargon.
+ - Keep the term length between 2-4 words, concise and clear.
+ - DO NOT translate, use the language of the original keywords.
+
+### Example:
+Keywords: Chinese football
+Related search terms:
+1. Current status of Chinese football
+2. Reform of Chinese football
+3. Youth training of Chinese football
+4. Chinese football in the Asian Cup
+5. Chinese football in the World Cup
+
+Reason:
+ - When searching, users often only use one or two keywords, making it difficult to fully express their information needs.
+ - Generating related search terms can help users dig deeper into relevant information and improve search efficiency.
+ - At the same time, related terms can also help search engines better understand user needs and return more accurate search results.
+
+"""
+    ans = chat_mdl.chat(prompt, [{"role": "user", "content": f"""
+Keywords: {question}
+Related search terms:
+    """}], {"temperature": 0.9})
+    return get_json_result(data=[re.sub(r"^[0-9]\. ", "", a) for a in ans.split("\n") if re.match(r"^[0-9]\. ", a)])
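The new /ask route streams server-sent events: one "data:" frame per partial answer (judging by the error path, each carries "answer" and "reference" fields), then a final frame whose data field is true. A hypothetical client sketch — base URL, port, and auth header are assumptions about the deployment, not part of the patch:

    import json
    import requests

    resp = requests.post(
        "http://localhost:9380/v1/conversation/ask",   # assumed host/port and mount point
        headers={"Authorization": "<api-token>"},      # assumed auth scheme
        json={"question": "What is RAGFlow?", "kb_ids": ["<kb-id>"]},
        stream=True,
    )
    for line in resp.iter_lines(decode_unicode=True):
        if not line or not line.startswith("data:"):
            continue
        frame = json.loads(line[len("data:"):])
        if frame["data"] is True:                      # terminal sentinel frame
            break
        print(frame["data"]["answer"])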
diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py
index 7f9fc84b4..105dd37be 100644
--- a/api/db/services/dialog_service.py
+++ b/api/db/services/dialog_service.py
@@ -210,7 +210,7 @@ def chat(dialog, messages, stream=True, **kwargs):
             answer += " Please set LLM API-Key in 'User Setting -> Model Providers -> API-Key'"
         done_tm = timer()
         prompt += "\n### Elapsed\n - Retrieval: %.1f ms\n - LLM: %.1f ms"%((retrieval_tm-st)*1000, (done_tm-st)*1000)
-        return {"answer": answer, "reference": refs, "prompt": re.sub(r"\n", "<br/>", prompt)}
+        return {"answer": answer, "reference": refs, "prompt": prompt}
 
     if stream:
         last_ans = ""
diff --git a/api/db/services/llm_service.py b/api/db/services/llm_service.py
index 9a548e728..0309e3858 100644
--- a/api/db/services/llm_service.py
+++ b/api/db/services/llm_service.py
@@ -190,7 +190,7 @@ class LLMBundle(object):
             tenant_id, llm_type, llm_name, lang=lang)
         assert self.mdl, "Can't find mole for {}/{}/{}".format(
             tenant_id, llm_type, llm_name)
-        self.max_length = 512
+        self.max_length = 8192
         for lm in LLMService.query(llm_name=llm_name):
            self.max_length = lm.max_tokens
            break
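The llm_service.py change only affects models missing from the LLM registry: the fallback token budget rises from 512 to 8192, while registered models still take max_tokens from their registry row. A sketch of that resolution order under stated assumptions (resolve_max_length and registry are illustrative stand-ins for LLMBundle's constructor logic):

    def resolve_max_length(llm_name, registry):
        max_length = 8192                  # new default for unregistered models (was 512)
        for lm in registry.query(llm_name=llm_name):
            max_length = lm.max_tokens     # a registry entry always wins
            break                          # only the first match is consulted
        return max_length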
", prompt)} + return {"answer": answer, "reference": refs, "prompt": prompt} if stream: last_ans = "" diff --git a/api/db/services/llm_service.py b/api/db/services/llm_service.py index 9a548e728..0309e3858 100644 --- a/api/db/services/llm_service.py +++ b/api/db/services/llm_service.py @@ -190,7 +190,7 @@ class LLMBundle(object): tenant_id, llm_type, llm_name, lang=lang) assert self.mdl, "Can't find mole for {}/{}/{}".format( tenant_id, llm_type, llm_name) - self.max_length = 512 + self.max_length = 8192 for lm in LLMService.query(llm_name=llm_name): self.max_length = lm.max_tokens break diff --git a/graphrag/search.py b/graphrag/search.py index fb4453408..85ba0698a 100644 --- a/graphrag/search.py +++ b/graphrag/search.py @@ -23,7 +23,7 @@ from rag.nlp.search import Dealer class KGSearch(Dealer): - def search(self, req, idxnm, emb_mdl=None): + def search(self, req, idxnm, emb_mdl=None, highlight=False): def merge_into_first(sres, title=""): df,texts = [],[] for d in sres["hits"]["hits"]: diff --git a/rag/nlp/search.py b/rag/nlp/search.py index 0ddb87ed4..d72580cfd 100644 --- a/rag/nlp/search.py +++ b/rag/nlp/search.py @@ -79,9 +79,9 @@ class Dealer: Q("bool", must_not=Q("range", available_int={"lt": 1}))) return bqry - def search(self, req, idxnm, emb_mdl=None): + def search(self, req, idxnm, emb_mdl=None, highlight=False): qst = req.get("question", "") - bqry, keywords = self.qryr.question(qst) + bqry, keywords = self.qryr.question(qst, min_match="30%") bqry = self._add_filters(bqry, req) bqry.boost = 0.05 @@ -130,7 +130,7 @@ class Dealer: qst, emb_mdl, req.get( "similarity", 0.1), topk) s["knn"]["filter"] = bqry.to_dict() - if "highlight" in s: + if not highlight and "highlight" in s: del s["highlight"] q_vec = s["knn"]["query_vector"] es_logger.info("【Q】: {}".format(json.dumps(s))) @@ -356,7 +356,7 @@ class Dealer: rag_tokenizer.tokenize(inst).split(" ")) def retrieval(self, question, embd_mdl, tenant_id, kb_ids, page, page_size, similarity_threshold=0.2, - vector_similarity_weight=0.3, top=1024, doc_ids=None, aggs=True, rerank_mdl=None): + vector_similarity_weight=0.3, top=1024, doc_ids=None, aggs=True, rerank_mdl=None, highlight=False): ranks = {"total": 0, "chunks": [], "doc_aggs": {}} if not question: return ranks @@ -364,7 +364,7 @@ class Dealer: "question": question, "vector": True, "topk": top, "similarity": similarity_threshold, "available_int": 1} - sres = self.search(req, index_name(tenant_id), embd_mdl) + sres = self.search(req, index_name(tenant_id), embd_mdl, highlight) if rerank_mdl: sim, tsim, vsim = self.rerank_by_model(rerank_mdl, @@ -405,6 +405,8 @@ class Dealer: "vector": self.trans2floats(sres.field[id].get("q_%d_vec" % dim, "\t".join(["0"] * dim))), "positions": sres.field[id].get("position_int", "").split("\t") } + if highlight: + d["highlight"] = rmSpace(sres.highlight[id]) if len(d["positions"]) % 5 == 0: poss = [] for i in range(0, len(d["positions"]), 5):