add search TAB backend api (#2375)

### What problem does this PR solve?
 #2247

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
Kevin Hu 2024-09-11 19:49:18 +08:00 committed by GitHub
parent 8052cbc70e
commit 333608a1d4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 117 additions and 16 deletions

View File

@ -58,7 +58,7 @@ def list_chunk():
}
if "available_int" in req:
query["available_int"] = int(req["available_int"])
sres = retrievaler.search(query, search.index_name(tenant_id))
sres = retrievaler.search(query, search.index_name(tenant_id), highlight=True)
res = {"total": sres.total, "chunks": [], "doc": doc.to_dict()}
for id in sres.ids:
d = {
@ -259,12 +259,25 @@ def retrieval_test():
size = int(req.get("size", 30))
question = req["question"]
kb_id = req["kb_id"]
if isinstance(kb_id, str): kb_id = [kb_id]
doc_ids = req.get("doc_ids", [])
similarity_threshold = float(req.get("similarity_threshold", 0.2))
vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3))
top = int(req.get("top_k", 1024))
try:
e, kb = KnowledgebaseService.get_by_id(kb_id)
tenants = UserTenantService.query(user_id=current_user.id)
for kid in kb_id:
for tenant in tenants:
if KnowledgebaseService.query(
tenant_id=tenant.tenant_id, id=kid):
break
else:
return get_json_result(
data=False, retmsg=f'Only owner of knowledgebase authorized for this operation.',
retcode=RetCode.OPERATING_ERROR)
e, kb = KnowledgebaseService.get_by_id(kb_id[0])
if not e:
return get_data_error_result(retmsg="Knowledgebase not found!")
@ -281,9 +294,9 @@ def retrieval_test():
question += keyword_extraction(chat_mdl, question)
retr = retrievaler if kb.parser_id != ParserType.KG else kg_retrievaler
ranks = retr.retrieval(question, embd_mdl, kb.tenant_id, [kb_id], page, size,
ranks = retr.retrieval(question, embd_mdl, kb.tenant_id, kb_id, page, size,
similarity_threshold, vector_similarity_weight, top,
doc_ids, rerank_mdl=rerank_mdl)
doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"))
for c in ranks["chunks"]:
if "vector" in c:
del c["vector"]

View File

@ -14,19 +14,22 @@
# limitations under the License.
#
import json
import re
from copy import deepcopy
from db.services.user_service import UserTenantService
from api.db.services.user_service import UserTenantService
from flask import request, Response
from flask_login import login_required, current_user
from api.db import LLMType
from api.db.services.dialog_service import DialogService, ConversationService, chat
from api.db.services.llm_service import LLMBundle, TenantService
from api.settings import RetCode
from api.db.services.dialog_service import DialogService, ConversationService, chat, ask
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle, TenantService, TenantLLMService
from api.settings import RetCode, retrievaler
from api.utils import get_uuid
from api.utils.api_utils import get_json_result
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
from graphrag.mind_map_extractor import MindMapExtractor
@manager.route('/set', methods=['POST'])
@ -286,3 +289,86 @@ def thumbup():
ConversationService.update_by_id(conv["id"], conv)
return get_json_result(data=conv)
@manager.route('/ask', methods=['POST'])
@login_required
@validate_request("question", "kb_ids")
def ask_about():
req = request.json
uid = current_user.id
def stream():
nonlocal req, uid
try:
for ans in ask(req["question"], req["kb_ids"], uid):
yield "data:" + json.dumps({"retcode": 0, "retmsg": "", "data": ans}, ensure_ascii=False) + "\n\n"
except Exception as e:
yield "data:" + json.dumps({"retcode": 500, "retmsg": str(e),
"data": {"answer": "**ERROR**: " + str(e), "reference": []}},
ensure_ascii=False) + "\n\n"
yield "data:" + json.dumps({"retcode": 0, "retmsg": "", "data": True}, ensure_ascii=False) + "\n\n"
resp = Response(stream(), mimetype="text/event-stream")
resp.headers.add_header("Cache-control", "no-cache")
resp.headers.add_header("Connection", "keep-alive")
resp.headers.add_header("X-Accel-Buffering", "no")
resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
return resp
@manager.route('/mindmap', methods=['POST'])
@login_required
@validate_request("question", "kb_ids")
def mindmap():
req = request.json
kb_ids = req["kb_ids"]
e, kb = KnowledgebaseService.get_by_id(kb_ids[0])
if not e:
return get_data_error_result(retmsg="Knowledgebase not found!")
embd_mdl = TenantLLMService.model_instance(
kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
chat_mdl = LLMBundle(current_user.id, LLMType.CHAT)
ranks = retrievaler.retrieval(req["question"], embd_mdl, kb.tenant_id, kb_ids, 1, 12,
0.3, 0.3, aggs=False)
mindmap = MindMapExtractor(chat_mdl)
mind_map = mindmap([c["content_with_weight"] for c in ranks["chunks"]]).output
return get_json_result(data=mind_map)
@manager.route('/related_questions', methods=['POST'])
@login_required
@validate_request("question")
def related_questions():
req = request.json
question = req["question"]
chat_mdl = LLMBundle(current_user.id, LLMType.CHAT)
prompt = """
Objective: To generate search terms related to the user's search keywords, helping users find more valuable information.
Instructions:
- Based on the keywords provided by the user, generate 5-10 related search terms.
- Each search term should be directly or indirectly related to the keyword, guiding the user to find more valuable information.
- Use common, general terms as much as possible, avoiding obscure words or technical jargon.
- Keep the term length between 2-4 words, concise and clear.
- DO NOT translate, use the language of the original keywords.
### Example:
Keywords: Chinese football
Related search terms:
1. Current status of Chinese football
2. Reform of Chinese football
3. Youth training of Chinese football
4. Chinese football in the Asian Cup
5. Chinese football in the World Cup
Reason:
- When searching, users often only use one or two keywords, making it difficult to fully express their information needs.
- Generating related search terms can help users dig deeper into relevant information and improve search efficiency.
- At the same time, related terms can also help search engines better understand user needs and return more accurate search results.
"""
ans = chat_mdl.chat(prompt, [{"role": "user", "content": f"""
Keywords: {question}
Related search terms:
"""}], {"temperature": 0.9})
return get_json_result(data=[re.sub(r"^[0-9]\. ", "", a) for a in ans.split("\n") if re.match(r"^[0-9]\. ", a)])

View File

@ -210,7 +210,7 @@ def chat(dialog, messages, stream=True, **kwargs):
answer += " Please set LLM API-Key in 'User Setting -> Model Providers -> API-Key'"
done_tm = timer()
prompt += "\n### Elapsed\n - Retrieval: %.1f ms\n - LLM: %.1f ms"%((retrieval_tm-st)*1000, (done_tm-st)*1000)
return {"answer": answer, "reference": refs, "prompt": re.sub(r"\n", "<br/>", prompt)}
return {"answer": answer, "reference": refs, "prompt": prompt}
if stream:
last_ans = ""

View File

@ -190,7 +190,7 @@ class LLMBundle(object):
tenant_id, llm_type, llm_name, lang=lang)
assert self.mdl, "Can't find mole for {}/{}/{}".format(
tenant_id, llm_type, llm_name)
self.max_length = 512
self.max_length = 8192
for lm in LLMService.query(llm_name=llm_name):
self.max_length = lm.max_tokens
break

View File

@ -23,7 +23,7 @@ from rag.nlp.search import Dealer
class KGSearch(Dealer):
def search(self, req, idxnm, emb_mdl=None):
def search(self, req, idxnm, emb_mdl=None, highlight=False):
def merge_into_first(sres, title=""):
df,texts = [],[]
for d in sres["hits"]["hits"]:

View File

@ -79,9 +79,9 @@ class Dealer:
Q("bool", must_not=Q("range", available_int={"lt": 1})))
return bqry
def search(self, req, idxnm, emb_mdl=None):
def search(self, req, idxnm, emb_mdl=None, highlight=False):
qst = req.get("question", "")
bqry, keywords = self.qryr.question(qst)
bqry, keywords = self.qryr.question(qst, min_match="30%")
bqry = self._add_filters(bqry, req)
bqry.boost = 0.05
@ -130,7 +130,7 @@ class Dealer:
qst, emb_mdl, req.get(
"similarity", 0.1), topk)
s["knn"]["filter"] = bqry.to_dict()
if "highlight" in s:
if not highlight and "highlight" in s:
del s["highlight"]
q_vec = s["knn"]["query_vector"]
es_logger.info("【Q】: {}".format(json.dumps(s)))
@ -356,7 +356,7 @@ class Dealer:
rag_tokenizer.tokenize(inst).split(" "))
def retrieval(self, question, embd_mdl, tenant_id, kb_ids, page, page_size, similarity_threshold=0.2,
vector_similarity_weight=0.3, top=1024, doc_ids=None, aggs=True, rerank_mdl=None):
vector_similarity_weight=0.3, top=1024, doc_ids=None, aggs=True, rerank_mdl=None, highlight=False):
ranks = {"total": 0, "chunks": [], "doc_aggs": {}}
if not question:
return ranks
@ -364,7 +364,7 @@ class Dealer:
"question": question, "vector": True, "topk": top,
"similarity": similarity_threshold,
"available_int": 1}
sres = self.search(req, index_name(tenant_id), embd_mdl)
sres = self.search(req, index_name(tenant_id), embd_mdl, highlight)
if rerank_mdl:
sim, tsim, vsim = self.rerank_by_model(rerank_mdl,
@ -405,6 +405,8 @@ class Dealer:
"vector": self.trans2floats(sres.field[id].get("q_%d_vec" % dim, "\t".join(["0"] * dim))),
"positions": sres.field[id].get("position_int", "").split("\t")
}
if highlight:
d["highlight"] = rmSpace(sres.highlight[id])
if len(d["positions"]) % 5 == 0:
poss = []
for i in range(0, len(d["positions"]), 5):