
### What problem does this PR solve?

When selected chunks were deleted, the chunks themselves were removed but the document's Chunk Number did not change. With this fix, the Chunk Number decreases by the number of chunks deleted: delete one chunk and it drops by one, delete two and it drops by two, and so on. #900

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

Signed-off-by: seaver <zhudan187@qq.com>
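As a quick sanity check of the fixed behavior, here is a minimal client-side sketch that deletes two chunks and verifies the counter. The host, port, `/v1/chunk` route prefix, `Authorization` header, and placeholder ids are assumptions about the deployment rather than anything specified in this PR, and it assumes the document record returned by `/list` exposes the counter as `chunk_num`:

```python
import requests

# Hypothetical deployment details: adjust host, port, route prefix,
# and auth header to match your instance.
BASE = "http://localhost:9380/v1/chunk"
HEADERS = {"Authorization": "<your-session-token>"}
DOC_ID = "<doc-id>"

def chunk_num():
    # /list returns the document record alongside one page of chunks.
    r = requests.post(f"{BASE}/list", headers=HEADERS, json={"doc_id": DOC_ID})
    return r.json()["data"]["doc"]["chunk_num"]

before = chunk_num()
requests.post(f"{BASE}/rm", headers=HEADERS,
              json={"doc_id": DOC_ID, "chunk_ids": ["<id-1>", "<id-2>"]})
assert chunk_num() == before - 2  # the counter now tracks deletions
```

The full source of the modified file follows.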
#
#  Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#

import datetime
import hashlib
import re

from flask import request
from flask_login import login_required, current_user
from elasticsearch_dsl import Q

from rag.app.qa import rmPrefix, beAdoc
from rag.nlp import search, rag_tokenizer, keyword_extraction
from rag.utils.es_conn import ELASTICSEARCH
from rag.utils import rmSpace
from api.db import LLMType, ParserType
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import TenantLLMService
from api.db.services.user_service import UserTenantService
from api.db.services.document_service import DocumentService
from api.settings import RetCode, retrievaler
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request, get_json_result

@manager.route('/list', methods=['POST'])
@login_required
@validate_request("doc_id")
def list_chunk():
    req = request.json
    doc_id = req["doc_id"]
    page = int(req.get("page", 1))
    size = int(req.get("size", 30))
    question = req.get("keywords", "")
    try:
        tenant_id = DocumentService.get_tenant_id(doc_id)
        if not tenant_id:
            return get_data_error_result(retmsg="Tenant not found!")
        e, doc = DocumentService.get_by_id(doc_id)
        if not e:
            return get_data_error_result(retmsg="Document not found!")
        query = {
            "doc_ids": [doc_id], "page": page, "size": size, "question": question, "sort": True
        }
        if "available_int" in req:
            query["available_int"] = int(req["available_int"])
        sres = retrievaler.search(query, search.index_name(tenant_id))
        res = {"total": sres.total, "chunks": [], "doc": doc.to_dict()}
        for id in sres.ids:
            d = {
                "chunk_id": id,
                # Prefer the highlighted snippet when a keyword search was given.
                "content_with_weight": rmSpace(sres.highlight[id]) if question and id in sres.highlight
                else sres.field[id].get("content_with_weight", ""),
                "doc_id": sres.field[id]["doc_id"],
                "docnm_kwd": sres.field[id]["docnm_kwd"],
                "important_kwd": sres.field[id].get("important_kwd", []),
                "img_id": sres.field[id].get("img_id", ""),
                "available_int": sres.field[id].get("available_int", 1),
                "positions": sres.field[id].get("position_int", "").split("\t")
            }
            # Positions are stored as a flat, tab-separated list; regroup
            # them into quintuples of floats for the client.
            if len(d["positions"]) % 5 == 0:
                poss = []
                for i in range(0, len(d["positions"]), 5):
                    poss.append([float(d["positions"][i]), float(d["positions"][i + 1]),
                                 float(d["positions"][i + 2]), float(d["positions"][i + 3]),
                                 float(d["positions"][i + 4])])
                d["positions"] = poss
            res["chunks"].append(d)
        return get_json_result(data=res)
    except Exception as e:
        if "not_found" in str(e):
            return get_json_result(data=False, retmsg='No chunk found!',
                                   retcode=RetCode.DATA_ERROR)
        return server_error_response(e)

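For reference, a minimal `requests` sketch of paging through a document's chunks. As before, the base URL, route prefix, auth header, and ids are hypothetical deployment details:

```python
import requests

BASE = "http://localhost:9380/v1/chunk"  # hypothetical deployment
HEADERS = {"Authorization": "<your-session-token>"}

# Page through one document's chunks, optionally filtering by keyword
# and by availability.
resp = requests.post(f"{BASE}/list", headers=HEADERS, json={
    "doc_id": "<doc-id>",
    "page": 1,
    "size": 30,
    "keywords": "invoice",  # optional full-text filter
    "available_int": 1,     # optional: only enabled chunks
})
data = resp.json()["data"]
print(data["total"], "chunks total;", len(data["chunks"]), "on this page")
```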
@manager.route('/get', methods=['GET'])
@login_required
def get():
    chunk_id = request.args["chunk_id"]
    try:
        tenants = UserTenantService.query(user_id=current_user.id)
        if not tenants:
            return get_data_error_result(retmsg="Tenant not found!")
        res = ELASTICSEARCH.get(
            chunk_id, search.index_name(
                tenants[0].tenant_id))
        if not res.get("found"):
            return server_error_response("Chunk not found")
        id = res["_id"]
        res = res["_source"]
        res["chunk_id"] = id
        # Strip internal index fields (embeddings and token lists) from
        # the response.
        k = []
        for n in res.keys():
            if re.search(r"(_vec$|_sm_|_tks|_ltks)", n):
                k.append(n)
        for n in k:
            del res[n]

        return get_json_result(data=res)
    except Exception as e:
        if "NotFoundError" in str(e):
            return get_json_result(data=False, retmsg='Chunk not found!',
                                   retcode=RetCode.DATA_ERROR)
        return server_error_response(e)

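A matching sketch for fetching a single chunk (same hypothetical deployment assumptions):

```python
import requests

BASE = "http://localhost:9380/v1/chunk"  # hypothetical deployment
HEADERS = {"Authorization": "<your-session-token>"}

# Fetch one chunk; vector and token fields are stripped server-side.
resp = requests.get(f"{BASE}/get", headers=HEADERS,
                    params={"chunk_id": "<chunk-id>"})
print(resp.json()["data"].get("content_with_weight", ""))
```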
@manager.route('/set', methods=['POST'])
@login_required
@validate_request("doc_id", "chunk_id", "content_with_weight",
                  "important_kwd")
def set():
    req = request.json
    d = {
        "id": req["chunk_id"],
        "content_with_weight": req["content_with_weight"]}
    d["content_ltks"] = rag_tokenizer.tokenize(req["content_with_weight"])
    d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
    d["important_kwd"] = req["important_kwd"]
    d["important_tks"] = rag_tokenizer.tokenize(" ".join(req["important_kwd"]))
    if "available_int" in req:
        d["available_int"] = req["available_int"]

    try:
        tenant_id = DocumentService.get_tenant_id(req["doc_id"])
        if not tenant_id:
            return get_data_error_result(retmsg="Tenant not found!")

        embd_id = DocumentService.get_embd_id(req["doc_id"])
        embd_mdl = TenantLLMService.model_instance(
            tenant_id, LLMType.EMBEDDING.value, embd_id)

        e, doc = DocumentService.get_by_id(req["doc_id"])
        if not e:
            return get_data_error_result(retmsg="Document not found!")

        # For QA-parsed documents, the content must be exactly one
        # question and one answer, separated by TAB or newline.
        if doc.parser_id == ParserType.QA:
            arr = [
                t for t in re.split(
                    r"[\n\t]",
                    req["content_with_weight"]) if len(t) > 1]
            if len(arr) != 2:
                return get_data_error_result(
                    retmsg="Q&A must be separated by TAB/ENTER key.")
            q, a = rmPrefix(arr[0]), rmPrefix(arr[1])
            d = beAdoc(d, arr[0], arr[1], not any(
                [rag_tokenizer.is_chinese(t) for t in q + a]))

        # Embed title and content, then blend them (10% title, 90%
        # content); QA chunks use the content vector alone.
        v, c = embd_mdl.encode([doc.name, req["content_with_weight"]])
        v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
        d["q_%d_vec" % len(v)] = v.tolist()
        ELASTICSEARCH.upsert([d], search.index_name(tenant_id))
        return get_json_result(data=True)
    except Exception as e:
        return server_error_response(e)

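A sketch of updating a chunk's text and keywords through this endpoint, under the same hypothetical deployment assumptions; the server re-tokenizes and re-embeds the new content:

```python
import requests

BASE = "http://localhost:9380/v1/chunk"  # hypothetical deployment
HEADERS = {"Authorization": "<your-session-token>"}

resp = requests.post(f"{BASE}/set", headers=HEADERS, json={
    "doc_id": "<doc-id>",
    "chunk_id": "<chunk-id>",
    "content_with_weight": "Updated chunk text.",
    "important_kwd": ["updated", "keywords"],
})
assert resp.json()["data"] is True
```

For a document parsed in QA mode, `content_with_weight` would instead carry the question and answer separated by a TAB or newline.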
@manager.route('/switch', methods=['POST'])
@login_required
@validate_request("chunk_ids", "available_int", "doc_id")
def switch():
    req = request.json
    try:
        tenant_id = DocumentService.get_tenant_id(req["doc_id"])
        if not tenant_id:
            return get_data_error_result(retmsg="Tenant not found!")
        # Toggle availability in place: available_int=0 hides the chunks
        # from retrieval, 1 re-enables them.
        if not ELASTICSEARCH.upsert([{"id": i, "available_int": int(req["available_int"])} for i in req["chunk_ids"]],
                                    search.index_name(tenant_id)):
            return get_data_error_result(retmsg="Index updating failure")
        return get_json_result(data=True)
    except Exception as e:
        return server_error_response(e)

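A sketch of disabling two chunks without deleting them (hypothetical deployment assumptions as above):

```python
import requests

BASE = "http://localhost:9380/v1/chunk"  # hypothetical deployment
HEADERS = {"Authorization": "<your-session-token>"}

# available_int=0 hides the chunks from retrieval; 1 re-enables them.
resp = requests.post(f"{BASE}/switch", headers=HEADERS, json={
    "doc_id": "<doc-id>",
    "chunk_ids": ["<chunk-id-1>", "<chunk-id-2>"],
    "available_int": 0,
})
assert resp.json()["data"] is True
```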
@manager.route('/rm', methods=['POST'])
@login_required
@validate_request("chunk_ids", "doc_id")
def rm():
    req = request.json
    try:
        if not ELASTICSEARCH.deleteByQuery(
                Q("ids", values=req["chunk_ids"]), search.index_name(current_user.id)):
            return get_data_error_result(retmsg="Index updating failure")
        e, doc = DocumentService.get_by_id(req["doc_id"])
        if not e:
            return get_data_error_result(retmsg="Document not found!")
        # Keep the document's chunk counter in sync with the index:
        # decrement it by the number of chunks that were just deleted.
        deleted_chunk_ids = req["chunk_ids"]
        chunk_number = len(deleted_chunk_ids)
        DocumentService.decrement_chunk_num(doc.id, doc.kb_id, 1, chunk_number, 0)
        return get_json_result(data=True)
    except Exception as e:
        return server_error_response(e)

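This is the endpoint the PR changes: `doc_id` lets the server locate the document whose chunk counter must shrink. A minimal call, under the same hypothetical deployment assumptions:

```python
import requests

BASE = "http://localhost:9380/v1/chunk"  # hypothetical deployment
HEADERS = {"Authorization": "<your-session-token>"}

chunk_ids = ["<chunk-id-1>", "<chunk-id-2>"]
resp = requests.post(f"{BASE}/rm", headers=HEADERS, json={
    "doc_id": "<doc-id>",
    "chunk_ids": chunk_ids,
})
# With this PR, the document's chunk_num drops by len(chunk_ids).
```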
@manager.route('/create', methods=['POST'])
@login_required
@validate_request("doc_id", "content_with_weight")
def create():
    req = request.json
    # Derive a deterministic chunk id from the content plus the doc id.
    md5 = hashlib.md5()
    md5.update((req["content_with_weight"] + req["doc_id"]).encode("utf-8"))
    chunk_id = md5.hexdigest()
    d = {"id": chunk_id, "content_ltks": rag_tokenizer.tokenize(req["content_with_weight"]),
         "content_with_weight": req["content_with_weight"]}
    d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
    d["important_kwd"] = req.get("important_kwd", [])
    d["important_tks"] = rag_tokenizer.tokenize(" ".join(req.get("important_kwd", [])))
    d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19]
    d["create_timestamp_flt"] = datetime.datetime.now().timestamp()

    try:
        e, doc = DocumentService.get_by_id(req["doc_id"])
        if not e:
            return get_data_error_result(retmsg="Document not found!")
        d["kb_id"] = [doc.kb_id]
        d["docnm_kwd"] = doc.name
        d["doc_id"] = doc.id

        tenant_id = DocumentService.get_tenant_id(req["doc_id"])
        if not tenant_id:
            return get_data_error_result(retmsg="Tenant not found!")

        embd_id = DocumentService.get_embd_id(req["doc_id"])
        embd_mdl = TenantLLMService.model_instance(
            tenant_id, LLMType.EMBEDDING.value, embd_id)

        # Blend the title and content embeddings (10% title, 90% content).
        v, c = embd_mdl.encode([doc.name, req["content_with_weight"]])
        v = 0.1 * v[0] + 0.9 * v[1]
        d["q_%d_vec" % len(v)] = v.tolist()
        ELASTICSEARCH.upsert([d], search.index_name(tenant_id))

        # The mirror of the /rm bookkeeping: adding a chunk bumps the
        # document's chunk counter by one.
        DocumentService.increment_chunk_num(
            doc.id, doc.kb_id, c, 1, 0)
        return get_json_result(data={"chunk_id": chunk_id})
    except Exception as e:
        return server_error_response(e)

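A sketch of adding a chunk by hand (hypothetical deployment assumptions as above); the returned id is the md5 of content plus doc id, so the same call is idempotent at the index level:

```python
import requests

BASE = "http://localhost:9380/v1/chunk"  # hypothetical deployment
HEADERS = {"Authorization": "<your-session-token>"}

resp = requests.post(f"{BASE}/create", headers=HEADERS, json={
    "doc_id": "<doc-id>",
    "content_with_weight": "A manually added chunk.",
    "important_kwd": ["manual"],
})
print(resp.json()["data"]["chunk_id"])  # md5(content + doc_id)
```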
@manager.route('/retrieval_test', methods=['POST'])
@login_required
@validate_request("kb_id", "question")
def retrieval_test():
    req = request.json
    page = int(req.get("page", 1))
    size = int(req.get("size", 30))
    question = req["question"]
    kb_id = req["kb_id"]
    doc_ids = req.get("doc_ids", [])
    similarity_threshold = float(req.get("similarity_threshold", 0.2))
    vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3))
    top = int(req.get("top_k", 1024))
    try:
        e, kb = KnowledgebaseService.get_by_id(kb_id)
        if not e:
            return get_data_error_result(retmsg="Knowledgebase not found!")

        embd_mdl = TenantLLMService.model_instance(
            kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)

        rerank_mdl = None
        if req.get("rerank_id"):
            rerank_mdl = TenantLLMService.model_instance(
                kb.tenant_id, LLMType.RERANK.value, llm_name=req["rerank_id"])

        # Optionally expand the query with LLM-extracted keywords before
        # retrieval.
        if req.get("keyword", False):
            chat_mdl = TenantLLMService.model_instance(kb.tenant_id, LLMType.CHAT)
            question += keyword_extraction(chat_mdl, question)

        ranks = retrievaler.retrieval(question, embd_mdl, kb.tenant_id, [kb_id], page, size,
                                      similarity_threshold, vector_similarity_weight, top,
                                      doc_ids, rerank_mdl=rerank_mdl)
        # Vectors are large and of no use to the caller; drop them.
        for c in ranks["chunks"]:
            if "vector" in c:
                del c["vector"]

        return get_json_result(data=ranks)
    except Exception as e:
        if "not_found" in str(e):
            return get_json_result(data=False, retmsg='No chunk found! Check the chunk status please!',
                                   retcode=RetCode.DATA_ERROR)
        return server_error_response(e)
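Finally, a sketch of probing retrieval quality against one knowledge base. The deployment details are hypothetical as in the earlier sketches, and the exact fields present on each returned chunk depend on `retrievaler.retrieval`, hence the defensive `.get`:

```python
import requests

BASE = "http://localhost:9380/v1/chunk"  # hypothetical deployment
HEADERS = {"Authorization": "<your-session-token>"}

resp = requests.post(f"{BASE}/retrieval_test", headers=HEADERS, json={
    "kb_id": "<kb-id>",
    "question": "What is the refund policy?",
    "similarity_threshold": 0.2,
    "vector_similarity_weight": 0.3,
    "top_k": 1024,
    "rerank_id": "",   # optional reranker model
    "keyword": False,  # optional LLM keyword expansion
})
for c in resp.json()["data"]["chunks"]:
    print(c.get("similarity"), c.get("content_with_weight", "")[:80])
```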