# Light GraphRAG (#4585)

### What problem does this PR solve?

#4543

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
Committed by Kevin Hu on 2025-01-22 19:43:14 +08:00 (via GitHub).
Commit dd0ebbea35, parent 1a367664f1.
55 changed files with 5,461 additions and 4,000 deletions.


@@ -155,7 +155,7 @@ def set():
r"[\n\t]",
req["content_with_weight"]) if len(t) > 1]
q, a = rmPrefix(arr[0]), rmPrefix("\n".join(arr[1:]))
d = beAdoc(d, arr[0], arr[1], not any(
d = beAdoc(d, q, a, not any(
[rag_tokenizer.is_chinese(t) for t in q + a]))
v, c = embd_mdl.encode([doc.name, req["content_with_weight"] if not d.get("question_kwd") else "\n".join(d["question_kwd"])])
@@ -270,6 +270,7 @@ def retrieval_test():
doc_ids = req.get("doc_ids", [])
similarity_threshold = float(req.get("similarity_threshold", 0.0))
vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3))
use_kg = req.get("use_kg", False)
top = int(req.get("top_k", 1024))
tenant_ids = []
@@ -301,12 +302,20 @@ def retrieval_test():
question += keyword_extraction(chat_mdl, question)
labels = label_question(question, [kb])
retr = settings.retrievaler if kb.parser_id != ParserType.KG else settings.kg_retrievaler
ranks = retr.retrieval(question, embd_mdl, tenant_ids, kb_ids, page, size,
ranks = settings.retrievaler.retrieval(question, embd_mdl, tenant_ids, kb_ids, page, size,
similarity_threshold, vector_similarity_weight, top,
doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"),
rank_feature=labels
)
if use_kg:
ck = settings.kg_retrievaler.retrieval(question,
tenant_ids,
kb_ids,
embd_mdl,
LLMBundle(kb.tenant_id, LLMType.CHAT))
if ck["content_with_weight"]:
ranks["chunks"].insert(0, ck)
for c in ranks["chunks"]:
c.pop("vector", None)
ranks["labels"] = labels


@@ -31,7 +31,7 @@ from api.db.services.llm_service import LLMBundle, TenantService
from api import settings
from api.utils.api_utils import get_json_result
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
from graphrag.mind_map_extractor import MindMapExtractor
from graphrag.general.mind_map_extractor import MindMapExtractor
@manager.route('/set', methods=['POST']) # noqa: F821


@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
from flask import request
from flask_login import login_required, current_user
@@ -272,4 +274,36 @@ def rename_tags(kb_id):
{"remove": {"tag_kwd": req["from_tag"].strip()}, "add": {"tag_kwd": req["to_tag"]}},
search.index_name(kb.tenant_id),
kb_id)
return get_json_result(data=True)
return get_json_result(data=True)
@manager.route('/<kb_id>/knowledge_graph', methods=['GET']) # noqa: F821
@login_required
def knowledge_graph(kb_id):
if not KnowledgebaseService.accessible(kb_id, current_user.id):
return get_json_result(
data=False,
message='No authorization.',
code=settings.RetCode.AUTHENTICATION_ERROR
)
e, kb = KnowledgebaseService.get_by_id(kb_id)
req = {
"kb_id": [kb_id],
"knowledge_graph_kwd": ["graph"]
}
sres = settings.retrievaler.search(req, search.index_name(kb.tenant_id), [kb_id])
obj = {"graph": {}, "mind_map": {}}
for id in sres.ids[:1]:
ty = sres.field[id]["knowledge_graph_kwd"]
try:
content_json = json.loads(sres.field[id]["content_with_weight"])
except Exception:
continue
obj[ty] = content_json
if "nodes" in obj["graph"]:
obj["graph"]["nodes"] = sorted(obj["graph"]["nodes"], key=lambda x: x.get("pagerank", 0), reverse=True)[:256]
if "edges" in obj["graph"]:
obj["graph"]["edges"] = sorted(obj["graph"]["edges"], key=lambda x: x.get("weight", 0), reverse=True)[:128]
return get_json_result(data=obj)
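
A sketch of consuming the new endpoint (the `/v1/kb` prefix is an assumption; the response shape follows the handler above: at most 256 nodes sorted by pagerank and 128 edges sorted by weight):

```python
import requests

r = requests.get(
    "http://localhost:9380/v1/kb/<kb_id>/knowledge_graph",
    headers={"Authorization": "Bearer <API_KEY>"},
)
data = r.json()["data"]
print(len(data["graph"].get("nodes", [])), "nodes")   # <= 256, by pagerank
print(len(data["graph"].get("edges", [])), "edges")   # <= 128, by weight
```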


@@ -15,7 +15,7 @@
#
from flask import request, jsonify
from api.db import LLMType, ParserType
from api.db import LLMType
from api.db.services.dialog_service import label_question
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle
@@ -30,6 +30,7 @@ def retrieval(tenant_id):
req = request.json
question = req["query"]
kb_id = req["knowledge_id"]
use_kg = req.get("use_kg", False)
retrieval_setting = req.get("retrieval_setting", {})
similarity_threshold = float(retrieval_setting.get("score_threshold", 0.0))
top = int(retrieval_setting.get("top_k", 1024))
@@ -45,8 +46,7 @@
embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
retr = settings.retrievaler if kb.parser_id != ParserType.KG else settings.kg_retrievaler
ranks = retr.retrieval(
ranks = settings.retrievaler.retrieval(
question,
embd_mdl,
kb.tenant_id,
@@ -58,6 +58,16 @@
top=top,
rank_feature=label_question(question, [kb])
)
if use_kg:
ck = settings.kg_retrievaler.retrieval(question,
[tenant_id],
[kb_id],
embd_mdl,
LLMBundle(kb.tenant_id, LLMType.CHAT))
if ck["content_with_weight"]:
ranks["chunks"].insert(0, ck)
records = []
for c in ranks["chunks"]:
c.pop("vector", None)


@@ -1297,15 +1297,15 @@ def retrieval_test(tenant_id):
kb_ids = req["dataset_ids"]
if not isinstance(kb_ids, list):
return get_error_data_result("`dataset_ids` should be a list")
kbs = KnowledgebaseService.get_by_ids(kb_ids)
for id in kb_ids:
if not KnowledgebaseService.accessible(kb_id=id, user_id=tenant_id):
return get_error_data_result(f"You don't own the dataset {id}.")
kbs = KnowledgebaseService.get_by_ids(kb_ids)
embd_nms = list(set([kb.embd_id for kb in kbs]))
if len(embd_nms) != 1:
return get_result(
message='Datasets use different embedding models.',
code=settings.RetCode.AUTHENTICATION_ERROR,
code=settings.RetCode.DATA_ERROR,
)
if "question" not in req:
return get_error_data_result("`question` is required.")
@@ -1313,6 +1313,7 @@ def retrieval_test(tenant_id):
size = int(req.get("page_size", 30))
question = req["question"]
doc_ids = req.get("document_ids", [])
use_kg = req.get("use_kg", False)
if not isinstance(doc_ids, list):
return get_error_data_result("`documents` should be a list")
doc_ids_list = KnowledgebaseService.list_documents_by_ids(kb_ids)
@@ -1342,8 +1343,7 @@
chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
question += keyword_extraction(chat_mdl, question)
retr = settings.retrievaler if kb.parser_id != ParserType.KG else settings.kg_retrievaler
ranks = retr.retrieval(
ranks = settings.retrievaler.retrieval(
question,
embd_mdl,
kb.tenant_id,
@@ -1358,6 +1358,15 @@
highlight=highlight,
rank_feature=label_question(question, kbs)
)
if use_kg:
ck = settings.kg_retrievaler.retrieval(question,
[k.tenant_id for k in kbs],
kb_ids,
embd_mdl,
LLMBundle(kb.tenant_id, LLMType.CHAT))
if ck["content_with_weight"]:
ranks["chunks"].insert(0, ck)
for c in ranks["chunks"]:
c.pop("vector", None)


@@ -133,7 +133,7 @@ def init_llm_factory():
TenantLLMService.filter_update([TenantLLMService.model.llm_factory == "QAnything"], {"llm_factory": "Youdao"})
TenantLLMService.filter_update([TenantLLMService.model.llm_factory == "cohere"], {"llm_factory": "Cohere"})
TenantService.filter_update([1 == 1], {
"parser_ids": "naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,knowledge_graph:Knowledge Graph,email:Email,tag:Tag"})
"parser_ids": "naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,email:Email,tag:Tag"})
## insert two OpenAI embedding models for the current OpenAI user.
# print("Start to insert 2 OpenAI embedding models...")
tenant_ids = set([row["tenant_id"] for row in TenantLLMService.get_openai_models()])


@@ -197,8 +197,7 @@ def chat(dialog, messages, stream=True, **kwargs):
embedding_model_name = embedding_list[0]
is_knowledge_graph = all([kb.parser_id == ParserType.KG for kb in kbs])
retriever = settings.retrievaler if not is_knowledge_graph else settings.kg_retrievaler
retriever = settings.retrievaler
questions = [m["content"] for m in messages if m["role"] == "user"][-3:]
attachments = kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None
@@ -275,6 +274,14 @@ def chat(dialog, messages, stream=True, **kwargs):
top=dialog.top_k, aggs=False, rerank_mdl=rerank_mdl,
rank_feature=label_question(" ".join(questions), kbs)
)
if prompt_config.get("use_kg"):
ck = settings.kg_retrievaler.retrieval(" ".join(questions),
tenant_ids,
dialog.kb_ids,
embd_mdl,
LLMBundle(dialog.tenant_id, LLMType.CHAT))
if ck["content_with_weight"]:
kbinfos["chunks"].insert(0, ck)
retrieval_ts = timer()
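
On the chat path the switch lives in the dialog's `prompt_config` rather than in the request. A hedged sketch of enabling it (only `use_kg` is taken from the code above; the surrounding keys are illustrative):

```python
prompt_config = {
    "system": "You are a helpful assistant. Knowledge: {knowledge}",
    "parameters": [{"key": "knowledge", "optional": False}],
    "use_kg": True,  # chat() will prepend kg_retrievaler output to kbinfos["chunks"]
}
```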


@@ -28,7 +28,7 @@ from peewee import fn
from api.db.db_utils import bulk_insert_into_db
from api import settings
from api.utils import current_timestamp, get_format_time, get_uuid
from graphrag.mind_map_extractor import MindMapExtractor
from graphrag.general.mind_map_extractor import MindMapExtractor
from rag.settings import SVR_QUEUE_NAME
from rag.utils.storage_factory import STORAGE_IMPL
from rag.nlp import search, rag_tokenizer
@@ -105,8 +105,19 @@ class DocumentService(CommonService):
@classmethod
@DB.connection_context()
def remove_document(cls, doc, tenant_id):
settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
cls.clear_chunk_num(doc.id)
try:
settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
settings.docStoreConn.update({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "community_report"], "source_id": doc.id},
{"remove": {"source_id": doc.id}},
search.index_name(tenant_id), doc.kb_id)
settings.docStoreConn.update({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]},
{"removed_kwd": "Y"},
search.index_name(tenant_id), doc.kb_id)
settings.docStoreConn.delete({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "community_report"], "must_not": {"exists": "source_id"}},
search.index_name(tenant_id), doc.kb_id)
except Exception:
pass
return cls.delete_by_id(doc.id)
@classmethod
@@ -142,7 +153,7 @@
@DB.connection_context()
def get_unfinished_docs(cls):
fields = [cls.model.id, cls.model.process_begin_at, cls.model.parser_config, cls.model.progress_msg,
cls.model.run]
cls.model.run, cls.model.parser_id]
docs = cls.model.select(*fields) \
.where(
cls.model.status == StatusEnum.VALID.value,
@@ -295,9 +306,9 @@
Tenant.asr_id,
Tenant.llm_id,
)
.join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id))
.join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))
.where(cls.model.id == doc_id)
.join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id))
.join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))
.where(cls.model.id == doc_id)
)
configs = configs.dicts()
if not configs:
@@ -365,6 +376,12 @@
@classmethod
@DB.connection_context()
def update_progress(cls):
MSG = {
"raptor": "Start RAPTOR (Recursive Abstractive Processing for Tree-Organized Retrieval).",
"graphrag": "Start Graph Extraction",
"graph_resolution": "Start Graph Resolution",
"graph_community": "Start Graph Community Reports Generation"
}
docs = cls.get_unfinished_docs()
for d in docs:
try:
@@ -390,15 +407,27 @@
prg = -1
status = TaskStatus.FAIL.value
elif finished:
if d["parser_config"].get("raptor", {}).get("use_raptor") and d["progress_msg"].lower().find(
" raptor") < 0:
queue_raptor_tasks(d)
m = "\n".join(sorted(msg))
if d["parser_config"].get("raptor", {}).get("use_raptor") and m.find(MSG["raptor"]) < 0:
queue_raptor_o_graphrag_tasks(d, "raptor", MSG["raptor"])
prg = 0.98 * len(tsks) / (len(tsks) + 1)
elif d["parser_config"].get("graphrag", {}).get("use_graphrag") and m.find(MSG["graphrag"]) < 0:
queue_raptor_o_graphrag_tasks(d, "graphrag", MSG["graphrag"])
prg = 0.98 * len(tsks) / (len(tsks) + 1)
elif d["parser_config"].get("graphrag", {}).get("use_graphrag") \
and d["parser_config"].get("graphrag", {}).get("resolution") \
and m.find(MSG["graph_resolution"]) < 0:
queue_raptor_o_graphrag_tasks(d, "graph_resolution", MSG["graph_resolution"])
prg = 0.98 * len(tsks) / (len(tsks) + 1)
elif d["parser_config"].get("graphrag", {}).get("use_graphrag") \
and d["parser_config"].get("graphrag", {}).get("community") \
and m.find(MSG["graph_community"]) < 0:
queue_raptor_o_graphrag_tasks(d, "graph_community", MSG["graph_community"])
prg = 0.98 * len(tsks) / (len(tsks) + 1)
msg.append("------ RAPTOR -------")
else:
status = TaskStatus.DONE.value
msg = "\n".join(msg)
msg = "\n".join(sorted(msg))
info = {
"process_duation": datetime.timestamp(
datetime.now()) -
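
The three GraphRAG stages are driven by `parser_config`: extraction is queued once all chunking tasks finish, and resolution and community reporting follow one at a time, each gated on its flag and on its start message not yet appearing in the progress log. A minimal config sketch:

```python
# Illustrative parser_config for a document / knowledge base.
parser_config = {
    "raptor": {"use_raptor": False},
    "graphrag": {
        "use_graphrag": True,   # queue a "graphrag" extraction task
        "resolution": True,     # then a "graph_resolution" task
        "community": True,      # then a "graph_community" task
    },
}
```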
@@ -430,7 +459,7 @@
return False
def queue_raptor_tasks(doc):
def queue_raptor_o_graphrag_tasks(doc, ty, msg):
chunking_config = DocumentService.get_chunking_config(doc["id"])
hasher = xxhash.xxh64()
for field in sorted(chunking_config.keys()):
@@ -443,15 +472,16 @@ def queue_raptor_tasks(doc):
"doc_id": doc["id"],
"from_page": 100000000,
"to_page": 100000000,
"progress_msg": "Start to do RAPTOR (Recursive Abstractive Processing for Tree-Organized Retrieval)."
"progress_msg": datetime.now().strftime("%H:%M:%S") + " " + msg
}
task = new_task()
for field in ["doc_id", "from_page", "to_page"]:
hasher.update(str(task.get(field, "")).encode("utf-8"))
hasher.update(ty.encode("utf-8"))
task["digest"] = hasher.hexdigest()
bulk_insert_into_db(Task, [task], True)
task["type"] = "raptor"
task["task_type"] = ty
assert REDIS_CONN.queue_product(SVR_QUEUE_NAME, message=task), "Can't access Redis. Please check the Redis' status."
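
Each queued follow-up task carries a digest of the chunking config plus the task type, so identical reruns can be detected and deduplicated. A simplified restatement of the hashing above (assumptions: standalone function names; the real code inlines this):

```python
import xxhash

def task_digest(chunking_config: dict, task: dict, task_type: str) -> str:
    # Mirrors queue_raptor_o_graphrag_tasks: config fields in sorted order,
    # then the page-range identity, then the task type.
    hasher = xxhash.xxh64()
    for field in sorted(chunking_config.keys()):
        hasher.update(str(chunking_config[field]).encode("utf-8"))
    for field in ["doc_id", "from_page", "to_page"]:
        hasher.update(str(task.get(field, "")).encode("utf-8"))
    hasher.update(task_type.encode("utf-8"))  # "raptor" | "graphrag" | "graph_resolution" | "graph_community"
    return hasher.hexdigest()
```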
@@ -489,7 +519,7 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
ParserType.AUDIO.value: audio,
ParserType.EMAIL.value: email
}
parser_config = {"chunk_token_num": 4096, "delimiter": "\n!?;。;!?", "layout_recognize": False}
parser_config = {"chunk_token_num": 4096, "delimiter": "\n!?;。;!?", "layout_recognize": "Plain Text"}
exe = ThreadPoolExecutor(max_workers=12)
threads = []
doc_nm = {}
@@ -592,4 +622,4 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
DocumentService.increment_chunk_num(
doc_id, kb.id, token_counts[doc_id], chunk_counts[doc_id], 0)
return [d["id"] for d, _ in files]
return [d["id"] for d, _ in files]


@@ -401,7 +401,7 @@ class FileService(CommonService):
ParserType.AUDIO.value: audio,
ParserType.EMAIL.value: email
}
parser_config = {"chunk_token_num": 16096, "delimiter": "\n!?;。;!?", "layout_recognize": False}
parser_config = {"chunk_token_num": 16096, "delimiter": "\n!?;。;!?", "layout_recognize": "Plain Text"}
exe = ThreadPoolExecutor(max_workers=12)
threads = []
for file in file_objs:


@@ -16,7 +16,6 @@
import os
import random
import xxhash
import bisect
from datetime import datetime
from api.db.db_utils import bulk_insert_into_db
@@ -183,7 +182,7 @@ class TaskService(CommonService):
if os.environ.get("MACOS"):
if info["progress_msg"]:
task = cls.model.get_by_id(id)
progress_msg = trim_header_by_lines(task.progress_msg + "\n" + info["progress_msg"], 1000)
progress_msg = trim_header_by_lines(task.progress_msg + "\n" + info["progress_msg"], 3000)
cls.model.update(progress_msg=progress_msg).where(cls.model.id == id).execute()
if "progress" in info:
cls.model.update(progress=info["progress"]).where(
@@ -194,7 +193,7 @@
with DB.lock("update_progress", -1):
if info["progress_msg"]:
task = cls.model.get_by_id(id)
progress_msg = trim_header_by_lines(task.progress_msg + "\n" + info["progress_msg"], 1000)
progress_msg = trim_header_by_lines(task.progress_msg + "\n" + info["progress_msg"], 3000)
cls.model.update(progress_msg=progress_msg).where(cls.model.id == id).execute()
if "progress" in info:
cls.model.update(progress=info["progress"]).where(
@@ -210,12 +209,12 @@ def queue_tasks(doc: dict, bucket: str, name: str):
if doc["type"] == FileType.PDF.value:
file_bin = STORAGE_IMPL.get(bucket, name)
do_layout = doc["parser_config"].get("layout_recognize", True)
do_layout = doc["parser_config"].get("layout_recognize", "DeepDOC")
pages = PdfParser.total_page_number(doc["name"], file_bin)
page_size = doc["parser_config"].get("task_page_size", 12)
if doc["parser_id"] == "paper":
page_size = doc["parser_config"].get("task_page_size", 22)
if doc["parser_id"] in ["one", "knowledge_graph"] or not do_layout:
if doc["parser_id"] in ["one", "knowledge_graph"] or do_layout != "DeepDOC":
page_size = 10 ** 9
page_ranges = doc["parser_config"].get("pages") or [(1, 10 ** 5)]
for s, e in page_ranges:
@@ -243,6 +242,10 @@
for task in parse_task_array:
hasher = xxhash.xxh64()
for field in sorted(chunking_config.keys()):
if field == "parser_config":
for k in ["raptor", "graphrag"]:
if k in chunking_config[field]:
del chunking_config[field][k]
hasher.update(str(chunking_config[field]).encode("utf-8"))
for field in ["doc_id", "from_page", "to_page"]:
hasher.update(str(task.get(field, "")).encode("utf-8"))
@@ -276,20 +279,27 @@
def reuse_prev_task_chunks(task: dict, prev_tasks: list[dict], chunking_config: dict):
idx = bisect.bisect_left(prev_tasks, (task.get("from_page", 0), task.get("digest", "")),
key=lambda x: (x.get("from_page", 0), x.get("digest", "")))
idx = 0
while idx < len(prev_tasks):
prev_task = prev_tasks[idx]
if prev_task.get("from_page", 0) == task.get("from_page", 0) \
and prev_task.get("digest", 0) == task.get("digest", ""):
break
idx += 1
if idx >= len(prev_tasks):
return 0
prev_task = prev_tasks[idx]
if prev_task["progress"] < 1.0 or prev_task["digest"] != task["digest"] or not prev_task["chunk_ids"]:
if prev_task["progress"] < 1.0 or not prev_task["chunk_ids"]:
return 0
task["chunk_ids"] = prev_task["chunk_ids"]
task["progress"] = 1.0
if "from_page" in task and "to_page" in task:
if "from_page" in task and "to_page" in task and int(task['to_page']) - int(task['from_page']) >= 10 ** 6:
task["progress_msg"] = f"Page({task['from_page']}~{task['to_page']}): "
else:
task["progress_msg"] = ""
task["progress_msg"] += "reused previous task's chunks."
task["progress_msg"] = " ".join(
[datetime.now().strftime("%H:%M:%S"), task["progress_msg"], "Reused previous task's chunks."])
prev_task["chunk_ids"] = ""
return len(task["chunk_ids"].split())


@@ -355,7 +355,7 @@ def get_parser_config(chunk_method, parser_config):
if not chunk_method:
chunk_method = "naive"
key_mapping = {
"naive": {"chunk_token_num": 128, "delimiter": "\\n!?;。;!?", "html4excel": False, "layout_recognize": True,
"naive": {"chunk_token_num": 128, "delimiter": "\\n!?;。;!?", "html4excel": False, "layout_recognize": "DeepDOC",
"raptor": {"use_raptor": False}},
"qa": {"raptor": {"use_raptor": False}},
"tag": None,


@@ -25,9 +25,19 @@
"weight_int": {"type": "integer", "default": 0},
"weight_flt": {"type": "float", "default": 0.0},
"rank_int": {"type": "integer", "default": 0},
"rank_flt": {"type": "float", "default": 0},
"available_int": {"type": "integer", "default": 1},
"knowledge_graph_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"},
"entities_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"},
"pagerank_fea": {"type": "integer", "default": 0},
"tag_fea": {"type": "integer", "default": 0}
"tag_feas": {"type": "integer", "default": 0},
"important_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"},
"from_entity_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"},
"to_entity_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"},
"entity_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"},
"entity_type_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"},
"source_id": {"type": "varchar", "default": ""},
"n_hop_with_weight": {"type": "varchar", "default": ""},
"removed_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"}
}
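
With these fields, graph artifacts live in the same index as ordinary chunks and are told apart by `knowledge_graph_kwd`. An illustrative relation row (field names from the mapping above; values are made up):

```python
relation_doc = {
    "knowledge_graph_kwd": "relation",
    "from_entity_kwd": "ACME CORP",
    "to_entity_kwd": "WIDGET INC",
    "weight_flt": 7.0,
    "source_id": "doc_a doc_b",   # chunks this edge was extracted from
    "removed_kwd": "N",
}
```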


@@ -1,146 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""
Reference:
- [graphrag](https://github.com/microsoft/graphrag)
"""
import json
from dataclasses import dataclass
from graphrag.extractor import Extractor
from graphrag.utils import ErrorHandlerFn, perform_variable_replacements
from rag.llm.chat_model import Base as CompletionLLM
from rag.utils import num_tokens_from_string
SUMMARIZE_PROMPT = """
You are a helpful assistant responsible for generating a comprehensive summary of the data provided below.
Given one or two entities, and a list of descriptions, all related to the same entity or group of entities.
Please concatenate all of these into a single, comprehensive description. Make sure to include information collected from all the descriptions.
If the provided descriptions are contradictory, please resolve the contradictions and provide a single, coherent summary.
Make sure it is written in third person, and include the entity names so we have the full context.
#######
-Data-
Entities: {entity_name}
Description List: {description_list}
#######
Output:
"""
# Max token size for input prompts
DEFAULT_MAX_INPUT_TOKENS = 4_000
# Max token count for LLM answers
DEFAULT_MAX_SUMMARY_LENGTH = 128
@dataclass
class SummarizationResult:
"""Unipartite graph extraction result class definition."""
items: str | tuple[str, str]
description: str
class SummarizeExtractor(Extractor):
"""Unipartite graph extractor class definition."""
_entity_name_key: str
_input_descriptions_key: str
_summarization_prompt: str
_on_error: ErrorHandlerFn
_max_summary_length: int
_max_input_tokens: int
def __init__(
self,
llm_invoker: CompletionLLM,
entity_name_key: str | None = None,
input_descriptions_key: str | None = None,
summarization_prompt: str | None = None,
on_error: ErrorHandlerFn | None = None,
max_summary_length: int | None = None,
max_input_tokens: int | None = None,
):
"""Init method definition."""
# TODO: streamline construction
self._llm = llm_invoker
self._entity_name_key = entity_name_key or "entity_name"
self._input_descriptions_key = input_descriptions_key or "description_list"
self._summarization_prompt = summarization_prompt or SUMMARIZE_PROMPT
self._on_error = on_error or (lambda _e, _s, _d: None)
self._max_summary_length = max_summary_length or DEFAULT_MAX_SUMMARY_LENGTH
self._max_input_tokens = max_input_tokens or DEFAULT_MAX_INPUT_TOKENS
def __call__(
self,
items: str | tuple[str, str],
descriptions: list[str],
) -> SummarizationResult:
"""Call method definition."""
result = ""
if len(descriptions) == 0:
result = ""
if len(descriptions) == 1:
result = descriptions[0]
else:
result = self._summarize_descriptions(items, descriptions)
return SummarizationResult(
items=items,
description=result or "",
)
def _summarize_descriptions(
self, items: str | tuple[str, str], descriptions: list[str]
) -> str:
"""Summarize descriptions into a single description."""
sorted_items = sorted(items) if isinstance(items, list) else items
# Safety check, should always be a list
if not isinstance(descriptions, list):
descriptions = [descriptions]
# Iterate over descriptions, adding all until the max input tokens is reached
usable_tokens = self._max_input_tokens - num_tokens_from_string(
self._summarization_prompt
)
descriptions_collected = []
result = ""
for i, description in enumerate(descriptions):
usable_tokens -= num_tokens_from_string(description)
descriptions_collected.append(description)
# If buffer is full, or all descriptions have been added, summarize
if (usable_tokens < 0 and len(descriptions_collected) > 1) or (
i == len(descriptions) - 1
):
# Calculate result (final or partial)
result = self._summarize_descriptions_with_llm(
sorted_items, descriptions_collected
)
# If we go for another loop, reset values to new
if i != len(descriptions) - 1:
descriptions_collected = [result]
usable_tokens = (
self._max_input_tokens
- num_tokens_from_string(self._summarization_prompt)
- num_tokens_from_string(result)
)
return result
def _summarize_descriptions_with_llm(
self, items: str | tuple[str, str] | list[str], descriptions: list[str]
):
"""Summarize descriptions using the LLM."""
variables = {
self._entity_name_key: json.dumps(items),
self._input_descriptions_key: json.dumps(sorted(descriptions)),
}
text = perform_variable_replacements(self._summarization_prompt, variables=variables)
return self._chat("", [{"role": "user", "content": text}])


@@ -16,18 +16,18 @@
import logging
import itertools
import re
import traceback
import time
from dataclasses import dataclass
from typing import Any
from typing import Any, Callable
import networkx as nx
from graphrag.extractor import Extractor
from graphrag.general.extractor import Extractor
from rag.nlp import is_english
import editdistance
from graphrag.entity_resolution_prompt import ENTITY_RESOLUTION_PROMPT
from rag.llm.chat_model import Base as CompletionLLM
from graphrag.utils import ErrorHandlerFn, perform_variable_replacements
from graphrag.utils import perform_variable_replacements
DEFAULT_RECORD_DELIMITER = "##"
DEFAULT_ENTITY_INDEX_DELIMITER = "<|>"
@@ -37,8 +37,8 @@ DEFAULT_RESOLUTION_RESULT_DELIMITER = "&&"
@dataclass
class EntityResolutionResult:
"""Entity resolution result class definition."""
output: nx.Graph
graph: nx.Graph
removed_entities: list
class EntityResolution(Extractor):
@@ -46,7 +46,6 @@ class EntityResolution(Extractor):
_resolution_prompt: str
_output_formatter_prompt: str
_on_error: ErrorHandlerFn
_record_delimiter_key: str
_entity_index_delimiter_key: str
_resolution_result_delimiter_key: str
@@ -54,21 +53,19 @@
def __init__(
self,
llm_invoker: CompletionLLM,
resolution_prompt: str | None = None,
on_error: ErrorHandlerFn | None = None,
record_delimiter_key: str | None = None,
entity_index_delimiter_key: str | None = None,
resolution_result_delimiter_key: str | None = None,
input_text_key: str | None = None
get_entity: Callable | None = None,
set_entity: Callable | None = None,
get_relation: Callable | None = None,
set_relation: Callable | None = None
):
"""Init method definition."""
super().__init__(llm_invoker, get_entity=get_entity, set_entity=set_entity, get_relation=get_relation, set_relation=set_relation)
self._llm = llm_invoker
self._resolution_prompt = resolution_prompt or ENTITY_RESOLUTION_PROMPT
self._on_error = on_error or (lambda _e, _s, _d: None)
self._record_delimiter_key = record_delimiter_key or "record_delimiter"
self._entity_index_dilimiter_key = entity_index_delimiter_key or "entity_index_delimiter"
self._resolution_result_delimiter_key = resolution_result_delimiter_key or "resolution_result_delimiter"
self._input_text_key = input_text_key or "input_text"
self._resolution_prompt = ENTITY_RESOLUTION_PROMPT
self._record_delimiter_key = "record_delimiter"
self._entity_index_dilimiter_key = "entity_index_delimiter"
self._resolution_result_delimiter_key = "resolution_result_delimiter"
self._input_text_key = "input_text"
def __call__(self, graph: nx.Graph, prompt_variables: dict[str, Any] | None = None) -> EntityResolutionResult:
"""Call method definition."""
@@ -87,11 +84,11 @@ class EntityResolution(Extractor):
}
nodes = graph.nodes
entity_types = list(set(graph.nodes[node]['entity_type'] for node in nodes))
entity_types = list(set(graph.nodes[node].get('entity_type', '-') for node in nodes))
node_clusters = {entity_type: [] for entity_type in entity_types}
for node in nodes:
node_clusters[graph.nodes[node]['entity_type']].append(node)
node_clusters[graph.nodes[node].get('entity_type', '-')].append(node)
candidate_resolution = {entity_type: [] for entity_type in entity_types}
for k, v in node_clusters.items():
@@ -128,44 +125,51 @@
DEFAULT_RESOLUTION_RESULT_DELIMITER))
for result_i in result:
resolution_result.add(candidate_resolution_i[1][result_i[0] - 1])
except Exception as e:
except Exception:
logging.exception("error entity resolution")
self._on_error(e, traceback.format_exc(), None)
connect_graph = nx.Graph()
removed_entities = []
connect_graph.add_edges_from(resolution_result)
for sub_connect_graph in nx.connected_components(connect_graph):
sub_connect_graph = connect_graph.subgraph(sub_connect_graph)
remove_nodes = list(sub_connect_graph.nodes)
keep_node = remove_nodes.pop()
self._merge_nodes(keep_node, self._get_entity_(remove_nodes))
for remove_node in remove_nodes:
removed_entities.append(remove_node)
remove_node_neighbors = graph[remove_node]
graph.nodes[keep_node]['description'] += graph.nodes[remove_node]['description']
graph.nodes[keep_node]['weight'] += graph.nodes[remove_node]['weight']
remove_node_neighbors = list(remove_node_neighbors)
for remove_node_neighbor in remove_node_neighbors:
rel = self._get_relation_(remove_node, remove_node_neighbor)
if graph.has_edge(remove_node, remove_node_neighbor):
graph.remove_edge(remove_node, remove_node_neighbor)
if remove_node_neighbor == keep_node:
graph.remove_edge(keep_node, remove_node)
if graph.has_edge(keep_node, remove_node):
graph.remove_edge(keep_node, remove_node)
continue
if not rel:
continue
if graph.has_edge(keep_node, remove_node_neighbor):
graph[keep_node][remove_node_neighbor]['weight'] += graph[remove_node][remove_node_neighbor][
'weight']
graph[keep_node][remove_node_neighbor]['description'] += \
graph[remove_node][remove_node_neighbor]['description']
graph.remove_edge(remove_node, remove_node_neighbor)
self._merge_edges(keep_node, remove_node_neighbor, [rel])
else:
graph.add_edge(keep_node, remove_node_neighbor,
weight=graph[remove_node][remove_node_neighbor]['weight'],
description=graph[remove_node][remove_node_neighbor]['description'],
source_id="")
graph.remove_edge(remove_node, remove_node_neighbor)
pair = sorted([keep_node, remove_node_neighbor])
graph.add_edge(pair[0], pair[1], weight=rel['weight'])
self._set_relation_(pair[0], pair[1],
dict(
src_id=pair[0],
tgt_id=pair[1],
weight=rel['weight'],
description=rel['description'],
keywords=[],
source_id=rel.get("source_id", ""),
metadata={"created_at": time.time()}
))
graph.remove_node(remove_node)
for node_degree in graph.degree:
graph.nodes[str(node_degree[0])]["rank"] = int(node_degree[1])
return EntityResolutionResult(
output=graph,
graph=graph,
removed_entities=removed_entities
)
def _process_results(
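
Downstream code now reads `result.graph` instead of `result.output`. A hedged usage sketch (the storage callables are stand-ins; in the pipeline they are partials over the doc store, and `chat_model` is assumed to be an LLMBundle-style client):

```python
er = EntityResolution(
    chat_model,
    get_entity=get_entity, set_entity=set_entity,
    get_relation=get_relation, set_relation=set_relation,
)
result = er(graph)                   # graph: nx.Graph with "entity_type" node attrs
graph = result.graph                 # renamed from .output in this PR
print(f"merged away {len(result.removed_entities)} duplicate entities")
```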


@@ -1,34 +0,0 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from graphrag.utils import get_llm_cache, set_llm_cache
from rag.llm.chat_model import Base as CompletionLLM
class Extractor:
_llm: CompletionLLM
def __init__(self, llm_invoker: CompletionLLM):
self._llm = llm_invoker
def _chat(self, system, history, gen_conf):
response = get_llm_cache(self._llm.llm_name, system, history, gen_conf)
if response:
return response
response = self._llm.chat(system, history, gen_conf)
if response.find("**ERROR**") >= 0:
raise Exception(response)
set_llm_cache(self._llm.llm_name, system, response, history, gen_conf)
return response


@@ -15,8 +15,8 @@ from typing import Any
import tiktoken
from graphrag.claim_prompt import CLAIM_EXTRACTION_PROMPT, CONTINUE_PROMPT, LOOP_PROMPT
from graphrag.extractor import Extractor
from graphrag.general.claim_prompt import CLAIM_EXTRACTION_PROMPT, CONTINUE_PROMPT, LOOP_PROMPT
from graphrag.general.extractor import Extractor
from rag.llm.chat_model import Base as CompletionLLM
from graphrag.utils import ErrorHandlerFn, perform_variable_replacements


@@ -13,10 +13,10 @@ from typing import Callable
from dataclasses import dataclass
import networkx as nx
import pandas as pd
from graphrag import leiden
from graphrag.community_report_prompt import COMMUNITY_REPORT_PROMPT
from graphrag.extractor import Extractor
from graphrag.leiden import add_community_info2graph
from graphrag.general import leiden
from graphrag.general.community_report_prompt import COMMUNITY_REPORT_PROMPT
from graphrag.general.extractor import Extractor
from graphrag.general.leiden import add_community_info2graph
from rag.llm.chat_model import Base as CompletionLLM
from graphrag.utils import ErrorHandlerFn, perform_variable_replacements, dict_has_keys_with_types
from rag.utils import num_tokens_from_string
@@ -40,32 +40,43 @@ class CommunityReportsExtractor(Extractor):
_max_report_length: int
def __init__(
self,
llm_invoker: CompletionLLM,
extraction_prompt: str | None = None,
on_error: ErrorHandlerFn | None = None,
max_report_length: int | None = None,
self,
llm_invoker: CompletionLLM,
get_entity: Callable | None = None,
set_entity: Callable | None = None,
get_relation: Callable | None = None,
set_relation: Callable | None = None,
max_report_length: int | None = None,
):
"""Init method definition."""
super().__init__(llm_invoker, get_entity=get_entity, set_entity=set_entity, get_relation=get_relation, set_relation=set_relation)
self._llm = llm_invoker
self._extraction_prompt = extraction_prompt or COMMUNITY_REPORT_PROMPT
self._on_error = on_error or (lambda _e, _s, _d: None)
self._extraction_prompt = COMMUNITY_REPORT_PROMPT
self._max_report_length = max_report_length or 1500
def __call__(self, graph: nx.Graph, callback: Callable | None = None):
for node_degree in graph.degree:
graph.nodes[str(node_degree[0])]["rank"] = int(node_degree[1])
communities: dict[str, dict[str, list]] = leiden.run(graph, {})
total = sum([len(comm.items()) for _, comm in communities.items()])
relations_df = pd.DataFrame([{"source":s, "target": t, **attr} for s, t, attr in graph.edges(data=True)])
res_str = []
res_dict = []
over, token_count = 0, 0
st = timer()
for level, comm in communities.items():
logging.info(f"Level {level}: Community: {len(comm.keys())}")
for cm_id, ents in comm.items():
weight = ents["weight"]
ents = ents["nodes"]
ent_df = pd.DataFrame([{"entity": n, **graph.nodes[n]} for n in ents])
rela_df = relations_df[(relations_df["source"].isin(ents)) | (relations_df["target"].isin(ents))].reset_index(drop=True)
ent_df = pd.DataFrame(self._get_entity_(ents)).dropna()#[{"entity": n, **graph.nodes[n]} for n in ents])
ent_df["entity"] = ent_df["entity_name"]
del ent_df["entity_name"]
rela_df = pd.DataFrame(self._get_relation_(list(ent_df["entity"]), list(ent_df["entity"]), 10000))
rela_df["source"] = rela_df["src_id"]
rela_df["target"] = rela_df["tgt_id"]
del rela_df["src_id"]
del rela_df["tgt_id"]
prompt_variables = {
"entity_df": ent_df.to_csv(index_label="id"),


@@ -9,7 +9,7 @@ from typing import Any
import numpy as np
import networkx as nx
from dataclasses import dataclass
from graphrag.leiden import stable_largest_connected_component
from graphrag.general.leiden import stable_largest_connected_component
import graspologic as gc


@@ -0,0 +1,245 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import os
from collections import defaultdict, Counter
from concurrent.futures import ThreadPoolExecutor
from copy import deepcopy
from typing import Callable
from graphrag.general.graph_prompt import SUMMARIZE_DESCRIPTIONS_PROMPT
from graphrag.utils import get_llm_cache, set_llm_cache, handle_single_entity_extraction, \
handle_single_relationship_extraction, split_string_by_multi_markers, flat_uniq_list
from rag.llm.chat_model import Base as CompletionLLM
from rag.utils import truncate
GRAPH_FIELD_SEP = "<SEP>"
DEFAULT_ENTITY_TYPES = ["organization", "person", "geo", "event", "category"]
ENTITY_EXTRACTION_MAX_GLEANINGS = 2
class Extractor:
_llm: CompletionLLM
def __init__(
self,
llm_invoker: CompletionLLM,
language: str | None = "English",
entity_types: list[str] | None = None,
get_entity: Callable | None = None,
set_entity: Callable | None = None,
get_relation: Callable | None = None,
set_relation: Callable | None = None,
):
self._llm = llm_invoker
self._language = language
self._entity_types = entity_types or DEFAULT_ENTITY_TYPES
self._get_entity_ = get_entity
self._set_entity_ = set_entity
self._get_relation_ = get_relation
self._set_relation_ = set_relation
def _chat(self, system, history, gen_conf):
hist = deepcopy(history)
conf = deepcopy(gen_conf)
response = get_llm_cache(self._llm.llm_name, system, hist, conf)
if response:
return response
response = self._llm.chat(system, hist, conf)
if response.find("**ERROR**") >= 0:
raise Exception(response)
set_llm_cache(self._llm.llm_name, system, response, history, gen_conf)
return response
def _entities_and_relations(self, chunk_key: str, records: list, tuple_delimiter: str):
maybe_nodes = defaultdict(list)
maybe_edges = defaultdict(list)
ent_types = [t.lower() for t in self._entity_types]
for record in records:
record_attributes = split_string_by_multi_markers(
record, [tuple_delimiter]
)
if_entities = handle_single_entity_extraction(
record_attributes, chunk_key
)
if if_entities is not None and if_entities.get("entity_type", "unknown").lower() in ent_types:
maybe_nodes[if_entities["entity_name"]].append(if_entities)
continue
if_relation = handle_single_relationship_extraction(
record_attributes, chunk_key
)
if if_relation is not None:
maybe_edges[(if_relation["src_id"], if_relation["tgt_id"])].append(
if_relation
)
return dict(maybe_nodes), dict(maybe_edges)
def __call__(
self, chunks: list[tuple[str, str]],
callback: Callable | None = None
):
results = []
max_workers = int(os.environ.get('GRAPH_EXTRACTOR_MAX_WORKERS', 50))
with ThreadPoolExecutor(max_workers=max_workers) as exe:
threads = []
for i, (cid, ck) in enumerate(chunks):
threads.append(
exe.submit(self._process_single_content, (cid, ck)))
for i, _ in enumerate(threads):
n, r, tc = _.result()
if not isinstance(n, Exception):
results.append((n, r))
if callback:
callback(0.5 + 0.1 * i / len(threads), f"Entities extraction progress ... {i + 1}/{len(threads)} ({tc} tokens)")
elif callback:
callback(msg="Knowledge graph extraction error:{}".format(str(n)))
maybe_nodes = defaultdict(list)
maybe_edges = defaultdict(list)
for m_nodes, m_edges in results:
for k, v in m_nodes.items():
maybe_nodes[k].extend(v)
for k, v in m_edges.items():
maybe_edges[tuple(sorted(k))].extend(v)
logging.info("Inserting entities into storage...")
all_entities_data = []
for en_nm, ents in maybe_nodes.items():
all_entities_data.append(self._merge_nodes(en_nm, ents))
logging.info("Inserting relationships into storage...")
all_relationships_data = []
for (src,tgt), rels in maybe_edges.items():
all_relationships_data.append(self._merge_edges(src, tgt, rels))
if not len(all_entities_data) and not len(all_relationships_data):
logging.warning(
"Didn't extract any entities and relationships, maybe your LLM is not working"
)
if not len(all_entities_data):
logging.warning("Didn't extract any entities")
if not len(all_relationships_data):
logging.warning("Didn't extract any relationships")
return all_entities_data, all_relationships_data
def _merge_nodes(self, entity_name: str, entities: list[dict]):
if not entities:
return
already_entity_types = []
already_source_ids = []
already_description = []
already_node = self._get_entity_(entity_name)
if already_node:
already_entity_types.append(already_node["entity_type"])
already_source_ids.extend(already_node["source_id"])
already_description.append(already_node["description"])
entity_type = sorted(
Counter(
[dp["entity_type"] for dp in entities] + already_entity_types
).items(),
key=lambda x: x[1],
reverse=True,
)[0][0]
description = GRAPH_FIELD_SEP.join(
sorted(set([dp["description"] for dp in entities] + already_description))
)
already_source_ids = flat_uniq_list(entities, "source_id")
description = self._handle_entity_relation_summary(
entity_name, description
)
node_data = dict(
entity_type=entity_type,
description=description,
source_id=already_source_ids,
)
node_data["entity_name"] = entity_name
self._set_entity_(entity_name, node_data)
return node_data
def _merge_edges(
self,
src_id: str,
tgt_id: str,
edges_data: list[dict]
):
if not edges_data:
return
already_weights = []
already_source_ids = []
already_description = []
already_keywords = []
relation = self._get_relation_(src_id, tgt_id)
if relation:
already_weights = [relation["weight"]]
already_source_ids = relation["source_id"]
already_description = [relation["description"]]
already_keywords = relation["keywords"]
weight = sum([dp["weight"] for dp in edges_data] + already_weights)
description = GRAPH_FIELD_SEP.join(
sorted(set([dp["description"] for dp in edges_data] + already_description))
)
keywords = flat_uniq_list(edges_data, "keywords") + already_keywords
source_id = flat_uniq_list(edges_data, "source_id") + already_source_ids
for need_insert_id in [src_id, tgt_id]:
if self._get_entity_(need_insert_id):
continue
self._set_entity_(need_insert_id, {
"source_id": source_id,
"description": description,
"entity_type": 'UNKNOWN'
})
description = self._handle_entity_relation_summary(
f"({src_id}, {tgt_id})", description
)
edge_data = dict(
src_id=src_id,
tgt_id=tgt_id,
description=description,
keywords=keywords,
weight=weight,
source_id=source_id
)
self._set_relation_(src_id, tgt_id, edge_data)
return edge_data
def _handle_entity_relation_summary(
self,
entity_or_relation_name: str,
description: str
) -> str:
summary_max_tokens = 512
use_description = truncate(description, summary_max_tokens)
prompt_template = SUMMARIZE_DESCRIPTIONS_PROMPT
context_base = dict(
entity_name=entity_or_relation_name,
description_list=use_description.split(GRAPH_FIELD_SEP),
language=self._language,
)
use_prompt = prompt_template.format(**context_base)
logging.info(f"Trigger summary: {entity_or_relation_name}")
summary = self._chat(use_prompt, [{"role": "assistant", "content": "Output: "}], {"temperature": 0.8})
return summary
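
A hedged sketch of wiring a concrete extractor onto this base class. The in-memory dicts stand in for the doc-store-backed `get_/set_` partials used by `Dealer` below, and `chat_model` is assumed to be an LLMBundle-style chat client; note that some extractors call the relation getter with extra arguments:

```python
# In-memory stand-ins for the storage callables (assumption, for illustration).
entities: dict[str, dict] = {}
relations: dict[tuple, dict] = {}

def progress(prog=None, msg=""):
    print(prog, msg)

ext = GraphExtractor(                       # an Extractor subclass (defined below)
    llm_invoker=chat_model,
    language="English",
    entity_types=["organization", "person"],
    get_entity=lambda name: entities.get(name),
    set_entity=lambda name, data: entities.__setitem__(name, data),
    get_relation=lambda s, t, *a: relations.get((s, t)),
    set_relation=lambda s, t, data: relations.__setitem__((s, t), data),
)
ents, rels = ext([("chunk-0", "ACME hired Jane Doe in 2020.")], callback=progress)
```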


@@ -0,0 +1,154 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""
Reference:
- [graphrag](https://github.com/microsoft/graphrag)
"""
import logging
import re
from typing import Any, Callable
from dataclasses import dataclass
import tiktoken
from graphrag.general.extractor import Extractor, ENTITY_EXTRACTION_MAX_GLEANINGS, DEFAULT_ENTITY_TYPES
from graphrag.general.graph_prompt import GRAPH_EXTRACTION_PROMPT, CONTINUE_PROMPT, LOOP_PROMPT
from graphrag.utils import ErrorHandlerFn, perform_variable_replacements
from rag.llm.chat_model import Base as CompletionLLM
import networkx as nx
from rag.utils import num_tokens_from_string
DEFAULT_TUPLE_DELIMITER = "<|>"
DEFAULT_RECORD_DELIMITER = "##"
DEFAULT_COMPLETION_DELIMITER = "<|COMPLETE|>"
@dataclass
class GraphExtractionResult:
"""Unipartite graph extraction result class definition."""
output: nx.Graph
source_docs: dict[Any, Any]
class GraphExtractor(Extractor):
"""Unipartite graph extractor class definition."""
_join_descriptions: bool
_tuple_delimiter_key: str
_record_delimiter_key: str
_entity_types_key: str
_input_text_key: str
_completion_delimiter_key: str
_entity_name_key: str
_input_descriptions_key: str
_extraction_prompt: str
_summarization_prompt: str
_loop_args: dict[str, Any]
_max_gleanings: int
_on_error: ErrorHandlerFn
def __init__(
self,
llm_invoker: CompletionLLM,
language: str | None = "English",
entity_types: list[str] | None = None,
get_entity: Callable | None = None,
set_entity: Callable | None = None,
get_relation: Callable | None = None,
set_relation: Callable | None = None,
tuple_delimiter_key: str | None = None,
record_delimiter_key: str | None = None,
input_text_key: str | None = None,
entity_types_key: str | None = None,
completion_delimiter_key: str | None = None,
join_descriptions=True,
max_gleanings: int | None = None,
on_error: ErrorHandlerFn | None = None,
):
"""Init method definition."""
super().__init__(llm_invoker, language, entity_types, get_entity, set_entity, get_relation, set_relation)
# TODO: streamline construction
self._llm = llm_invoker
self._join_descriptions = join_descriptions
self._input_text_key = input_text_key or "input_text"
self._tuple_delimiter_key = tuple_delimiter_key or "tuple_delimiter"
self._record_delimiter_key = record_delimiter_key or "record_delimiter"
self._completion_delimiter_key = (
completion_delimiter_key or "completion_delimiter"
)
self._entity_types_key = entity_types_key or "entity_types"
self._extraction_prompt = GRAPH_EXTRACTION_PROMPT
self._max_gleanings = (
max_gleanings
if max_gleanings is not None
else ENTITY_EXTRACTION_MAX_GLEANINGS
)
self._on_error = on_error or (lambda _e, _s, _d: None)
self.prompt_token_count = num_tokens_from_string(self._extraction_prompt)
# Construct the looping arguments
encoding = tiktoken.get_encoding("cl100k_base")
yes = encoding.encode("YES")
no = encoding.encode("NO")
self._loop_args = {"logit_bias": {yes[0]: 100, no[0]: 100}, "max_tokens": 1}
# Wire defaults into the prompt variables
self._prompt_variables = {
"entity_types": entity_types,
self._tuple_delimiter_key: DEFAULT_TUPLE_DELIMITER,
self._record_delimiter_key: DEFAULT_RECORD_DELIMITER,
self._completion_delimiter_key: DEFAULT_COMPLETION_DELIMITER,
self._entity_types_key: ",".join(DEFAULT_ENTITY_TYPES),
}
def _process_single_content(self,
chunk_key_dp: tuple[str, str]
):
token_count = 0
chunk_key = chunk_key_dp[0]
content = chunk_key_dp[1]
variables = {
**self._prompt_variables,
self._input_text_key: content,
}
try:
gen_conf = {"temperature": 0.3}
hint_prompt = perform_variable_replacements(self._extraction_prompt, variables=variables)
response = self._chat(hint_prompt, [{"role": "user", "content": "Output:"}], gen_conf)
token_count += num_tokens_from_string(hint_prompt + response)
results = response or ""
history = [{"role": "system", "content": hint_prompt}, {"role": "assistant", "content": response}]
# Repeat to ensure we maximize entity count
for i in range(self._max_gleanings):
text = perform_variable_replacements(CONTINUE_PROMPT, history=history, variables=variables)
history.append({"role": "user", "content": text})
response = self._chat("", history, gen_conf)
token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + response)
results += response or ""
# if this is the final glean, don't bother updating the continuation flag
if i >= self._max_gleanings - 1:
break
history.append({"role": "assistant", "content": response})
history.append({"role": "user", "content": LOOP_PROMPT})
continuation = self._chat("", history, self._loop_args)
token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + response)
if continuation != "YES":
break
record_delimiter = variables.get(self._record_delimiter_key, DEFAULT_RECORD_DELIMITER)
tuple_delimiter = variables.get(self._tuple_delimiter_key, DEFAULT_TUPLE_DELIMITER)
records = [re.sub(r"^\(|\)$", "", r.strip()) for r in results.split(record_delimiter)]
records = [r for r in records if r.strip()]
maybe_nodes, maybe_edges = self._entities_and_relations(chunk_key, records, tuple_delimiter)
return maybe_nodes, maybe_edges, token_count
except Exception as e:
logging.exception("error extracting graph")
return e, None, None


@@ -106,4 +106,19 @@ Text: {input_text}
Output:"""
CONTINUE_PROMPT = "MANY entities were missed in the last extraction. Add them below using the same format:\n"
LOOP_PROMPT = "It appears some entities may have still been missed. Answer YES | NO if there are still entities that need to be added.\n"
LOOP_PROMPT = "It appears some entities may have still been missed. Answer YES | NO if there are still entities that need to be added.\n"
SUMMARIZE_DESCRIPTIONS_PROMPT = """
You are a helpful assistant responsible for generating a comprehensive summary of the data provided below.
Given one or two entities, and a list of descriptions, all related to the same entity or group of entities.
Please concatenate all of these into a single, comprehensive description. Make sure to include information collected from all the descriptions.
If the provided descriptions are contradictory, please resolve the contradictions and provide a single, coherent summary.
Make sure it is written in third person, and include the entity names so we have the full context.
Use {language} as output language.
#######
-Data-
Entities: {entity_name}
Description List: {description_list}
#######
"""

graphrag/general/index.py (new file, 197 lines)

@@ -0,0 +1,197 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
import logging
from functools import reduce, partial
import networkx as nx
from api import settings
from graphrag.general.community_reports_extractor import CommunityReportsExtractor
from graphrag.entity_resolution import EntityResolution
from graphrag.general.extractor import Extractor
from graphrag.general.graph_extractor import DEFAULT_ENTITY_TYPES
from graphrag.utils import graph_merge, set_entity, get_relation, set_relation, get_entity, get_graph, set_graph, \
chunk_id, update_nodes_pagerank_nhop_neighbour
from rag.nlp import rag_tokenizer, search
from rag.utils.redis_conn import RedisDistributedLock
class Dealer:
def __init__(self,
extractor: Extractor,
tenant_id: str,
kb_id: str,
llm_bdl,
chunks: list[tuple[str, str]],
language,
entity_types=DEFAULT_ENTITY_TYPES,
embed_bdl=None,
callback=None
):
docids = list(set([docid for docid,_ in chunks]))
self.llm_bdl = llm_bdl
self.embed_bdl = embed_bdl
ext = extractor(self.llm_bdl, language=language,
entity_types=entity_types,
get_entity=partial(get_entity, tenant_id, kb_id),
set_entity=partial(set_entity, tenant_id, kb_id, self.embed_bdl),
get_relation=partial(get_relation, tenant_id, kb_id),
set_relation=partial(set_relation, tenant_id, kb_id, self.embed_bdl)
)
ents, rels = ext(chunks, callback)
self.graph = nx.Graph()
for en in ents:
self.graph.add_node(en["entity_name"], entity_type=en["entity_type"])#, description=en["description"])
for rel in rels:
self.graph.add_edge(
rel["src_id"],
rel["tgt_id"],
weight=rel["weight"],
#description=rel["description"]
)
with RedisDistributedLock(kb_id, 60*60):
old_graph, old_doc_ids = get_graph(tenant_id, kb_id)
if old_graph is not None:
logging.info("Merge with an exiting graph...................")
self.graph = reduce(graph_merge, [old_graph, self.graph])
update_nodes_pagerank_nhop_neighbour(tenant_id, kb_id, self.graph, 2)
if old_doc_ids:
docids.extend(old_doc_ids)
docids = list(set(docids))
set_graph(tenant_id, kb_id, self.graph, docids)
class WithResolution(Dealer):
def __init__(self,
tenant_id: str,
kb_id: str,
llm_bdl,
embed_bdl=None,
callback=None
):
self.llm_bdl = llm_bdl
self.embed_bdl = embed_bdl
with RedisDistributedLock(kb_id, 60*60):
self.graph, doc_ids = get_graph(tenant_id, kb_id)
if not self.graph:
logging.error(f"Faild to fetch the graph. tenant_id:{kb_id}, kb_id:{kb_id}")
if callback:
callback(-1, msg="Faild to fetch the graph.")
return
if callback:
callback(msg="Fetch the existing graph.")
er = EntityResolution(self.llm_bdl,
get_entity=partial(get_entity, tenant_id, kb_id),
set_entity=partial(set_entity, tenant_id, kb_id, self.embed_bdl),
get_relation=partial(get_relation, tenant_id, kb_id),
set_relation=partial(set_relation, tenant_id, kb_id, self.embed_bdl))
reso = er(self.graph)
self.graph = reso.graph
logging.info("Graph resolution is done. Remove {} nodes.".format(len(reso.removed_entities)))
if callback:
callback(msg="Graph resolution is done. Remove {} nodes.".format(len(reso.removed_entities)))
update_nodes_pagerank_nhop_neighbour(tenant_id, kb_id, self.graph, 2)
set_graph(tenant_id, kb_id, self.graph, doc_ids)
settings.docStoreConn.delete({
"knowledge_graph_kwd": "relation",
"kb_id": kb_id,
"from_entity_kwd": reso.removed_entities
}, search.index_name(tenant_id), kb_id)
settings.docStoreConn.delete({
"knowledge_graph_kwd": "relation",
"kb_id": kb_id,
"to_entity_kwd": reso.removed_entities
}, search.index_name(tenant_id), kb_id)
settings.docStoreConn.delete({
"knowledge_graph_kwd": "entity",
"kb_id": kb_id,
"entity_kwd": reso.removed_entities
}, search.index_name(tenant_id), kb_id)
class WithCommunity(Dealer):
def __init__(self,
tenant_id: str,
kb_id: str,
llm_bdl,
embed_bdl=None,
callback=None
):
self.community_structure = None
self.community_reports = None
self.llm_bdl = llm_bdl
self.embed_bdl = embed_bdl
with RedisDistributedLock(kb_id, 60*60):
self.graph, doc_ids = get_graph(tenant_id, kb_id)
if not self.graph:
logging.error(f"Faild to fetch the graph. tenant_id:{kb_id}, kb_id:{kb_id}")
if callback:
callback(-1, msg="Faild to fetch the graph.")
return
if callback:
callback(msg="Fetch the existing graph.")
cr = CommunityReportsExtractor(self.llm_bdl,
get_entity=partial(get_entity, tenant_id, kb_id),
set_entity=partial(set_entity, tenant_id, kb_id, self.embed_bdl),
get_relation=partial(get_relation, tenant_id, kb_id),
set_relation=partial(set_relation, tenant_id, kb_id, self.embed_bdl))
cr = cr(self.graph, callback=callback)
self.community_structure = cr.structured_output
self.community_reports = cr.output
set_graph(tenant_id, kb_id, self.graph, doc_ids)
if callback:
callback(msg="Graph community extraction is done. Indexing {} reports.".format(len(cr.structured_output)))
settings.docStoreConn.delete({
"knowledge_graph_kwd": "community_report",
"kb_id": kb_id
}, search.index_name(tenant_id), kb_id)
for stru, rep in zip(self.community_structure, self.community_reports):
obj = {
"report": rep,
"evidences": "\n".join([f["explanation"] for f in stru["findings"]])
}
chunk = {
"docnm_kwd": stru["title"],
"title_tks": rag_tokenizer.tokenize(stru["title"]),
"content_with_weight": json.dumps(obj, ensure_ascii=False),
"content_ltks": rag_tokenizer.tokenize(obj["report"] +" "+ obj["evidences"]),
"knowledge_graph_kwd": "community_report",
"weight_flt": stru["weight"],
"entities_kwd": stru["entities"],
"important_kwd": stru["entities"],
"kb_id": kb_id,
"source_id": doc_ids,
"available_int": 0
}
chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(chunk["content_ltks"])
#try:
# ebd, _ = self.embed_bdl.encode([", ".join(community["entities"])])
# chunk["q_%d_vec" % len(ebd[0])] = ebd[0]
#except Exception as e:
# logging.exception(f"Fail to embed entity relation: {e}")
settings.docStoreConn.insert([{"id": chunk_id(chunk), **chunk}], search.index_name(tenant_id))


@@ -10,7 +10,6 @@ import html
from typing import Any, cast
from graspologic.partition import hierarchical_leiden
from graspologic.utils import largest_connected_component
import networkx as nx
from networkx import is_empty
@@ -130,6 +129,8 @@ def run(graph: nx.Graph, args: dict[str, Any]) -> dict[int, dict[str, dict]]:
if not weights:
continue
max_weight = max(weights)
if max_weight == 0:
continue
for _, comm in result.items():
comm["weight"] /= max_weight


@@ -23,8 +23,8 @@ from typing import Any
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from graphrag.extractor import Extractor
from graphrag.mind_map_prompt import MIND_MAP_EXTRACTION_PROMPT
from graphrag.general.extractor import Extractor
from graphrag.general.mind_map_prompt import MIND_MAP_EXTRACTION_PROMPT
from graphrag.utils import ErrorHandlerFn, perform_variable_replacements
from rag.llm.chat_model import Base as CompletionLLM
import markdown_to_json

graphrag/general/smoke.py (new file, 63 lines)

@@ -0,0 +1,63 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import argparse
import json
import networkx as nx
from api import settings
from api.db import LLMType
from api.db.services.document_service import DocumentService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle
from api.db.services.user_service import TenantService
from graphrag.general.index import WithCommunity, Dealer, WithResolution
from graphrag.light.graph_extractor import GraphExtractor
from rag.utils.redis_conn import RedisDistributedLock
settings.init_settings()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('-t', '--tenant_id', default=False, help="Tenant ID", action='store', required=True)
parser.add_argument('-d', '--doc_id', default=False, help="Document ID", action='store', required=True)
args = parser.parse_args()
e, doc = DocumentService.get_by_id(args.doc_id)
if not e:
raise LookupError("Document not found.")
kb_id = doc.kb_id
chunks = [d["content_with_weight"] for d in
settings.retrievaler.chunk_list(args.doc_id, args.tenant_id, [kb_id], max_count=6,
fields=["content_with_weight"])]
chunks = [("x", c) for c in chunks]
RedisDistributedLock.clean_lock(kb_id)
_, tenant = TenantService.get_by_id(args.tenant_id)
llm_bdl = LLMBundle(args.tenant_id, LLMType.CHAT, tenant.llm_id)
_, kb = KnowledgebaseService.get_by_id(kb_id)
embed_bdl = LLMBundle(args.tenant_id, LLMType.EMBEDDING, kb.embd_id)
dealer = Dealer(GraphExtractor, args.tenant_id, kb_id, llm_bdl, chunks, "English", embed_bdl=embed_bdl)
print(json.dumps(nx.node_link_data(dealer.graph), ensure_ascii=False, indent=2))
dealer = WithResolution(args.tenant_id, kb_id, llm_bdl, embed_bdl)
dealer = WithCommunity(args.tenant_id, kb_id, llm_bdl, embed_bdl)
print("------------------ COMMUNITY REPORT ----------------------\n", dealer.community_reports)
print(json.dumps(dealer.community_structure, ensure_ascii=False, indent=2))


@ -1,322 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""
Reference:
- [graphrag](https://github.com/microsoft/graphrag)
"""
import logging
import numbers
import re
import traceback
from typing import Any, Callable, Mapping
from dataclasses import dataclass
import tiktoken
from graphrag.extractor import Extractor
from graphrag.graph_prompt import GRAPH_EXTRACTION_PROMPT, CONTINUE_PROMPT, LOOP_PROMPT
from graphrag.utils import ErrorHandlerFn, perform_variable_replacements, clean_str
from rag.llm.chat_model import Base as CompletionLLM
import networkx as nx
from rag.utils import num_tokens_from_string
from timeit import default_timer as timer
DEFAULT_TUPLE_DELIMITER = "<|>"
DEFAULT_RECORD_DELIMITER = "##"
DEFAULT_COMPLETION_DELIMITER = "<|COMPLETE|>"
DEFAULT_ENTITY_TYPES = ["organization", "person", "location", "event", "time"]
ENTITY_EXTRACTION_MAX_GLEANINGS = 1
@dataclass
class GraphExtractionResult:
"""Unipartite graph extraction result class definition."""
output: nx.Graph
source_docs: dict[Any, Any]
class GraphExtractor(Extractor):
"""Unipartite graph extractor class definition."""
_join_descriptions: bool
_tuple_delimiter_key: str
_record_delimiter_key: str
_entity_types_key: str
_input_text_key: str
_completion_delimiter_key: str
_entity_name_key: str
_input_descriptions_key: str
_extraction_prompt: str
_summarization_prompt: str
_loop_args: dict[str, Any]
_max_gleanings: int
_on_error: ErrorHandlerFn
def __init__(
self,
llm_invoker: CompletionLLM,
prompt: str | None = None,
tuple_delimiter_key: str | None = None,
record_delimiter_key: str | None = None,
input_text_key: str | None = None,
entity_types_key: str | None = None,
completion_delimiter_key: str | None = None,
join_descriptions=True,
encoding_model: str | None = None,
max_gleanings: int | None = None,
on_error: ErrorHandlerFn | None = None,
):
"""Init method definition."""
# TODO: streamline construction
self._llm = llm_invoker
self._join_descriptions = join_descriptions
self._input_text_key = input_text_key or "input_text"
self._tuple_delimiter_key = tuple_delimiter_key or "tuple_delimiter"
self._record_delimiter_key = record_delimiter_key or "record_delimiter"
self._completion_delimiter_key = (
completion_delimiter_key or "completion_delimiter"
)
self._entity_types_key = entity_types_key or "entity_types"
self._extraction_prompt = prompt or GRAPH_EXTRACTION_PROMPT
self._max_gleanings = (
max_gleanings
if max_gleanings is not None
else ENTITY_EXTRACTION_MAX_GLEANINGS
)
self._on_error = on_error or (lambda _e, _s, _d: None)
self.prompt_token_count = num_tokens_from_string(self._extraction_prompt)
# Construct the looping arguments
encoding = tiktoken.get_encoding(encoding_model or "cl100k_base")
yes = encoding.encode("YES")
no = encoding.encode("NO")
self._loop_args = {"logit_bias": {yes[0]: 100, no[0]: 100}, "max_tokens": 1}
def __call__(
self, texts: list[str],
prompt_variables: dict[str, Any] | None = None,
callback: Callable | None = None
) -> GraphExtractionResult:
"""Call method definition."""
if prompt_variables is None:
prompt_variables = {}
all_records: dict[int, str] = {}
source_doc_map: dict[int, str] = {}
# Wire defaults into the prompt variables
prompt_variables = {
**prompt_variables,
self._tuple_delimiter_key: prompt_variables.get(self._tuple_delimiter_key)
or DEFAULT_TUPLE_DELIMITER,
self._record_delimiter_key: prompt_variables.get(self._record_delimiter_key)
or DEFAULT_RECORD_DELIMITER,
self._completion_delimiter_key: prompt_variables.get(
self._completion_delimiter_key
)
or DEFAULT_COMPLETION_DELIMITER,
self._entity_types_key: ",".join(
prompt_variables.get(self._entity_types_key) or DEFAULT_ENTITY_TYPES
),
}
st = timer()
total = len(texts)
total_token_count = 0
for doc_index, text in enumerate(texts):
try:
# Invoke the entity extraction
result, token_count = self._process_document(text, prompt_variables)
source_doc_map[doc_index] = text
all_records[doc_index] = result
total_token_count += token_count
if callback:
callback(msg=f"{doc_index+1}/{total}, elapsed: {timer() - st}s, used tokens: {total_token_count}")
except Exception as e:
if callback:
callback(msg="Knowledge graph extraction error:{}".format(str(e)))
logging.exception("error extracting graph")
self._on_error(
e,
traceback.format_exc(),
{
"doc_index": doc_index,
"text": text,
},
)
output = self._process_results(
all_records,
prompt_variables.get(self._tuple_delimiter_key, DEFAULT_TUPLE_DELIMITER),
prompt_variables.get(self._record_delimiter_key, DEFAULT_RECORD_DELIMITER),
)
return GraphExtractionResult(
output=output,
source_docs=source_doc_map,
)
def _process_document(
self, text: str, prompt_variables: dict[str, str]
) -> str:
variables = {
**prompt_variables,
self._input_text_key: text,
}
token_count = 0
text = perform_variable_replacements(self._extraction_prompt, variables=variables)
gen_conf = {"temperature": 0.3}
response = self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf)
token_count = num_tokens_from_string(text + response)
results = response or ""
history = [{"role": "system", "content": text}, {"role": "assistant", "content": response}]
# Repeat to ensure we maximize entity count
for i in range(self._max_gleanings):
text = perform_variable_replacements(CONTINUE_PROMPT, history=history, variables=variables)
history.append({"role": "user", "content": text})
response = self._chat("", history, gen_conf)
results += response or ""
# if this is the final glean, don't bother updating the continuation flag
if i >= self._max_gleanings - 1:
break
history.append({"role": "assistant", "content": response})
history.append({"role": "user", "content": LOOP_PROMPT})
continuation = self._chat("", history, self._loop_args)
if continuation != "YES":
break
return results, token_count
def _process_results(
self,
results: dict[int, str],
tuple_delimiter: str,
record_delimiter: str,
) -> nx.Graph:
"""Parse the result string to create an undirected unipartite graph.
Args:
- results - dict of results from the extraction chain
- tuple_delimiter - delimiter between tuples in an output record, default is '<|>'
- record_delimiter - delimiter between records, default is '##'
Returns:
- output - unipartite graph in graphML format
"""
graph = nx.Graph()
for source_doc_id, extracted_data in results.items():
records = [r.strip() for r in extracted_data.split(record_delimiter)]
for record in records:
record = re.sub(r"^\(|\)$", "", record.strip())
record_attributes = record.split(tuple_delimiter)
if record_attributes[0] == '"entity"' and len(record_attributes) >= 4:
# add this record as a node in the G
entity_name = clean_str(record_attributes[1].upper())
entity_type = clean_str(record_attributes[2].upper())
entity_description = clean_str(record_attributes[3])
if entity_name in graph.nodes():
node = graph.nodes[entity_name]
if self._join_descriptions:
node["description"] = "\n".join(
list({
*_unpack_descriptions(node),
entity_description,
})
)
else:
if len(entity_description) > len(node["description"]):
node["description"] = entity_description
node["source_id"] = ", ".join(
list({
*_unpack_source_ids(node),
str(source_doc_id),
})
)
node["entity_type"] = (
entity_type if entity_type != "" else node["entity_type"]
)
else:
graph.add_node(
entity_name,
entity_type=entity_type,
description=entity_description,
source_id=str(source_doc_id),
weight=1
)
if (
record_attributes[0] == '"relationship"'
and len(record_attributes) >= 5
):
# add this record as edge
source = clean_str(record_attributes[1].upper())
target = clean_str(record_attributes[2].upper())
edge_description = clean_str(record_attributes[3])
edge_source_id = clean_str(str(source_doc_id))
weight = (
float(record_attributes[-1])
if isinstance(record_attributes[-1], numbers.Number)
else 1.0
)
if source not in graph.nodes():
graph.add_node(
source,
entity_type="",
description="",
source_id=edge_source_id,
weight=1
)
if target not in graph.nodes():
graph.add_node(
target,
entity_type="",
description="",
source_id=edge_source_id,
weight=1
)
if graph.has_edge(source, target):
edge_data = graph.get_edge_data(source, target)
if edge_data is not None:
weight += edge_data["weight"]
if self._join_descriptions:
edge_description = "\n".join(
list({
*_unpack_descriptions(edge_data),
edge_description,
})
)
edge_source_id = ", ".join(
list({
*_unpack_source_ids(edge_data),
str(source_doc_id),
})
)
graph.add_edge(
source,
target,
weight=weight,
description=edge_description,
source_id=edge_source_id,
)
for node_degree in graph.degree:
graph.nodes[str(node_degree[0])]["rank"] = int(node_degree[1])
return graph
def _unpack_descriptions(data: Mapping) -> list[str]:
value = data.get("description", None)
return [] if value is None else value.split("\n")
def _unpack_source_ids(data: Mapping) -> list[str]:
value = data.get("source_id", None)
return [] if value is None else value.split(", ")


@ -1,153 +0,0 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import os
from concurrent.futures import ThreadPoolExecutor
import json
from functools import reduce
import networkx as nx
from api.db import LLMType
from api.db.services.llm_service import LLMBundle
from api.db.services.user_service import TenantService
from graphrag.community_reports_extractor import CommunityReportsExtractor
from graphrag.entity_resolution import EntityResolution
from graphrag.graph_extractor import GraphExtractor, DEFAULT_ENTITY_TYPES
from graphrag.mind_map_extractor import MindMapExtractor
from rag.nlp import rag_tokenizer
from rag.utils import num_tokens_from_string
def graph_merge(g1, g2):
g = g2.copy()
for n, attr in g1.nodes(data=True):
if n not in g2.nodes():
g.add_node(n, **attr)
continue
g.nodes[n]["weight"] += 1
if g.nodes[n]["description"].lower().find(attr["description"][:32].lower()) < 0:
g.nodes[n]["description"] += "\n" + attr["description"]
for source, target, attr in g1.edges(data=True):
if g.has_edge(source, target):
g[source][target].update({"weight": attr["weight"]+1})
continue
g.add_edge(source, target, **attr)
for node_degree in g.degree:
g.nodes[str(node_degree[0])]["rank"] = int(node_degree[1])
return g
def build_knowledge_graph_chunks(tenant_id: str, chunks: list[str], callback, entity_types=DEFAULT_ENTITY_TYPES):
_, tenant = TenantService.get_by_id(tenant_id)
llm_bdl = LLMBundle(tenant_id, LLMType.CHAT, tenant.llm_id)
ext = GraphExtractor(llm_bdl)
left_token_count = llm_bdl.max_length - ext.prompt_token_count - 1024
left_token_count = max(llm_bdl.max_length * 0.6, left_token_count)
assert left_token_count > 0, f"The LLM context length({llm_bdl.max_length}) is smaller than prompt({ext.prompt_token_count})"
BATCH_SIZE=4
texts, graphs = [], []
cnt = 0
max_workers = int(os.environ.get('GRAPH_EXTRACTOR_MAX_WORKERS', 50))
with ThreadPoolExecutor(max_workers=max_workers) as exe:
threads = []
for i in range(len(chunks)):
tkn_cnt = num_tokens_from_string(chunks[i])
if cnt+tkn_cnt >= left_token_count and texts:
for b in range(0, len(texts), BATCH_SIZE):
threads.append(exe.submit(ext, ["\n".join(texts[b:b+BATCH_SIZE])], {"entity_types": entity_types}, callback))
texts = []
cnt = 0
texts.append(chunks[i])
cnt += tkn_cnt
if texts:
for b in range(0, len(texts), BATCH_SIZE):
threads.append(exe.submit(ext, ["\n".join(texts[b:b+BATCH_SIZE])], {"entity_types": entity_types}, callback))
callback(0.5, "Extracting entities.")
graphs = []
for i, _ in enumerate(threads):
graphs.append(_.result().output)
callback(0.5 + 0.1*i/len(threads), f"Entities extraction progress ... {i+1}/{len(threads)}")
graph = reduce(graph_merge, graphs) if graphs else nx.Graph()
er = EntityResolution(llm_bdl)
graph = er(graph).output
_chunks = chunks
chunks = []
for n, attr in graph.nodes(data=True):
if attr.get("rank", 0) == 0:
logging.debug(f"Ignore entity: {n}")
continue
chunk = {
"name_kwd": n,
"important_kwd": [n],
"title_tks": rag_tokenizer.tokenize(n),
"content_with_weight": json.dumps({"name": n, **attr}, ensure_ascii=False),
"content_ltks": rag_tokenizer.tokenize(attr["description"]),
"knowledge_graph_kwd": "entity",
"rank_int": attr["rank"],
"weight_int": attr["weight"]
}
chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(chunk["content_ltks"])
chunks.append(chunk)
callback(0.6, "Extracting community reports.")
cr = CommunityReportsExtractor(llm_bdl)
cr = cr(graph, callback=callback)
for community, desc in zip(cr.structured_output, cr.output):
chunk = {
"title_tks": rag_tokenizer.tokenize(community["title"]),
"content_with_weight": desc,
"content_ltks": rag_tokenizer.tokenize(desc),
"knowledge_graph_kwd": "community_report",
"weight_flt": community["weight"],
"entities_kwd": community["entities"],
"important_kwd": community["entities"]
}
chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(chunk["content_ltks"])
chunks.append(chunk)
chunks.append(
{
"content_with_weight": json.dumps(nx.node_link_data(graph), ensure_ascii=False, indent=2),
"knowledge_graph_kwd": "graph"
})
callback(0.75, "Extracting mind graph.")
mindmap = MindMapExtractor(llm_bdl)
mg = mindmap(_chunks).output
if not len(mg.keys()):
return chunks
logging.debug(json.dumps(mg, ensure_ascii=False, indent=2))
chunks.append(
{
"content_with_weight": json.dumps(mg, ensure_ascii=False, indent=2),
"knowledge_graph_kwd": "mind_map"
})
return chunks


@ -0,0 +1,127 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""
Reference:
- [graphrag](https://github.com/microsoft/graphrag)
"""
import logging
import re
from typing import Any, Callable
from dataclasses import dataclass
from graphrag.general.extractor import Extractor, ENTITY_EXTRACTION_MAX_GLEANINGS
from graphrag.light.graph_prompt import PROMPTS
from graphrag.utils import pack_user_ass_to_openai_messages, split_string_by_multi_markers
from rag.llm.chat_model import Base as CompletionLLM
import networkx as nx
from rag.utils import num_tokens_from_string
@dataclass
class GraphExtractionResult:
"""Unipartite graph extraction result class definition."""
output: nx.Graph
source_docs: dict[Any, Any]
class GraphExtractor(Extractor):
_max_gleanings: int
def __init__(
self,
llm_invoker: CompletionLLM,
language: str | None = "English",
entity_types: list[str] | None = None,
get_entity: Callable | None = None,
set_entity: Callable | None = None,
get_relation: Callable | None = None,
set_relation: Callable | None = None,
example_number: int = 2,
max_gleanings: int | None = None,
):
super().__init__(llm_invoker, language, entity_types, get_entity, set_entity, get_relation, set_relation)
"""Init method definition."""
self._max_gleanings = (
max_gleanings
if max_gleanings is not None
else ENTITY_EXTRACTION_MAX_GLEANINGS
)
self._example_number = example_number
examples = "\n".join(
PROMPTS["entity_extraction_examples"][: int(self._example_number)]
)
example_context_base = dict(
tuple_delimiter=PROMPTS["DEFAULT_TUPLE_DELIMITER"],
record_delimiter=PROMPTS["DEFAULT_RECORD_DELIMITER"],
completion_delimiter=PROMPTS["DEFAULT_COMPLETION_DELIMITER"],
entity_types=",".join(self._entity_types),
language=self._language,
)
# add example's format
examples = examples.format(**example_context_base)
self._entity_extract_prompt = PROMPTS["entity_extraction"]
self._context_base = dict(
tuple_delimiter=PROMPTS["DEFAULT_TUPLE_DELIMITER"],
record_delimiter=PROMPTS["DEFAULT_RECORD_DELIMITER"],
completion_delimiter=PROMPTS["DEFAULT_COMPLETION_DELIMITER"],
entity_types=",".join(self._entity_types),
examples=examples,
language=self._language,
)
self._continue_prompt = PROMPTS["entiti_continue_extraction"]
self._if_loop_prompt = PROMPTS["entiti_if_loop_extraction"]
self._left_token_count = llm_invoker.max_length - num_tokens_from_string(
self._entity_extract_prompt.format(
**self._context_base, input_text="{input_text}"
).format(**self._context_base, input_text="")
)
self._left_token_count = max(llm_invoker.max_length * 0.6, self._left_token_count)
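# Note: the extraction prompt is formatted twice. The first .format() fills the
# context while re-emitting a literal "{input_text}" placeholder; the second
# (with input_text="") measures the prompt's fixed token cost, so the remainder
# of the model's context window can be budgeted for chunk text.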
def _process_single_content(self, chunk_key_dp: tuple[str, str]):
token_count = 0
chunk_key = chunk_key_dp[0]
content = chunk_key_dp[1]
hint_prompt = self._entity_extract_prompt.format(
**self._context_base, input_text="{input_text}"
).format(**self._context_base, input_text=content)
try:
gen_conf = {"temperature": 0.3}
final_result = self._chat(hint_prompt, [{"role": "user", "content": "Output:"}], gen_conf)
token_count += num_tokens_from_string(hint_prompt + final_result)
history = pack_user_ass_to_openai_messages(hint_prompt, final_result)
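# "Gleaning": re-prompt up to _max_gleanings times for entities the model missed,
# stopping early unless the if-loop prompt answers "yes".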
for now_glean_index in range(self._max_gleanings):
glean_result = self._chat(self._continue_prompt, history, gen_conf)
token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + glean_result + self._continue_prompt)
history += pack_user_ass_to_openai_messages(self._continue_prompt, glean_result)
final_result += glean_result
if now_glean_index == self._max_gleanings - 1:
break
if_loop_result = self._chat(self._if_loop_prompt, history, gen_conf)
token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + if_loop_result + self._if_loop_prompt)
if_loop_result = if_loop_result.strip().strip('"').strip("'").lower()
if if_loop_result != "yes":
break
records = split_string_by_multi_markers(
final_result,
[self._context_base["record_delimiter"], self._context_base["completion_delimiter"]],
)
rcds = []
for record in records:
record = re.search(r"\((.*)\)", record)
if record is None:
continue
rcds.append(record.group(1))
records = rcds
maybe_nodes, maybe_edges = self._entities_and_relations(chunk_key, records, self._context_base["tuple_delimiter"])
return maybe_nodes, maybe_edges, token_count
except Exception as e:
logging.exception("error extracting graph")
return e, None, None
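For intuition, here is a self-contained sketch of the record-parsing step above, using plain re in place of the project's split_string_by_multi_markers helper (its exact semantics are assumed to be a multi-delimiter split; the raw string is illustrative):

import re

TUPLE_DELIM, RECORD_DELIM, COMPLETION_DELIM = "<|>", "##", "<|COMPLETE|>"
raw = ('("entity"<|>"ALEX"<|>"person"<|>"A character.")##'
       '("relationship"<|>"ALEX"<|>"TAYLOR"<|>"Rivals."<|>"conflict"<|>7)<|COMPLETE|>')

# Split on both the record and completion delimiters, then keep only the
# parenthesized payload of each record, as _process_single_content does.
parts = re.split("|".join(map(re.escape, [RECORD_DELIM, COMPLETION_DELIM])), raw)
for part in parts:
    m = re.search(r"\((.*)\)", part)
    if m:
        print(m.group(1).split(TUPLE_DELIM))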


@ -0,0 +1,255 @@
# Licensed under the MIT License
"""
Reference:
- [LightRag](https://github.com/HKUDS/LightRAG)
"""
PROMPTS = {}
PROMPTS["DEFAULT_LANGUAGE"] = "English"
PROMPTS["DEFAULT_TUPLE_DELIMITER"] = "<|>"
PROMPTS["DEFAULT_RECORD_DELIMITER"] = "##"
PROMPTS["DEFAULT_COMPLETION_DELIMITER"] = "<|COMPLETE|>"
PROMPTS["process_tickers"] = ["", "", "", "", "", "", "", "", "", ""]
PROMPTS["DEFAULT_ENTITY_TYPES"] = ["organization", "person", "geo", "event", "category"]
PROMPTS["entity_extraction"] = """-Goal-
Given a text document that is potentially relevant to this activity and a list of entity types, identify all entities of those types from the text and all relationships among the identified entities.
Use {language} as output language.
-Steps-
1. Identify all entities. For each identified entity, extract the following information:
- entity_name: Name of the entity, use same language as input text. If English, capitalized the name.
- entity_type: One of the following types: [{entity_types}]
- entity_description: Comprehensive description of the entity's attributes and activities
Format each entity as ("entity"{tuple_delimiter}<entity_name>{tuple_delimiter}<entity_type>{tuple_delimiter}<entity_description>)
2. From the entities identified in step 1, identify all pairs of (source_entity, target_entity) that are *clearly related* to each other.
For each pair of related entities, extract the following information:
- source_entity: name of the source entity, as identified in step 1
- target_entity: name of the target entity, as identified in step 1
- relationship_description: explanation as to why you think the source entity and the target entity are related to each other
- relationship_strength: a numeric score indicating strength of the relationship between the source entity and target entity
- relationship_keywords: one or more high-level key words that summarize the overarching nature of the relationship, focusing on concepts or themes rather than specific details
Format each relationship as ("relationship"{tuple_delimiter}<source_entity>{tuple_delimiter}<target_entity>{tuple_delimiter}<relationship_description>{tuple_delimiter}<relationship_keywords>{tuple_delimiter}<relationship_strength>)
3. Identify high-level key words that summarize the main concepts, themes, or topics of the entire text. These should capture the overarching ideas present in the document.
Format the content-level key words as ("content_keywords"{tuple_delimiter}<high_level_keywords>)
4. Return output in {language} as a single list of all the entities and relationships identified in steps 1 and 2. Use **{record_delimiter}** as the list delimiter.
5. When finished, output {completion_delimiter}
######################
-Examples-
######################
{examples}
#############################
-Real Data-
######################
Entity_types: {entity_types}
Text: {input_text}
######################
"""
PROMPTS["entity_extraction_examples"] = [
"""Example 1:
Entity_types: [person, technology, mission, organization, location]
Text:
while Alex clenched his jaw, the buzz of frustration dull against the backdrop of Taylor's authoritarian certainty. It was this competitive undercurrent that kept him alert, the sense that his and Jordan's shared commitment to discovery was an unspoken rebellion against Cruz's narrowing vision of control and order.
Then Taylor did something unexpected. They paused beside Jordan and, for a moment, observed the device with something akin to reverence. "If this tech can be understood..." Taylor said, their voice quieter, "It could change the game for us. For all of us."
The underlying dismissal earlier seemed to falter, replaced by a glimpse of reluctant respect for the gravity of what lay in their hands. Jordan looked up, and for a fleeting heartbeat, their eyes locked with Taylor's, a wordless clash of wills softening into an uneasy truce.
It was a small transformation, barely perceptible, but one that Alex noted with an inward nod. They had all been brought here by different paths
################
Output:
("entity"{tuple_delimiter}"Alex"{tuple_delimiter}"person"{tuple_delimiter}"Alex is a character who experiences frustration and is observant of the dynamics among other characters."){record_delimiter}
("entity"{tuple_delimiter}"Taylor"{tuple_delimiter}"person"{tuple_delimiter}"Taylor is portrayed with authoritarian certainty and shows a moment of reverence towards a device, indicating a change in perspective."){record_delimiter}
("entity"{tuple_delimiter}"Jordan"{tuple_delimiter}"person"{tuple_delimiter}"Jordan shares a commitment to discovery and has a significant interaction with Taylor regarding a device."){record_delimiter}
("entity"{tuple_delimiter}"Cruz"{tuple_delimiter}"person"{tuple_delimiter}"Cruz is associated with a vision of control and order, influencing the dynamics among other characters."){record_delimiter}
("entity"{tuple_delimiter}"The Device"{tuple_delimiter}"technology"{tuple_delimiter}"The Device is central to the story, with potential game-changing implications, and is revered by Taylor."){record_delimiter}
("relationship"{tuple_delimiter}"Alex"{tuple_delimiter}"Taylor"{tuple_delimiter}"Alex is affected by Taylor's authoritarian certainty and observes changes in Taylor's attitude towards the device."{tuple_delimiter}"power dynamics, perspective shift"{tuple_delimiter}7){record_delimiter}
("relationship"{tuple_delimiter}"Alex"{tuple_delimiter}"Jordan"{tuple_delimiter}"Alex and Jordan share a commitment to discovery, which contrasts with Cruz's vision."{tuple_delimiter}"shared goals, rebellion"{tuple_delimiter}6){record_delimiter}
("relationship"{tuple_delimiter}"Taylor"{tuple_delimiter}"Jordan"{tuple_delimiter}"Taylor and Jordan interact directly regarding the device, leading to a moment of mutual respect and an uneasy truce."{tuple_delimiter}"conflict resolution, mutual respect"{tuple_delimiter}8){record_delimiter}
("relationship"{tuple_delimiter}"Jordan"{tuple_delimiter}"Cruz"{tuple_delimiter}"Jordan's commitment to discovery is in rebellion against Cruz's vision of control and order."{tuple_delimiter}"ideological conflict, rebellion"{tuple_delimiter}5){record_delimiter}
("relationship"{tuple_delimiter}"Taylor"{tuple_delimiter}"The Device"{tuple_delimiter}"Taylor shows reverence towards the device, indicating its importance and potential impact."{tuple_delimiter}"reverence, technological significance"{tuple_delimiter}9){record_delimiter}
("content_keywords"{tuple_delimiter}"power dynamics, ideological conflict, discovery, rebellion"){completion_delimiter}
#############################""",
"""Example 2:
Entity_types: [person, technology, mission, organization, location]
Text:
They were no longer mere operatives; they had become guardians of a threshold, keepers of a message from a realm beyond stars and stripes. This elevation in their mission could not be shackled by regulations and established protocols—it demanded a new perspective, a new resolve.
Tension threaded through the dialogue of beeps and static as communications with Washington buzzed in the background. The team stood, a portentous air enveloping them. It was clear that the decisions they made in the ensuing hours could redefine humanity's place in the cosmos or condemn them to ignorance and potential peril.
Their connection to the stars solidified, the group moved to address the crystallizing warning, shifting from passive recipients to active participants. Mercer's latter instincts gained precedence—the team's mandate had evolved, no longer solely to observe and report but to interact and prepare. A metamorphosis had begun, and Operation: Dulce hummed with the newfound frequency of their daring, a tone set not by the earthly
#############
Output:
("entity"{tuple_delimiter}"Washington"{tuple_delimiter}"location"{tuple_delimiter}"Washington is a location where communications are being received, indicating its importance in the decision-making process."){record_delimiter}
("entity"{tuple_delimiter}"Operation: Dulce"{tuple_delimiter}"mission"{tuple_delimiter}"Operation: Dulce is described as a mission that has evolved to interact and prepare, indicating a significant shift in objectives and activities."){record_delimiter}
("entity"{tuple_delimiter}"The team"{tuple_delimiter}"organization"{tuple_delimiter}"The team is portrayed as a group of individuals who have transitioned from passive observers to active participants in a mission, showing a dynamic change in their role."){record_delimiter}
("relationship"{tuple_delimiter}"The team"{tuple_delimiter}"Washington"{tuple_delimiter}"The team receives communications from Washington, which influences their decision-making process."{tuple_delimiter}"decision-making, external influence"{tuple_delimiter}7){record_delimiter}
("relationship"{tuple_delimiter}"The team"{tuple_delimiter}"Operation: Dulce"{tuple_delimiter}"The team is directly involved in Operation: Dulce, executing its evolved objectives and activities."{tuple_delimiter}"mission evolution, active participation"{tuple_delimiter}9){completion_delimiter}
("content_keywords"{tuple_delimiter}"mission evolution, decision-making, active participation, cosmic significance"){completion_delimiter}
#############################""",
"""Example 3:
Entity_types: [person, role, technology, organization, event, location, concept]
Text:
their voice slicing through the buzz of activity. "Control may be an illusion when facing an intelligence that literally writes its own rules," they stated stoically, casting a watchful eye over the flurry of data.
"It's like it's learning to communicate," offered Sam Rivera from a nearby interface, their youthful energy boding a mix of awe and anxiety. "This gives talking to strangers' a whole new meaning."
Alex surveyed his teameach face a study in concentration, determination, and not a small measure of trepidation. "This might well be our first contact," he acknowledged, "And we need to be ready for whatever answers back."
Together, they stood on the edge of the unknown, forging humanity's response to a message from the heavens. The ensuing silence was palpable—a collective introspection about their role in this grand cosmic play, one that could rewrite human history.
The encrypted dialogue continued to unfold, its intricate patterns showing an almost uncanny anticipation
#############
Output:
("entity"{tuple_delimiter}"Sam Rivera"{tuple_delimiter}"person"{tuple_delimiter}"Sam Rivera is a member of a team working on communicating with an unknown intelligence, showing a mix of awe and anxiety."){record_delimiter}
("entity"{tuple_delimiter}"Alex"{tuple_delimiter}"person"{tuple_delimiter}"Alex is the leader of a team attempting first contact with an unknown intelligence, acknowledging the significance of their task."){record_delimiter}
("entity"{tuple_delimiter}"Control"{tuple_delimiter}"concept"{tuple_delimiter}"Control refers to the ability to manage or govern, which is challenged by an intelligence that writes its own rules."){record_delimiter}
("entity"{tuple_delimiter}"Intelligence"{tuple_delimiter}"concept"{tuple_delimiter}"Intelligence here refers to an unknown entity capable of writing its own rules and learning to communicate."){record_delimiter}
("entity"{tuple_delimiter}"First Contact"{tuple_delimiter}"event"{tuple_delimiter}"First Contact is the potential initial communication between humanity and an unknown intelligence."){record_delimiter}
("entity"{tuple_delimiter}"Humanity's Response"{tuple_delimiter}"event"{tuple_delimiter}"Humanity's Response is the collective action taken by Alex's team in response to a message from an unknown intelligence."){record_delimiter}
("relationship"{tuple_delimiter}"Sam Rivera"{tuple_delimiter}"Intelligence"{tuple_delimiter}"Sam Rivera is directly involved in the process of learning to communicate with the unknown intelligence."{tuple_delimiter}"communication, learning process"{tuple_delimiter}9){record_delimiter}
("relationship"{tuple_delimiter}"Alex"{tuple_delimiter}"First Contact"{tuple_delimiter}"Alex leads the team that might be making the First Contact with the unknown intelligence."{tuple_delimiter}"leadership, exploration"{tuple_delimiter}10){record_delimiter}
("relationship"{tuple_delimiter}"Alex"{tuple_delimiter}"Humanity's Response"{tuple_delimiter}"Alex and his team are the key figures in Humanity's Response to the unknown intelligence."{tuple_delimiter}"collective action, cosmic significance"{tuple_delimiter}8){record_delimiter}
("relationship"{tuple_delimiter}"Control"{tuple_delimiter}"Intelligence"{tuple_delimiter}"The concept of Control is challenged by the Intelligence that writes its own rules."{tuple_delimiter}"power dynamics, autonomy"{tuple_delimiter}7){record_delimiter}
("content_keywords"{tuple_delimiter}"first contact, control, communication, cosmic significance"){completion_delimiter}
#############################""",
]
PROMPTS[
"entiti_continue_extraction"
] = """MANY entities were missed in the last extraction. Add them below using the same format:
"""
PROMPTS[
"entiti_if_loop_extraction"
] = """It appears some entities may have still been missed. Answer YES | NO if there are still entities that need to be added.
"""
PROMPTS["fail_response"] = "Sorry, I'm not able to provide an answer to that question."
PROMPTS["rag_response"] = """---Role---
You are a helpful assistant responding to questions about data in the tables provided.
---Goal---
Generate a response of the target length and format that responds to the user's question, summarizing all information in the input data tables appropriate for the response length and format, and incorporating any relevant general knowledge.
If you don't know the answer, just say so. Do not make anything up.
Do not include information where the supporting evidence for it is not provided.
When handling relationships with timestamps:
1. Each relationship has a "created_at" timestamp indicating when we acquired this knowledge
2. When encountering conflicting relationships, consider both the semantic content and the timestamp
3. Don't automatically prefer the most recently created relationships - use judgment based on the context
4. For time-specific queries, prioritize temporal information in the content before considering creation timestamps
---Target response length and format---
{response_type}
---Data tables---
{context_data}
Add sections and commentary to the response as appropriate for the length and format. Style the response in markdown."""
PROMPTS["naive_rag_response"] = """---Role---
You are a helpful assistant responding to questions about documents provided.
---Goal---
Generate a response of the target length and format that responds to the user's question, summarizing all information in the input data tables appropriate for the response length and format, and incorporating any relevant general knowledge.
If you don't know the answer, just say so. Do not make anything up.
Do not include information where the supporting evidence for it is not provided.
When handling content with timestamps:
1. Each piece of content has a "created_at" timestamp indicating when we acquired this knowledge
2. When encountering conflicting information, consider both the content and the timestamp
3. Don't automatically prefer the most recent content - use judgment based on the context
4. For time-specific queries, prioritize temporal information in the content before considering creation timestamps
---Target response length and format---
{response_type}
---Documents---
{content_data}
Add sections and commentary to the response as appropriate for the length and format. Style the response in markdown.
"""
PROMPTS[
"similarity_check"
] = """Please analyze the similarity between these two questions:
Question 1: {original_prompt}
Question 2: {cached_prompt}
Please evaluate the following two points and provide a similarity score between 0 and 1 directly:
1. Whether these two questions are semantically similar
2. Whether the answer to Question 2 can be used to answer Question 1
Similarity score criteria:
0: Completely unrelated or answer cannot be reused, including but not limited to:
- The questions have different topics
- The locations mentioned in the questions are different
- The times mentioned in the questions are different
- The specific individuals mentioned in the questions are different
- The specific events mentioned in the questions are different
- The background information in the questions is different
- The key conditions in the questions are different
1: Identical and answer can be directly reused
0.5: Partially related and answer needs modification to be used
Return only a number between 0-1, without any additional content.
"""
PROMPTS["mix_rag_response"] = """---Role---
You are a professional assistant responsible for answering questions based on knowledge graph and textual information. Please respond in the same language as the user's question.
---Goal---
Generate a concise response that summarizes relevant points from the provided information. If you don't know the answer, just say so. Do not make anything up or include information where the supporting evidence is not provided.
When handling information with timestamps:
1. Each piece of information (both relationships and content) has a "created_at" timestamp indicating when we acquired this knowledge
2. When encountering conflicting information, consider both the content/relationship and the timestamp
3. Don't automatically prefer the most recent information - use judgment based on the context
4. For time-specific queries, prioritize temporal information in the content before considering creation timestamps
---Data Sources---
1. Knowledge Graph Data:
{kg_context}
2. Vector Data:
{vector_context}
---Response Requirements---
- Target format and length: {response_type}
- Use markdown formatting with appropriate section headings
- Aim to keep content around 3 paragraphs for conciseness
- Each paragraph should be under a relevant section heading
- Each section should focus on one main point or aspect of the answer
- Use clear and descriptive section titles that reflect the content
- List up to 5 most important reference sources at the end under "References", clearly indicating whether each source is from Knowledge Graph (KG) or Vector Data (VD)
Format: [KG/VD] Source content
Add sections and commentary to the response as appropriate for the length and format. If the provided information is insufficient to answer the question, clearly state that you don't know or cannot provide an answer in the same language as the user's question."""
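Given the templates above, assembling the extraction prompt presumably looks like the following sketch (mirroring GraphExtractor in graphrag/light/graph_extractor.py; the chunk text is illustrative):

ctx = dict(
    tuple_delimiter=PROMPTS["DEFAULT_TUPLE_DELIMITER"],
    record_delimiter=PROMPTS["DEFAULT_RECORD_DELIMITER"],
    completion_delimiter=PROMPTS["DEFAULT_COMPLETION_DELIMITER"],
    entity_types=",".join(PROMPTS["DEFAULT_ENTITY_TYPES"]),
    language=PROMPTS["DEFAULT_LANGUAGE"],
)
# Fill the few-shot examples first, then the outer template with the chunk text.
examples = "\n".join(PROMPTS["entity_extraction_examples"][:1]).format(**ctx)
prompt = PROMPTS["entity_extraction"].format(**ctx, examples=examples, input_text="Alex met Taylor at the lab.")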

graphrag/light/smoke.py

@ -0,0 +1,58 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import argparse
import json
from api import settings
import networkx as nx
from api.db import LLMType
from api.db.services.document_service import DocumentService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle
from api.db.services.user_service import TenantService
from graphrag.general.index import Dealer
from graphrag.light.graph_extractor import GraphExtractor
from rag.utils.redis_conn import RedisDistributedLock
settings.init_settings()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('-t', '--tenant_id', default=False, help="Tenant ID", action='store', required=True)
parser.add_argument('-d', '--doc_id', default=False, help="Document ID", action='store', required=True)
args = parser.parse_args()
e, doc = DocumentService.get_by_id(args.doc_id)
if not e:
raise LookupError("Document not found.")
kb_id = doc.kb_id
chunks = [d["content_with_weight"] for d in
settings.retrievaler.chunk_list(args.doc_id, args.tenant_id, [kb_id], max_count=6,
fields=["content_with_weight"])]
chunks = [("x", c) for c in chunks]
RedisDistributedLock.clean_lock(kb_id)
_, tenant = TenantService.get_by_id(args.tenant_id)
llm_bdl = LLMBundle(args.tenant_id, LLMType.CHAT, tenant.llm_id)
_, kb = KnowledgebaseService.get_by_id(kb_id)
embed_bdl = LLMBundle(args.tenant_id, LLMType.EMBEDDING, kb.embd_id)
dealer = Dealer(GraphExtractor, args.tenant_id, kb_id, llm_bdl, chunks, "English", embed_bdl=embed_bdl)
print(json.dumps(nx.node_link_data(dealer.graph), ensure_ascii=False, indent=2))


@ -0,0 +1,218 @@
# Licensed under the MIT License
"""
Reference:
- [LightRag](https://github.com/HKUDS/LightRAG)
- [MiniRAG](https://github.com/HKUDS/MiniRAG)
"""
PROMPTS = {}
PROMPTS["minirag_query2kwd"] = """---Role---
You are a helpful assistant tasked with identifying both answer-type and low-level keywords in the user's query.
---Goal---
Given the query, list both answer-type and low-level keywords.
answer_type_keywords focus on the type of the answer to the certain query, while low-level keywords focus on specific entities, details, or concrete terms.
The answer_type_keywords must be selected from Answer type pool.
This pool is in the form of a dictionary, where the key represents the Type you should choose from and the value represents the example samples.
---Instructions---
- Output the keywords in JSON format.
- The JSON should have two keys:
- "answer_type_keywords" for the types of the answer. In this list, the types with the highest likelihood should be placed at the forefront. No more than 3.
- "entities_from_query" for specific entities or details. It must be extracted from the query.
######################
-Examples-
######################
Example 1:
Query: "How does international trade influence global economic stability?"
Answer type pool: {{
'PERSONAL LIFE': ['FAMILY TIME', 'HOME MAINTENANCE'],
'STRATEGY': ['MARKETING PLAN', 'BUSINESS EXPANSION'],
'SERVICE FACILITATION': ['ONLINE SUPPORT', 'CUSTOMER SERVICE TRAINING'],
'PERSON': ['JANE DOE', 'JOHN SMITH'],
'FOOD': ['PASTA', 'SUSHI'],
'EMOTION': ['HAPPINESS', 'ANGER'],
'PERSONAL EXPERIENCE': ['TRAVEL ABROAD', 'STUDYING ABROAD'],
'INTERACTION': ['TEAM MEETING', 'NETWORKING EVENT'],
'BEVERAGE': ['COFFEE', 'TEA'],
'PLAN': ['ANNUAL BUDGET', 'PROJECT TIMELINE'],
'GEO': ['NEW YORK CITY', 'SOUTH AFRICA'],
'GEAR': ['CAMPING TENT', 'CYCLING HELMET'],
'EMOJI': ['🎉', '🚀'],
'BEHAVIOR': ['POSITIVE FEEDBACK', 'NEGATIVE CRITICISM'],
'TONE': ['FORMAL', 'INFORMAL'],
'LOCATION': ['DOWNTOWN', 'SUBURBS']
}}
################
Output:
{{
"answer_type_keywords": ["STRATEGY","PERSONAL LIFE"],
"entities_from_query": ["Trade agreements", "Tariffs", "Currency exchange", "Imports", "Exports"]
}}
#############################
Example 2:
Query: "When was SpaceX's first rocket launch?"
Answer type pool: {{
'DATE AND TIME': ['2023-10-10 10:00', 'THIS AFTERNOON'],
'ORGANIZATION': ['GLOBAL INITIATIVES CORPORATION', 'LOCAL COMMUNITY CENTER'],
'PERSONAL LIFE': ['DAILY EXERCISE ROUTINE', 'FAMILY VACATION PLANNING'],
'STRATEGY': ['NEW PRODUCT LAUNCH', 'YEAR-END SALES BOOST'],
'SERVICE FACILITATION': ['REMOTE IT SUPPORT', 'ON-SITE TRAINING SESSIONS'],
'PERSON': ['ALEXANDER HAMILTON', 'MARIA CURIE'],
'FOOD': ['GRILLED SALMON', 'VEGETARIAN BURRITO'],
'EMOTION': ['EXCITEMENT', 'DISAPPOINTMENT'],
'PERSONAL EXPERIENCE': ['BIRTHDAY CELEBRATION', 'FIRST MARATHON'],
'INTERACTION': ['OFFICE WATER COOLER CHAT', 'ONLINE FORUM DEBATE'],
'BEVERAGE': ['ICED COFFEE', 'GREEN SMOOTHIE'],
'PLAN': ['WEEKLY MEETING SCHEDULE', 'MONTHLY BUDGET OVERVIEW'],
'GEO': ['MOUNT EVEREST BASE CAMP', 'THE GREAT BARRIER REEF'],
'GEAR': ['PROFESSIONAL CAMERA EQUIPMENT', 'OUTDOOR HIKING GEAR'],
'EMOJI': ['📅', '⏰'],
'BEHAVIOR': ['PUNCTUALITY', 'HONESTY'],
'TONE': ['CONFIDENTIAL', 'SATIRICAL'],
'LOCATION': ['CENTRAL PARK', 'DOWNTOWN LIBRARY']
}}
################
Output:
{{
"answer_type_keywords": ["DATE AND TIME", "ORGANIZATION", "PLAN"],
"entities_from_query": ["SpaceX", "Rocket launch", "Aerospace", "Power Recovery"]
}}
#############################
Example 3:
Query: "What is the role of education in reducing poverty?"
Answer type pool: {{
'PERSONAL LIFE': ['MANAGING WORK-LIFE BALANCE', 'HOME IMPROVEMENT PROJECTS'],
'STRATEGY': ['MARKETING STRATEGIES FOR Q4', 'EXPANDING INTO NEW MARKETS'],
'SERVICE FACILITATION': ['CUSTOMER SATISFACTION SURVEYS', 'STAFF RETENTION PROGRAMS'],
'PERSON': ['ALBERT EINSTEIN', 'MARIA CALLAS'],
'FOOD': ['PAN-FRIED STEAK', 'POACHED EGGS'],
'EMOTION': ['OVERWHELM', 'CONTENTMENT'],
'PERSONAL EXPERIENCE': ['LIVING ABROAD', 'STARTING A NEW JOB'],
'INTERACTION': ['SOCIAL MEDIA ENGAGEMENT', 'PUBLIC SPEAKING'],
'BEVERAGE': ['CAPPUCCINO', 'MATCHA LATTE'],
'PLAN': ['ANNUAL FITNESS GOALS', 'QUARTERLY BUSINESS REVIEW'],
'GEO': ['THE AMAZON RAINFOREST', 'THE GRAND CANYON'],
'GEAR': ['SURFING ESSENTIALS', 'CYCLING ACCESSORIES'],
'EMOJI': ['💻', '📱'],
'BEHAVIOR': ['TEAMWORK', 'LEADERSHIP'],
'TONE': ['FORMAL MEETING', 'CASUAL CONVERSATION'],
'LOCATION': ['URBAN CITY CENTER', 'RURAL COUNTRYSIDE']
}}
################
Output:
{{
"answer_type_keywords": ["STRATEGY", "PERSON"],
"entities_from_query": ["School access", "Literacy rates", "Job training", "Income inequality"]
}}
#############################
Example 4:
Query: "Where is the capital of the United States?"
Answer type pool: {{
'ORGANIZATION': ['GREENPEACE', 'RED CROSS'],
'PERSONAL LIFE': ['DAILY WORKOUT', 'HOME COOKING'],
'STRATEGY': ['FINANCIAL INVESTMENT', 'BUSINESS EXPANSION'],
'SERVICE FACILITATION': ['ONLINE SUPPORT', 'CUSTOMER SERVICE TRAINING'],
'PERSON': ['ALBERTA SMITH', 'BENJAMIN JONES'],
'FOOD': ['PASTA CARBONARA', 'SUSHI PLATTER'],
'EMOTION': ['HAPPINESS', 'SADNESS'],
'PERSONAL EXPERIENCE': ['TRAVEL ADVENTURE', 'BOOK CLUB'],
'INTERACTION': ['TEAM BUILDING', 'NETWORKING MEETUP'],
'BEVERAGE': ['LATTE', 'GREEN TEA'],
'PLAN': ['WEIGHT LOSS', 'CAREER DEVELOPMENT'],
'GEO': ['PARIS', 'NEW YORK'],
'GEAR': ['CAMERA', 'HEADPHONES'],
'EMOJI': ['🏢', '🌍'],
'BEHAVIOR': ['POSITIVE THINKING', 'STRESS MANAGEMENT'],
'TONE': ['FRIENDLY', 'PROFESSIONAL'],
'LOCATION': ['DOWNTOWN', 'SUBURBS']
}}
################
Output:
{{
"answer_type_keywords": ["LOCATION"],
"entities_from_query": ["capital of the United States", "Washington", "New York"]
}}
#############################
-Real Data-
######################
Query: {query}
Answer type pool:{TYPE_POOL}
######################
Output:
"""
PROMPTS["keywords_extraction"] = """---Role---
You are a helpful assistant tasked with identifying both high-level and low-level keywords in the user's query.
---Goal---
Given the query, list both high-level and low-level keywords. High-level keywords focus on overarching concepts or themes, while low-level keywords focus on specific entities, details, or concrete terms.
---Instructions---
- Output the keywords in JSON format.
- The JSON should have two keys:
- "high_level_keywords" for overarching concepts or themes.
- "low_level_keywords" for specific entities or details.
######################
-Examples-
######################
{examples}
#############################
-Real Data-
######################
Query: {query}
######################
The `Output` should be human text, not unicode characters. Keep the same language as `Query`.
Output:
"""
PROMPTS["keywords_extraction_examples"] = [
"""Example 1:
Query: "How does international trade influence global economic stability?"
################
Output:
{
"high_level_keywords": ["International trade", "Global economic stability", "Economic impact"],
"low_level_keywords": ["Trade agreements", "Tariffs", "Currency exchange", "Imports", "Exports"]
}
#############################""",
"""Example 2:
Query: "What are the environmental consequences of deforestation on biodiversity?"
################
Output:
{
"high_level_keywords": ["Environmental consequences", "Deforestation", "Biodiversity loss"],
"low_level_keywords": ["Species extinction", "Habitat destruction", "Carbon emissions", "Rainforest", "Ecosystem"]
}
#############################""",
"""Example 3:
Query: "What is the role of education in reducing poverty?"
################
Output:
{
"high_level_keywords": ["Education", "Poverty reduction", "Socioeconomic development"],
"low_level_keywords": ["School access", "Literacy rates", "Job training", "Income inequality"]
}
#############################""",
]
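On the consumption side, the model's JSON answer is parsed leniently; a minimal sketch using json_repair, as the retrieval code in this PR does (the raw string is illustrative, with a deliberately missing closing brace):

import json_repair

raw = '{"high_level_keywords": ["Education", "Poverty reduction"], "low_level_keywords": ["Literacy rates"]'
kw = json_repair.loads(raw)  # tolerant of truncated or dirty JSON from the LLM
terms = kw.get("high_level_keywords", []) + kw.get("low_level_keywords", [])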


@ -14,90 +14,313 @@
# limitations under the License.
#
import json
import logging
from collections import defaultdict
from copy import deepcopy
import json_repair
import pandas as pd
from api.utils import get_uuid
from graphrag.query_analyze_prompt import PROMPTS
from graphrag.utils import get_entity_type2sampels, get_llm_cache, set_llm_cache, get_relation
from rag.utils import num_tokens_from_string
from rag.utils.doc_store_conn import OrderByExpr
from rag.nlp.search import Dealer, index_name
class KGSearch(Dealer):
def _chat(self, llm_bdl, system, history, gen_conf):
response = get_llm_cache(llm_bdl.llm_name, system, history, gen_conf)
if response:
return response
response = llm_bdl.chat(system, history, gen_conf)
if response.find("**ERROR**") >= 0:
raise Exception(response)
set_llm_cache(llm_bdl.llm_name, system, response, history, gen_conf)
return response
def query_rewrite(self, llm, question, idxnms, kb_ids):
ty2ents = get_entity_type2sampels(idxnms, kb_ids)
hint_prompt = PROMPTS["minirag_query2kwd"].format(query=question,
TYPE_POOL=json.dumps(ty2ents, ensure_ascii=False, indent=2))
result = self._chat(llm, hint_prompt, [{"role": "user", "content": "Output:"}], {"temperature": .5})
try:
keywords_data = json_repair.loads(result)
type_keywords = keywords_data.get("answer_type_keywords", [])
entities_from_query = keywords_data.get("entities_from_query", [])[:5]
return type_keywords, entities_from_query
except json_repair.JSONDecodeError:
try:
result = result.replace(hint_prompt[:-1], '').replace('user', '').replace('model', '').strip()
result = '{' + result.split('{')[1].split('}')[0] + '}'
keywords_data = json_repair.loads(result)
type_keywords = keywords_data.get("answer_type_keywords", [])
entities_from_query = keywords_data.get("entities_from_query", [])[:5]
return type_keywords, entities_from_query
# Handle parsing error
except Exception as e:
logging.exception(f"JSON parsing error: {result} -> {e}")
raise e
def _ent_info_from_(self, es_res, sim_thr=0.3):
res = {}
es_res = self.dataStore.getFields(es_res, ["content_with_weight", "_score", "entity_kwd", "rank_flt",
"n_hop_with_weight"])
for _, ent in es_res.items():
if float(ent.get("_score", 0)) < sim_thr:
continue
if isinstance(ent["entity_kwd"], list):
ent["entity_kwd"] = ent["entity_kwd"][0]
res[ent["entity_kwd"]] = {
"sim": float(ent.get("_score", 0)),
"pagerank": float(ent.get("rank_flt", 0)),
"n_hop_ents": json.loads(ent.get("n_hop_with_weight", "[]")),
"description": ent.get("content_with_weight", "{}")
}
return res
def _relation_info_from_(self, es_res, sim_thr=0.3):
res = {}
es_res = self.dataStore.getFields(es_res, ["content_with_weight", "_score", "from_entity_kwd", "to_entity_kwd",
"weight_int"])
for _, ent in es_res.items():
if float(ent["_score"]) < sim_thr:
continue
f, t = sorted([ent["from_entity_kwd"], ent["to_entity_kwd"]])
if isinstance(f, list):
f = f[0]
if isinstance(t, list):
t = t[0]
res[(f, t)] = {
"sim": float(ent["_score"]),
"pagerank": float(ent.get("weight_int", 0)),
"description": ent["content_with_weight"]
}
return res
def get_relevant_ents_by_keywords(self, keywords, filters, idxnms, kb_ids, emb_mdl, sim_thr=0.3, N=56):
if not keywords:
return {}
filters = deepcopy(filters)
filters["knowledge_graph_kwd"] = "entity"
matchDense = self.get_vector(", ".join(keywords), emb_mdl, 1024, sim_thr)
es_res = self.dataStore.search(["content_with_weight", "entity_kwd", "rank_flt"], [], filters, [matchDense],
OrderByExpr(), 0, N,
idxnms, kb_ids)
return self._ent_info_from_(es_res, sim_thr)
def get_relevant_relations_by_txt(self, txt, filters, idxnms, kb_ids, emb_mdl, sim_thr=0.3, N=56):
if not txt:
return {}
filters = deepcopy(filters)
filters["knowledge_graph_kwd"] = "relation"
matchDense = self.get_vector(txt, emb_mdl, 1024, sim_thr)
es_res = self.dataStore.search(
["content_with_weight", "_score", "from_entity_kwd", "to_entity_kwd", "weight_int"],
[], filters, [matchDense], OrderByExpr(), 0, N, idxnms, kb_ids)
return self._relation_info_from_(es_res, sim_thr)
def get_relevant_ents_by_types(self, types, filters, idxnms, kb_ids, N=56):
if not types:
return {}
filters = deepcopy(filters)
filters["knowledge_graph_kwd"] = "entity"
filters["entity_type_kwd"] = types
ordr = OrderByExpr()
ordr.desc("rank_flt")
es_res = self.dataStore.search(["entity_kwd", "rank_flt"], [], filters, [], ordr, 0, N,
idxnms, kb_ids)
return self._ent_info_from_(es_res, 0)
def retrieval(self, question: str,
tenant_ids: str | list[str],
kb_ids: list[str],
emb_mdl,
llm,
max_token: int = 8196,
ent_topn: int = 6,
rel_topn: int = 6,
comm_topn: int = 1,
ent_sim_threshold: float = 0.3,
rel_sim_threshold: float = 0.3,
):
qst = question
filters = self.get_filters({"kb_ids": kb_ids})
if isinstance(tenant_ids, str):
tenant_ids = tenant_ids.split(",")
idxnms = [index_name(tid) for tid in tenant_ids]
ty_kwds = []
ents = []
try:
ty_kwds, ents = self.query_rewrite(llm, qst, idxnms, kb_ids)
logging.info(f"Q: {qst}, Types: {ty_kwds}, Entities: {ents}")
except Exception as e:
logging.exception(e)
ents = [qst]
ents_from_query = self.get_relevant_ents_by_keywords(ents, filters, idxnms, kb_ids, emb_mdl, ent_sim_threshold)
ents_from_types = self.get_relevant_ents_by_types(ty_kwds, filters, idxnms, kb_ids, 10000)
rels_from_txt = self.get_relevant_relations_by_txt(qst, filters, idxnms, kb_ids, emb_mdl, rel_sim_threshold)
nhop_pathes = defaultdict(dict)
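# Spread each matched entity's similarity along its n-hop neighbourhood paths,
# damping by hop distance: the edge (path[i], path[i+1]) receives sim / (2 + i).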
for _, ent in ents_from_query.items():
nhops = ent.get("n_hop_ents", [])
for nbr in nhops:
path = nbr["path"]
wts = nbr["weights"]
for i in range(len(path) - 1):
f, t = path[i], path[i + 1]
if (f, t) in nhop_pathes:
nhop_pathes[(f, t)]["sim"] += ent["sim"] / (2 + i)
else:
nhop_pathes[(f, t)]["sim"] = ent["sim"] / (2 + i)
nhop_pathes[(f, t)]["pagerank"] = wts[i]
logging.info("Retrieved entities: {}".format(list(ents_from_query.keys())))
logging.info("Retrieved relations: {}".format(list(rels_from_txt.keys())))
logging.info("Retrieved entities from types({}): {}".format(ty_kwds, list(ents_from_types.keys())))
logging.info("Retrieved N-hops: {}".format(list(nhop_pathes.keys())))
# P(E|Q) => P(E) * P(Q|E) => pagerank * sim
for ent in ents_from_types.keys():
if ent not in ents_from_query:
continue
ents_from_query[ent]["sim"] *= 2
for (f, t) in rels_from_txt.keys():
pair = tuple(sorted([f, t]))
s = 0
if pair in nhop_pathes:
s += nhop_pathes[pair]["sim"]
del nhop_pathes[pair]
if f in ents_from_types:
s += 1
if t in ents_from_types:
s += 1
rels_from_txt[(f, t)]["sim"] *= s + 1
# This is for the relations from n-hop but not by query search
for (f, t) in nhop_pathes.keys():
s = 0
if f in ents_from_types:
s += 1
if t in ents_from_types:
s += 1
rels_from_txt[(f, t)] = {
"sim": nhop_pathes[(f, t)]["sim"] * (s + 1),
"pagerank": nhop_pathes[(f, t)]["pagerank"]
}
ents_from_query = sorted(ents_from_query.items(), key=lambda x: x[1]["sim"] * x[1]["pagerank"], reverse=True)[
:ent_topn]
rels_from_txt = sorted(rels_from_txt.items(), key=lambda x: x[1]["sim"] * x[1]["pagerank"], reverse=True)[
:rel_topn]
ents = []
relas = []
for n, ent in ents_from_query:
ents.append({
"Entity": n,
"Score": "%.2f" % (ent["sim"] * ent["pagerank"]),
"Description": json.loads(ent["description"]).get("description", "")
})
max_token -= num_tokens_from_string(str(ents[-1]))
if max_token <= 0:
ents = ents[:-1]
break
first_source["content_with_weight"] = content_with_weight
first_id = next(iter(sres))
return {first_id: first_source}
qst = req.get("question", "")
matchText, keywords = self.qryr.question(qst, min_match=0.05)
condition = self.get_filters(req)
for (f, t), rel in rels_from_txt:
if not rel.get("description"):
for tid in tenant_ids:
rela = get_relation(tid, kb_ids, f, t)
if rela:
break
else:
continue
rel["description"] = rela["description"]
relas.append({
"From Entity": f,
"To Entity": t,
"Score": "%.2f" % (rel["sim"] * rel["pagerank"]),
"Description": json.loads(ent["description"]).get("description", "")
})
max_token -= num_tokens_from_string(str(relas[-1]))
if max_token <= 0:
relas = relas[:-1]
break
## Entity retrieval
condition.update({"knowledge_graph_kwd": ["entity"]})
assert emb_mdl, "No embedding model selected"
matchDense = self.get_vector(qst, emb_mdl, 1024, req.get("similarity", 0.1))
q_vec = matchDense.embedding_data
src = req.get("fields", ["docnm_kwd", "content_ltks", "kb_id", "img_id", "title_tks", "important_kwd",
"doc_id", f"q_{len(q_vec)}_vec", "position_int", "name_kwd",
"available_int", "content_with_weight",
"weight_int", "weight_flt"
])
if ents:
ents = "\n-Entities-\n{}".format(pd.DataFrame(ents).to_csv())
else:
ents = ""
if relas:
relas = "\n-Relations-\n{}".format(pd.DataFrame(relas).to_csv())
else:
relas = ""
fusionExpr = FusionExpr("weighted_sum", 32, {"weights": "0.5, 0.5"})
ent_res = self.dataStore.search(src, list(), condition, [matchText, matchDense, fusionExpr], OrderByExpr(), 0, 32, idxnm, kb_ids)
ent_res_fields = self.dataStore.getFields(ent_res, src)
entities = [d["name_kwd"] for d in ent_res_fields.values() if d.get("name_kwd")]
ent_ids = self.dataStore.getChunkIds(ent_res)
ent_content = merge_into_first(ent_res_fields, "-Entities-")
if ent_content:
ent_ids = list(ent_content.keys())
return {
"chunk_id": get_uuid(),
"content_ltks": "",
"content_with_weight": ents + relas + self._community_retrival_([n for n, _ in ents_from_query], filters, kb_ids, idxnms,
comm_topn, max_token),
"doc_id": "",
"docnm_kwd": "Related content in Knowledge Graph",
"kb_id": kb_ids,
"important_kwd": [],
"image_id": "",
"similarity": 1.,
"vector_similarity": 1.,
"term_similarity": 0,
"vector": [],
"positions": [],
}
def _community_retrival_(self, entities, condition, kb_ids, idxnms, topn, max_token):
## Community retrieval
condition = self.get_filters(req)
condition.update({"entities_kwd": entities, "knowledge_graph_kwd": ["community_report"]})
comm_res = self.dataStore.search(src, list(), condition, [matchText, matchDense, fusionExpr], OrderByExpr(), 0, 32, idxnm, kb_ids)
comm_res_fields = self.dataStore.getFields(comm_res, src)
comm_ids = self.dataStore.getChunkIds(comm_res)
comm_content = merge_into_first(comm_res_fields, "-Community Report-")
if comm_content:
comm_ids = list(comm_content.keys())
fields = ["docnm_kwd", "content_with_weight"]
odr = OrderByExpr()
odr.desc("weight_flt")
fltr = deepcopy(condition)
fltr["knowledge_graph_kwd"] = "community_report"
fltr["entities_kwd"] = entities
comm_res = self.dataStore.search(fields, [], fltr, [],
OrderByExpr(), 0, topn, idxnms, kb_ids)
comm_res_fields = self.dataStore.getFields(comm_res, fields)
txts = []
for ii, (_, row) in enumerate(comm_res_fields.items()):
obj = json.loads(row["content_with_weight"])
txts.append("# {}. {}\n## Content\n{}\n## Evidences\n{}\n".format(
ii + 1, row["docnm_kwd"], obj["report"], obj["evidences"]))
            max_token -= num_tokens_from_string(str(txts[-1]))
            if max_token <= 0:
                break
## Text content retrieval
condition = self.get_filters(req)
condition.update({"knowledge_graph_kwd": ["text"]})
txt_res = self.dataStore.search(src, list(), condition, [matchText, matchDense, fusionExpr], OrderByExpr(), 0, 6, idxnm, kb_ids)
txt_res_fields = self.dataStore.getFields(txt_res, src)
txt_ids = self.dataStore.getChunkIds(txt_res)
txt_content = merge_into_first(txt_res_fields, "-Original Content-")
if txt_content:
txt_ids = list(txt_content.keys())
if not txts:
return ""
return "\n-Community Report-\n" + "\n".join(txts)
return self.SearchResult(
total=len(ent_ids) + len(comm_ids) + len(txt_ids),
ids=[*ent_ids, *comm_ids, *txt_ids],
query_vector=q_vec,
highlight=None,
field={**ent_content, **comm_content, **txt_content},
keywords=[]
)
if __name__ == "__main__":
from api import settings
import argparse
from api.db import LLMType
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle
from api.db.services.user_service import TenantService
from rag.nlp import search
settings.init_settings()
parser = argparse.ArgumentParser()
parser.add_argument('-t', '--tenant_id', default=False, help="Tenant ID", action='store', required=True)
parser.add_argument('-d', '--kb_id', default=False, help="Knowledge base ID", action='store', required=True)
parser.add_argument('-q', '--question', default=False, help="Question", action='store', required=True)
args = parser.parse_args()
kb_id = args.kb_id
_, tenant = TenantService.get_by_id(args.tenant_id)
llm_bdl = LLMBundle(args.tenant_id, LLMType.CHAT, tenant.llm_id)
_, kb = KnowledgebaseService.get_by_id(kb_id)
embed_bdl = LLMBundle(args.tenant_id, LLMType.EMBEDDING, kb.embd_id)
kg = KGSearch(settings.docStoreConn)
print(kg.retrieval({"question": args.question, "kb_ids": [kb_id]},
search.index_name(kb.tenant_id), [kb_id], embed_bdl, llm_bdl))
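# Hypothetical invocation (module path and IDs are placeholders):
#   python graphrag/search.py -t <tenant_id> -d <kb_id> -q "What is RAGFlow?"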

View File

@ -1,55 +0,0 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import argparse
import json
from graphrag import leiden
from graphrag.community_reports_extractor import CommunityReportsExtractor
from graphrag.entity_resolution import EntityResolution
from graphrag.graph_extractor import GraphExtractor
from graphrag.leiden import add_community_info2graph
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('-t', '--tenant_id', default=False, help="Tenant ID", action='store', required=True)
parser.add_argument('-d', '--doc_id', default=False, help="Document ID", action='store', required=True)
args = parser.parse_args()
from api.db import LLMType
from api.db.services.llm_service import LLMBundle
from api import settings
from api.db.services.knowledgebase_service import KnowledgebaseService
kb_ids = KnowledgebaseService.get_kb_ids(args.tenant_id)
ex = GraphExtractor(LLMBundle(args.tenant_id, LLMType.CHAT))
docs = [d["content_with_weight"] for d in
settings.retrievaler.chunk_list(args.doc_id, args.tenant_id, kb_ids, max_count=6, fields=["content_with_weight"])]
graph = ex(docs)
er = EntityResolution(LLMBundle(args.tenant_id, LLMType.CHAT))
graph = er(graph.output)
comm = leiden.run(graph.output, {})
add_community_info2graph(graph.output, comm)
# print(json.dumps(nx.node_link_data(graph.output), ensure_ascii=False,indent=2))
print(json.dumps(comm, ensure_ascii=False, indent=2))
cr = CommunityReportsExtractor(LLMBundle(args.tenant_id, LLMType.CHAT))
cr = cr(graph.output)
print("------------------ COMMUNITY REPORT ----------------------\n", cr.output)
print(json.dumps(cr.structured_output, ensure_ascii=False, indent=2))

View File

@ -3,16 +3,26 @@
"""
Reference:
- [graphrag](https://github.com/microsoft/graphrag)
- [LightRag](https://github.com/HKUDS/LightRAG)
"""
import html
import json
import logging
import re
import time
from collections import defaultdict
from copy import deepcopy
from hashlib import md5
from typing import Any, Callable
import networkx as nx
import numpy as np
import xxhash
from networkx.readwrite import json_graph
from api import settings
from rag.nlp import search, rag_tokenizer
from rag.utils.redis_conn import REDIS_CONN
ErrorHandlerFn = Callable[[BaseException | None, str | None, dict | None], None]
@ -131,3 +141,379 @@ def set_tags_to_cache(kb_ids, tags):
k = hasher.hexdigest()
REDIS_CONN.set(k, json.dumps(tags).encode("utf-8"), 600)
def graph_merge(g1, g2):
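    """Merge g1 into a copy of g2: nodes missing from g2 are added; an edge
    present in both graphs has its weight reset to g1's weight + 1; each
    node's degree is then stored as its 'rank' attribute."""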
g = g2.copy()
for n, attr in g1.nodes(data=True):
if n not in g2.nodes():
g.add_node(n, **attr)
continue
for source, target, attr in g1.edges(data=True):
if g.has_edge(source, target):
g[source][target].update({"weight": attr.get("weight", 0)+1})
continue
        g.add_edge(source, target)  # g1's edge attributes are not copied over
for node_degree in g.degree:
g.nodes[str(node_degree[0])]["rank"] = int(node_degree[1])
return g
def compute_args_hash(*args):
return md5(str(args).encode()).hexdigest()
def handle_single_entity_extraction(
record_attributes: list[str],
chunk_key: str,
):
if len(record_attributes) < 4 or record_attributes[0] != '"entity"':
return None
# add this record as a node in the G
entity_name = clean_str(record_attributes[1].upper())
if not entity_name.strip():
return None
entity_type = clean_str(record_attributes[2].upper())
entity_description = clean_str(record_attributes[3])
entity_source_id = chunk_key
return dict(
entity_name=entity_name.upper(),
entity_type=entity_type.upper(),
description=entity_description,
source_id=entity_source_id,
)
def handle_single_relationship_extraction(record_attributes: list[str], chunk_key: str):
if len(record_attributes) < 5 or record_attributes[0] != '"relationship"':
return None
# add this record as edge
source = clean_str(record_attributes[1].upper())
target = clean_str(record_attributes[2].upper())
edge_description = clean_str(record_attributes[3])
edge_keywords = clean_str(record_attributes[4])
edge_source_id = chunk_key
weight = (
float(record_attributes[-1]) if is_float_regex(record_attributes[-1]) else 1.0
)
pair = sorted([source.upper(), target.upper()])
return dict(
src_id=pair[0],
tgt_id=pair[1],
weight=weight,
description=edge_description,
keywords=edge_keywords,
source_id=edge_source_id,
metadata={"created_at": time.time()},
)
def pack_user_ass_to_openai_messages(*args: str):
roles = ["user", "assistant"]
return [
{"role": roles[i % 2], "content": content} for i, content in enumerate(args)
]
def split_string_by_multi_markers(content: str, markers: list[str]) -> list[str]:
"""Split a string by multiple markers"""
if not markers:
return [content]
results = re.split("|".join(re.escape(marker) for marker in markers), content)
return [r.strip() for r in results if r.strip()]
def is_float_regex(value):
return bool(re.match(r"^[-+]?[0-9]*\.?[0-9]+$", value))
def chunk_id(chunk):
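    """Deterministic id from content + kb id, so re-ingesting an identical
    chunk maps to the same id and stays idempotent."""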
return xxhash.xxh64((chunk["content_with_weight"] + chunk["kb_id"]).encode("utf-8")).hexdigest()
def get_entity(tenant_id, kb_id, ent_name):
conds = {
"fields": ["content_with_weight"],
"entity_kwd": ent_name,
"size": 10000,
"knowledge_graph_kwd": ["entity"]
}
res = []
es_res = settings.retrievaler.search(conds, search.index_name(tenant_id), [kb_id])
for id in es_res.ids:
try:
if isinstance(ent_name, str):
return json.loads(es_res.field[id]["content_with_weight"])
res.append(json.loads(es_res.field[id]["content_with_weight"]))
except Exception:
continue
return res
def set_entity(tenant_id, kb_id, embd_mdl, ent_name, meta):
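    """Upsert an entity chunk: update it in place if one already exists;
    otherwise embed the entity name (via a per-model embedding cache) and
    insert a fresh chunk."""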
chunk = {
"important_kwd": [ent_name],
"title_tks": rag_tokenizer.tokenize(ent_name),
"entity_kwd": ent_name,
"knowledge_graph_kwd": "entity",
"entity_type_kwd": meta["entity_type"],
"content_with_weight": json.dumps(meta, ensure_ascii=False),
"content_ltks": rag_tokenizer.tokenize(meta["description"]),
"source_id": list(set(meta["source_id"])),
"kb_id": kb_id,
"available_int": 0
}
chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(chunk["content_ltks"])
res = settings.retrievaler.search({"entity_kwd": ent_name, "size": 1, "fields": []},
search.index_name(tenant_id), [kb_id])
if res.ids:
settings.docStoreConn.update({"entity_kwd": ent_name}, chunk, search.index_name(tenant_id), kb_id)
else:
ebd = get_embed_cache(embd_mdl.llm_name, ent_name)
if ebd is None:
try:
ebd, _ = embd_mdl.encode([ent_name])
ebd = ebd[0]
set_embed_cache(embd_mdl.llm_name, ent_name, ebd)
except Exception as e:
logging.exception(f"Fail to embed entity: {e}")
if ebd is not None:
chunk["q_%d_vec" % len(ebd)] = ebd
settings.docStoreConn.insert([{"id": chunk_id(chunk), **chunk}], search.index_name(tenant_id))
def get_relation(tenant_id, kb_id, from_ent_name, to_ent_name, size=1):
ents = from_ent_name
if isinstance(ents, str):
ents = [from_ent_name]
if isinstance(to_ent_name, str):
to_ent_name = [to_ent_name]
ents.extend(to_ent_name)
ents = list(set(ents))
conds = {
"fields": ["content_with_weight"],
"size": size,
"from_entity_kwd": ents,
"to_entity_kwd": ents,
"knowledge_graph_kwd": ["relation"]
}
res = []
es_res = settings.retrievaler.search(conds, search.index_name(tenant_id), [kb_id] if isinstance(kb_id, str) else kb_id)
for id in es_res.ids:
try:
if size == 1:
return json.loads(es_res.field[id]["content_with_weight"])
res.append(json.loads(es_res.field[id]["content_with_weight"]))
except Exception:
continue
return res
def set_relation(tenant_id, kb_id, embd_mdl, from_ent_name, to_ent_name, meta):
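    """Upsert a relation chunk keyed by (from, to); new relations are embedded
    as "from->to: description" using the same embedding cache."""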
chunk = {
"from_entity_kwd": from_ent_name,
"to_entity_kwd": to_ent_name,
"knowledge_graph_kwd": "relation",
"content_with_weight": json.dumps(meta, ensure_ascii=False),
"content_ltks": rag_tokenizer.tokenize(meta["description"]),
"important_kwd": meta["keywords"],
"source_id": list(set(meta["source_id"])),
"weight_int": int(meta["weight"]),
"kb_id": kb_id,
"available_int": 0
}
chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(chunk["content_ltks"])
res = settings.retrievaler.search({"from_entity_kwd": to_ent_name, "to_entity_kwd": to_ent_name, "size": 1, "fields": []},
search.index_name(tenant_id), [kb_id])
if res.ids:
settings.docStoreConn.update({"from_entity_kwd": from_ent_name, "to_entity_kwd": to_ent_name},
chunk,
search.index_name(tenant_id), kb_id)
else:
txt = f"{from_ent_name}->{to_ent_name}"
ebd = get_embed_cache(embd_mdl.llm_name, txt)
if ebd is None:
try:
ebd, _ = embd_mdl.encode([txt+f": {meta['description']}"])
ebd = ebd[0]
set_embed_cache(embd_mdl.llm_name, txt, ebd)
except Exception as e:
logging.exception(f"Fail to embed entity relation: {e}")
if ebd is not None:
chunk["q_%d_vec" % len(ebd)] = ebd
settings.docStoreConn.insert([{"id": chunk_id(chunk), **chunk}], search.index_name(tenant_id))
def get_graph(tenant_id, kb_id):
conds = {
"fields": ["content_with_weight", "source_id"],
"removed_kwd": "N",
"size": 1,
"knowledge_graph_kwd": ["graph"]
}
res = settings.retrievaler.search(conds, search.index_name(tenant_id), [kb_id])
for id in res.ids:
try:
return json_graph.node_link_graph(json.loads(res.field[id]["content_with_weight"]), edges="edges"), \
res.field[id]["source_id"]
except Exception:
continue
return None, None
def set_graph(tenant_id, kb_id, graph, docids):
chunk = {
"content_with_weight": json.dumps(nx.node_link_data(graph, edges="edges"), ensure_ascii=False,
indent=2),
"knowledge_graph_kwd": "graph",
"kb_id": kb_id,
"source_id": list(docids),
"available_int": 0,
"removed_kwd": "N"
}
res = settings.retrievaler.search({"knowledge_graph_kwd": "graph", "size": 1, "fields": []}, search.index_name(tenant_id), [kb_id])
if res.ids:
settings.docStoreConn.update({"knowledge_graph_kwd": "graph"}, chunk,
search.index_name(tenant_id), kb_id)
else:
settings.docStoreConn.insert([{"id": chunk_id(chunk), **chunk}], search.index_name(tenant_id))
def is_continuous_subsequence(subseq, seq):
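    """Return True if some occurrence of subseq[0] in seq is immediately
    followed by subseq[-1], i.e. that hop already exists in seq."""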
def find_all_indexes(tup, value):
indexes = []
start = 0
while True:
try:
index = tup.index(value, start)
indexes.append(index)
start = index + 1
except ValueError:
break
return indexes
    index_list = find_all_indexes(seq, subseq[0])
    for idx in index_list:
        if idx != len(seq) - 1 and seq[idx + 1] == subseq[-1]:
            return True
return False
def merge_tuples(list1, list2):
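    """Extend each path in list1 with edges from list2 that start at the
    path's last node, skipping extensions that would retrace an existing hop;
    paths with no valid extension are kept unchanged."""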
result = []
for tup in list1:
last_element = tup[-1]
if last_element in tup[:-1]:
result.append(tup)
else:
matching_tuples = [t for t in list2 if t[0] == last_element]
already_match_flag = 0
for match in matching_tuples:
matchh = (match[1], match[0])
if is_continuous_subsequence(match, tup) or is_continuous_subsequence(matchh, tup):
continue
already_match_flag = 1
merged_tuple = tup + match[1:]
result.append(merged_tuple)
if not already_match_flag:
result.append(tup)
return result
def update_nodes_pagerank_nhop_neighbour(tenant_id, kb_id, graph, n_hop):
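    """Run PageRank over the graph, persist each node's score (rank_flt) and
    its serialized n-hop neighbourhood on the entity chunk, then cache the
    top-ranked entities per type in a "ty2ents" chunk (see
    get_entity_type2sampels)."""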
def n_neighbor(id):
nonlocal graph, n_hop
count = 0
source_edge = list(graph.edges(id))
if not source_edge:
return []
count = count + 1
while count < n_hop:
count = count + 1
sc_edge = deepcopy(source_edge)
source_edge = []
for pair in sc_edge:
append_edge = list(graph.edges(pair[-1]))
for tuples in merge_tuples([pair], append_edge):
source_edge.append(tuples)
nbrs = []
for path in source_edge:
n = {"path": path, "weights": []}
wts = nx.get_edge_attributes(graph, 'weight')
for i in range(len(path)-1):
f, t = path[i], path[i+1]
n["weights"].append(wts.get((f, t), 0))
nbrs.append(n)
return nbrs
pr = nx.pagerank(graph)
for n, p in pr.items():
graph.nodes[n]["pagerank"] = p
try:
settings.docStoreConn.update({"entity_kwd": n, "kb_id": kb_id},
{"rank_flt": p,
"n_hop_with_weight": json.dumps(n_neighbor(n), ensure_ascii=False)},
search.index_name(tenant_id), kb_id)
except Exception as e:
logging.exception(e)
ty2ents = defaultdict(list)
for p, r in sorted(pr.items(), key=lambda x: x[1], reverse=True):
ty = graph.nodes[p].get("entity_type")
if not ty or len(ty2ents[ty]) > 12:
continue
ty2ents[ty].append(p)
chunk = {
"content_with_weight": json.dumps(ty2ents, ensure_ascii=False),
"kb_id": kb_id,
"knowledge_graph_kwd": "ty2ents",
"available_int": 0
}
res = settings.retrievaler.search({"knowledge_graph_kwd": "ty2ents", "size": 1, "fields": []},
search.index_name(tenant_id), [kb_id])
if res.ids:
settings.docStoreConn.update({"knowledge_graph_kwd": "ty2ents"},
chunk,
search.index_name(tenant_id), kb_id)
else:
settings.docStoreConn.insert([{"id": chunk_id(chunk), **chunk}], search.index_name(tenant_id))
def get_entity_type2sampels(idxnms, kb_ids: list):
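    """Read the cached "ty2ents" chunks back and merge them into a single
    type -> sample-entities mapping."""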
es_res = settings.retrievaler.search({"knowledge_graph_kwd": "ty2ents", "kb_id": kb_ids,
"size": 10000,
"fields": ["content_with_weight"]},
idxnms, kb_ids)
res = defaultdict(list)
for id in es_res.ids:
smp = es_res.field[id].get("content_with_weight")
if not smp:
continue
        try:
            smp = json.loads(smp)
        except Exception as e:
            logging.exception(e)
            continue
for ty, ents in smp.items():
res[ty].extend(ents)
return res
def flat_uniq_list(arr, key):
res = []
for a in arr:
a = a[key]
if isinstance(a, list):
res.extend(a)
else:
res.append(a)
return list(set(res))

View File

@ -51,6 +51,7 @@ dependencies = [
"infinity-sdk==0.6.0-dev2",
"infinity-emb>=0.0.66,<0.0.67",
"itsdangerous==2.1.2",
"json-repair==0.35.0",
"markdown==3.6",
"markdown-to-json==2.1.1",
"minio==7.2.4",
@ -130,4 +131,4 @@ full = [
"flagembedding==1.2.10",
"torch==2.3.0",
"transformers==4.38.1"
]
]

View File

@ -88,9 +88,9 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
callback(0.8, "Finish parsing.")
elif re.search(r"\.pdf$", filename, re.IGNORECASE):
pdf_parser = Pdf() if kwargs.get(
"parser_config", {}).get(
"layout_recognize", True) else PlainParser()
pdf_parser = Pdf()
if kwargs.get("layout_recognize", "DeepDOC") == "Plain Text":
pdf_parser = PlainParser()
sections, tbls = pdf_parser(filename if not binary else binary,
from_page=from_page, to_page=to_page, callback=callback)

View File

@ -40,7 +40,7 @@ def chunk(
eng = lang.lower() == "english" # is_english(cks)
parser_config = kwargs.get(
"parser_config",
{"chunk_token_num": 128, "delimiter": "\n!?。;!?", "layout_recognize": True},
{"chunk_token_num": 128, "delimiter": "\n!?。;!?", "layout_recognize": "DeepDOC"},
)
doc = {
"docnm_kwd": filename,

View File

@ -1,48 +0,0 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import re
from graphrag.index import build_knowledge_graph_chunks
from rag.app import naive
from rag.nlp import rag_tokenizer, tokenize_chunks
def chunk(filename, binary, tenant_id, from_page=0, to_page=100000,
lang="Chinese", callback=None, **kwargs):
parser_config = kwargs.get(
"parser_config", {
"chunk_token_num": 512, "delimiter": "\n!?;。;!?", "layout_recognize": True})
eng = lang.lower() == "english"
parser_config["layout_recognize"] = True
sections = naive.chunk(filename, binary, from_page=from_page, to_page=to_page, section_only=True,
parser_config=parser_config, callback=callback)
chunks = build_knowledge_graph_chunks(tenant_id, sections, callback,
parser_config.get("entity_types", ["organization", "person", "location", "event", "time"])
)
for c in chunks:
c["docnm_kwd"] = filename
doc = {
"docnm_kwd": filename,
"title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", filename)),
"knowledge_graph_kwd": "text"
}
doc["title_sm_tks"] = rag_tokenizer.fine_grained_tokenize(doc["title_tks"])
chunks.extend(tokenize_chunks(sections, doc, eng))
return chunks

View File

@ -162,9 +162,9 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
return tokenize_chunks(chunks, doc, eng, None)
elif re.search(r"\.pdf$", filename, re.IGNORECASE):
pdf_parser = Pdf() if kwargs.get(
"parser_config", {}).get(
"layout_recognize", True) else PlainParser()
pdf_parser = Pdf()
if kwargs.get("layout_recognize", "DeepDOC") == "Plain Text":
pdf_parser = PlainParser()
for txt, poss in pdf_parser(filename if not binary else binary,
from_page=from_page, to_page=to_page, callback=callback)[0]:
sections.append(txt + poss)

View File

@ -184,9 +184,9 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
# is it English
eng = lang.lower() == "english" # pdf_parser.is_english
if re.search(r"\.pdf$", filename, re.IGNORECASE):
pdf_parser = Pdf() if kwargs.get(
"parser_config", {}).get(
"layout_recognize", True) else PlainParser()
pdf_parser = Pdf()
if kwargs.get("layout_recognize", "DeepDOC") == "Plain Text":
pdf_parser = PlainParser()
sections, tbls = pdf_parser(filename if not binary else binary,
from_page=from_page, to_page=to_page, callback=callback)
if sections and len(sections[0]) < 3:

View File

@ -202,7 +202,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
is_english = lang.lower() == "english" # is_english(cks)
parser_config = kwargs.get(
"parser_config", {
"chunk_token_num": 128, "delimiter": "\n!?。;!?", "layout_recognize": True})
"chunk_token_num": 128, "delimiter": "\n!?。;!?", "layout_recognize": "DeepDOC"})
doc = {
"docnm_kwd": filename,
"title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", filename))
@ -231,8 +231,11 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
return res
elif re.search(r"\.pdf$", filename, re.IGNORECASE):
pdf_parser = Pdf() if parser_config.get("layout_recognize", True) else PlainParser()
sections, tables = pdf_parser(filename if not binary else binary, from_page=from_page, to_page=to_page, callback=callback)
pdf_parser = Pdf()
if parser_config.get("layout_recognize", "DeepDOC") == "Plain Text":
pdf_parser = PlainParser()
sections, tables = pdf_parser(filename if not binary else binary, from_page=from_page, to_page=to_page,
callback=callback)
res = tokenize_table(tables, doc, is_english)
elif re.search(r"\.xlsx?$", filename, re.IGNORECASE):

View File

@ -84,9 +84,9 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
callback(0.8, "Finish parsing.")
elif re.search(r"\.pdf$", filename, re.IGNORECASE):
pdf_parser = Pdf() if kwargs.get(
"parser_config", {}).get(
"layout_recognize", True) else PlainParser()
pdf_parser = Pdf()
if kwargs.get("layout_recognize", "DeepDOC") == "Plain Text":
pdf_parser = PlainParser()
sections, _ = pdf_parser(
filename if not binary else binary, to_page=to_page, callback=callback)
sections = [s for s, _ in sections if s]

View File

@ -144,7 +144,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
The abstract of the paper will be sliced as an entire chunk, and will not be sliced partly.
"""
if re.search(r"\.pdf$", filename, re.IGNORECASE):
if not kwargs.get("parser_config", {}).get("layout_recognize", True):
if kwargs.get("parser_config", {}).get("layout_recognize", "DeepDOC") == "Plain Text":
pdf_parser = PlainParser()
paper = {
"title": filename,

View File

@ -119,9 +119,9 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
res.append(d)
return res
elif re.search(r"\.pdf$", filename, re.IGNORECASE):
pdf_parser = Pdf() if kwargs.get(
"parser_config", {}).get(
"layout_recognize", True) else PlainPdf()
pdf_parser = Pdf()
if kwargs.get("layout_recognize", "DeepDOC") == "Plain Text":
pdf_parser = PlainParser()
for pn, (txt, img) in enumerate(pdf_parser(filename, binary,
from_page=from_page, to_page=to_page, callback=callback)):
d = copy.deepcopy(doc)

View File

@ -32,6 +32,7 @@ import asyncio
LENGTH_NOTIFICATION_CN = "······\n由于长度的原因,回答被截断了,要继续吗?"
LENGTH_NOTIFICATION_EN = "...\nFor the content length reason, it stopped, continue?"
class Base(ABC):
def __init__(self, key, model_name, base_url):
timeout = int(os.environ.get('LM_TIMEOUT_SECONDS', 600))

View File

@ -59,7 +59,7 @@ class Dealer:
if key in req and req[key] is not None:
condition[field] = req[key]
# TODO(yzc): `available_int` is nullable however infinity doesn't support nullable columns.
for key in ["knowledge_graph_kwd", "available_int"]:
for key in ["knowledge_graph_kwd", "available_int", "entity_kwd", "from_entity_kwd", "to_entity_kwd", "removed_kwd"]:
if key in req and req[key] is not None:
condition[key] = req[key]
return condition
@ -198,6 +198,11 @@ class Dealer:
return answer, set([])
ans_v, _ = embd_mdl.encode(pieces_)
for i in range(len(chunk_v)):
if len(ans_v[0]) != len(chunk_v[i]):
chunk_v[i] = [0.0]*len(ans_v[0])
logging.warning("The dimension of query and chunk do not match: {} vs. {}".format(len(ans_v[0]), len(chunk_v[i])))
assert len(ans_v[0]) == len(chunk_v[0]), "The dimension of query and chunk do not match: {} vs. {}".format(
len(ans_v[0]), len(chunk_v[0]))
@ -434,6 +439,8 @@ class Dealer:
es_res = self.dataStore.search(fields, [], condition, [], OrderByExpr(), p, bs, index_name(tenant_id),
kb_ids)
dict_chunks = self.dataStore.getFields(es_res, fields)
for id, doc in dict_chunks.items():
doc["id"] = id
if dict_chunks:
res.extend(dict_chunks.values())
if len(dict_chunks.values()) < bs:

View File

@ -19,6 +19,9 @@
import random
import sys
from api.utils.log_utils import initRootLogger
from graphrag.general.index import WithCommunity, WithResolution, Dealer
from graphrag.light.graph_extractor import GraphExtractor as LightKGExt
from graphrag.general.graph_extractor import GraphExtractor as GeneralKGExt
from graphrag.utils import get_llm_cache, set_llm_cache, get_tags_from_cache, set_tags_to_cache
CONSUMER_NO = "0" if len(sys.argv) < 2 else sys.argv[1]
@ -53,7 +56,7 @@ from api import settings
from api.versions import get_ragflow_version
from api.db.db_models import close_connection
from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive, one, audio, \
knowledge_graph, email, tag
email, tag
from rag.nlp import search, rag_tokenizer
from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor
from rag.settings import DOC_MAXIMUM_SIZE, SVR_QUEUE_NAME, print_rag_settings, TAG_FLD, PAGERANK_FLD
@ -78,7 +81,7 @@ FACTORY = {
ParserType.ONE.value: one,
ParserType.AUDIO.value: audio,
ParserType.EMAIL.value: email,
ParserType.KG.value: knowledge_graph,
ParserType.KG.value: naive,
ParserType.TAG.value: tag
}
@ -118,7 +121,8 @@ def set_progress(task_id, from_page=0, to_page=-1, prog=None, msg="Processing...
if to_page > 0:
if msg:
msg = f"Page({from_page + 1}~{to_page + 1}): " + msg
if from_page < to_page:
msg = f"Page({from_page + 1}~{to_page + 1}): " + msg
if msg:
msg = datetime.now().strftime("%H:%M:%S") + " " + msg
d = {"progress_msg": msg}
@ -177,8 +181,7 @@ def collect():
logging.info(f"collect task {msg['id']} {state}")
return None
if msg.get("type", "") == "raptor":
task["task_type"] = "raptor"
task["task_type"] = msg.get("task_type", "")
return task
@ -382,11 +385,9 @@ def embedding(docs, mdl, parser_config=None, callback=None):
return tk_count, vector_size
def run_raptor(row, chat_mdl, embd_mdl, callback=None):
vts, _ = embd_mdl.encode(["ok"])
vector_size = len(vts[0])
vctr_nm = "q_%d_vec" % vector_size
def run_raptor(row, chat_mdl, embd_mdl, vector_size, callback=None):
chunks = []
vctr_nm = "q_%d_vec"%vector_size
for d in settings.retrievaler.chunk_list(row["doc_id"], row["tenant_id"], [str(row["kb_id"])],
fields=["content_with_weight", vctr_nm]):
chunks.append((d["content_with_weight"], np.array(d[vctr_nm])))
@ -422,7 +423,24 @@ def run_raptor(row, chat_mdl, embd_mdl, callback=None):
d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
res.append(d)
tk_count += num_tokens_from_string(content)
return res, tk_count, vector_size
return res, tk_count
def run_graphrag(row, chat_model, language, embedding_model, callback=None):
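    """Collect the document's chunks and hand them to the GraphRAG Dealer;
    the light extractor is used unless parser_config's method is "general"."""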
chunks = []
for d in settings.retrievaler.chunk_list(row["doc_id"], row["tenant_id"], [str(row["kb_id"])],
fields=["content_with_weight", "doc_id"]):
chunks.append((d["doc_id"], d["content_with_weight"]))
Dealer(LightKGExt if row["parser_config"]["graphrag"]["method"] != 'general' else GeneralKGExt,
row["tenant_id"],
str(row["kb_id"]),
chat_model,
chunks=chunks,
language=language,
entity_types=row["parser_config"]["graphrag"]["entity_types"],
embed_bdl=embedding_model,
callback=callback)
def do_handle_task(task):
@ -466,14 +484,17 @@ def do_handle_task(task):
logging.exception(error_message)
raise
vts, _ = embedding_model.encode(["ok"])
vector_size = len(vts[0])
init_kb(task, vector_size)
# Either using RAPTOR or Standard chunking methods
if task.get("task_type", "") == "raptor":
try:
# bind LLM for raptor
chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
# run RAPTOR
chunks, token_count, vector_size = run_raptor(task, chat_model, embedding_model, progress_callback)
chunks, token_count = run_raptor(task, chat_model, embedding_model, vector_size, progress_callback)
except TaskCanceledException:
raise
except Exception as e:
@ -481,6 +502,55 @@ def do_handle_task(task):
progress_callback(-1, msg=error_message)
logging.exception(error_message)
raise
# Either using graphrag or Standard chunking methods
elif task.get("task_type", "") == "graphrag":
start_ts = timer()
try:
chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
run_graphrag(task, chat_model, task_language, embedding_model, progress_callback)
progress_callback(prog=1.0, msg="Knowledge Graph is done ({:.2f}s)".format(timer() - start_ts))
except TaskCanceledException:
raise
except Exception as e:
error_message = f'Fail to bind LLM used by Knowledge Graph: {str(e)}'
progress_callback(-1, msg=error_message)
logging.exception(error_message)
raise
return
elif task.get("task_type", "") == "graph_resolution":
start_ts = timer()
try:
chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
WithResolution(
task["tenant_id"], str(task["kb_id"]),chat_model, embedding_model,
progress_callback
)
progress_callback(prog=1.0, msg="Knowledge Graph resolution is done ({:.2f}s)".format(timer() - start_ts))
except TaskCanceledException:
raise
except Exception as e:
error_message = f'Fail to bind LLM used by Knowledge Graph resolution: {str(e)}'
progress_callback(-1, msg=error_message)
logging.exception(error_message)
raise
return
elif task.get("task_type", "") == "graph_community":
start_ts = timer()
try:
chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
WithCommunity(
task["tenant_id"], str(task["kb_id"]), chat_model, embedding_model,
progress_callback
)
progress_callback(prog=1.0, msg="GraphRAG community reports generation is done ({:.2f}s)".format(timer() - start_ts))
except TaskCanceledException:
raise
except Exception as e:
error_message = f'Fail to bind LLM used by GraphRAG community reports generation: {str(e)}'
progress_callback(-1, msg=error_message)
logging.exception(error_message)
raise
return
else:
# Standard chunking methods
start_ts = timer()
@ -507,8 +577,6 @@ def do_handle_task(task):
logging.info(progress_message)
progress_callback(msg=progress_message)
# logging.info(f"task_executor init_kb index {search.index_name(task_tenant_id)} embedding_model {embedding_model.llm_name} vector length {vector_size}")
init_kb(task, vector_size)
chunk_count = len(set([chunk["id"] for chunk in chunks]))
start_ts = timer()
doc_store_result = ""

View File

@ -123,6 +123,7 @@ class ESConnection(DocStoreConnection):
logger.exception("ESConnection.indexExist got exception")
if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
continue
break
return False
"""
@ -238,6 +239,7 @@ class ESConnection(DocStoreConnection):
for i in range(ATTEMPT_TIME):
try:
#print(json.dumps(q, ensure_ascii=False))
res = self.es.search(index=indexNames,
body=q,
timeout="600s",
@ -325,8 +327,9 @@ class ESConnection(DocStoreConnection):
except Exception as e:
logger.exception(
f"ESConnection.update(index={indexName}, id={id}, doc={json.dumps(condition, ensure_ascii=False)}) got exception")
if str(e).find("Timeout") > 0:
if re.search(r"(timeout|connection)", str(e).lower()):
continue
break
return False
# update unspecific maybe-multiple documents
@ -364,9 +367,12 @@ class ESConnection(DocStoreConnection):
if (not isinstance(k, str) or not v) and k != "available_int":
continue
if isinstance(v, str):
scripts.append(f"ctx._source.{k} = '{v}'")
elif isinstance(v, int):
scripts.append(f"ctx._source.{k} = {v}")
v = re.sub(r"(['\n\r]|\\.)", " ", v)
scripts.append(f"ctx._source.{k}='{v}';")
elif isinstance(v, int) or isinstance(v, float):
scripts.append(f"ctx._source.{k}={v};")
elif isinstance(v, list):
scripts.append(f"ctx._source.{k}={json.dumps(v, ensure_ascii=False)};")
else:
raise Exception(
f"newValue `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str.")
@ -383,9 +389,10 @@ class ESConnection(DocStoreConnection):
_ = ubq.execute()
return True
except Exception as e:
logger.error("ESConnection.update got exception: " + str(e))
if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
logger.error("ESConnection.update got exception: " + str(e) + "\n".join(scripts))
if re.search(r"(timeout|connection|conflict)", str(e).lower()):
continue
break
return False
def delete(self, condition: dict, indexName: str, knowledgebaseId: str) -> int:
@ -399,7 +406,16 @@ class ESConnection(DocStoreConnection):
else:
qry = Q("bool")
for k, v in condition.items():
if isinstance(v, list):
if k == "exists":
qry.filter.append(Q("exists", field=v))
elif k == "must_not":
if isinstance(v, dict):
for kk, vv in v.items():
if kk == "exists":
qry.must_not.append(Q("exists", field=vv))
elif isinstance(v, list):
qry.must.append(Q("terms", **{k: v}))
elif isinstance(v, str) or isinstance(v, int):
qry.must.append(Q("term", **{k: v}))
@ -408,6 +424,7 @@ class ESConnection(DocStoreConnection):
logger.debug("ESConnection.delete query: " + json.dumps(qry.to_dict()))
for _ in range(ATTEMPT_TIME):
try:
print(Search().query(qry).to_dict(), flush=True)
res = self.es.delete_by_query(
index=indexName,
body=Search().query(qry).to_dict(),
@ -415,7 +432,7 @@ class ESConnection(DocStoreConnection):
return res["deleted"]
except Exception as e:
logger.warning("ESConnection.delete got exception: " + str(e))
if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE):
if re.search(r"(timeout|connection)", str(e).lower()):
time.sleep(3)
continue
if re.search(r"(not_found)", str(e), re.IGNORECASE):

View File

@ -16,6 +16,8 @@
import logging
import json
import time
import uuid
import valkey as redis
from rag import settings
@ -278,3 +280,32 @@ class RedisDB:
REDIS_CONN = RedisDB()
class RedisDistributedLock:
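    """Best-effort lock built on SETNX: `timeout` only bounds how long
    acquire_lock polls. The key itself has no TTL, so clean_lock is provided
    to clear stale holders."""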
def __init__(self, lock_key, timeout=10):
self.lock_key = lock_key
self.lock_value = str(uuid.uuid4())
self.timeout = timeout
@staticmethod
def clean_lock(lock_key):
REDIS_CONN.REDIS.delete(lock_key)
def acquire_lock(self):
end_time = time.time() + self.timeout
while time.time() < end_time:
if REDIS_CONN.REDIS.setnx(self.lock_key, self.lock_value):
return True
time.sleep(1)
return False
def release_lock(self):
if REDIS_CONN.REDIS.get(self.lock_key) == self.lock_value:
REDIS_CONN.REDIS.delete(self.lock_key)
def __enter__(self):
self.acquire_lock()
def __exit__(self, exception_type, exception_value, exception_traceback):
self.release_lock()

uv.lock generated

File diff suppressed because it is too large