mirror of https://git.mirrors.martin98.com/https://github.com/infiniflow/ragflow.git (synced 2025-06-04 11:24:00 +08:00)
Add pagerank to KB. (#3809)
### What problem does this PR solve?

#3794

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
parent 7543047de3
commit 74b28ef1b0
In the chunk-creation endpoint (`def create()`): manually added chunks now get a tokenized title and inherit the knowledge base's pagerank.

```diff
@@ -227,12 +227,18 @@ def create():
             return get_data_error_result(message="Document not found!")
         d["kb_id"] = [doc.kb_id]
         d["docnm_kwd"] = doc.name
+        d["title_tks"] = rag_tokenizer.tokenize(doc.name)
         d["doc_id"] = doc.id

         tenant_id = DocumentService.get_tenant_id(req["doc_id"])
         if not tenant_id:
             return get_data_error_result(message="Tenant not found!")

+        e, kb = KnowledgebaseService.get_by_id(doc.kb_id)
+        if not e:
+            return get_data_error_result(message="Knowledgebase not found!")
+        if kb.pagerank: d["pagerank_fea"] = kb.pagerank
+
         embd_id = DocumentService.get_embd_id(req["doc_id"])
         embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING.value, embd_id)

```
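Without this, a manually added chunk carried no `pagerank_fea` and could never rank alongside parsed chunks once the KB's pagerank was set. A minimal sketch of the resulting chunk payload (all values here are made-up placeholders):

```python
# Sketch of the chunk payload after the change; every value is illustrative.
d = {
    "doc_id": "doc_1",
    "kb_id": ["kb_1"],
    "docnm_kwd": "handbook.pdf",
    "title_tks": "handbook",            # now tokenized from the document name
}
kb_pagerank = 10                        # Knowledgebase.pagerank
if kb_pagerank:                         # attach only when non-zero, matching the hunk
    d["pagerank_fea"] = kb_pagerank
print(d)
```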
In the knowledge-base update endpoint (`def update()`): changing a KB's pagerank rewrites the feature on all of its stored chunks.

```diff
@@ -102,6 +102,14 @@ def update():
         if not KnowledgebaseService.update_by_id(kb.id, req):
             return get_data_error_result()

+        if kb.pagerank != req.get("pagerank", 0):
+            if req.get("pagerank", 0) > 0:
+                settings.docStoreConn.update({"kb_id": kb.id}, {"pagerank_fea": req["pagerank"]},
+                                             search.index_name(kb.tenant_id), kb.id)
+            else:
+                settings.docStoreConn.update({"exist": "pagerank_fea"}, {"remove": "pagerank_fea"},
+                                             search.index_name(kb.tenant_id), kb.id)
+
         e, kb = KnowledgebaseService.get_by_id(kb.id)
         if not e:
             return get_data_error_result(
```
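The two branches keep existing chunks in sync with the KB setting: a positive pagerank is written onto every chunk, while resetting it to 0 strips the field entirely rather than storing a zero, likely because Elasticsearch `rank_feature` values must be strictly positive. A dict-level sketch of the semantics, using an in-memory stand-in for the document store:

```python
# In-memory stand-in for docStoreConn.update(condition, new_value, index, kb_id).
chunks = [{"kb_id": "kb_1", "pagerank_fea": 5}, {"kb_id": "kb_1"}]

def sync_pagerank(value):
    for c in chunks:
        if value > 0:
            c["pagerank_fea"] = value      # {"kb_id": ...} -> {"pagerank_fea": value}
        else:
            c.pop("pagerank_fea", None)    # {"exist": "pagerank_fea"} -> {"remove": ...}

sync_pagerank(0)
print(chunks)  # -> [{'kb_id': 'kb_1'}, {'kb_id': 'kb_1'}]
```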
In the `Knowledgebase` model: a new integer column with default 0.

```diff
@@ -703,6 +703,7 @@ class Knowledgebase(DataBaseModel):
         default=ParserType.NAIVE.value,
         index=True)
     parser_config = JSONField(null=False, default={"pages": [[1, 1000000]]})
+    pagerank = IntegerField(default=0, index=False)
     status = CharField(
         max_length=1,
         null=True,
```
In `migrate_db()`: the matching add-column migration for existing deployments.

```diff
@@ -1076,4 +1077,10 @@ def migrate_db():
         )
     except Exception:
         pass
+    try:
+        migrate(
+            migrator.add_column("knowledgebase", "pagerank", IntegerField(default=0, index=False))
+        )
+    except Exception:
+        pass

```
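The migration follows the same add-column-and-ignore-failure pattern as the surrounding `migrate_db()` entries. A self-contained sketch of that pattern with peewee's playhouse migrator (SQLite is an assumption here; ragflow picks the migrator for its configured database):

```python
from peewee import IntegerField, SqliteDatabase
from playhouse.migrate import SqliteMigrator, migrate

db = SqliteDatabase("example.db")
migrator = SqliteMigrator(db)
try:
    migrate(migrator.add_column("knowledgebase", "pagerank", IntegerField(default=0)))
except Exception:
    pass  # column already exists: re-running the migration is a no-op
```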
In `KnowledgebaseService` (field list returned for a KB): expose the new column.

```diff
@@ -104,7 +104,8 @@ class KnowledgebaseService(CommonService):
             cls.model.token_num,
             cls.model.chunk_num,
             cls.model.parser_id,
-            cls.model.parser_config]
+            cls.model.parser_config,
+            cls.model.pagerank]
         kbs = cls.model.select(*fields).join(Tenant, on=(
             (Tenant.id == cls.model.tenant_id) & (Tenant.status == StatusEnum.VALID.value))).where(
             (cls.model.id == kb_id),
```
In `TenantLLMService`: the usage lookup can now be scoped by `llm_factory`, and the create/update branches are reordered so the missing-row case comes first.

```diff
@@ -191,15 +191,18 @@ class TenantLLMService(CommonService):

         num = 0
         try:
+            if llm_factory:
+                tenant_llms = cls.query(tenant_id=tenant_id, llm_name=llm_name, llm_factory=llm_factory)
+            else:
                 tenant_llms = cls.query(tenant_id=tenant_id, llm_name=llm_name)
-            if tenant_llms:
+            if not tenant_llms:
+                if not llm_factory: llm_factory = mdlnm
+                num = cls.model.create(tenant_id=tenant_id, llm_factory=llm_factory, llm_name=llm_name, used_tokens=used_tokens)
+            else:
                 tenant_llm = tenant_llms[0]
                 num = cls.model.update(used_tokens=tenant_llm.used_tokens + used_tokens)\
                     .where(cls.model.tenant_id == tenant_id, cls.model.llm_factory == tenant_llm.llm_factory, cls.model.llm_name == llm_name)\
                     .execute()
-            else:
-                if not llm_factory: llm_factory = mdlnm
-                num = cls.model.create(tenant_id=tenant_id, llm_factory=llm_factory, llm_name=llm_name, used_tokens=used_tokens)
         except Exception:
             logging.exception("TenantLLMService.increase_usage got exception")
         return num
```
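This hunk is a restructuring of token accounting shipped alongside the pagerank feature rather than part of it. Reduced to its upsert skeleton (a plain dict stands in for the peewee model, purely illustrative):

```python
usage = {}  # (tenant_id, llm_name) -> used_tokens, standing in for the TenantLLM table

def increase_usage(tenant_id, llm_name, used_tokens):
    key = (tenant_id, llm_name)
    if key not in usage:          # "if not tenant_llms": first use, create the row
        usage[key] = used_tokens
    else:                         # existing row: accumulate
        usage[key] += used_tokens
    return usage[key]

increase_usage("t1", "bge-m3", 120)
print(increase_usage("t1", "bge-m3", 80))  # -> 200
```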
In `TaskService`: parsing tasks now carry the KB's pagerank down to the executor.

```diff
@@ -53,6 +53,7 @@ class TaskService(CommonService):
             Knowledgebase.tenant_id,
             Knowledgebase.language,
             Knowledgebase.embd_id,
+            Knowledgebase.pagerank,
             Tenant.img2txt_id,
             Tenant.asr_id,
             Tenant.llm_id,
```
In the Infinity index mapping (JSON): the feature field is stored as a plain integer.

```diff
@@ -22,5 +22,6 @@
     "rank_int": {"type": "integer", "default": 0},
     "available_int": {"type": "integer", "default": 1},
     "knowledge_graph_kwd": {"type": "varchar", "default": ""},
-    "entities_kwd": {"type": "varchar", "default": ""}
+    "entities_kwd": {"type": "varchar", "default": ""},
+    "pagerank_fea": {"type": "integer", "default": 0}
 }
```
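On the Elasticsearch side, the `rank_feature` query in a later hunk implies that `pagerank_fea` is mapped as a `rank_feature` field there; that mapping is not part of this diff, so the fragment below is an assumption about what it looks like:

```python
# Assumed ES mapping fragment for the feature field (not shown in this diff).
pagerank_mapping = {
    "properties": {
        "pagerank_fea": {"type": "rank_feature"}  # required for rank_feature queries
    }
}
print(pagerank_mapping)
```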
In `Dealer` (default retrieval field list): fetch the feature with each chunk.

```diff
@@ -75,7 +75,7 @@ class Dealer:

         src = req.get("fields", ["docnm_kwd", "content_ltks", "kb_id", "img_id", "title_tks", "important_kwd",
                                  "doc_id", "position_list", "knowledge_graph_kwd",
-                                 "available_int", "content_with_weight"])
+                                 "available_int", "content_with_weight", "pagerank_fea"])
         kwds = set([])

         qst = req.get("question", "")
```
In `Dealer` (rerank: collect the per-chunk pagerank alongside each embedding):

```diff
@@ -234,11 +234,13 @@ class Dealer:
         vector_column = f"q_{vector_size}_vec"
         zero_vector = [0.0] * vector_size
         ins_embd = []
+        pageranks = []
         for chunk_id in sres.ids:
             vector = sres.field[chunk_id].get(vector_column, zero_vector)
             if isinstance(vector, str):
                 vector = [float(v) for v in vector.split("\t")]
             ins_embd.append(vector)
+            pageranks.append(sres.field[chunk_id].get("pagerank_fea", 0))
         if not ins_embd:
             return [], [], []

```
In `Dealer` (rerank: fuse pagerank into the hybrid score):

```diff
@@ -257,7 +259,8 @@ class Dealer:
                                                         ins_embd,
                                                         keywords,
                                                         ins_tw, tkweight, vtweight)
-        return sim, tksim, vtsim
+
+        return sim+np.array(pageranks, dtype=float), tksim, vtsim

     def rerank_by_model(self, rerank_mdl, sres, query, tkweight=0.3,
                         vtweight=0.7, cfield="content_ltks"):
```
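Because the hybrid similarity is roughly in [0, 1] while pagerank is an integer, adding the two makes any chunk with a non-zero pagerank sort above every plain match: pagerank acts as a coarse tier, and similarity breaks ties within a tier. A small numpy illustration with made-up scores:

```python
import numpy as np

sim = np.array([0.82, 0.75, 0.91])              # hybrid similarity per chunk
pageranks = np.array([0, 10, 0], dtype=float)   # per-chunk pagerank_fea
fused = sim + pageranks
print(fused)                  # [ 0.82 10.75  0.91]
print(fused.argsort()[::-1])  # chunk 1 now ranks first despite the lowest similarity
```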
In `Dealer` (chunk formatting: guard the highlight lookup so an empty or missing highlight map is skipped safely):

```diff
@@ -351,7 +354,7 @@ class Dealer:
                 "vector": chunk.get(vector_column, zero_vector),
                 "positions": json.loads(position_list)
             }
-            if highlight:
+            if highlight and sres.highlight:
                 if id in sres.highlight:
                     d["highlight"] = rmSpace(sres.highlight[id])
                 else:
```
In `build_chunks()`: every chunk built by the task executor inherits the task's pagerank.

```diff
@@ -201,6 +201,7 @@ def build_chunks(task, progress_callback):
         "doc_id": task["doc_id"],
         "kb_id": str(task["kb_id"])
     }
+    if task["pagerank"]: doc["pagerank_fea"] = int(task["pagerank"])
     el = 0
     for ck in cks:
         d = copy.deepcopy(doc)
```
In `run_raptor()`: RAPTOR summary chunks get the same treatment.

```diff
@@ -339,6 +340,7 @@ def run_raptor(row, chat_mdl, embd_mdl, callback=None):
         "docnm_kwd": row["name"],
         "title_tks": rag_tokenizer.tokenize(row["name"])
     }
+    if row["pagerank"]: doc["pagerank_fea"] = int(row["pagerank"])
     res = []
     tk_count = 0
     for content, vctr in chunks[original_length:]:
```
In `do_handle_task()`: a missing f-string prefix is fixed in the insert-error message.

```diff
@@ -431,7 +433,7 @@ def do_handle_task(task):
             progress_callback(prog=0.8 + 0.1 * (b + 1) / len(chunks), msg="")
     logging.info("Indexing {} elapsed: {:.2f}".format(task_document_name, timer() - start_ts))
     if doc_store_result:
-        error_message = "Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!"
+        error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!"
        progress_callback(-1, msg=error_message)
        settings.docStoreConn.delete({"doc_id": task_doc_id}, search.index_name(task_tenant_id), task_dataset_id)
        logging.error(error_message)
```
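The change above is a one-character bug fix: without the `f` prefix the braces are literal, so operators saw the placeholder name instead of the actual error. Demonstration:

```python
doc_store_result = 3
print("Insert chunk error: {doc_store_result}")   # Insert chunk error: {doc_store_result}
print(f"Insert chunk error: {doc_store_result}")  # Insert chunk error: 3
```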
In `ESConnection` (query construction): boost matches by the stored feature.

```diff
@@ -175,6 +175,7 @@ class ESConnection(DocStoreConnection):
             )

         if bqry:
+            bqry.should.append(Q("rank_feature", field="pagerank_fea", linear={}, boost=10))
             s = s.query(bqry)
         for field in highlightFields:
             s = s.highlight(field)
```
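The `rank_feature` clause is appended as a `should`, so it only boosts scores and never filters: chunks without the field still match, and `linear={}` scales the boost linearly with the stored value. A minimal sketch with `elasticsearch_dsl` (the match query content is made up):

```python
from elasticsearch_dsl import Q, Search

bqry = Q("bool", must=[Q("match", content_ltks="quick brown fox")])
# Same boost clause as the hunk above; non-scoring docs are unaffected.
bqry.should.append(Q("rank_feature", field="pagerank_fea", linear={}, boost=10))
s = Search().query(bqry)
print(s.to_dict())
# {'query': {'bool': {'must': [...], 'should': [{'rank_feature': {...}}]}}}
```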
In `ESConnection.update()`: the single-document path now returns `False` after exhausting retries, and the multi-document path learns an `exist` condition.

```diff
@@ -283,12 +284,16 @@ class ESConnection(DocStoreConnection):
                         f"ESConnection.update(index={indexName}, id={id}, doc={json.dumps(condition, ensure_ascii=False)}) got exception")
                     if str(e).find("Timeout") > 0:
                         continue
+            return False
         else:
             # update unspecific maybe-multiple documents
             bqry = Q("bool")
             for k, v in condition.items():
                 if not isinstance(k, str) or not v:
                     continue
+                if k == "exist":
+                    bqry.filter.append(Q("exists", field=v))
+                    continue
                 if isinstance(v, list):
                     bqry.filter.append(Q("terms", **{k: v}))
                 elif isinstance(v, str) or isinstance(v, int):
```
In `ESConnection.update()` (script assembly): a `remove` verb deletes a field from matched documents.

```diff
@@ -298,6 +303,9 @@ class ESConnection(DocStoreConnection):
                     f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.")
             scripts = []
             for k, v in newValue.items():
+                if k == "remove":
+                    scripts.append(f"ctx._source.remove('{v}');")
+                    continue
                 if (not isinstance(k, str) or not v) and k != "available_int":
                     continue
                 if isinstance(v, str):
```
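With the new verb, `update({"exist": "pagerank_fea"}, {"remove": "pagerank_fea"}, ...)` compiles into an update-by-query whose painless script deletes the field from `_source`. A sketch of the request body it should produce (assembled by hand here, so treat the exact shape as an assumption):

```python
import json

body = {
    "query": {"bool": {"filter": [{"exists": {"field": "pagerank_fea"}}]}},
    "script": {
        "source": "ctx._source.remove('pagerank_fea');",  # same script the hunk appends
        "lang": "painless",
    },
}
print(json.dumps(body, indent=2))
```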
In the Python SDK `DataSet` model: the new attribute with a default of 0.

```diff
@@ -21,6 +21,7 @@ class DataSet(Base):
         self.chunk_count = 0
         self.chunk_method = "naive"
         self.parser_config = None
+        self.pagerank = 0
         for k in list(res_dict.keys()):
             if k not in self.__dict__:
                 res_dict.pop(k)
```
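On the SDK side, `DataSet` simply gains the attribute so server responses that include `pagerank` round-trip cleanly. A hedged usage sketch (`api_key`/`base_url` are placeholders, and whether the update endpoint accepts `pagerank` through the SDK at this commit is an assumption):

```python
from ragflow_sdk import RAGFlow

rag = RAGFlow(api_key="<YOUR_API_KEY>", base_url="http://localhost:9380")
for ds in rag.list_datasets():
    print(ds.name, ds.pagerank)  # 0 unless raised via the knowledge-base update API
```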