Integration with Infinity (#2894)

### What problem does this PR solve?

Integration with Infinity

- Replaced the global `ELASTICSEARCH` connection with the `docStoreConn` abstraction (see the interface sketch below)
- Renamed `deleteByQuery` to `delete`
- Renamed `bulk` to `upsertBulk`
- Added `getHighlight` and `getAggregation`
- Fixed `KGSearch.search`
- Moved `Dealer.sql_retrieval` to es_conn.py
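
For orientation, the call sites in this diff (`get`, `insert`, `update`, `delete`, `indexExist`, `createIdx`, `deleteIdx`, `health`) imply roughly the following connection interface. This is a minimal sketch inferred from usage only; the class name `DocStoreConnection` and the exact signatures are assumptions, not the committed definitions in rag/utils/doc_store_conn.py.

```python
from abc import ABC, abstractmethod


class DocStoreConnection(ABC):
    """Hypothetical shape of the abstraction behind `docStoreConn` (inferred from call sites)."""

    @abstractmethod
    def health(self) -> dict: ...  # replaces ELASTICSEARCH.health()

    @abstractmethod
    def indexExist(self, indexName: str, knowledgebaseId: str) -> bool: ...

    @abstractmethod
    def createIdx(self, indexName: str, knowledgebaseId: str, vectorSize: int): ...

    @abstractmethod
    def deleteIdx(self, indexName: str, knowledgebaseId: str): ...

    @abstractmethod
    def get(self, chunkId: str, indexName: str, knowledgebaseIds: list[str]) -> dict | None: ...

    @abstractmethod
    def insert(self, documents: list[dict], indexName: str, knowledgebaseId: str): ...

    @abstractmethod
    def update(self, condition: dict, newValue: dict, indexName: str, knowledgebaseId: str) -> bool: ...

    @abstractmethod
    def delete(self, condition: dict, indexName: str, knowledgebaseId: str) -> int: ...

    # upsertBulk (the renamed `bulk`), getHighlight, getAggregation, getFields and
    # getChunkIds follow the same pattern; their signatures are not visible in this excerpt.
```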


### Type of change

- [x] Refactoring
Zhichang Yu 2024-11-12 14:59:41 +08:00 committed by GitHub
parent 00b6000b76
commit f4c52371ab
42 changed files with 2647 additions and 1878 deletions

View File

@ -78,7 +78,7 @@ jobs:
echo "Waiting for service to be available..."
sleep 5
done
cd sdk/python && poetry install && source .venv/bin/activate && cd test && pytest t_dataset.py t_chat.py t_session.py t_document.py t_chunk.py
cd sdk/python && poetry install && source .venv/bin/activate && cd test && pytest --tb=short t_dataset.py t_chat.py t_session.py t_document.py t_chunk.py
- name: Stop ragflow:dev
if: always() # always run this step even if previous steps failed

View File

@ -285,7 +285,7 @@ docker build -f Dockerfile -t infiniflow/ragflow:dev .
git clone https://github.com/infiniflow/ragflow.git
cd ragflow/
export POETRY_VIRTUALENVS_CREATE=true POETRY_VIRTUALENVS_IN_PROJECT=true
~/.local/bin/poetry install --sync --no-root # install RAGFlow dependent python modules
~/.local/bin/poetry install --sync --no-root --with=full # install RAGFlow dependent python modules
```
3. Launch the dependent services (MinIO, Elasticsearch, Redis, and MySQL) using Docker Compose:
@ -295,7 +295,7 @@ docker build -f Dockerfile -t infiniflow/ragflow:dev .
Add the following line to `/etc/hosts` to resolve all hosts specified in **docker/service_conf.yaml** to `127.0.0.1`:
```
127.0.0.1 es01 mysql minio redis
127.0.0.1 es01 infinity mysql minio redis
```
In **docker/service_conf.yaml**, update mysql port to `5455` and es port to `1200`, as specified in **docker/.env**.
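
Before launching the RAGFlow server against these containers, a quick sanity check of the host aliases and remapped ports can save a debugging round. This is only an illustrative sketch, not part of this commit; it assumes the `1200`/`5455` values quoted above from docker/.env.

```python
import socket

# Host aliases that /etc/hosts should resolve to 127.0.0.1, mapped to the
# ports exposed by docker-compose-base.yml (values taken from docker/.env).
CHECKS = {"es01": 1200, "mysql": 5455}

for host, port in CHECKS.items():
    assert socket.gethostbyname(host) == "127.0.0.1", f"{host} does not resolve to 127.0.0.1"
    with socket.create_connection((host, port), timeout=3):
        print(f"{host}:{port} is reachable")
```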

View File

@ -250,7 +250,7 @@ docker build -f Dockerfile -t infiniflow/ragflow:dev .
Add the following line to `/etc/hosts` to resolve all hosts specified in **docker/service_conf.yaml** to `127.0.0.1`:
```
127.0.0.1 es01 mysql minio redis
127.0.0.1 es01 infinity mysql minio redis
```
In **docker/service_conf.yaml**, update the mysql port to `5455` and the es port to `1200`, as specified in **docker/.env**.

View File

@ -254,7 +254,7 @@ docker build -f Dockerfile -t infiniflow/ragflow:dev .
Add the following line to `/etc/hosts` to resolve all hosts specified in **docker/service_conf.yaml** to `127.0.0.1`:
```
127.0.0.1 es01 mysql minio redis
127.0.0.1 es01 infinity mysql minio redis
```
In **docker/service_conf.yaml**, update the mysql port to `5455` and the es port to `1200`, as specified in **docker/.env**.

View File

@ -252,7 +252,7 @@ docker build -f Dockerfile -t infiniflow/ragflow:dev .
Add the following line to `/etc/hosts` to resolve all hosts specified in **docker/service_conf.yaml** to `127.0.0.1`:
```
127.0.0.1 es01 mysql minio redis
127.0.0.1 es01 infinity mysql minio redis
```
In **docker/service_conf.yaml**, update the mysql port to `5455` and the es port to `1200` to match the settings in **docker/.env**.

View File

@ -529,13 +529,14 @@ def list_chunks():
return get_json_result(
data=False, message="Can't find doc_name or doc_id"
)
kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
res = retrievaler.chunk_list(doc_id=doc_id, tenant_id=tenant_id)
res = retrievaler.chunk_list(doc_id, tenant_id, kb_ids)
res = [
{
"content": res_item["content_with_weight"],
"doc_name": res_item["docnm_kwd"],
"img_id": res_item["img_id"]
"image_id": res_item["img_id"]
} for res_item in res
]

View File

@ -18,12 +18,10 @@ import json
from flask import request
from flask_login import login_required, current_user
from elasticsearch_dsl import Q
from api.db.services.dialog_service import keyword_extraction
from rag.app.qa import rmPrefix, beAdoc
from rag.nlp import search, rag_tokenizer
from rag.utils.es_conn import ELASTICSEARCH
from rag.utils import rmSpace
from api.db import LLMType, ParserType
from api.db.services.knowledgebase_service import KnowledgebaseService
@ -31,12 +29,11 @@ from api.db.services.llm_service import LLMBundle
from api.db.services.user_service import UserTenantService
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
from api.db.services.document_service import DocumentService
from api.settings import RetCode, retrievaler, kg_retrievaler
from api.settings import RetCode, retrievaler, kg_retrievaler, docStoreConn
from api.utils.api_utils import get_json_result
import hashlib
import re
@manager.route('/list', methods=['POST'])
@login_required
@validate_request("doc_id")
@ -53,12 +50,13 @@ def list_chunk():
e, doc = DocumentService.get_by_id(doc_id)
if not e:
return get_data_error_result(message="Document not found!")
kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
query = {
"doc_ids": [doc_id], "page": page, "size": size, "question": question, "sort": True
}
if "available_int" in req:
query["available_int"] = int(req["available_int"])
sres = retrievaler.search(query, search.index_name(tenant_id), highlight=True)
sres = retrievaler.search(query, search.index_name(tenant_id), kb_ids, highlight=True)
res = {"total": sres.total, "chunks": [], "doc": doc.to_dict()}
for id in sres.ids:
d = {
@ -69,16 +67,12 @@ def list_chunk():
"doc_id": sres.field[id]["doc_id"],
"docnm_kwd": sres.field[id]["docnm_kwd"],
"important_kwd": sres.field[id].get("important_kwd", []),
"img_id": sres.field[id].get("img_id", ""),
"image_id": sres.field[id].get("img_id", ""),
"available_int": sres.field[id].get("available_int", 1),
"positions": sres.field[id].get("position_int", "").split("\t")
"positions": json.loads(sres.field[id].get("position_list", "[]")),
}
if len(d["positions"]) % 5 == 0:
poss = []
for i in range(0, len(d["positions"]), 5):
poss.append([float(d["positions"][i]), float(d["positions"][i + 1]), float(d["positions"][i + 2]),
float(d["positions"][i + 3]), float(d["positions"][i + 4])])
d["positions"] = poss
assert isinstance(d["positions"], list)
assert len(d["positions"])==0 or (isinstance(d["positions"][0], list) and len(d["positions"][0]) == 5)
res["chunks"].append(d)
return get_json_result(data=res)
except Exception as e:
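
The hunk above also changes how chunk positions are returned: the old code re-assembled a tab-separated `position_int` string into groups of five floats, while the new code expects `position_list` to hold a JSON array of five-element lists (matching the `json.dumps` calls added to rag/app/naive.py later in this diff). A small sketch of the two encodings, using made-up coordinates:

```python
import json

# Old storage format: flat, tab-separated string in "position_int".
old = "1\t10\t200\t15\t120"
values = [float(x) for x in old.split("\t")]
grouped = [values[i:i + 5] for i in range(0, len(values), 5)]
assert grouped == [[1.0, 10.0, 200.0, 15.0, 120.0]]

# New storage format: JSON-encoded list of five-element lists in "position_list".
# Each entry follows the (page, left, right, top, bottom) layout written by rag/app/naive.py.
new = json.dumps([(1, 10, 200, 15, 120)])
assert json.loads(new) == [[1, 10, 200, 15, 120]]
```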
@ -96,22 +90,20 @@ def get():
tenants = UserTenantService.query(user_id=current_user.id)
if not tenants:
return get_data_error_result(message="Tenant not found!")
res = ELASTICSEARCH.get(
chunk_id, search.index_name(
tenants[0].tenant_id))
if not res.get("found"):
tenant_id = tenants[0].tenant_id
kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
chunk = docStoreConn.get(chunk_id, search.index_name(tenant_id), kb_ids)
if chunk is None:
return server_error_response("Chunk not found")
id = res["_id"]
res = res["_source"]
res["chunk_id"] = id
k = []
for n in res.keys():
for n in chunk.keys():
if re.search(r"(_vec$|_sm_|_tks|_ltks)", n):
k.append(n)
for n in k:
del res[n]
del chunk[n]
return get_json_result(data=res)
return get_json_result(data=chunk)
except Exception as e:
if str(e).find("NotFoundError") >= 0:
return get_json_result(data=False, message='Chunk not found!',
@ -162,7 +154,7 @@ def set():
v, c = embd_mdl.encode([doc.name, req["content_with_weight"]])
v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
d["q_%d_vec" % len(v)] = v.tolist()
ELASTICSEARCH.upsert([d], search.index_name(tenant_id))
docStoreConn.insert([d], search.index_name(tenant_id), doc.kb_id)
return get_json_result(data=True)
except Exception as e:
return server_error_response(e)
@ -174,11 +166,11 @@ def set():
def switch():
req = request.json
try:
tenant_id = DocumentService.get_tenant_id(req["doc_id"])
if not tenant_id:
return get_data_error_result(message="Tenant not found!")
if not ELASTICSEARCH.upsert([{"id": i, "available_int": int(req["available_int"])} for i in req["chunk_ids"]],
search.index_name(tenant_id)):
e, doc = DocumentService.get_by_id(req["doc_id"])
if not e:
return get_data_error_result(message="Document not found!")
if not docStoreConn.update({"id": req["chunk_ids"]}, {"available_int": int(req["available_int"])},
search.index_name(doc.tenant_id), doc.kb_id):
return get_data_error_result(message="Index updating failure")
return get_json_result(data=True)
except Exception as e:
@ -191,12 +183,11 @@ def switch():
def rm():
req = request.json
try:
if not ELASTICSEARCH.deleteByQuery(
Q("ids", values=req["chunk_ids"]), search.index_name(current_user.id)):
return get_data_error_result(message="Index updating failure")
e, doc = DocumentService.get_by_id(req["doc_id"])
if not e:
return get_data_error_result(message="Document not found!")
if not docStoreConn.delete({"id": req["chunk_ids"]}, search.index_name(current_user.id), doc.kb_id):
return get_data_error_result(message="Index updating failure")
deleted_chunk_ids = req["chunk_ids"]
chunk_number = len(deleted_chunk_ids)
DocumentService.decrement_chunk_num(doc.id, doc.kb_id, 1, chunk_number, 0)
@ -239,7 +230,7 @@ def create():
v, c = embd_mdl.encode([doc.name, req["content_with_weight"]])
v = 0.1 * v[0] + 0.9 * v[1]
d["q_%d_vec" % len(v)] = v.tolist()
ELASTICSEARCH.upsert([d], search.index_name(tenant_id))
docStoreConn.insert([d], search.index_name(tenant_id), doc.kb_id)
DocumentService.increment_chunk_num(
doc.id, doc.kb_id, c, 1, 0)
@ -256,8 +247,9 @@ def retrieval_test():
page = int(req.get("page", 1))
size = int(req.get("size", 30))
question = req["question"]
kb_id = req["kb_id"]
if isinstance(kb_id, str): kb_id = [kb_id]
kb_ids = req["kb_id"]
if isinstance(kb_ids, str):
kb_ids = [kb_ids]
doc_ids = req.get("doc_ids", [])
similarity_threshold = float(req.get("similarity_threshold", 0.0))
vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3))
@ -265,17 +257,17 @@ def retrieval_test():
try:
tenants = UserTenantService.query(user_id=current_user.id)
for kid in kb_id:
for kb_id in kb_ids:
for tenant in tenants:
if KnowledgebaseService.query(
tenant_id=tenant.tenant_id, id=kid):
tenant_id=tenant.tenant_id, id=kb_id):
break
else:
return get_json_result(
data=False, message='Only owner of knowledgebase authorized for this operation.',
code=RetCode.OPERATING_ERROR)
e, kb = KnowledgebaseService.get_by_id(kb_id[0])
e, kb = KnowledgebaseService.get_by_id(kb_ids[0])
if not e:
return get_data_error_result(message="Knowledgebase not found!")
@ -290,7 +282,7 @@ def retrieval_test():
question += keyword_extraction(chat_mdl, question)
retr = retrievaler if kb.parser_id != ParserType.KG else kg_retrievaler
ranks = retr.retrieval(question, embd_mdl, kb.tenant_id, kb_id, page, size,
ranks = retr.retrieval(question, embd_mdl, kb.tenant_id, kb_ids, page, size,
similarity_threshold, vector_similarity_weight, top,
doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"))
for c in ranks["chunks"]:
@ -309,12 +301,16 @@ def retrieval_test():
@login_required
def knowledge_graph():
doc_id = request.args["doc_id"]
e, doc = DocumentService.get_by_id(doc_id)
if not e:
return get_data_error_result(message="Document not found!")
tenant_id = DocumentService.get_tenant_id(doc_id)
kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
req = {
"doc_ids":[doc_id],
"knowledge_graph_kwd": ["graph", "mind_map"]
}
tenant_id = DocumentService.get_tenant_id(doc_id)
sres = retrievaler.search(req, search.index_name(tenant_id))
sres = retrievaler.search(req, search.index_name(tenant_id), kb_ids, doc.kb_id)
obj = {"graph": {}, "mind_map": {}}
for id in sres.ids[:2]:
ty = sres.field[id]["knowledge_graph_kwd"]
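
Throughout this file, per-chunk operations now pass condition dictionaries to `docStoreConn` instead of building elasticsearch_dsl queries. A hedged sketch of the calling pattern used by the `switch` handler above; the wrapper function and its parameters are illustrative, not part of the commit:

```python
from rag.nlp import search  # index_name(tenant_id) -> per-tenant index name


def set_chunks_available(doc_store, tenant_id: str, kb_id: str, chunk_ids: list[str], available: bool) -> bool:
    """Sketch of the condition-dict update used by the `switch` handler."""
    return doc_store.update(
        {"id": chunk_ids},                  # condition: match these chunk ids
        {"available_int": int(available)},  # fields to set on every matching row
        search.index_name(tenant_id),
        kb_id,
    )
```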

View File

@ -17,7 +17,6 @@ import pathlib
import re
import flask
from elasticsearch_dsl import Q
from flask import request
from flask_login import login_required, current_user
@ -27,14 +26,13 @@ from api.db.services.file_service import FileService
from api.db.services.task_service import TaskService, queue_tasks
from api.db.services.user_service import UserTenantService
from rag.nlp import search
from rag.utils.es_conn import ELASTICSEARCH
from api.db.services import duplicate_name
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
from api.utils import get_uuid
from api.db import FileType, TaskStatus, ParserType, FileSource
from api.db.services.document_service import DocumentService, doc_upload_and_parse
from api.settings import RetCode
from api.settings import RetCode, docStoreConn
from api.utils.api_utils import get_json_result
from rag.utils.storage_factory import STORAGE_IMPL
from api.utils.file_utils import filename_type, thumbnail
@ -275,18 +273,8 @@ def change_status():
return get_data_error_result(
message="Database error (Document update)!")
if str(req["status"]) == "0":
ELASTICSEARCH.updateScriptByQuery(Q("term", doc_id=req["doc_id"]),
scripts="ctx._source.available_int=0;",
idxnm=search.index_name(
kb.tenant_id)
)
else:
ELASTICSEARCH.updateScriptByQuery(Q("term", doc_id=req["doc_id"]),
scripts="ctx._source.available_int=1;",
idxnm=search.index_name(
kb.tenant_id)
)
status = int(req["status"])
docStoreConn.update({"doc_id": req["doc_id"]}, {"available_int": status}, search.index_name(kb.tenant_id), doc.kb_id)
return get_json_result(data=True)
except Exception as e:
return server_error_response(e)
@ -365,8 +353,11 @@ def run():
tenant_id = DocumentService.get_tenant_id(id)
if not tenant_id:
return get_data_error_result(message="Tenant not found!")
ELASTICSEARCH.deleteByQuery(
Q("match", doc_id=id), idxnm=search.index_name(tenant_id))
e, doc = DocumentService.get_by_id(id)
if not e:
return get_data_error_result(message="Document not found!")
if docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), doc.kb_id)
if str(req["run"]) == TaskStatus.RUNNING.value:
TaskService.filter_delete([Task.doc_id == id])
@ -490,8 +481,8 @@ def change_parser():
tenant_id = DocumentService.get_tenant_id(req["doc_id"])
if not tenant_id:
return get_data_error_result(message="Tenant not found!")
ELASTICSEARCH.deleteByQuery(
Q("match", doc_id=doc.id), idxnm=search.index_name(tenant_id))
if docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
return get_json_result(data=True)
except Exception as e:

View File

@ -28,6 +28,8 @@ from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.db_models import File
from api.settings import RetCode
from api.utils.api_utils import get_json_result
from api.settings import docStoreConn
from rag.nlp import search
@manager.route('/create', methods=['post'])
@ -166,6 +168,9 @@ def rm():
if not KnowledgebaseService.delete_by_id(req["kb_id"]):
return get_data_error_result(
message="Database error (Knowledgebase removal)!")
tenants = UserTenantService.query(user_id=current_user.id)
for tenant in tenants:
docStoreConn.deleteIdx(search.index_name(tenant.tenant_id), req["kb_id"])
return get_json_result(data=True)
except Exception as e:
return server_error_response(e)

View File

@ -30,7 +30,6 @@ from api.db.services.task_service import TaskService, queue_tasks
from api.utils.api_utils import server_error_response
from api.utils.api_utils import get_result, get_error_data_result
from io import BytesIO
from elasticsearch_dsl import Q
from flask import request, send_file
from api.db import FileSource, TaskStatus, FileType
from api.db.db_models import File
@ -42,7 +41,7 @@ from api.settings import RetCode, retrievaler
from api.utils.api_utils import construct_json_result, get_parser_config
from rag.nlp import search
from rag.utils import rmSpace
from rag.utils.es_conn import ELASTICSEARCH
from api.settings import docStoreConn
from rag.utils.storage_factory import STORAGE_IMPL
import os
@ -293,9 +292,7 @@ def update_doc(tenant_id, dataset_id, document_id):
)
if not e:
return get_error_data_result(message="Document not found!")
ELASTICSEARCH.deleteByQuery(
Q("match", doc_id=doc.id), idxnm=search.index_name(tenant_id)
)
docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), dataset_id)
return get_result()
@ -647,9 +644,7 @@ def parse(tenant_id, dataset_id):
info["chunk_num"] = 0
info["token_num"] = 0
DocumentService.update_by_id(id, info)
ELASTICSEARCH.deleteByQuery(
Q("match", doc_id=id), idxnm=search.index_name(tenant_id)
)
docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), dataset_id)
TaskService.filter_delete([Task.doc_id == id])
e, doc = DocumentService.get_by_id(id)
doc = doc.to_dict()
@ -713,9 +708,7 @@ def stop_parsing(tenant_id, dataset_id):
)
info = {"run": "2", "progress": 0, "chunk_num": 0}
DocumentService.update_by_id(id, info)
ELASTICSEARCH.deleteByQuery(
Q("match", doc_id=id), idxnm=search.index_name(tenant_id)
)
docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), dataset_id)
return get_result()
@ -812,7 +805,6 @@ def list_chunks(tenant_id, dataset_id, document_id):
"question": question,
"sort": True,
}
sres = retrievaler.search(query, search.index_name(tenant_id), highlight=True)
key_mapping = {
"chunk_num": "chunk_count",
"kb_id": "dataset_id",
@ -833,51 +825,56 @@ def list_chunks(tenant_id, dataset_id, document_id):
renamed_doc[new_key] = value
if key == "run":
renamed_doc["run"] = run_mapping.get(str(value))
res = {"total": sres.total, "chunks": [], "doc": renamed_doc}
origin_chunks = []
sign = 0
for id in sres.ids:
d = {
"chunk_id": id,
"content_with_weight": (
rmSpace(sres.highlight[id])
if question and id in sres.highlight
else sres.field[id].get("content_with_weight", "")
),
"doc_id": sres.field[id]["doc_id"],
"docnm_kwd": sres.field[id]["docnm_kwd"],
"important_kwd": sres.field[id].get("important_kwd", []),
"img_id": sres.field[id].get("img_id", ""),
"available_int": sres.field[id].get("available_int", 1),
"positions": sres.field[id].get("position_int", "").split("\t"),
}
if len(d["positions"]) % 5 == 0:
poss = []
for i in range(0, len(d["positions"]), 5):
poss.append(
[
float(d["positions"][i]),
float(d["positions"][i + 1]),
float(d["positions"][i + 2]),
float(d["positions"][i + 3]),
float(d["positions"][i + 4]),
]
)
d["positions"] = poss
origin_chunks.append(d)
res = {"total": 0, "chunks": [], "doc": renamed_doc}
origin_chunks = []
if docStoreConn.indexExist(search.index_name(tenant_id), dataset_id):
sres = retrievaler.search(query, search.index_name(tenant_id), [dataset_id], emb_mdl=None, highlight=True)
res["total"] = sres.total
sign = 0
for id in sres.ids:
d = {
"id": id,
"content_with_weight": (
rmSpace(sres.highlight[id])
if question and id in sres.highlight
else sres.field[id].get("content_with_weight", "")
),
"doc_id": sres.field[id]["doc_id"],
"docnm_kwd": sres.field[id]["docnm_kwd"],
"important_kwd": sres.field[id].get("important_kwd", []),
"img_id": sres.field[id].get("img_id", ""),
"available_int": sres.field[id].get("available_int", 1),
"positions": sres.field[id].get("position_int", "").split("\t"),
}
if len(d["positions"]) % 5 == 0:
poss = []
for i in range(0, len(d["positions"]), 5):
poss.append(
[
float(d["positions"][i]),
float(d["positions"][i + 1]),
float(d["positions"][i + 2]),
float(d["positions"][i + 3]),
float(d["positions"][i + 4]),
]
)
d["positions"] = poss
origin_chunks.append(d)
if req.get("id"):
if req.get("id") == id:
origin_chunks.clear()
origin_chunks.append(d)
sign = 1
break
if req.get("id"):
if req.get("id") == id:
origin_chunks.clear()
origin_chunks.append(d)
sign = 1
break
if req.get("id"):
if sign == 0:
return get_error_data_result(f"Can't find this chunk {req.get('id')}")
if sign == 0:
return get_error_data_result(f"Can't find this chunk {req.get('id')}")
for chunk in origin_chunks:
key_mapping = {
"chunk_id": "id",
"id": "id",
"content_with_weight": "content",
"doc_id": "document_id",
"important_kwd": "important_keywords",
@ -996,9 +993,9 @@ def add_chunk(tenant_id, dataset_id, document_id):
)
d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19]
d["create_timestamp_flt"] = datetime.datetime.now().timestamp()
d["kb_id"] = [doc.kb_id]
d["kb_id"] = dataset_id
d["docnm_kwd"] = doc.name
d["doc_id"] = doc.id
d["doc_id"] = document_id
embd_id = DocumentService.get_embd_id(document_id)
embd_mdl = TenantLLMService.model_instance(
tenant_id, LLMType.EMBEDDING.value, embd_id
@ -1006,14 +1003,12 @@ def add_chunk(tenant_id, dataset_id, document_id):
v, c = embd_mdl.encode([doc.name, req["content"]])
v = 0.1 * v[0] + 0.9 * v[1]
d["q_%d_vec" % len(v)] = v.tolist()
ELASTICSEARCH.upsert([d], search.index_name(tenant_id))
docStoreConn.insert([d], search.index_name(tenant_id), dataset_id)
DocumentService.increment_chunk_num(doc.id, doc.kb_id, c, 1, 0)
d["chunk_id"] = chunk_id
d["kb_id"] = doc.kb_id
# rename keys
key_mapping = {
"chunk_id": "id",
"id": "id",
"content_with_weight": "content",
"doc_id": "document_id",
"important_kwd": "important_keywords",
@ -1079,36 +1074,16 @@ def rm_chunk(tenant_id, dataset_id, document_id):
"""
if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
return get_error_data_result(message=f"You don't own the dataset {dataset_id}.")
doc = DocumentService.query(id=document_id, kb_id=dataset_id)
if not doc:
return get_error_data_result(
message=f"You don't own the document {document_id}."
)
doc = doc[0]
req = request.json
if not req.get("chunk_ids"):
return get_error_data_result("`chunk_ids` is required")
query = {"doc_ids": [doc.id], "page": 1, "size": 1024, "question": "", "sort": True}
sres = retrievaler.search(query, search.index_name(tenant_id), highlight=True)
if not req:
chunk_ids = None
else:
chunk_ids = req.get("chunk_ids")
if not chunk_ids:
chunk_list = sres.ids
else:
chunk_list = chunk_ids
for chunk_id in chunk_list:
if chunk_id not in sres.ids:
return get_error_data_result(f"Chunk {chunk_id} not found")
if not ELASTICSEARCH.deleteByQuery(
Q("ids", values=chunk_list), search.index_name(tenant_id)
):
return get_error_data_result(message="Index updating failure")
deleted_chunk_ids = chunk_list
chunk_number = len(deleted_chunk_ids)
DocumentService.decrement_chunk_num(doc.id, doc.kb_id, 1, chunk_number, 0)
return get_result()
condition = {"doc_id": document_id}
if "chunk_ids" in req:
condition["id"] = req["chunk_ids"]
chunk_number = docStoreConn.delete(condition, search.index_name(tenant_id), dataset_id)
if chunk_number != 0:
DocumentService.decrement_chunk_num(document_id, dataset_id, 1, chunk_number, 0)
if "chunk_ids" in req and chunk_number != len(req["chunk_ids"]):
return get_error_data_result(message=f"rm_chunk deleted chunks {chunk_number}, expect {len(req['chunk_ids'])}")
return get_result(message=f"deleted {chunk_number} chunks")
@manager.route(
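
One behavioural detail of the rewritten `rm_chunk` above: `delete` now returns the number of rows it removed, and the handler compares that count against the requested chunk ids. A small sketch of that check in isolation (the helper name and the `doc_store` parameter are illustrative):

```python
def remove_chunks(doc_store, index_name: str, dataset_id: str, document_id: str,
                  chunk_ids: list[str] | None) -> int:
    """Delete a document's chunks, optionally restricted to specific chunk ids."""
    condition = {"doc_id": document_id}
    if chunk_ids is not None:
        condition["id"] = chunk_ids
    deleted = doc_store.delete(condition, index_name, dataset_id)
    if chunk_ids is not None and deleted != len(chunk_ids):
        raise ValueError(f"deleted {deleted} chunks, expected {len(chunk_ids)}")
    return deleted
```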
@ -1168,9 +1143,8 @@ def update_chunk(tenant_id, dataset_id, document_id, chunk_id):
schema:
type: object
"""
try:
res = ELASTICSEARCH.get(chunk_id, search.index_name(tenant_id))
except Exception:
chunk = docStoreConn.get(chunk_id, search.index_name(tenant_id), [dataset_id])
if chunk is None:
return get_error_data_result(f"Can't find this chunk {chunk_id}")
if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
return get_error_data_result(message=f"You don't own the dataset {dataset_id}.")
@ -1180,19 +1154,12 @@ def update_chunk(tenant_id, dataset_id, document_id, chunk_id):
message=f"You don't own the document {document_id}."
)
doc = doc[0]
query = {
"doc_ids": [document_id],
"page": 1,
"size": 1024,
"question": "",
"sort": True,
}
sres = retrievaler.search(query, search.index_name(tenant_id), highlight=True)
if chunk_id not in sres.ids:
return get_error_data_result(f"You don't own the chunk {chunk_id}")
req = request.json
content = res["_source"].get("content_with_weight")
d = {"id": chunk_id, "content_with_weight": req.get("content", content)}
if "content" in req:
content = req["content"]
else:
content = chunk.get("content_with_weight", "")
d = {"id": chunk_id, "content_with_weight": content}
d["content_ltks"] = rag_tokenizer.tokenize(d["content_with_weight"])
d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
if "important_keywords" in req:
@ -1220,7 +1187,7 @@ def update_chunk(tenant_id, dataset_id, document_id, chunk_id):
v, c = embd_mdl.encode([doc.name, d["content_with_weight"]])
v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
d["q_%d_vec" % len(v)] = v.tolist()
ELASTICSEARCH.upsert([d], search.index_name(tenant_id))
docStoreConn.update({"id": chunk_id}, d, search.index_name(tenant_id), dataset_id)
return get_result()

View File

@ -31,7 +31,7 @@ from api.utils.api_utils import (
generate_confirmation_token,
)
from api.versions import get_rag_version
from rag.utils.es_conn import ELASTICSEARCH
from api.settings import docStoreConn
from rag.utils.storage_factory import STORAGE_IMPL, STORAGE_IMPL_TYPE
from timeit import default_timer as timer
@ -98,10 +98,11 @@ def status():
res = {}
st = timer()
try:
res["es"] = ELASTICSEARCH.health()
res["es"]["elapsed"] = "{:.1f}".format((timer() - st) * 1000.0)
res["doc_store"] = docStoreConn.health()
res["doc_store"]["elapsed"] = "{:.1f}".format((timer() - st) * 1000.0)
except Exception as e:
res["es"] = {
res["doc_store"] = {
"type": "unknown",
"status": "red",
"elapsed": "{:.1f}".format((timer() - st) * 1000.0),
"error": str(e),

View File

@ -470,7 +470,7 @@ class User(DataBaseModel, UserMixin):
status = CharField(
max_length=1,
null=True,
help_text="is it validate(0: wasted,1: validate)",
help_text="is it validate(0: wasted, 1: validate)",
default="1",
index=True)
is_superuser = BooleanField(null=True, help_text="is root", default=False, index=True)
@ -525,7 +525,7 @@ class Tenant(DataBaseModel):
status = CharField(
max_length=1,
null=True,
help_text="is it validate(0: wasted,1: validate)",
help_text="is it validate(0: wasted, 1: validate)",
default="1",
index=True)
@ -542,7 +542,7 @@ class UserTenant(DataBaseModel):
status = CharField(
max_length=1,
null=True,
help_text="is it validate(0: wasted,1: validate)",
help_text="is it validate(0: wasted, 1: validate)",
default="1",
index=True)
@ -559,7 +559,7 @@ class InvitationCode(DataBaseModel):
status = CharField(
max_length=1,
null=True,
help_text="is it validate(0: wasted,1: validate)",
help_text="is it validate(0: wasted, 1: validate)",
default="1",
index=True)
@ -582,7 +582,7 @@ class LLMFactories(DataBaseModel):
status = CharField(
max_length=1,
null=True,
help_text="is it validate(0: wasted,1: validate)",
help_text="is it validate(0: wasted, 1: validate)",
default="1",
index=True)
@ -616,7 +616,7 @@ class LLM(DataBaseModel):
status = CharField(
max_length=1,
null=True,
help_text="is it validate(0: wasted,1: validate)",
help_text="is it validate(0: wasted, 1: validate)",
default="1",
index=True)
@ -703,7 +703,7 @@ class Knowledgebase(DataBaseModel):
status = CharField(
max_length=1,
null=True,
help_text="is it validate(0: wasted,1: validate)",
help_text="is it validate(0: wasted, 1: validate)",
default="1",
index=True)
@ -767,7 +767,7 @@ class Document(DataBaseModel):
status = CharField(
max_length=1,
null=True,
help_text="is it validate(0: wasted,1: validate)",
help_text="is it validate(0: wasted, 1: validate)",
default="1",
index=True)
@ -904,7 +904,7 @@ class Dialog(DataBaseModel):
status = CharField(
max_length=1,
null=True,
help_text="is it validate(0: wasted,1: validate)",
help_text="is it validate(0: wasted, 1: validate)",
default="1",
index=True)
@ -987,7 +987,7 @@ def migrate_db():
help_text="where dose this document come from",
index=True))
)
except Exception as e:
except Exception:
pass
try:
migrate(
@ -996,7 +996,7 @@ def migrate_db():
help_text="default rerank model ID"))
)
except Exception as e:
except Exception:
pass
try:
migrate(
@ -1004,59 +1004,59 @@ def migrate_db():
help_text="default rerank model ID"))
)
except Exception as e:
except Exception:
pass
try:
migrate(
migrator.add_column('dialog', 'top_k', IntegerField(default=1024))
)
except Exception as e:
except Exception:
pass
try:
migrate(
migrator.alter_column_type('tenant_llm', 'api_key',
CharField(max_length=1024, null=True, help_text="API KEY", index=True))
)
except Exception as e:
except Exception:
pass
try:
migrate(
migrator.add_column('api_token', 'source',
CharField(max_length=16, null=True, help_text="none|agent|dialog", index=True))
)
except Exception as e:
except Exception:
pass
try:
migrate(
migrator.add_column("tenant","tts_id",
CharField(max_length=256,null=True,help_text="default tts model ID",index=True))
)
except Exception as e:
except Exception:
pass
try:
migrate(
migrator.add_column('api_4_conversation', 'source',
CharField(max_length=16, null=True, help_text="none|agent|dialog", index=True))
)
except Exception as e:
except Exception:
pass
try:
DB.execute_sql('ALTER TABLE llm DROP PRIMARY KEY;')
DB.execute_sql('ALTER TABLE llm ADD PRIMARY KEY (llm_name,fid);')
except Exception as e:
except Exception:
pass
try:
migrate(
migrator.add_column('task', 'retry_count', IntegerField(default=0))
)
except Exception as e:
except Exception:
pass
try:
migrate(
migrator.alter_column_type('api_token', 'dialog_id',
CharField(max_length=32, null=True, index=True))
)
except Exception as e:
except Exception:
pass

View File

@ -15,7 +15,6 @@
#
import hashlib
import json
import os
import random
import re
import traceback
@ -24,16 +23,13 @@ from copy import deepcopy
from datetime import datetime
from io import BytesIO
from elasticsearch_dsl import Q
from peewee import fn
from api.db.db_utils import bulk_insert_into_db
from api.settings import stat_logger
from api.settings import stat_logger, docStoreConn
from api.utils import current_timestamp, get_format_time, get_uuid
from api.utils.file_utils import get_project_base_directory
from graphrag.mind_map_extractor import MindMapExtractor
from rag.settings import SVR_QUEUE_NAME
from rag.utils.es_conn import ELASTICSEARCH
from rag.utils.storage_factory import STORAGE_IMPL
from rag.nlp import search, rag_tokenizer
@ -112,8 +108,7 @@ class DocumentService(CommonService):
@classmethod
@DB.connection_context()
def remove_document(cls, doc, tenant_id):
ELASTICSEARCH.deleteByQuery(
Q("match", doc_id=doc.id), idxnm=search.index_name(tenant_id))
docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
cls.clear_chunk_num(doc.id)
return cls.delete_by_id(doc.id)
@ -225,6 +220,15 @@ class DocumentService(CommonService):
return
return docs[0]["tenant_id"]
@classmethod
@DB.connection_context()
def get_knowledgebase_id(cls, doc_id):
docs = cls.model.select(cls.model.kb_id).where(cls.model.id == doc_id)
docs = docs.dicts()
if not docs:
return
return docs[0]["kb_id"]
@classmethod
@DB.connection_context()
def get_tenant_id_by_name(cls, name):
@ -438,11 +442,6 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
if not e:
raise LookupError("Can't find this knowledgebase!")
idxnm = search.index_name(kb.tenant_id)
if not ELASTICSEARCH.indexExist(idxnm):
ELASTICSEARCH.createIdx(idxnm, json.load(
open(os.path.join(get_project_base_directory(), "conf", "mapping.json"), "r")))
embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING, llm_name=kb.embd_id, lang=kb.language)
err, files = FileService.upload_document(kb, file_objs, user_id)
@ -486,7 +485,7 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
md5 = hashlib.md5()
md5.update((ck["content_with_weight"] +
str(d["doc_id"])).encode("utf-8"))
d["_id"] = md5.hexdigest()
d["id"] = md5.hexdigest()
d["create_time"] = str(datetime.now()).replace("T", " ")[:19]
d["create_timestamp_flt"] = datetime.now().timestamp()
if not d.get("image"):
@ -499,8 +498,8 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
else:
d["image"].save(output_buffer, format='JPEG')
STORAGE_IMPL.put(kb.id, d["_id"], output_buffer.getvalue())
d["img_id"] = "{}-{}".format(kb.id, d["_id"])
STORAGE_IMPL.put(kb.id, d["id"], output_buffer.getvalue())
d["img_id"] = "{}-{}".format(kb.id, d["id"])
del d["image"]
docs.append(d)
@ -520,6 +519,9 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
token_counts[doc_id] += c
return vects
idxnm = search.index_name(kb.tenant_id)
try_create_idx = True
_, tenant = TenantService.get_by_id(kb.tenant_id)
llm_bdl = LLMBundle(kb.tenant_id, LLMType.CHAT, tenant.llm_id)
for doc_id in docids:
@ -550,7 +552,11 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
v = vects[i]
d["q_%d_vec" % len(v)] = v
for b in range(0, len(cks), es_bulk_size):
ELASTICSEARCH.bulk(cks[b:b + es_bulk_size], idxnm)
if try_create_idx:
if not docStoreConn.indexExist(idxnm, kb_id):
docStoreConn.createIdx(idxnm, kb_id, len(vects[0]))
try_create_idx = False
docStoreConn.insert(cks[b:b + es_bulk_size], idxnm, kb_id)
DocumentService.increment_chunk_num(
doc_id, kb.id, token_counts[doc_id], chunk_counts[doc_id], 0)
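
The hunk above defers index creation until the first batch of chunks is ready, because the vector column (`q_<dim>_vec`) and therefore the index schema are only known once the embedding model has produced a vector. A sketch of that pattern in isolation; the function and variable names are illustrative:

```python
def insert_in_batches(doc_store, idxnm: str, kb_id: str, chunks: list[dict], bulk_size: int = 64) -> None:
    """Create the index lazily, then insert chunks in fixed-size batches."""
    index_created = False
    for start in range(0, len(chunks), bulk_size):
        batch = chunks[start:start + bulk_size]
        if not index_created:
            # The vector field is named like "q_1024_vec"; its length fixes the index schema.
            vector_size = next(len(v) for k, v in batch[0].items() if k.endswith("_vec"))
            if not doc_store.indexExist(idxnm, kb_id):
                doc_store.createIdx(idxnm, kb_id, vector_size)
            index_created = True
        doc_store.insert(batch, idxnm, kb_id)
```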

View File

@ -66,6 +66,16 @@ class KnowledgebaseService(CommonService):
return list(kbs.dicts())
@classmethod
@DB.connection_context()
def get_kb_ids(cls, tenant_id):
fields = [
cls.model.id,
]
kbs = cls.model.select(*fields).where(cls.model.tenant_id == tenant_id)
kb_ids = [kb["id"] for kb in kbs]
return kb_ids
@classmethod
@DB.connection_context()
def get_detail(cls, kb_id):

View File

@ -18,6 +18,8 @@ from datetime import date
from enum import IntEnum, Enum
from api.utils.file_utils import get_project_base_directory
from api.utils.log_utils import LoggerFactory, getLogger
import rag.utils.es_conn
import rag.utils.infinity_conn
# Logger
LoggerFactory.set_directory(
@ -33,7 +35,7 @@ access_logger = getLogger("access")
database_logger = getLogger("database")
chat_logger = getLogger("chat")
from rag.utils.es_conn import ELASTICSEARCH
import rag.utils
from rag.nlp import search
from graphrag import search as kg_search
from api.utils import get_base_config, decrypt_database_config
@ -206,8 +208,12 @@ AUTHENTICATION_DEFAULT_TIMEOUT = 7 * 24 * 60 * 60 # s
PRIVILEGE_COMMAND_WHITELIST = []
CHECK_NODES_IDENTITY = False
retrievaler = search.Dealer(ELASTICSEARCH)
kg_retrievaler = kg_search.KGSearch(ELASTICSEARCH)
if 'username' in get_base_config("es", {}):
docStoreConn = rag.utils.es_conn.ESConnection()
else:
docStoreConn = rag.utils.infinity_conn.InfinityConnection()
retrievaler = search.Dealer(docStoreConn)
kg_retrievaler = kg_search.KGSearch(docStoreConn)
class CustomEnum(Enum):
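
The selection above keys off whether the `es` section of the service configuration carries a `username`, i.e. whether an Elasticsearch instance is actually configured; otherwise the new Infinity backend is used. A minimal sketch of the same decision, written against a plain dictionary instead of RAGFlow's `get_base_config` helper:

```python
def choose_doc_store(service_conf: dict):
    """Return an ES-backed connection if es credentials are configured, else Infinity."""
    import rag.utils.es_conn
    import rag.utils.infinity_conn

    if "username" in service_conf.get("es", {}):
        return rag.utils.es_conn.ESConnection()
    return rag.utils.infinity_conn.InfinityConnection()
```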

View File

@ -126,10 +126,6 @@ def server_error_response(e):
if len(e.args) > 1:
return get_json_result(
code=RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=e.args[1])
if repr(e).find("index_not_found_exception") >= 0:
return get_json_result(code=RetCode.EXCEPTION_ERROR,
message="No chunk found, please upload file and parse it.")
return get_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e))
@ -270,10 +266,6 @@ def construct_error_response(e):
pass
if len(e.args) > 1:
return construct_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=e.args[1])
if repr(e).find("index_not_found_exception") >= 0:
return construct_json_result(code=RetCode.EXCEPTION_ERROR,
message="No chunk found, please upload file and parse it.")
return construct_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e))
@ -295,7 +287,7 @@ def token_required(func):
return decorated_function
def get_result(code=RetCode.SUCCESS, message='error', data=None):
def get_result(code=RetCode.SUCCESS, message="", data=None):
if code == 0:
if data is not None:
response = {"code": code, "data": data}

View File

@ -0,0 +1,26 @@
{
"id": {"type": "varchar", "default": ""},
"doc_id": {"type": "varchar", "default": ""},
"kb_id": {"type": "varchar", "default": ""},
"create_time": {"type": "varchar", "default": ""},
"create_timestamp_flt": {"type": "float", "default": 0.0},
"img_id": {"type": "varchar", "default": ""},
"docnm_kwd": {"type": "varchar", "default": ""},
"title_tks": {"type": "varchar", "default": ""},
"title_sm_tks": {"type": "varchar", "default": ""},
"name_kwd": {"type": "varchar", "default": ""},
"important_kwd": {"type": "varchar", "default": ""},
"important_tks": {"type": "varchar", "default": ""},
"content_with_weight": {"type": "varchar", "default": ""},
"content_ltks": {"type": "varchar", "default": ""},
"content_sm_ltks": {"type": "varchar", "default": ""},
"page_num_list": {"type": "varchar", "default": ""},
"top_list": {"type": "varchar", "default": ""},
"position_list": {"type": "varchar", "default": ""},
"weight_int": {"type": "integer", "default": 0},
"weight_flt": {"type": "float", "default": 0.0},
"rank_int": {"type": "integer", "default": 0},
"available_int": {"type": "integer", "default": 1},
"knowledge_graph_kwd": {"type": "varchar", "default": ""},
"entities_kwd": {"type": "varchar", "default": ""}
}
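
The new conf/infinity_mapping.json fixes the column set for the Infinity table; only the embedding column varies per index, since its name encodes the vector dimension passed to `createIdx`. A hedged sketch of how such a mapping could be turned into a column specification; this loader is illustrative, and the `vector,<dim>,float` type string is an assumption about the Infinity SDK's notation, not code from this commit:

```python
import json
from pathlib import Path


def load_infinity_columns(mapping_path: str, vector_size: int) -> dict:
    """Build a column spec from conf/infinity_mapping.json plus the per-KB vector column."""
    columns = json.loads(Path(mapping_path).read_text())
    # The embedding column is not in the mapping file because its name depends on the
    # model's output dimension, e.g. "q_1024_vec" for a 1024-d embedder. The type string
    # below assumes Infinity's "vector,<dim>,float" column notation.
    columns[f"q_{vector_size}_vec"] = {"type": f"vector,{vector_size},float"}
    return columns
```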

View File

@ -1,200 +1,203 @@
{
{
"settings": {
"index": {
"number_of_shards": 2,
"number_of_replicas": 0,
"refresh_interval" : "1000ms"
"refresh_interval": "1000ms"
},
"similarity": {
"scripted_sim": {
"type": "scripted",
"script": {
"source": "double idf = Math.log(1+(field.docCount-term.docFreq+0.5)/(term.docFreq + 0.5))/Math.log(1+((field.docCount-0.5)/1.5)); return query.boost * idf * Math.min(doc.freq, 1);"
}
"scripted_sim": {
"type": "scripted",
"script": {
"source": "double idf = Math.log(1+(field.docCount-term.docFreq+0.5)/(term.docFreq + 0.5))/Math.log(1+((field.docCount-0.5)/1.5)); return query.boost * idf * Math.min(doc.freq, 1);"
}
}
}
},
"mappings": {
"properties": {
"lat_lon": {"type": "geo_point", "store":"true"}
},
"date_detection": "true",
"dynamic_templates": [
{
"int": {
"match": "*_int",
"mapping": {
"type": "integer",
"store": "true"
}
"properties": {
"lat_lon": {
"type": "geo_point",
"store": "true"
}
},
"date_detection": "true",
"dynamic_templates": [
{
"int": {
"match": "*_int",
"mapping": {
"type": "integer",
"store": "true"
}
},
{
"ulong": {
"match": "*_ulong",
"mapping": {
"type": "unsigned_long",
"store": "true"
}
}
},
{
"long": {
"match": "*_long",
"mapping": {
"type": "long",
"store": "true"
}
}
},
{
"short": {
"match": "*_short",
"mapping": {
"type": "short",
"store": "true"
}
}
},
{
"numeric": {
"match": "*_flt",
"mapping": {
"type": "float",
"store": true
}
}
},
{
"tks": {
"match": "*_tks",
"mapping": {
"type": "text",
"similarity": "scripted_sim",
"analyzer": "whitespace",
"store": true
}
}
},
{
"ltks":{
"match": "*_ltks",
"mapping": {
"type": "text",
"analyzer": "whitespace",
"store": true
}
}
},
{
"kwd": {
"match_pattern": "regex",
"match": "^(.*_(kwd|id|ids|uid|uids)|uid)$",
"mapping": {
"type": "keyword",
"similarity": "boolean",
"store": true
}
}
},
{
"dt": {
"match_pattern": "regex",
"match": "^.*(_dt|_time|_at)$",
"mapping": {
"type": "date",
"format": "yyyy-MM-dd HH:mm:ss||yyyy-MM-dd||yyyy-MM-dd_HH:mm:ss",
"store": true
}
}
},
{
"nested": {
"match": "*_nst",
"mapping": {
"type": "nested"
}
}
},
{
"object": {
"match": "*_obj",
"mapping": {
"type": "object",
"dynamic": "true"
}
}
},
{
"string": {
"match": "*_with_weight",
"mapping": {
"type": "text",
"index": "false",
"store": true
}
}
},
{
"string": {
"match": "*_fea",
"mapping": {
"type": "rank_feature"
}
}
},
{
"dense_vector": {
"match": "*_512_vec",
"mapping": {
"type": "dense_vector",
"index": true,
"similarity": "cosine",
"dims": 512
}
}
},
{
"dense_vector": {
"match": "*_768_vec",
"mapping": {
"type": "dense_vector",
"index": true,
"similarity": "cosine",
"dims": 768
}
}
},
{
"dense_vector": {
"match": "*_1024_vec",
"mapping": {
"type": "dense_vector",
"index": true,
"similarity": "cosine",
"dims": 1024
}
}
},
{
"dense_vector": {
"match": "*_1536_vec",
"mapping": {
"type": "dense_vector",
"index": true,
"similarity": "cosine",
"dims": 1536
}
}
},
{
"binary": {
"match": "*_bin",
"mapping": {
"type": "binary"
}
}
}
]
}
},
{
"ulong": {
"match": "*_ulong",
"mapping": {
"type": "unsigned_long",
"store": "true"
}
}
},
{
"long": {
"match": "*_long",
"mapping": {
"type": "long",
"store": "true"
}
}
},
{
"short": {
"match": "*_short",
"mapping": {
"type": "short",
"store": "true"
}
}
},
{
"numeric": {
"match": "*_flt",
"mapping": {
"type": "float",
"store": true
}
}
},
{
"tks": {
"match": "*_tks",
"mapping": {
"type": "text",
"similarity": "scripted_sim",
"analyzer": "whitespace",
"store": true
}
}
},
{
"ltks": {
"match": "*_ltks",
"mapping": {
"type": "text",
"analyzer": "whitespace",
"store": true
}
}
},
{
"kwd": {
"match_pattern": "regex",
"match": "^(.*_(kwd|id|ids|uid|uids)|uid)$",
"mapping": {
"type": "keyword",
"similarity": "boolean",
"store": true
}
}
},
{
"dt": {
"match_pattern": "regex",
"match": "^.*(_dt|_time|_at)$",
"mapping": {
"type": "date",
"format": "yyyy-MM-dd HH:mm:ss||yyyy-MM-dd||yyyy-MM-dd_HH:mm:ss",
"store": true
}
}
},
{
"nested": {
"match": "*_nst",
"mapping": {
"type": "nested"
}
}
},
{
"object": {
"match": "*_obj",
"mapping": {
"type": "object",
"dynamic": "true"
}
}
},
{
"string": {
"match": "*_(with_weight|list)$",
"mapping": {
"type": "text",
"index": "false",
"store": true
}
}
},
{
"string": {
"match": "*_fea",
"mapping": {
"type": "rank_feature"
}
}
},
{
"dense_vector": {
"match": "*_512_vec",
"mapping": {
"type": "dense_vector",
"index": true,
"similarity": "cosine",
"dims": 512
}
}
},
{
"dense_vector": {
"match": "*_768_vec",
"mapping": {
"type": "dense_vector",
"index": true,
"similarity": "cosine",
"dims": 768
}
}
},
{
"dense_vector": {
"match": "*_1024_vec",
"mapping": {
"type": "dense_vector",
"index": true,
"similarity": "cosine",
"dims": 1024
}
}
},
{
"dense_vector": {
"match": "*_1536_vec",
"mapping": {
"type": "dense_vector",
"index": true,
"similarity": "cosine",
"dims": 1536
}
}
},
{
"binary": {
"match": "*_bin",
"mapping": {
"type": "binary"
}
}
}
]
}
}

View File

@ -19,6 +19,11 @@ KIBANA_PASSWORD=infini_rag_flow
# Update it according to the available memory in the host machine.
MEM_LIMIT=8073741824
# Port to expose Infinity API to the host
INFINITY_THRIFT_PORT=23817
INFINITY_HTTP_PORT=23820
INFINITY_PSQL_PORT=5432
# The password for MySQL.
# When updated, you must revise the `mysql.password` entry in service_conf.yaml.
MYSQL_PASSWORD=infini_rag_flow

View File

@ -6,6 +6,7 @@ services:
- esdata01:/usr/share/elasticsearch/data
ports:
- ${ES_PORT}:9200
env_file: .env
environment:
- node.name=es01
- ELASTIC_PASSWORD=${ELASTIC_PASSWORD}
@ -27,12 +28,40 @@ services:
retries: 120
networks:
- ragflow
restart: always
restart: on-failure
# infinity:
# container_name: ragflow-infinity
# image: infiniflow/infinity:v0.5.0-dev2
# volumes:
# - infinity_data:/var/infinity
# ports:
# - ${INFINITY_THRIFT_PORT}:23817
# - ${INFINITY_HTTP_PORT}:23820
# - ${INFINITY_PSQL_PORT}:5432
# env_file: .env
# environment:
# - TZ=${TIMEZONE}
# mem_limit: ${MEM_LIMIT}
# ulimits:
# nofile:
# soft: 500000
# hard: 500000
# networks:
# - ragflow
# healthcheck:
# test: ["CMD", "curl", "http://localhost:23820/admin/node/current"]
# interval: 10s
# timeout: 10s
# retries: 120
# restart: on-failure
mysql:
# mysql:5.7 linux/arm64 image is unavailable.
image: mysql:8.0.39
container_name: ragflow-mysql
env_file: .env
environment:
- MYSQL_ROOT_PASSWORD=${MYSQL_PASSWORD}
- TZ=${TIMEZONE}
@ -55,7 +84,7 @@ services:
interval: 10s
timeout: 10s
retries: 3
restart: always
restart: on-failure
minio:
image: quay.io/minio/minio:RELEASE.2023-12-20T01-00-02Z
@ -64,6 +93,7 @@ services:
ports:
- ${MINIO_PORT}:9000
- ${MINIO_CONSOLE_PORT}:9001
env_file: .env
environment:
- MINIO_ROOT_USER=${MINIO_USER}
- MINIO_ROOT_PASSWORD=${MINIO_PASSWORD}
@ -72,25 +102,28 @@ services:
- minio_data:/data
networks:
- ragflow
restart: always
restart: on-failure
redis:
image: valkey/valkey:8
container_name: ragflow-redis
command: redis-server --requirepass ${REDIS_PASSWORD} --maxmemory 128mb --maxmemory-policy allkeys-lru
env_file: .env
ports:
- ${REDIS_PORT}:6379
volumes:
- redis_data:/data
networks:
- ragflow
restart: always
restart: on-failure
volumes:
esdata01:
driver: local
infinity_data:
driver: local
mysql_data:
driver: local
minio_data:

View File

@ -1,6 +1,5 @@
include:
- path: ./docker-compose-base.yml
env_file: ./.env
- ./docker-compose-base.yml
services:
ragflow:
@ -15,19 +14,21 @@ services:
- ${SVR_HTTP_PORT}:9380
- 80:80
- 443:443
- 5678:5678
volumes:
- ./service_conf.yaml:/ragflow/conf/service_conf.yaml
- ./ragflow-logs:/ragflow/logs
- ./nginx/ragflow.conf:/etc/nginx/conf.d/ragflow.conf
- ./nginx/proxy.conf:/etc/nginx/proxy.conf
- ./nginx/nginx.conf:/etc/nginx/nginx.conf
env_file: .env
environment:
- TZ=${TIMEZONE}
- HF_ENDPOINT=${HF_ENDPOINT}
- MACOS=${MACOS}
networks:
- ragflow
restart: always
restart: on-failure
# https://docs.docker.com/engine/daemon/prometheus/#create-a-prometheus-configuration
# If you're using Docker Desktop, the --add-host flag is optional. This flag makes sure that the host's internal IP gets exposed to the Prometheus container.
extra_hosts:
- "host.docker.internal:host-gateway"

View File

@ -67,7 +67,7 @@ docker compose -f docker/docker-compose-base.yml up -d
1. Add the following line to `/etc/hosts` to resolve all hosts specified in **docker/service_conf.yaml** to `127.0.0.1`:
```
127.0.0.1 es01 mysql minio redis
127.0.0.1 es01 infinity mysql minio redis
```
2. In **docker/service_conf.yaml**, update mysql port to `5455` and es port to `1200`, as specified in **docker/.env**.

View File

@ -1280,7 +1280,7 @@ Success:
"document_keyword": "1.txt",
"highlight": "<em>ragflow</em> content",
"id": "d78435d142bd5cf6704da62c778795c5",
"img_id": "",
"image_id": "",
"important_keywords": [
""
],

View File

@ -1351,7 +1351,7 @@ A list of `Chunk` objects representing references to the message, each containin
The chunk ID.
- `content` `str`
The content of the chunk.
- `image_id` `str`
- `img_id` `str`
The ID of the snapshot of the chunk. Applicable only when the source of the chunk is an image, PPT, PPTX, or PDF file.
- `document_id` `str`
The ID of the referenced document.

View File

@ -254,9 +254,12 @@ if __name__ == "__main__":
from api.db import LLMType
from api.db.services.llm_service import LLMBundle
from api.settings import retrievaler
from api.db.services.knowledgebase_service import KnowledgebaseService
kb_ids = KnowledgebaseService.get_kb_ids(args.tenant_id)
ex = ClaimExtractor(LLMBundle(args.tenant_id, LLMType.CHAT))
docs = [d["content_with_weight"] for d in retrievaler.chunk_list(args.doc_id, args.tenant_id, max_count=12, fields=["content_with_weight"])]
docs = [d["content_with_weight"] for d in retrievaler.chunk_list(args.doc_id, args.tenant_id, kb_ids, max_count=12, fields=["content_with_weight"])]
info = {
"input_text": docs,
"entity_specs": "organization, person",

View File

@ -15,95 +15,90 @@
#
import json
from copy import deepcopy
from typing import Dict
import pandas as pd
from elasticsearch_dsl import Q, Search
from rag.utils.doc_store_conn import OrderByExpr, FusionExpr
from rag.nlp.search import Dealer
class KGSearch(Dealer):
def search(self, req, idxnm, emb_mdl=None, highlight=False):
def merge_into_first(sres, title=""):
df,texts = [],[]
for d in sres["hits"]["hits"]:
def search(self, req, idxnm, kb_ids, emb_mdl, highlight=False):
def merge_into_first(sres, title="") -> Dict[str, str]:
if not sres:
return {}
content_with_weight = ""
df, texts = [],[]
for d in sres.values():
try:
df.append(json.loads(d["_source"]["content_with_weight"]))
except Exception as e:
texts.append(d["_source"]["content_with_weight"])
pass
if not df and not texts: return False
df.append(json.loads(d["content_with_weight"]))
except Exception:
texts.append(d["content_with_weight"])
if df:
try:
sres["hits"]["hits"][0]["_source"]["content_with_weight"] = title + "\n" + pd.DataFrame(df).to_csv()
except Exception as e:
pass
content_with_weight = title + "\n" + pd.DataFrame(df).to_csv()
else:
sres["hits"]["hits"][0]["_source"]["content_with_weight"] = title + "\n" + "\n".join(texts)
return True
content_with_weight = title + "\n" + "\n".join(texts)
first_id = ""
first_source = {}
for k, v in sres.items():
first_id = id
first_source = deepcopy(v)
break
first_source["content_with_weight"] = content_with_weight
first_id = next(iter(sres))
return {first_id: first_source}
qst = req.get("question", "")
matchText, keywords = self.qryr.question(qst, min_match=0.05)
condition = self.get_filters(req)
## Entity retrieval
condition.update({"knowledge_graph_kwd": ["entity"]})
assert emb_mdl, "No embedding model selected"
matchDense = self.get_vector(qst, emb_mdl, 1024, req.get("similarity", 0.1))
q_vec = matchDense.embedding_data
src = req.get("fields", ["docnm_kwd", "content_ltks", "kb_id", "img_id", "title_tks", "important_kwd",
"image_id", "doc_id", "q_512_vec", "q_768_vec", "position_int", "name_kwd",
"doc_id", f"q_{len(q_vec)}_vec", "position_list", "name_kwd",
"q_1024_vec", "q_1536_vec", "available_int", "content_with_weight",
"weight_int", "weight_flt", "rank_int"
])
qst = req.get("question", "")
binary_query, keywords = self.qryr.question(qst, min_match="5%")
binary_query = self._add_filters(binary_query, req)
fusionExpr = FusionExpr("weighted_sum", 32, {"weights": "0.5, 0.5"})
## Entity retrieval
bqry = deepcopy(binary_query)
bqry.filter.append(Q("terms", knowledge_graph_kwd=["entity"]))
s = Search()
s = s.query(bqry)[0: 32]
s = s.to_dict()
q_vec = []
if req.get("vector"):
assert emb_mdl, "No embedding model selected"
s["knn"] = self._vector(
qst, emb_mdl, req.get(
"similarity", 0.1), 1024)
s["knn"]["filter"] = bqry.to_dict()
q_vec = s["knn"]["query_vector"]
ent_res = self.es.search(deepcopy(s), idxnms=idxnm, timeout="600s", src=src)
entities = [d["name_kwd"] for d in self.es.getSource(ent_res)]
ent_ids = self.es.getDocIds(ent_res)
if merge_into_first(ent_res, "-Entities-"):
ent_ids = ent_ids[0:1]
ent_res = self.dataStore.search(src, list(), condition, [matchText, matchDense, fusionExpr], OrderByExpr(), 0, 32, idxnm, kb_ids)
ent_res_fields = self.dataStore.getFields(ent_res, src)
entities = [d["name_kwd"] for d in ent_res_fields.values()]
ent_ids = self.dataStore.getChunkIds(ent_res)
ent_content = merge_into_first(ent_res_fields, "-Entities-")
if ent_content:
ent_ids = list(ent_content.keys())
## Community retrieval
bqry = deepcopy(binary_query)
bqry.filter.append(Q("terms", entities_kwd=entities))
bqry.filter.append(Q("terms", knowledge_graph_kwd=["community_report"]))
s = Search()
s = s.query(bqry)[0: 32]
s = s.to_dict()
comm_res = self.es.search(deepcopy(s), idxnms=idxnm, timeout="600s", src=src)
comm_ids = self.es.getDocIds(comm_res)
if merge_into_first(comm_res, "-Community Report-"):
comm_ids = comm_ids[0:1]
condition = self.get_filters(req)
condition.update({"entities_kwd": entities, "knowledge_graph_kwd": ["community_report"]})
comm_res = self.dataStore.search(src, list(), condition, [matchText, matchDense, fusionExpr], OrderByExpr(), 0, 32, idxnm, kb_ids)
comm_res_fields = self.dataStore.getFields(comm_res, src)
comm_ids = self.dataStore.getChunkIds(comm_res)
comm_content = merge_into_first(comm_res_fields, "-Community Report-")
if comm_content:
comm_ids = list(comm_content.keys())
## Text content retrieval
bqry = deepcopy(binary_query)
bqry.filter.append(Q("terms", knowledge_graph_kwd=["text"]))
s = Search()
s = s.query(bqry)[0: 6]
s = s.to_dict()
txt_res = self.es.search(deepcopy(s), idxnms=idxnm, timeout="600s", src=src)
txt_ids = self.es.getDocIds(txt_res)
if merge_into_first(txt_res, "-Original Content-"):
txt_ids = txt_ids[0:1]
condition = self.get_filters(req)
condition.update({"knowledge_graph_kwd": ["text"]})
txt_res = self.dataStore.search(src, list(), condition, [matchText, matchDense, fusionExpr], OrderByExpr(), 0, 6, idxnm, kb_ids)
txt_res_fields = self.dataStore.getFields(txt_res, src)
txt_ids = self.dataStore.getChunkIds(txt_res)
txt_content = merge_into_first(txt_res_fields, "-Original Content-")
if txt_content:
txt_ids = list(txt_content.keys())
return self.SearchResult(
total=len(ent_ids) + len(comm_ids) + len(txt_ids),
ids=[*ent_ids, *comm_ids, *txt_ids],
query_vector=q_vec,
aggregation=None,
highlight=None,
field={**self.getFields(ent_res, src), **self.getFields(comm_res, src), **self.getFields(txt_res, src)},
field={**ent_content, **comm_content, **txt_content},
keywords=[]
)
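
The rewritten `KGSearch.search` above replaces hand-built elasticsearch_dsl queries with a single doc-store search call that combines a full-text match, a dense-vector match, and a fusion expression. A condensed sketch of that calling convention, as used three times above; the function name and parameters are placeholders, and `dealer` stands for the `Dealer` instance:

```python
from rag.utils.doc_store_conn import FusionExpr, OrderByExpr


def entity_hits(dealer, req: dict, idxnm: str, kb_ids: list[str], emb_mdl, src: list[str]):
    """One retrieval round: text match + vector match, fused 50/50, top 32 rows."""
    qst = req.get("question", "")
    match_text, _keywords = dealer.qryr.question(qst, min_match=0.05)
    match_dense = dealer.get_vector(qst, emb_mdl, 1024, req.get("similarity", 0.1))
    fusion = FusionExpr("weighted_sum", 32, {"weights": "0.5, 0.5"})

    condition = dealer.get_filters(req)
    condition["knowledge_graph_kwd"] = ["entity"]  # restrict to graph entity chunks

    res = dealer.dataStore.search(src, list(), condition,
                                  [match_text, match_dense, fusion],
                                  OrderByExpr(), 0, 32, idxnm, kb_ids)
    return dealer.dataStore.getFields(res, src), dealer.dataStore.getChunkIds(res)
```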

View File

@ -31,10 +31,13 @@ if __name__ == "__main__":
from api.db import LLMType
from api.db.services.llm_service import LLMBundle
from api.settings import retrievaler
from api.db.services.knowledgebase_service import KnowledgebaseService
kb_ids = KnowledgebaseService.get_kb_ids(args.tenant_id)
ex = GraphExtractor(LLMBundle(args.tenant_id, LLMType.CHAT))
docs = [d["content_with_weight"] for d in
retrievaler.chunk_list(args.doc_id, args.tenant_id, max_count=6, fields=["content_with_weight"])]
retrievaler.chunk_list(args.doc_id, args.tenant_id, kb_ids, max_count=6, fields=["content_with_weight"])]
graph = ex(docs)
er = EntityResolution(LLMBundle(args.tenant_id, LLMType.CHAT))

poetry.lock (generated): 871 changed lines

File diff suppressed because it is too large.

View File

@ -46,22 +46,23 @@ hanziconv = "0.3.2"
html-text = "0.6.2"
httpx = "0.27.0"
huggingface-hub = "^0.25.0"
infinity-emb = "0.0.51"
infinity-sdk = "0.5.0.dev2"
infinity-emb = "^0.0.66"
itsdangerous = "2.1.2"
markdown = "3.6"
markdown-to-json = "2.1.1"
minio = "7.2.4"
mistralai = "0.4.2"
nltk = "3.9.1"
numpy = "1.26.4"
numpy = "^1.26.0"
ollama = "0.2.1"
onnxruntime = "1.19.2"
openai = "1.45.0"
opencv-python = "4.10.0.84"
opencv-python-headless = "4.10.0.84"
openpyxl = "3.1.2"
openpyxl = "^3.1.0"
ormsgpack = "1.5.0"
pandas = "2.2.2"
pandas = "^2.2.0"
pdfplumber = "0.10.4"
peewee = "3.17.1"
pillow = "10.4.0"
@ -70,7 +71,7 @@ psycopg2-binary = "2.9.9"
pyclipper = "1.3.0.post5"
pycryptodomex = "3.20.0"
pypdf = "^5.0.0"
pytest = "8.2.2"
pytest = "^8.3.0"
python-dotenv = "1.0.1"
python-dateutil = "2.8.2"
python-pptx = "^1.0.2"
@ -86,7 +87,7 @@ ruamel-base = "1.0.0"
scholarly = "1.7.11"
scikit-learn = "1.5.0"
selenium = "4.22.0"
setuptools = "70.0.0"
setuptools = "^75.2.0"
shapely = "2.0.5"
six = "1.16.0"
strenum = "0.4.15"
@ -115,6 +116,7 @@ pymysql = "^1.1.1"
mini-racer = "^0.12.4"
pyicu = "^2.13.1"
flasgger = "^0.9.7.1"
polars = "^1.9.0"
[tool.poetry.group.full]

View File

@ -20,6 +20,7 @@ from rag.nlp import tokenize, is_english
from rag.nlp import rag_tokenizer
from deepdoc.parser import PdfParser, PptParser, PlainParser
from PyPDF2 import PdfReader as pdf2_read
import json
class Ppt(PptParser):
@ -107,9 +108,9 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
d = copy.deepcopy(doc)
pn += from_page
d["image"] = img
d["page_num_int"] = [pn + 1]
d["top_int"] = [0]
d["position_int"] = [(pn + 1, 0, img.size[0], 0, img.size[1])]
d["page_num_list"] = json.dumps([pn + 1])
d["top_list"] = json.dumps([0])
d["position_list"] = json.dumps([(pn + 1, 0, img.size[0], 0, img.size[1])])
tokenize(d, txt, eng)
res.append(d)
return res
@ -123,10 +124,10 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
pn += from_page
if img:
d["image"] = img
d["page_num_int"] = [pn + 1]
d["top_int"] = [0]
d["position_int"] = [
(pn + 1, 0, img.size[0] if img else 0, 0, img.size[1] if img else 0)]
d["page_num_list"] = json.dumps([pn + 1])
d["top_list"] = json.dumps([0])
d["position_list"] = json.dumps([
(pn + 1, 0, img.size[0] if img else 0, 0, img.size[1] if img else 0)])
tokenize(d, txt, eng)
res.append(d)
return res

View File

@ -74,7 +74,7 @@ class Excel(ExcelParser):
def trans_datatime(s):
try:
return datetime_parse(s.strip()).strftime("%Y-%m-%d %H:%M:%S")
except Exception as e:
except Exception:
pass
@ -112,7 +112,7 @@ def column_data_type(arr):
continue
try:
arr[i] = trans[ty](str(arr[i]))
except Exception as e:
except Exception:
arr[i] = None
# if ty == "text":
# if len(arr) > 128 and uni / len(arr) < 0.1:
@ -182,7 +182,7 @@ def chunk(filename, binary=None, from_page=0, to_page=10000000000,
"datetime": "_dt",
"bool": "_kwd"}
for df in dfs:
for n in ["id", "_id", "index", "idx"]:
for n in ["id", "index", "idx"]:
if n in df.columns:
del df[n]
clmns = df.columns.values

View File

@ -15,50 +15,51 @@
#
import json
import os
import sys
import time
import argparse
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from copy import deepcopy
from api.db import LLMType
from api.db.services.llm_service import LLMBundle
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.settings import retrievaler
from api.settings import retrievaler, docStoreConn
from api.utils import get_uuid
from api.utils.file_utils import get_project_base_directory
from rag.nlp import tokenize, search
from rag.utils.es_conn import ELASTICSEARCH
from ranx import evaluate
import pandas as pd
from tqdm import tqdm
from ranx import Qrels, Run
global max_docs
max_docs = sys.maxsize
class Benchmark:
def __init__(self, kb_id):
self.kb_id = kb_id
e, self.kb = KnowledgebaseService.get_by_id(kb_id)
self.similarity_threshold = self.kb.similarity_threshold
self.vector_similarity_weight = self.kb.vector_similarity_weight
self.embd_mdl = LLMBundle(self.kb.tenant_id, LLMType.EMBEDDING, llm_name=self.kb.embd_id, lang=self.kb.language)
self.tenant_id = ''
self.index_name = ''
self.initialized_index = False
def _get_benchmarks(self, query, dataset_idxnm, count=16):
req = {"question": query, "size": count, "vector": True, "similarity": self.similarity_threshold}
sres = retrievaler.search(req, search.index_name(dataset_idxnm), self.embd_mdl)
return sres
def _get_retrieval(self, qrels, dataset_idxnm):
def _get_retrieval(self, qrels):
# Need to wait for the ES and Infinity index to be ready
time.sleep(20)
run = defaultdict(dict)
query_list = list(qrels.keys())
for query in query_list:
ranks = retrievaler.retrieval(query, self.embd_mdl,
dataset_idxnm, [self.kb.id], 1, 30,
0.0, self.vector_similarity_weight)
ranks = retrievaler.retrieval(query, self.embd_mdl, self.tenant_id, [self.kb.id], 1, 30,
0.0, self.vector_similarity_weight)
if len(ranks["chunks"]) == 0:
print(f"deleted query: {query}")
del qrels[query]
continue
for c in ranks["chunks"]:
if "vector" in c:
del c["vector"]
run[query][c["chunk_id"]] = c["similarity"]
return run
def embedding(self, docs, batch_size=16):
@ -68,40 +69,37 @@ class Benchmark:
vts, c = self.embd_mdl.encode(cnts[i: i + batch_size])
vects.extend(vts.tolist())
assert len(docs) == len(vects)
vector_size = 0
for i, d in enumerate(docs):
v = vects[i]
vector_size = len(v)
d["q_%d_vec" % len(v)] = v
return docs
return docs, vector_size
@staticmethod
def init_kb(index_name):
idxnm = search.index_name(index_name)
if ELASTICSEARCH.indexExist(idxnm):
ELASTICSEARCH.deleteIdx(search.index_name(index_name))
return ELASTICSEARCH.createIdx(idxnm, json.load(
open(os.path.join(get_project_base_directory(), "conf", "mapping.json"), "r")))
def init_index(self, vector_size: int):
if self.initialized_index:
return
if docStoreConn.indexExist(self.index_name, self.kb_id):
docStoreConn.deleteIdx(self.index_name, self.kb_id)
docStoreConn.createIdx(self.index_name, self.kb_id, vector_size)
self.initialized_index = True
def ms_marco_index(self, file_path, index_name):
qrels = defaultdict(dict)
texts = defaultdict(dict)
docs_count = 0
docs = []
filelist = os.listdir(file_path)
self.init_kb(index_name)
max_workers = int(os.environ.get('MAX_WORKERS', 3))
exe = ThreadPoolExecutor(max_workers=max_workers)
threads = []
def slow_actions(es_docs, idx_nm):
es_docs = self.embedding(es_docs)
ELASTICSEARCH.bulk(es_docs, idx_nm)
return True
for dir in filelist:
data = pd.read_parquet(os.path.join(file_path, dir))
for i in tqdm(range(len(data)), colour="green", desc="Tokenizing:" + dir):
filelist = sorted(os.listdir(file_path))
for fn in filelist:
if docs_count >= max_docs:
break
if not fn.endswith(".parquet"):
continue
data = pd.read_parquet(os.path.join(file_path, fn))
for i in tqdm(range(len(data)), colour="green", desc="Tokenizing:" + fn):
if docs_count >= max_docs:
break
query = data.iloc[i]['query']
for rel, text in zip(data.iloc[i]['passages']['is_selected'], data.iloc[i]['passages']['passage_text']):
d = {
@ -115,27 +113,33 @@ class Benchmark:
texts[d["id"]] = text
qrels[query][d["id"]] = int(rel)
if len(docs) >= 32:
threads.append(
exe.submit(slow_actions, deepcopy(docs), search.index_name(index_name)))
docs_count += len(docs)
docs, vector_size = self.embedding(docs)
self.init_index(vector_size)
docStoreConn.insert(docs, self.index_name, self.kb_id)
docs = []
threads.append(
exe.submit(slow_actions, deepcopy(docs), search.index_name(index_name)))
for i in tqdm(range(len(threads)), colour="red", desc="Indexing:" + dir):
if not threads[i].result().output:
print("Indexing error...")
if docs:
docs, vector_size = self.embedding(docs)
self.init_index(vector_size)
docStoreConn.insert(docs, self.index_name, self.kb_id)
return qrels, texts
def trivia_qa_index(self, file_path, index_name):
qrels = defaultdict(dict)
texts = defaultdict(dict)
docs_count = 0
docs = []
filelist = os.listdir(file_path)
for dir in filelist:
data = pd.read_parquet(os.path.join(file_path, dir))
for i in tqdm(range(len(data)), colour="green", desc="Indexing:" + dir):
filelist = sorted(os.listdir(file_path))
for fn in filelist:
if docs_count >= max_docs:
break
if not fn.endswith(".parquet"):
continue
data = pd.read_parquet(os.path.join(file_path, fn))
for i in tqdm(range(len(data)), colour="green", desc="Indexing:" + fn):
if docs_count >= max_docs:
break
query = data.iloc[i]['question']
for rel, text in zip(data.iloc[i]["search_results"]['rank'],
data.iloc[i]["search_results"]['search_context']):
@ -150,16 +154,18 @@ class Benchmark:
texts[d["id"]] = text
qrels[query][d["id"]] = int(rel)
if len(docs) >= 32:
docs = self.embedding(docs)
ELASTICSEARCH.bulk(docs, search.index_name(index_name))
docs_count += len(docs)
docs, vector_size = self.embedding(docs)
self.init_index(vector_size)
docStoreConn.insert(docs, self.index_name, self.kb_id)
docs = []
docs = self.embedding(docs)
ELASTICSEARCH.bulk(docs, search.index_name(index_name))
docs, vector_size = self.embedding(docs)
self.init_index(vector_size)
docStoreConn.insert(docs, self.index_name, self.kb_id)
return qrels, texts
def miracl_index(self, file_path, corpus_path, index_name):
corpus_total = {}
for corpus_file in os.listdir(corpus_path):
tmp_data = pd.read_json(os.path.join(corpus_path, corpus_file), lines=True)
@ -176,14 +182,19 @@ class Benchmark:
qrels = defaultdict(dict)
texts = defaultdict(dict)
docs_count = 0
docs = []
for qrels_file in os.listdir(os.path.join(file_path, 'qrels')):
if 'test' in qrels_file:
continue
if docs_count >= max_docs:
break
tmp_data = pd.read_csv(os.path.join(file_path, 'qrels', qrels_file), sep='\t',
names=['qid', 'Q0', 'docid', 'relevance'])
for i in tqdm(range(len(tmp_data)), colour="green", desc="Indexing:" + qrels_file):
if docs_count >= max_docs:
break
query = topics_total[tmp_data.iloc[i]['qid']]
text = corpus_total[tmp_data.iloc[i]['docid']]
rel = tmp_data.iloc[i]['relevance']
@ -198,13 +209,15 @@ class Benchmark:
texts[d["id"]] = text
qrels[query][d["id"]] = int(rel)
if len(docs) >= 32:
docs = self.embedding(docs)
ELASTICSEARCH.bulk(docs, search.index_name(index_name))
docs_count += len(docs)
docs, vector_size = self.embedding(docs)
self.init_index(vector_size)
docStoreConn.insert(docs, self.index_name, self.kb_id)
docs = []
docs = self.embedding(docs)
ELASTICSEARCH.bulk(docs, search.index_name(index_name))
docs, vector_size = self.embedding(docs)
self.init_index(vector_size)
docStoreConn.insert(docs, self.index_name, self.kb_id)
return qrels, texts
def save_results(self, qrels, run, texts, dataset, file_path):
@ -213,7 +226,7 @@ class Benchmark:
for run_i in tqdm(range(len(run_keys)), desc="Calculating ndcg@10 for single query"):
key = run_keys[run_i]
keep_result.append({'query': key, 'qrel': qrels[key], 'run': run[key],
'ndcg@10': evaluate(Qrels({key: qrels[key]}), Run({key: run[key]}), "ndcg@10")})
'ndcg@10': evaluate({key: qrels[key]}, {key: run[key]}, "ndcg@10")})
keep_result = sorted(keep_result, key=lambda kk: kk['ndcg@10'])
with open(os.path.join(file_path, dataset + 'result.md'), 'w', encoding='utf-8') as f:
f.write('## Score For Every Query\n')
@ -229,14 +242,18 @@ class Benchmark:
def __call__(self, dataset, file_path, miracl_corpus=''):
if dataset == "ms_marco_v1.1":
self.tenant_id = "benchmark_ms_marco_v11"
self.index_name = search.index_name(self.tenant_id)
qrels, texts = self.ms_marco_index(file_path, "benchmark_ms_marco_v1.1")
run = self._get_retrieval(qrels, "benchmark_ms_marco_v1.1")
print(dataset, evaluate(Qrels(qrels), Run(run), ["ndcg@10", "map@5", "mrr"]))
run = self._get_retrieval(qrels)
print(dataset, evaluate(qrels, run, ["ndcg@10", "map@5", "mrr"]))
self.save_results(qrels, run, texts, dataset, file_path)
if dataset == "trivia_qa":
self.tenant_id = "benchmark_trivia_qa"
self.index_name = search.index_name(self.tenant_id)
qrels, texts = self.trivia_qa_index(file_path, "benchmark_trivia_qa")
run = self._get_retrieval(qrels, "benchmark_trivia_qa")
print(dataset, evaluate(Qrels(qrels), Run(run), ["ndcg@10", "map@5", "mrr"]))
run = self._get_retrieval(qrels)
print(dataset, evaluate(qrels, run, ["ndcg@10", "map@5", "mrr"]))
self.save_results(qrels, run, texts, dataset, file_path)
if dataset == "miracl":
for lang in ['ar', 'bn', 'de', 'en', 'es', 'fa', 'fi', 'fr', 'hi', 'id', 'ja', 'ko', 'ru', 'sw', 'te', 'th',
@ -253,28 +270,41 @@ class Benchmark:
if not os.path.isdir(os.path.join(miracl_corpus, 'miracl-corpus-v1.0-' + lang)):
print('Directory: ' + os.path.join(miracl_corpus, 'miracl-corpus-v1.0-' + lang) + ' not found!')
continue
self.tenant_id = "benchmark_miracl_" + lang
self.index_name = search.index_name(self.tenant_id)
self.initialized_index = False
qrels, texts = self.miracl_index(os.path.join(file_path, 'miracl-v1.0-' + lang),
os.path.join(miracl_corpus, 'miracl-corpus-v1.0-' + lang),
"benchmark_miracl_" + lang)
run = self._get_retrieval(qrels, "benchmark_miracl_" + lang)
print(dataset, evaluate(Qrels(qrels), Run(run), ["ndcg@10", "map@5", "mrr"]))
run = self._get_retrieval(qrels)
print(dataset, evaluate(qrels, run, ["ndcg@10", "map@5", "mrr"]))
self.save_results(qrels, run, texts, dataset, file_path)
if __name__ == '__main__':
print('*****************RAGFlow Benchmark*****************')
kb_id = input('Please input kb_id:\n')
parser = argparse.ArgumentParser(usage="benchmark.py <max_docs> <kb_id> <dataset> <dataset_path> [<miracl_corpus_path>]", description='RAGFlow Benchmark')
parser.add_argument('max_docs', metavar='max_docs', type=int, help='max docs to evaluate')
parser.add_argument('kb_id', metavar='kb_id', help='knowledgebase id')
parser.add_argument('dataset', metavar='dataset', help='dataset name, must be one of ms_marco_v1.1 (https://huggingface.co/datasets/microsoft/ms_marco), trivia_qa (https://huggingface.co/datasets/mandarjoshi/trivia_qa), miracl (https://huggingface.co/datasets/miracl/miracl)')
parser.add_argument('dataset_path', metavar='dataset_path', help='dataset path')
parser.add_argument('miracl_corpus_path', metavar='miracl_corpus_path', nargs='?', default="", help='miracl corpus path. Only needed when dataset is miracl')
args = parser.parse_args()
max_docs = args.max_docs
kb_id = args.kb_id
ex = Benchmark(kb_id)
dataset = input(
'RAGFlow Benchmark Support:\n\tms_marco_v1.1:<https://huggingface.co/datasets/microsoft/ms_marco>\n\ttrivia_qa:<https://huggingface.co/datasets/mandarjoshi/trivia_qa>\n\tmiracl:<https://huggingface.co/datasets/miracl/miracl>\nPlease input dataset choice:\n')
if dataset in ['ms_marco_v1.1', 'trivia_qa']:
if dataset == "ms_marco_v1.1":
print("Notice: Please provide the ms_marco_v1.1 dataset only. ms_marco_v2.1 is not supported!")
dataset_path = input('Please input ' + dataset + ' dataset path:\n')
dataset = args.dataset
dataset_path = args.dataset_path
if dataset == "ms_marco_v1.1" or dataset == "trivia_qa":
ex(dataset, dataset_path)
elif dataset == 'miracl':
dataset_path = input('Please input ' + dataset + ' dataset path:\n')
corpus_path = input('Please input ' + dataset + '-corpus dataset path:\n')
ex(dataset, dataset_path, miracl_corpus=corpus_path)
elif dataset == "miracl":
if not args.miracl_corpus_path:
print('Please input the correct parameters!')
exit(1)
ex(dataset, dataset_path, miracl_corpus=args.miracl_corpus_path)
else:
print("Dataset: ", dataset, "not supported!")

View File

@ -25,6 +25,7 @@ import roman_numbers as r
from word2number import w2n
from cn2an import cn2an
from PIL import Image
import json
all_codecs = [
'utf-8', 'gb2312', 'gbk', 'utf_16', 'ascii', 'big5', 'big5hkscs',
@ -51,12 +52,12 @@ def find_codec(blob):
try:
blob[:1024].decode(c)
return c
except Exception as e:
except Exception:
pass
try:
blob.decode(c)
return c
except Exception as e:
except Exception:
pass
return "utf-8"
@ -241,7 +242,7 @@ def tokenize_chunks(chunks, doc, eng, pdf_parser=None):
d["image"], poss = pdf_parser.crop(ck, need_position=True)
add_positions(d, poss)
ck = pdf_parser.remove_tag(ck)
except NotImplementedError as e:
except NotImplementedError:
pass
tokenize(d, ck, eng)
res.append(d)
@ -289,13 +290,16 @@ def tokenize_table(tbls, doc, eng, batch_size=10):
def add_positions(d, poss):
if not poss:
return
d["page_num_int"] = []
d["position_int"] = []
d["top_int"] = []
page_num_list = []
position_list = []
top_list = []
for pn, left, right, top, bottom in poss:
d["page_num_int"].append(int(pn + 1))
d["top_int"].append(int(top))
d["position_int"].append((int(pn + 1), int(left), int(right), int(top), int(bottom)))
page_num_list.append(int(pn + 1))
top_list.append(int(top))
position_list.append((int(pn + 1), int(left), int(right), int(top), int(bottom)))
d["page_num_list"] = json.dumps(page_num_list)
d["position_list"] = json.dumps(position_list)
d["top_list"] = json.dumps(top_list)
def remove_contents_table(sections, eng=False):

View File

@ -15,20 +15,25 @@
#
import json
import math
import re
import logging
import copy
from elasticsearch_dsl import Q
from rag.utils.doc_store_conn import MatchTextExpr
from rag.nlp import rag_tokenizer, term_weight, synonym
class EsQueryer:
def __init__(self, es):
class FulltextQueryer:
def __init__(self):
self.tw = term_weight.Dealer()
self.es = es
self.syn = synonym.Dealer()
self.flds = ["ask_tks^10", "ask_small_tks"]
self.query_fields = [
"title_tks^10",
"title_sm_tks^5",
"important_kwd^30",
"important_tks^20",
"content_ltks^2",
"content_sm_ltks",
]
@staticmethod
def subSpecialChar(line):
@ -43,12 +48,15 @@ class EsQueryer:
for t in arr:
if not re.match(r"[a-zA-Z]+$", t):
e += 1
return e * 1. / len(arr) >= 0.7
return e * 1.0 / len(arr) >= 0.7
@staticmethod
def rmWWW(txt):
patts = [
(r"是*(什么样的|哪家|一下|那家|请问|啥样|咋样了|什么时候|何时|何地|何人|是否|是不是|多少|哪里|怎么|哪儿|怎么样|如何|哪些|是啥|啥是|啊|吗|呢|吧|咋|什么|有没有|呀)是*", ""),
(
r"是*(什么样的|哪家|一下|那家|请问|啥样|咋样了|什么时候|何时|何地|何人|是否|是不是|多少|哪里|怎么|哪儿|怎么样|如何|哪些|是啥|啥是|啊|吗|呢|吧|咋|什么|有没有|呀)是*",
"",
),
(r"(^| )(what|who|how|which|where|why)('re|'s)? ", " "),
(r"(^| )('s|'re|is|are|were|was|do|does|did|don't|doesn't|didn't|has|have|be|there|you|me|your|my|mine|just|please|may|i|should|would|wouldn't|will|won't|done|go|for|with|so|the|a|an|by|i'm|it's|he's|she's|they|they're|you're|as|by|on|in|at|up|out|down|of) ", " ")
]
@ -56,16 +64,16 @@ class EsQueryer:
txt = re.sub(r, p, txt, flags=re.IGNORECASE)
return txt
def question(self, txt, tbl="qa", min_match="60%"):
def question(self, txt, tbl="qa", min_match:float=0.6):
txt = re.sub(
r"[ :\r\n\t,,。??/`!&\^%%]+",
" ",
rag_tokenizer.tradi2simp(
rag_tokenizer.strQ2B(
txt.lower()))).strip()
rag_tokenizer.tradi2simp(rag_tokenizer.strQ2B(txt.lower())),
).strip()
txt = FulltextQueryer.rmWWW(txt)
if not self.isChinese(txt):
txt = EsQueryer.rmWWW(txt)
txt = FulltextQueryer.rmWWW(txt)
tks = rag_tokenizer.tokenize(txt).split(" ")
tks_w = self.tw.weights(tks)
tks_w = [(re.sub(r"[ \\\"'^]", "", tk), w) for tk, w in tks_w]
@ -73,14 +81,20 @@ class EsQueryer:
tks_w = [(re.sub(r"^[\+-]", "", tk), w) for tk, w in tks_w if tk]
q = ["{}^{:.4f}".format(tk, w) for tk, w in tks_w if tk]
for i in range(1, len(tks_w)):
q.append("\"%s %s\"^%.4f" % (tks_w[i - 1][0], tks_w[i][0], max(tks_w[i - 1][1], tks_w[i][1])*2))
q.append(
'"%s %s"^%.4f'
% (
tks_w[i - 1][0],
tks_w[i][0],
max(tks_w[i - 1][1], tks_w[i][1]) * 2,
)
)
if not q:
q.append(txt)
return Q("bool",
must=Q("query_string", fields=self.flds,
type="best_fields", query=" ".join(q),
boost=1)#, minimum_should_match=min_match)
), list(set([t for t in txt.split(" ") if t]))
query = " ".join(q)
return MatchTextExpr(
self.query_fields, query, 100
), tks
def need_fine_grained_tokenize(tk):
if len(tk) < 3:
@ -89,7 +103,7 @@ class EsQueryer:
return False
return True
txt = EsQueryer.rmWWW(txt)
txt = FulltextQueryer.rmWWW(txt)
qs, keywords = [], []
for tt in self.tw.split(txt)[:256]: # .split(" "):
if not tt:
@ -101,65 +115,71 @@ class EsQueryer:
logging.info(json.dumps(twts, ensure_ascii=False))
tms = []
for tk, w in sorted(twts, key=lambda x: x[1] * -1):
sm = rag_tokenizer.fine_grained_tokenize(tk).split(" ") if need_fine_grained_tokenize(tk) else []
sm = (
rag_tokenizer.fine_grained_tokenize(tk).split(" ")
if need_fine_grained_tokenize(tk)
else []
)
sm = [
re.sub(
r"[ ,\./;'\[\]\\`~!@#$%\^&\*\(\)=\+_<>\?:\"\{\}\|,。;‘’【】、!¥……()——《》?:“”-]+",
"",
m) for m in sm]
sm = [EsQueryer.subSpecialChar(m) for m in sm if len(m) > 1]
m,
)
for m in sm
]
sm = [FulltextQueryer.subSpecialChar(m) for m in sm if len(m) > 1]
sm = [m for m in sm if len(m) > 1]
keywords.append(re.sub(r"[ \\\"']+", "", tk))
keywords.extend(sm)
if len(keywords) >= 12: break
if len(keywords) >= 12:
break
tk_syns = self.syn.lookup(tk)
tk = EsQueryer.subSpecialChar(tk)
tk = FulltextQueryer.subSpecialChar(tk)
if tk.find(" ") > 0:
tk = "\"%s\"" % tk
tk = '"%s"' % tk
if tk_syns:
tk = f"({tk} %s)" % " ".join(tk_syns)
if sm:
tk = f"{tk} OR \"%s\" OR (\"%s\"~2)^0.5" % (
" ".join(sm), " ".join(sm))
tk = f'{tk} OR "%s" OR ("%s"~2)^0.5' % (" ".join(sm), " ".join(sm))
if tk.strip():
tms.append((tk, w))
tms = " ".join([f"({t})^{w}" for t, w in tms])
if len(twts) > 1:
tms += f" (\"%s\"~4)^1.5" % (" ".join([t for t, _ in twts]))
tms += ' ("%s"~4)^1.5' % (" ".join([t for t, _ in twts]))
if re.match(r"[0-9a-z ]+$", tt):
tms = f"(\"{tt}\" OR \"%s\")" % rag_tokenizer.tokenize(tt)
tms = f'("{tt}" OR "%s")' % rag_tokenizer.tokenize(tt)
syns = " OR ".join(
["\"%s\"^0.7" % EsQueryer.subSpecialChar(rag_tokenizer.tokenize(s)) for s in syns])
[
'"%s"^0.7'
% FulltextQueryer.subSpecialChar(rag_tokenizer.tokenize(s))
for s in syns
]
)
if syns:
tms = f"({tms})^5 OR ({syns})^0.7"
qs.append(tms)
flds = copy.deepcopy(self.flds)
mst = []
if qs:
mst.append(
Q("query_string", fields=flds, type="best_fields",
query=" OR ".join([f"({t})" for t in qs if t]), boost=1, minimum_should_match=min_match)
)
query = " OR ".join([f"({t})" for t in qs if t])
return MatchTextExpr(
self.query_fields, query, 100, {"minimum_should_match": min_match}
), keywords
return None, keywords
return Q("bool",
must=mst,
), list(set(keywords))
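A minimal usage sketch of the rewritten query builder, assuming the import path shown elsewhere in this diff; the question text is arbitrary:

```python
from rag.nlp.query import FulltextQueryer

queryer = FulltextQueryer()  # no Elasticsearch handle needed any more
match_text, keywords = queryer.question("what is retrieval augmented generation", min_match=0.3)

# match_text is a MatchTextExpr over the weighted full-text fields (query_fields above);
# keywords are the tokens kept for highlighting and reranking.
print(match_text.fields, match_text.matching_text, match_text.extra_options)
print(keywords)
```

The expression is backend-neutral: the ES adapter later in this diff turns it into a query_string clause, while other DocStoreConnection implementations are free to map it differently.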
def hybrid_similarity(self, avec, bvecs, atks, btkss, tkweight=0.3,
vtweight=0.7):
def hybrid_similarity(self, avec, bvecs, atks, btkss, tkweight=0.3, vtweight=0.7):
from sklearn.metrics.pairwise import cosine_similarity as CosineSimilarity
import numpy as np
sims = CosineSimilarity([avec], bvecs)
tksim = self.token_similarity(atks, btkss)
return np.array(sims[0]) * vtweight + \
np.array(tksim) * tkweight, tksim, sims[0]
return np.array(sims[0]) * vtweight + np.array(tksim) * tkweight, tksim, sims[0]
def token_similarity(self, atks, btkss):
def toDict(tks):

View File

@ -14,34 +14,25 @@
# limitations under the License.
#
import json
import re
from copy import deepcopy
from elasticsearch_dsl import Q, Search
import json
from typing import List, Optional, Dict, Union
from dataclasses import dataclass
from rag.settings import es_logger
from rag.settings import doc_store_logger
from rag.utils import rmSpace
from rag.nlp import rag_tokenizer, query, is_english
from rag.nlp import rag_tokenizer, query
import numpy as np
from rag.utils.doc_store_conn import DocStoreConnection, MatchDenseExpr, FusionExpr, OrderByExpr
def index_name(uid): return f"ragflow_{uid}"
class Dealer:
def __init__(self, es):
self.qryr = query.EsQueryer(es)
self.qryr.flds = [
"title_tks^10",
"title_sm_tks^5",
"important_kwd^30",
"important_tks^20",
"content_ltks^2",
"content_sm_ltks"]
self.es = es
def __init__(self, dataStore: DocStoreConnection):
self.qryr = query.FulltextQueryer()
self.dataStore = dataStore
@dataclass
class SearchResult:
@ -54,170 +45,99 @@ class Dealer:
keywords: Optional[List[str]] = None
group_docs: List[List] = None
def _vector(self, txt, emb_mdl, sim=0.8, topk=10):
qv, c = emb_mdl.encode_queries(txt)
return {
"field": "q_%d_vec" % len(qv),
"k": topk,
"similarity": sim,
"num_candidates": topk * 2,
"query_vector": [float(v) for v in qv]
}
def get_vector(self, txt, emb_mdl, topk=10, similarity=0.1):
qv, _ = emb_mdl.encode_queries(txt)
embedding_data = [float(v) for v in qv]
vector_column_name = f"q_{len(embedding_data)}_vec"
return MatchDenseExpr(vector_column_name, embedding_data, 'float', 'cosine', topk, {"similarity": similarity})
def _add_filters(self, bqry, req):
if req.get("kb_ids"):
bqry.filter.append(Q("terms", kb_id=req["kb_ids"]))
if req.get("doc_ids"):
bqry.filter.append(Q("terms", doc_id=req["doc_ids"]))
if req.get("knowledge_graph_kwd"):
bqry.filter.append(Q("terms", knowledge_graph_kwd=req["knowledge_graph_kwd"]))
if "available_int" in req:
if req["available_int"] == 0:
bqry.filter.append(Q("range", available_int={"lt": 1}))
else:
bqry.filter.append(
Q("bool", must_not=Q("range", available_int={"lt": 1})))
return bqry
def get_filters(self, req):
condition = dict()
for key, field in {"kb_ids": "kb_id", "doc_ids": "doc_id"}.items():
if key in req and req[key] is not None:
condition[field] = req[key]
# TODO(yzc): `available_int` is nullable however infinity doesn't support nullable columns.
for key in ["knowledge_graph_kwd"]:
if key in req and req[key] is not None:
condition[key] = req[key]
return condition
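Taken together, the two helpers above translate a request dict into backend-neutral inputs. A sketch, where embd_mdl stands for an embedding LLMBundle resolved elsewhere (see the benchmark setup earlier in this diff):

```python
from api.settings import retrievaler  # a Dealer instance

req = {"question": "what is RAGFlow", "kb_ids": ["my_kb_id"], "doc_ids": None,
       "similarity": 0.1, "topk": 1024}

condition = retrievaler.get_filters(req)   # -> {"kb_id": ["my_kb_id"]}; None values are skipped
match_dense = retrievaler.get_vector(req["question"], embd_mdl, req["topk"], req["similarity"])
print(match_dense.vector_column_name)      # e.g. "q_1024_vec", named after the embedding length
```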
def search(self, req, idxnms, emb_mdl=None, highlight=False):
qst = req.get("question", "")
bqry, keywords = self.qryr.question(qst, min_match="30%")
bqry = self._add_filters(bqry, req)
bqry.boost = 0.05
def search(self, req, idx_names: list[str], kb_ids: list[str], emb_mdl=None, highlight = False):
filters = self.get_filters(req)
orderBy = OrderByExpr()
s = Search()
pg = int(req.get("page", 1)) - 1
topk = int(req.get("topk", 1024))
ps = int(req.get("size", topk))
offset, limit = pg * ps, (pg + 1) * ps
src = req.get("fields", ["docnm_kwd", "content_ltks", "kb_id", "img_id", "title_tks", "important_kwd",
"image_id", "doc_id", "q_512_vec", "q_768_vec", "position_int", "knowledge_graph_kwd",
"q_1024_vec", "q_1536_vec", "available_int", "content_with_weight"])
s = s.query(bqry)[pg * ps:(pg + 1) * ps]
s = s.highlight("content_ltks")
s = s.highlight("title_ltks")
if not qst:
if not req.get("sort"):
s = s.sort(
#{"create_time": {"order": "desc", "unmapped_type": "date"}},
{"create_timestamp_flt": {
"order": "desc", "unmapped_type": "float"}}
)
else:
s = s.sort(
{"page_num_int": {"order": "asc", "unmapped_type": "float",
"mode": "avg", "numeric_type": "double"}},
{"top_int": {"order": "asc", "unmapped_type": "float",
"mode": "avg", "numeric_type": "double"}},
#{"create_time": {"order": "desc", "unmapped_type": "date"}},
{"create_timestamp_flt": {
"order": "desc", "unmapped_type": "float"}}
)
if qst:
s = s.highlight_options(
fragment_size=120,
number_of_fragments=5,
boundary_scanner_locale="zh-CN",
boundary_scanner="SENTENCE",
boundary_chars=",./;:\\!(),。?:!……()——、"
)
s = s.to_dict()
q_vec = []
if req.get("vector"):
assert emb_mdl, "No embedding model selected"
s["knn"] = self._vector(
qst, emb_mdl, req.get(
"similarity", 0.1), topk)
s["knn"]["filter"] = bqry.to_dict()
if not highlight and "highlight" in s:
del s["highlight"]
q_vec = s["knn"]["query_vector"]
es_logger.info("【Q】: {}".format(json.dumps(s)))
res = self.es.search(deepcopy(s), idxnms=idxnms, timeout="600s", src=src)
es_logger.info("TOTAL: {}".format(self.es.getTotal(res)))
if self.es.getTotal(res) == 0 and "knn" in s:
bqry, _ = self.qryr.question(qst, min_match="10%")
if req.get("doc_ids"):
bqry = Q("bool", must=[])
bqry = self._add_filters(bqry, req)
s["query"] = bqry.to_dict()
s["knn"]["filter"] = bqry.to_dict()
s["knn"]["similarity"] = 0.17
res = self.es.search(s, idxnms=idxnms, timeout="600s", src=src)
es_logger.info("【Q】: {}".format(json.dumps(s)))
"doc_id", "position_list", "knowledge_graph_kwd",
"available_int", "content_with_weight"])
kwds = set([])
for k in keywords:
kwds.add(k)
for kk in rag_tokenizer.fine_grained_tokenize(k).split(" "):
if len(kk) < 2:
continue
if kk in kwds:
continue
kwds.add(kk)
aggs = self.getAggregation(res, "docnm_kwd")
qst = req.get("question", "")
q_vec = []
if not qst:
if req.get("sort"):
orderBy.desc("create_timestamp_flt")
res = self.dataStore.search(src, [], filters, [], orderBy, offset, limit, idx_names, kb_ids)
total=self.dataStore.getTotal(res)
doc_store_logger.info("Dealer.search TOTAL: {}".format(total))
else:
highlightFields = ["content_ltks", "title_tks"] if highlight else []
matchText, keywords = self.qryr.question(qst, min_match=0.3)
if emb_mdl is None:
matchExprs = [matchText]
res = self.dataStore.search(src, highlightFields, filters, matchExprs, orderBy, offset, limit, idx_names, kb_ids)
total=self.dataStore.getTotal(res)
doc_store_logger.info("Dealer.search TOTAL: {}".format(total))
else:
matchDense = self.get_vector(qst, emb_mdl, topk, req.get("similarity", 0.1))
q_vec = matchDense.embedding_data
src.append(f"q_{len(q_vec)}_vec")
fusionExpr = FusionExpr("weighted_sum", topk, {"weights": "0.05, 0.95"})
matchExprs = [matchText, matchDense, fusionExpr]
res = self.dataStore.search(src, highlightFields, filters, matchExprs, orderBy, offset, limit, idx_names, kb_ids)
total=self.dataStore.getTotal(res)
doc_store_logger.info("Dealer.search TOTAL: {}".format(total))
# If result is empty, try again with lower min_match
if total == 0:
matchText, _ = self.qryr.question(qst, min_match=0.1)
if "doc_ids" in filters:
del filters["doc_ids"]
matchDense.extra_options["similarity"] = 0.17
res = self.dataStore.search(src, highlightFields, filters, [matchText, matchDense, fusionExpr], orderBy, offset, limit, idx_names, kb_ids)
total=self.dataStore.getTotal(res)
doc_store_logger.info("Dealer.search 2 TOTAL: {}".format(total))
for k in keywords:
kwds.add(k)
for kk in rag_tokenizer.fine_grained_tokenize(k).split(" "):
if len(kk) < 2:
continue
if kk in kwds:
continue
kwds.add(kk)
doc_store_logger.info(f"TOTAL: {total}")
ids=self.dataStore.getChunkIds(res)
keywords=list(kwds)
highlight = self.dataStore.getHighlight(res, keywords, "content_with_weight")
aggs = self.dataStore.getAggregation(res, "docnm_kwd")
return self.SearchResult(
total=self.es.getTotal(res),
ids=self.es.getDocIds(res),
total=total,
ids=ids,
query_vector=q_vec,
aggregation=aggs,
highlight=self.getHighlight(res, keywords, "content_with_weight"),
field=self.getFields(res, src),
keywords=list(kwds)
highlight=highlight,
field=self.dataStore.getFields(res, src),
keywords=keywords
)
def getAggregation(self, res, g):
if not "aggregations" in res or "aggs_" + g not in res["aggregations"]:
return
bkts = res["aggregations"]["aggs_" + g]["buckets"]
return [(b["key"], b["doc_count"]) for b in bkts]
def getHighlight(self, res, keywords, fieldnm):
ans = {}
for d in res["hits"]["hits"]:
hlts = d.get("highlight")
if not hlts:
continue
txt = "...".join([a for a in list(hlts.items())[0][1]])
if not is_english(txt.split(" ")):
ans[d["_id"]] = txt
continue
txt = d["_source"][fieldnm]
txt = re.sub(r"[\r\n]", " ", txt, flags=re.IGNORECASE|re.MULTILINE)
txts = []
for t in re.split(r"[.?!;\n]", txt):
for w in keywords:
t = re.sub(r"(^|[ .?/'\"\(\)!,:;-])(%s)([ .?/'\"\(\)!,:;-])"%re.escape(w), r"\1<em>\2</em>\3", t, flags=re.IGNORECASE|re.MULTILINE)
if not re.search(r"<em>[^<>]+</em>", t, flags=re.IGNORECASE|re.MULTILINE): continue
txts.append(t)
ans[d["_id"]] = "...".join(txts) if txts else "...".join([a for a in list(hlts.items())[0][1]])
return ans
def getFields(self, sres, flds):
res = {}
if not flds:
return {}
for d in self.es.getSource(sres):
m = {n: d.get(n) for n in flds if d.get(n) is not None}
for n, v in m.items():
if isinstance(v, type([])):
m[n] = "\t".join([str(vv) if not isinstance(
vv, list) else "\t".join([str(vvv) for vvv in vv]) for vv in v])
continue
if not isinstance(v, type("")):
m[n] = str(m[n])
#if n.find("tks") > 0:
# m[n] = rmSpace(m[n])
if m:
res[d["id"]] = m
return res
@staticmethod
def trans2floats(txt):
return [float(t) for t in txt.split("\t")]
@ -260,7 +180,7 @@ class Dealer:
continue
idx.append(i)
pieces_.append(t)
es_logger.info("{} => {}".format(answer, pieces_))
doc_store_logger.info("{} => {}".format(answer, pieces_))
if not pieces_:
return answer, set([])
@ -281,7 +201,7 @@ class Dealer:
chunks_tks,
tkweight, vtweight)
mx = np.max(sim) * 0.99
es_logger.info("{} SIM: {}".format(pieces_[i], mx))
doc_store_logger.info("{} SIM: {}".format(pieces_[i], mx))
if mx < thr:
continue
cites[idx[i]] = list(
@ -309,9 +229,15 @@ class Dealer:
def rerank(self, sres, query, tkweight=0.3,
vtweight=0.7, cfield="content_ltks"):
_, keywords = self.qryr.question(query)
ins_embd = [
Dealer.trans2floats(
sres.field[i].get("q_%d_vec" % len(sres.query_vector), "\t".join(["0"] * len(sres.query_vector)))) for i in sres.ids]
vector_size = len(sres.query_vector)
vector_column = f"q_{vector_size}_vec"
zero_vector = [0.0] * vector_size
ins_embd = []
for chunk_id in sres.ids:
vector = sres.field[chunk_id].get(vector_column, zero_vector)
if isinstance(vector, str):
vector = [float(v) for v in vector.split("\t")]
ins_embd.append(vector)
if not ins_embd:
return [], [], []
@ -377,7 +303,7 @@ class Dealer:
if isinstance(tenant_ids, str):
tenant_ids = tenant_ids.split(",")
sres = self.search(req, [index_name(tid) for tid in tenant_ids], embd_mdl, highlight)
sres = self.search(req, [index_name(tid) for tid in tenant_ids], kb_ids, embd_mdl, highlight)
ranks["total"] = sres.total
if page <= RERANK_PAGE_LIMIT:
@ -393,6 +319,8 @@ class Dealer:
idx = list(range(len(sres.ids)))
dim = len(sres.query_vector)
vector_column = f"q_{dim}_vec"
zero_vector = [0.0] * dim
for i in idx:
if sim[i] < similarity_threshold:
break
@ -401,34 +329,32 @@ class Dealer:
continue
break
id = sres.ids[i]
dnm = sres.field[id]["docnm_kwd"]
did = sres.field[id]["doc_id"]
chunk = sres.field[id]
dnm = chunk["docnm_kwd"]
did = chunk["doc_id"]
position_list = chunk.get("position_list", "[]")
if not position_list:
position_list = "[]"
d = {
"chunk_id": id,
"content_ltks": sres.field[id]["content_ltks"],
"content_with_weight": sres.field[id]["content_with_weight"],
"doc_id": sres.field[id]["doc_id"],
"content_ltks": chunk["content_ltks"],
"content_with_weight": chunk["content_with_weight"],
"doc_id": chunk["doc_id"],
"docnm_kwd": dnm,
"kb_id": sres.field[id]["kb_id"],
"important_kwd": sres.field[id].get("important_kwd", []),
"img_id": sres.field[id].get("img_id", ""),
"kb_id": chunk["kb_id"],
"important_kwd": chunk.get("important_kwd", []),
"image_id": chunk.get("img_id", ""),
"similarity": sim[i],
"vector_similarity": vsim[i],
"term_similarity": tsim[i],
"vector": self.trans2floats(sres.field[id].get("q_%d_vec" % dim, "\t".join(["0"] * dim))),
"positions": sres.field[id].get("position_int", "").split("\t")
"vector": chunk.get(vector_column, zero_vector),
"positions": json.loads(position_list)
}
if highlight:
if id in sres.highlight:
d["highlight"] = rmSpace(sres.highlight[id])
else:
d["highlight"] = d["content_with_weight"]
if len(d["positions"]) % 5 == 0:
poss = []
for i in range(0, len(d["positions"]), 5):
poss.append([float(d["positions"][i]), float(d["positions"][i + 1]), float(d["positions"][i + 2]),
float(d["positions"][i + 3]), float(d["positions"][i + 4])])
d["positions"] = poss
ranks["chunks"].append(d)
if dnm not in ranks["doc_aggs"]:
ranks["doc_aggs"][dnm] = {"doc_id": did, "count": 0}
@ -442,39 +368,11 @@ class Dealer:
return ranks
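A usage sketch of the retrieval entry point with its new signature (tenant id plus explicit knowledgebase ids instead of a prebuilt index name); the kb_id is a placeholder:

```python
from api.db import LLMType
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle
from api.settings import retrievaler

kb_id = "my_kb_id"                         # placeholder
_, kb = KnowledgebaseService.get_by_id(kb_id)
embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING, llm_name=kb.embd_id, lang=kb.language)

ranks = retrievaler.retrieval(
    "what is RAGFlow", embd_mdl,
    kb.tenant_id, [kb_id],                 # tenant id (drives the index name) and kb filter
    1, 30,                                 # page, page_size
    0.2, 0.3,                              # similarity threshold, vector similarity weight
)
for chunk in ranks["chunks"]:
    print(chunk["doc_id"], chunk["similarity"], chunk["positions"])
```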
def sql_retrieval(self, sql, fetch_size=128, format="json"):
from api.settings import chat_logger
sql = re.sub(r"[ `]+", " ", sql)
sql = sql.replace("%", "")
es_logger.info(f"Get es sql: {sql}")
replaces = []
for r in re.finditer(r" ([a-z_]+_l?tks)( like | ?= ?)'([^']+)'", sql):
fld, v = r.group(1), r.group(3)
match = " MATCH({}, '{}', 'operator=OR;minimum_should_match=30%') ".format(
fld, rag_tokenizer.fine_grained_tokenize(rag_tokenizer.tokenize(v)))
replaces.append(
("{}{}'{}'".format(
r.group(1),
r.group(2),
r.group(3)),
match))
tbl = self.dataStore.sql(sql, fetch_size, format)
return tbl
for p, r in replaces:
sql = sql.replace(p, r, 1)
chat_logger.info(f"To es: {sql}")
try:
tbl = self.es.sql(sql, fetch_size, format)
return tbl
except Exception as e:
chat_logger.error(f"SQL failure: {sql} =>" + str(e))
return {"error": str(e)}
def chunk_list(self, doc_id, tenant_id, max_count=1024, fields=["docnm_kwd", "content_with_weight", "img_id"]):
s = Search()
s = s.query(Q("match", doc_id=doc_id))[0:max_count]
s = s.to_dict()
es_res = self.es.search(s, idxnms=index_name(tenant_id), timeout="600s", src=fields)
res = []
for index, chunk in enumerate(es_res['hits']['hits']):
res.append({fld: chunk['_source'].get(fld) for fld in fields})
return res
def chunk_list(self, doc_id: str, tenant_id: str, kb_ids: list[str], max_count=1024, fields=["docnm_kwd", "content_with_weight", "img_id"]):
condition = {"doc_id": doc_id}
res = self.dataStore.search(fields, [], condition, [], OrderByExpr(), 0, max_count, index_name(tenant_id), kb_ids)
dict_chunks = self.dataStore.getFields(res, fields)
return dict_chunks.values()
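For comparison, a sketch of the reworked chunk_list call, which now takes the knowledgebase ids explicitly; the ids are placeholders:

```python
from api.settings import retrievaler

for chunk in retrievaler.chunk_list(
        "my_doc_id", "my_tenant_id", ["my_kb_id"],
        fields=["docnm_kwd", "content_with_weight", "img_id"]):
    print(chunk["docnm_kwd"], chunk["content_with_weight"][:80])
```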

View File

@ -25,12 +25,13 @@ RAG_CONF_PATH = os.path.join(get_project_base_directory(), "conf")
SUBPROCESS_STD_LOG_NAME = "std.log"
ES = get_base_config("es", {})
INFINITY = get_base_config("infinity", {"uri": "infinity:23817"})
AZURE = get_base_config("azure", {})
S3 = get_base_config("s3", {})
MINIO = decrypt_database_config(name="minio")
try:
REDIS = decrypt_database_config(name="redis")
except Exception as e:
except Exception:
REDIS = {}
pass
DOC_MAXIMUM_SIZE = int(os.environ.get("MAX_CONTENT_LENGTH", 128 * 1024 * 1024))
@ -44,7 +45,7 @@ LoggerFactory.set_directory(
# {CRITICAL: 50, FATAL:50, ERROR:40, WARNING:30, WARN:30, INFO:20, DEBUG:10, NOTSET:0}
LoggerFactory.LEVEL = 30
es_logger = getLogger("es")
doc_store_logger = getLogger("doc_store")
minio_logger = getLogger("minio")
s3_logger = getLogger("s3")
azure_logger = getLogger("azure")
@ -53,7 +54,7 @@ chunk_logger = getLogger("chunk_logger")
database_logger = getLogger("database")
formatter = logging.Formatter("%(asctime)-15s %(levelname)-8s (%(process)d) %(message)s")
for logger in [es_logger, minio_logger, s3_logger, azure_logger, cron_logger, chunk_logger, database_logger]:
for logger in [doc_store_logger, minio_logger, s3_logger, azure_logger, cron_logger, chunk_logger, database_logger]:
logger.setLevel(logging.INFO)
for handler in logger.handlers:
handler.setFormatter(fmt=formatter)

View File

@ -31,7 +31,6 @@ from timeit import default_timer as timer
import numpy as np
import pandas as pd
from elasticsearch_dsl import Q
from api.db import LLMType, ParserType
from api.db.services.dialog_service import keyword_extraction, question_proposal
@ -39,8 +38,7 @@ from api.db.services.document_service import DocumentService
from api.db.services.llm_service import LLMBundle
from api.db.services.task_service import TaskService
from api.db.services.file2document_service import File2DocumentService
from api.settings import retrievaler
from api.utils.file_utils import get_project_base_directory
from api.settings import retrievaler, docStoreConn
from api.db.db_models import close_connection
from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive, one, audio, knowledge_graph, email
from rag.nlp import search, rag_tokenizer
@ -48,7 +46,6 @@ from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as
from rag.settings import database_logger, SVR_QUEUE_NAME
from rag.settings import cron_logger, DOC_MAXIMUM_SIZE
from rag.utils import rmSpace, num_tokens_from_string
from rag.utils.es_conn import ELASTICSEARCH
from rag.utils.redis_conn import REDIS_CONN, Payload
from rag.utils.storage_factory import STORAGE_IMPL
@ -126,7 +123,7 @@ def collect():
return pd.DataFrame()
tasks = TaskService.get_tasks(msg["id"])
if not tasks:
cron_logger.warn("{} empty task!".format(msg["id"]))
cron_logger.warning("{} empty task!".format(msg["id"]))
return []
tasks = pd.DataFrame(tasks)
@ -187,7 +184,7 @@ def build(row):
docs = []
doc = {
"doc_id": row["doc_id"],
"kb_id": [str(row["kb_id"])]
"kb_id": str(row["kb_id"])
}
el = 0
for ck in cks:
@ -196,10 +193,14 @@ def build(row):
md5 = hashlib.md5()
md5.update((ck["content_with_weight"] +
str(d["doc_id"])).encode("utf-8"))
d["_id"] = md5.hexdigest()
d["id"] = md5.hexdigest()
d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19]
d["create_timestamp_flt"] = datetime.datetime.now().timestamp()
if not d.get("image"):
d["img_id"] = ""
d["page_num_list"] = json.dumps([])
d["position_list"] = json.dumps([])
d["top_list"] = json.dumps([])
docs.append(d)
continue
@ -211,13 +212,13 @@ def build(row):
d["image"].save(output_buffer, format='JPEG')
st = timer()
STORAGE_IMPL.put(row["kb_id"], d["_id"], output_buffer.getvalue())
STORAGE_IMPL.put(row["kb_id"], d["id"], output_buffer.getvalue())
el += timer() - st
except Exception as e:
cron_logger.error(str(e))
traceback.print_exc()
d["img_id"] = "{}-{}".format(row["kb_id"], d["_id"])
d["img_id"] = "{}-{}".format(row["kb_id"], d["id"])
del d["image"]
docs.append(d)
cron_logger.info("MINIO PUT({}):{}".format(row["name"], el))
@ -245,12 +246,9 @@ def build(row):
return docs
def init_kb(row):
def init_kb(row, vector_size: int):
idxnm = search.index_name(row["tenant_id"])
if ELASTICSEARCH.indexExist(idxnm):
return
return ELASTICSEARCH.createIdx(idxnm, json.load(
open(os.path.join(get_project_base_directory(), "conf", "mapping.json"), "r")))
return docStoreConn.createIdx(idxnm, row["kb_id"], vector_size)
def embedding(docs, mdl, parser_config=None, callback=None):
@ -288,17 +286,20 @@ def embedding(docs, mdl, parser_config=None, callback=None):
cnts) if len(tts) == len(cnts) else cnts
assert len(vects) == len(docs)
vector_size = 0
for i, d in enumerate(docs):
v = vects[i].tolist()
vector_size = len(v)
d["q_%d_vec" % len(v)] = v
return tk_count
return tk_count, vector_size
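Combining the build() and embedding() changes above, a hypothetical chunk row as it reaches docStoreConn.insert (all values are placeholders; the vector column name encodes the embedding size):

```python
import json

chunk = {
    "id": "9f86d081884c7d65",                  # md5 of content + doc_id (previously "_id")
    "doc_id": "my_doc_id",
    "kb_id": "my_kb_id",                       # now a plain string, not a one-element list
    "docnm_kwd": "manual.pdf",
    "content_with_weight": "chunk text",
    "content_ltks": "chunk text",
    "create_time": "2024-11-12 12:00:00",
    "create_timestamp_flt": 1731384000.0,
    "img_id": "",
    "page_num_list": json.dumps([1]),          # positions are serialized as JSON strings
    "position_list": json.dumps([[1, 0, 595, 0, 842]]),
    "top_list": json.dumps([0]),
    "q_1024_vec": [0.01] * 1024,               # embedding added by embedding(); name encodes its length
}
```

Rows of this shape are what the task executor later passes to docStoreConn.insert in batches of es_bulk_size.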
def run_raptor(row, chat_mdl, embd_mdl, callback=None):
vts, _ = embd_mdl.encode(["ok"])
vctr_nm = "q_%d_vec" % len(vts[0])
vector_size = len(vts[0])
vctr_nm = "q_%d_vec" % vector_size
chunks = []
for d in retrievaler.chunk_list(row["doc_id"], row["tenant_id"], fields=["content_with_weight", vctr_nm]):
for d in retrievaler.chunk_list(row["doc_id"], row["tenant_id"], [str(row["kb_id"])], fields=["content_with_weight", vctr_nm]):
chunks.append((d["content_with_weight"], np.array(d[vctr_nm])))
raptor = Raptor(
@ -323,7 +324,7 @@ def run_raptor(row, chat_mdl, embd_mdl, callback=None):
d = copy.deepcopy(doc)
md5 = hashlib.md5()
md5.update((content + str(d["doc_id"])).encode("utf-8"))
d["_id"] = md5.hexdigest()
d["id"] = md5.hexdigest()
d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19]
d["create_timestamp_flt"] = datetime.datetime.now().timestamp()
d[vctr_nm] = vctr.tolist()
@ -332,7 +333,7 @@ def run_raptor(row, chat_mdl, embd_mdl, callback=None):
d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
res.append(d)
tk_count += num_tokens_from_string(content)
return res, tk_count
return res, tk_count, vector_size
def main():
@ -352,7 +353,7 @@ def main():
if r.get("task_type", "") == "raptor":
try:
chat_mdl = LLMBundle(r["tenant_id"], LLMType.CHAT, llm_name=r["llm_id"], lang=r["language"])
cks, tk_count = run_raptor(r, chat_mdl, embd_mdl, callback)
cks, tk_count, vector_size = run_raptor(r, chat_mdl, embd_mdl, callback)
except Exception as e:
callback(-1, msg=str(e))
cron_logger.error(str(e))
@ -373,7 +374,7 @@ def main():
len(cks))
st = timer()
try:
tk_count = embedding(cks, embd_mdl, r["parser_config"], callback)
tk_count, vector_size = embedding(cks, embd_mdl, r["parser_config"], callback)
except Exception as e:
callback(-1, "Embedding error:{}".format(str(e)))
cron_logger.error(str(e))
@ -381,26 +382,25 @@ def main():
cron_logger.info("Embedding elapsed({}): {:.2f}".format(r["name"], timer() - st))
callback(msg="Finished embedding({:.2f})! Start to build index!".format(timer() - st))
init_kb(r)
chunk_count = len(set([c["_id"] for c in cks]))
# cron_logger.info(f"task_executor init_kb index {search.index_name(r["tenant_id"])} embd_mdl {embd_mdl.llm_name} vector length {vector_size}")
init_kb(r, vector_size)
chunk_count = len(set([c["id"] for c in cks]))
st = timer()
es_r = ""
es_bulk_size = 4
for b in range(0, len(cks), es_bulk_size):
es_r = ELASTICSEARCH.bulk(cks[b:b + es_bulk_size], search.index_name(r["tenant_id"]))
es_r = docStoreConn.insert(cks[b:b + es_bulk_size], search.index_name(r["tenant_id"]), r["kb_id"])
if b % 128 == 0:
callback(prog=0.8 + 0.1 * (b + 1) / len(cks), msg="")
cron_logger.info("Indexing elapsed({}): {:.2f}".format(r["name"], timer() - st))
if es_r:
callback(-1, "Insert chunk error, detail info please check ragflow-logs/api/cron_logger.log. Please also check ES status!")
ELASTICSEARCH.deleteByQuery(
Q("match", doc_id=r["doc_id"]), idxnm=search.index_name(r["tenant_id"]))
cron_logger.error(str(es_r))
docStoreConn.delete({"doc_id": r["doc_id"]}, search.index_name(r["tenant_id"]), r["kb_id"])
cron_logger.error('Insert chunk error: ' + str(es_r))
else:
if TaskService.do_cancel(r["id"]):
ELASTICSEARCH.deleteByQuery(
Q("match", doc_id=r["doc_id"]), idxnm=search.index_name(r["tenant_id"]))
docStoreConn.delete({"doc_id": r["doc_id"]}, search.index_name(r["tenant_id"]), r["kb_id"])
continue
callback(1., "Done!")
DocumentService.increment_chunk_num(

251
rag/utils/doc_store_conn.py Normal file
View File

@ -0,0 +1,251 @@
from abc import ABC, abstractmethod
from typing import Optional, Union
from dataclasses import dataclass
import numpy as np
import polars as pl
from typing import List, Dict
DEFAULT_MATCH_VECTOR_TOPN = 10
DEFAULT_MATCH_SPARSE_TOPN = 10
VEC = Union[list, np.ndarray]
@dataclass
class SparseVector:
indices: list[int]
values: Union[list[float], list[int], None] = None
def __post_init__(self):
assert (self.values is None) or (len(self.indices) == len(self.values))
def to_dict_old(self):
d = {"indices": self.indices}
if self.values is not None:
d["values"] = self.values
return d
def to_dict(self):
if self.values is None:
raise ValueError("SparseVector.values is None")
result = {}
for i, v in zip(self.indices, self.values):
result[str(i)] = v
return result
@staticmethod
def from_dict(d):
return SparseVector(d["indices"], d.get("values"))
def __str__(self):
return f"SparseVector(indices={self.indices}{'' if self.values is None else f', values={self.values}'})"
def __repr__(self):
return str(self)
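A quick round trip through the helper above:

```python
from rag.utils.doc_store_conn import SparseVector

sv = SparseVector(indices=[2, 7, 11], values=[0.5, 0.25, 0.125])
assert SparseVector.from_dict(sv.to_dict_old()).indices == sv.indices
print(sv.to_dict())   # {'2': 0.5, '7': 0.25, '11': 0.125}
print(sv)             # SparseVector(indices=[2, 7, 11], values=[0.5, 0.25, 0.125])
```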
class MatchTextExpr(ABC):
def __init__(
self,
fields: str,
matching_text: str,
topn: int,
extra_options: dict = dict(),
):
self.fields = fields
self.matching_text = matching_text
self.topn = topn
self.extra_options = extra_options
class MatchDenseExpr(ABC):
def __init__(
self,
vector_column_name: str,
embedding_data: VEC,
embedding_data_type: str,
distance_type: str,
topn: int = DEFAULT_MATCH_VECTOR_TOPN,
extra_options: dict = dict(),
):
self.vector_column_name = vector_column_name
self.embedding_data = embedding_data
self.embedding_data_type = embedding_data_type
self.distance_type = distance_type
self.topn = topn
self.extra_options = extra_options
class MatchSparseExpr(ABC):
def __init__(
self,
vector_column_name: str,
sparse_data: SparseVector | dict,
distance_type: str,
topn: int,
opt_params: Optional[dict] = None,
):
self.vector_column_name = vector_column_name
self.sparse_data = sparse_data
self.distance_type = distance_type
self.topn = topn
self.opt_params = opt_params
class MatchTensorExpr(ABC):
def __init__(
self,
column_name: str,
query_data: VEC,
query_data_type: str,
topn: int,
extra_option: Optional[dict] = None,
):
self.column_name = column_name
self.query_data = query_data
self.query_data_type = query_data_type
self.topn = topn
self.extra_option = extra_option
class FusionExpr(ABC):
def __init__(self, method: str, topn: int, fusion_params: Optional[dict] = None):
self.method = method
self.topn = topn
self.fusion_params = fusion_params
MatchExpr = Union[
MatchTextExpr, MatchDenseExpr, MatchSparseExpr, MatchTensorExpr, FusionExpr
]
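How these expression types are combined for hybrid retrieval (mirroring Dealer.search earlier in this diff); the field list, vector, and weights are illustrative values only:

```python
from rag.utils.doc_store_conn import MatchTextExpr, MatchDenseExpr, FusionExpr

match_text = MatchTextExpr(
    ["title_tks^10", "content_ltks^2"],        # weighted full-text fields (subset shown)
    "what is RAGFlow", 100,
    {"minimum_should_match": 0.3},
)
match_dense = MatchDenseExpr("q_1024_vec", [0.01] * 1024, "float", "cosine", 1024,
                             {"similarity": 0.1})
fusion = FusionExpr("weighted_sum", 1024, {"weights": "0.05, 0.95"})

# The ES adapter below expects exactly this order: [MatchTextExpr, MatchDenseExpr, FusionExpr],
# and reads the second weight as the vector similarity weight.
match_exprs = [match_text, match_dense, fusion]
```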
class OrderByExpr(ABC):
def __init__(self):
self.fields = list()
def asc(self, field: str):
self.fields.append((field, 0))
return self
def desc(self, field: str):
self.fields.append((field, 1))
return self
def fields(self):
return self.fields
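A usage sketch of the order-by builder. Note that once __init__ runs, the instance attribute fields (the list) shadows the fields() method, so callers such as the ES adapter below read orderBy.fields directly:

```python
from rag.utils.doc_store_conn import OrderByExpr

order_by = OrderByExpr().asc("page_num_int").asc("top_int").desc("create_timestamp_flt")
print(order_by.fields)   # [('page_num_int', 0), ('top_int', 0), ('create_timestamp_flt', 1)]
```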
class DocStoreConnection(ABC):
"""
Database operations
"""
@abstractmethod
def dbType(self) -> str:
"""
Return the type of the database.
"""
raise NotImplementedError("Not implemented")
@abstractmethod
def health(self) -> dict:
"""
Return the health status of the database.
"""
raise NotImplementedError("Not implemented")
"""
Table operations
"""
@abstractmethod
def createIdx(self, indexName: str, knowledgebaseId: str, vectorSize: int):
"""
Create an index with given name
"""
raise NotImplementedError("Not implemented")
@abstractmethod
def deleteIdx(self, indexName: str, knowledgebaseId: str):
"""
Delete an index with given name
"""
raise NotImplementedError("Not implemented")
@abstractmethod
def indexExist(self, indexName: str, knowledgebaseId: str) -> bool:
"""
Check if an index with given name exists
"""
raise NotImplementedError("Not implemented")
"""
CRUD operations
"""
@abstractmethod
def search(
self, selectFields: list[str], highlight: list[str], condition: dict, matchExprs: list[MatchExpr], orderBy: OrderByExpr, offset: int, limit: int, indexNames: str|list[str], knowledgebaseIds: list[str]
) -> list[dict] | pl.DataFrame:
"""
Search with the given conjunctive equality filter condition and return the selected fields of matched documents
"""
raise NotImplementedError("Not implemented")
@abstractmethod
def get(self, chunkId: str, indexName: str, knowledgebaseIds: list[str]) -> dict | None:
"""
Get single chunk with given id
"""
raise NotImplementedError("Not implemented")
@abstractmethod
def insert(self, rows: list[dict], indexName: str, knowledgebaseId: str) -> list[str]:
"""
Upsert (update or insert) a batch of rows
"""
raise NotImplementedError("Not implemented")
@abstractmethod
def update(self, condition: dict, newValue: dict, indexName: str, knowledgebaseId: str) -> bool:
"""
Update rows matching the given conjunctive equality filter condition
"""
raise NotImplementedError("Not implemented")
@abstractmethod
def delete(self, condition: dict, indexName: str, knowledgebaseId: str) -> int:
"""
Delete rows matching the given conjunctive equality filter condition
"""
raise NotImplementedError("Not implemented")
"""
Helper functions for search result
"""
@abstractmethod
def getTotal(self, res):
raise NotImplementedError("Not implemented")
@abstractmethod
def getChunkIds(self, res):
raise NotImplementedError("Not implemented")
@abstractmethod
def getFields(self, res, fields: List[str]) -> Dict[str, dict]:
raise NotImplementedError("Not implemented")
@abstractmethod
def getHighlight(self, res, keywords: List[str], fieldnm: str):
raise NotImplementedError("Not implemented")
@abstractmethod
def getAggregation(self, res, fieldnm: str):
raise NotImplementedError("Not implemented")
"""
SQL
"""
@abstractmethod
def sql(self, sql: str, fetch_size: int, format: str):
"""
Run the sql generated by text-to-sql
"""
raise NotImplementedError("Not implemented")
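Putting the interface together, a sketch of how a caller drives a connection plus its result helpers, mirroring Dealer.search earlier in this diff. docStoreConn, the ids, and match_exprs (see the expression sketch above) are placeholders:

```python
from api.settings import docStoreConn
from rag.nlp import search
from rag.utils.doc_store_conn import OrderByExpr

res = docStoreConn.search(
    ["doc_id", "docnm_kwd", "content_with_weight"],   # selectFields
    ["content_ltks"],                                 # highlight fields
    {"kb_id": ["my_kb_id"]},                          # conjunctive equality filter
    match_exprs,                                      # e.g. [MatchTextExpr, MatchDenseExpr, FusionExpr]
    OrderByExpr(),
    0, 30,                                            # offset, limit
    search.index_name("my_tenant_id"),
    ["my_kb_id"],
)

total = docStoreConn.getTotal(res)
ids = docStoreConn.getChunkIds(res)
fields = docStoreConn.getFields(res, ["doc_id", "docnm_kwd", "content_with_weight"])
highlights = docStoreConn.getHighlight(res, ["ragflow"], "content_with_weight")
aggs = docStoreConn.getAggregation(res, "docnm_kwd")
```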

View File

@ -1,29 +1,29 @@
import re
import json
import time
import copy
import os
from typing import List, Dict
import elasticsearch
from elastic_transport import ConnectionTimeout
import copy
from elasticsearch import Elasticsearch
from elasticsearch_dsl import UpdateByQuery, Search, Index
from rag.settings import es_logger
from elasticsearch_dsl import UpdateByQuery, Q, Search, Index
from elastic_transport import ConnectionTimeout
from rag.settings import doc_store_logger
from rag import settings
from rag.utils import singleton
from api.utils.file_utils import get_project_base_directory
import polars as pl
from rag.utils.doc_store_conn import DocStoreConnection, MatchExpr, OrderByExpr, MatchTextExpr, MatchDenseExpr, FusionExpr
from rag.nlp import is_english, rag_tokenizer
es_logger.info("Elasticsearch version: "+str(elasticsearch.__version__))
doc_store_logger.info("Elasticsearch sdk version: "+str(elasticsearch.__version__))
@singleton
class ESConnection:
class ESConnection(DocStoreConnection):
def __init__(self):
self.info = {}
self.conn()
self.idxnm = settings.ES.get("index_name", "")
if not self.es.ping():
raise Exception("Can't connect to ES cluster")
def conn(self):
for _ in range(10):
try:
self.es = Elasticsearch(
@ -34,390 +34,317 @@ class ESConnection:
)
if self.es:
self.info = self.es.info()
es_logger.info("Connect to es.")
doc_store_logger.info("Connect to es.")
break
except Exception as e:
es_logger.error("Fail to connect to es: " + str(e))
doc_store_logger.error("Fail to connect to es: " + str(e))
time.sleep(1)
def version(self):
if not self.es.ping():
raise Exception("Can't connect to ES cluster")
v = self.info.get("version", {"number": "5.6"})
v = v["number"].split(".")[0]
return int(v) >= 7
if int(v) < 8:
raise Exception(f"ES version must be greater than or equal to 8, current version: {v}")
fp_mapping = os.path.join(get_project_base_directory(), "conf", "mapping.json")
if not os.path.exists(fp_mapping):
raise Exception(f"Mapping file not found at {fp_mapping}")
self.mapping = json.load(open(fp_mapping, "r"))
def health(self):
return dict(self.es.cluster.health())
"""
Database operations
"""
def dbType(self) -> str:
return "elasticsearch"
def upsert(self, df, idxnm=""):
res = []
for d in df:
id = d["id"]
del d["id"]
d = {"doc": d, "doc_as_upsert": "true"}
T = False
for _ in range(10):
try:
if not self.version():
r = self.es.update(
index=(
self.idxnm if not idxnm else idxnm),
body=d,
id=id,
doc_type="doc",
refresh=True,
retry_on_conflict=100)
else:
r = self.es.update(
index=(
self.idxnm if not idxnm else idxnm),
body=d,
id=id,
refresh=True,
retry_on_conflict=100)
es_logger.info("Successfully upsert: %s" % id)
T = True
break
except Exception as e:
es_logger.warning("Fail to index: " +
json.dumps(d, ensure_ascii=False) + str(e))
if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE):
time.sleep(3)
continue
self.conn()
T = False
def health(self) -> dict:
return {**dict(self.es.cluster.health()), "type": "elasticsearch"}
if not T:
res.append(d)
es_logger.error(
"Fail to index: " +
re.sub(
"[\r\n]",
"",
json.dumps(
d,
ensure_ascii=False)))
d["id"] = id
d["_index"] = self.idxnm
if not res:
"""
Table operations
"""
def createIdx(self, indexName: str, knowledgebaseId: str, vectorSize: int):
if self.indexExist(indexName, knowledgebaseId):
return True
return False
try:
from elasticsearch.client import IndicesClient
return IndicesClient(self.es).create(index=indexName,
settings=self.mapping["settings"],
mappings=self.mapping["mappings"])
except Exception as e:
doc_store_logger.error("ES create index error %s ----%s" % (indexName, str(e)))
def bulk(self, df, idx_nm=None):
ids, acts = {}, []
for d in df:
id = d["id"] if "id" in d else d["_id"]
ids[id] = copy.deepcopy(d)
ids[id]["_index"] = self.idxnm if not idx_nm else idx_nm
if "id" in d:
del d["id"]
if "_id" in d:
del d["_id"]
acts.append(
{"update": {"_id": id, "_index": ids[id]["_index"]}, "retry_on_conflict": 100})
acts.append({"doc": d, "doc_as_upsert": "true"})
def deleteIdx(self, indexName: str, knowledgebaseId: str):
try:
return self.es.indices.delete(index=indexName, allow_no_indices=True)
except Exception as e:
doc_store_logger.error("ES delete index error %s ----%s" % (indexName, str(e)))
res = []
for _ in range(100):
try:
if elasticsearch.__version__[0] < 8:
r = self.es.bulk(
index=(
self.idxnm if not idx_nm else idx_nm),
body=acts,
refresh=False,
timeout="600s")
else:
r = self.es.bulk(index=(self.idxnm if not idx_nm else
idx_nm), operations=acts,
refresh=False, timeout="600s")
if re.search(r"False", str(r["errors"]), re.IGNORECASE):
return res
for it in r["items"]:
if "error" in it["update"]:
res.append(str(it["update"]["_id"]) +
":" + str(it["update"]["error"]))
return res
except Exception as e:
es_logger.warn("Fail to bulk: " + str(e))
if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE):
time.sleep(3)
continue
self.conn()
return res
def bulk4script(self, df):
ids, acts = {}, []
for d in df:
id = d["id"]
ids[id] = copy.deepcopy(d["raw"])
acts.append({"update": {"_id": id, "_index": self.idxnm}})
acts.append(d["script"])
es_logger.info("bulk upsert: %s" % id)
res = []
for _ in range(10):
try:
if not self.version():
r = self.es.bulk(
index=self.idxnm,
body=acts,
refresh=False,
timeout="600s",
doc_type="doc")
else:
r = self.es.bulk(
index=self.idxnm,
body=acts,
refresh=False,
timeout="600s")
if re.search(r"False", str(r["errors"]), re.IGNORECASE):
return res
for it in r["items"]:
if "error" in it["update"]:
res.append(str(it["update"]["_id"]))
return res
except Exception as e:
es_logger.warning("Fail to bulk: " + str(e))
if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE):
time.sleep(3)
continue
self.conn()
return res
def rm(self, d):
for _ in range(10):
try:
if not self.version():
r = self.es.delete(
index=self.idxnm,
id=d["id"],
doc_type="doc",
refresh=True)
else:
r = self.es.delete(
index=self.idxnm,
id=d["id"],
refresh=True,
doc_type="_doc")
es_logger.info("Remove %s" % d["id"])
return True
except Exception as e:
es_logger.warn("Fail to delete: " + str(d) + str(e))
if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE):
time.sleep(3)
continue
if re.search(r"(not_found)", str(e), re.IGNORECASE):
return True
self.conn()
es_logger.error("Fail to delete: " + str(d))
return False
def search(self, q, idxnms=None, src=False, timeout="2s"):
if not isinstance(q, dict):
q = Search().query(q).to_dict()
if isinstance(idxnms, str):
idxnms = idxnms.split(",")
for i in range(3):
try:
res = self.es.search(index=(self.idxnm if not idxnms else idxnms),
body=q,
timeout=timeout,
# search_type="dfs_query_then_fetch",
track_total_hits=True,
_source=src)
if str(res.get("timed_out", "")).lower() == "true":
raise Exception("Es Timeout.")
return res
except Exception as e:
es_logger.error(
"ES search exception: " +
str(e) +
"【Q】" +
str(q))
if str(e).find("Timeout") > 0:
continue
raise e
es_logger.error("ES search timeout for 3 times!")
raise Exception("ES search timeout.")
def sql(self, sql, fetch_size=128, format="json", timeout="2s"):
for i in range(3):
try:
res = self.es.sql.query(body={"query": sql, "fetch_size": fetch_size}, format=format, request_timeout=timeout)
return res
except ConnectionTimeout as e:
es_logger.error("Timeout【Q】" + sql)
continue
except Exception as e:
raise e
es_logger.error("ES search timeout for 3 times!")
raise ConnectionTimeout()
def get(self, doc_id, idxnm=None):
for i in range(3):
try:
res = self.es.get(index=(self.idxnm if not idxnm else idxnm),
id=doc_id)
if str(res.get("timed_out", "")).lower() == "true":
raise Exception("Es Timeout.")
return res
except Exception as e:
es_logger.error(
"ES get exception: " +
str(e) +
"【Q】" +
doc_id)
if str(e).find("Timeout") > 0:
continue
raise e
es_logger.error("ES search timeout for 3 times!")
raise Exception("ES search timeout.")
def updateByQuery(self, q, d):
ubq = UpdateByQuery(index=self.idxnm).using(self.es).query(q)
scripts = ""
for k, v in d.items():
scripts += "ctx._source.%s = params.%s;" % (str(k), str(k))
ubq = ubq.script(source=scripts, params=d)
ubq = ubq.params(refresh=False)
ubq = ubq.params(slices=5)
ubq = ubq.params(conflicts="proceed")
for i in range(3):
try:
r = ubq.execute()
return True
except Exception as e:
es_logger.error("ES updateByQuery exception: " +
str(e) + "【Q】" + str(q.to_dict()))
if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
continue
self.conn()
return False
def updateScriptByQuery(self, q, scripts, idxnm=None):
ubq = UpdateByQuery(
index=self.idxnm if not idxnm else idxnm).using(
self.es).query(q)
ubq = ubq.script(source=scripts)
ubq = ubq.params(refresh=True)
ubq = ubq.params(slices=5)
ubq = ubq.params(conflicts="proceed")
for i in range(3):
try:
r = ubq.execute()
return True
except Exception as e:
es_logger.error("ES updateByQuery exception: " +
str(e) + "【Q】" + str(q.to_dict()))
if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
continue
self.conn()
return False
def deleteByQuery(self, query, idxnm=""):
for i in range(3):
try:
r = self.es.delete_by_query(
index=idxnm if idxnm else self.idxnm,
refresh = True,
body=Search().query(query).to_dict())
return True
except Exception as e:
es_logger.error("ES updateByQuery deleteByQuery: " +
str(e) + "【Q】" + str(query.to_dict()))
if str(e).find("NotFoundError") > 0: return True
if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
continue
return False
def update(self, id, script, routing=None):
for i in range(3):
try:
if not self.version():
r = self.es.update(
index=self.idxnm,
id=id,
body=json.dumps(
script,
ensure_ascii=False),
doc_type="doc",
routing=routing,
refresh=False)
else:
r = self.es.update(index=self.idxnm, id=id, body=json.dumps(script, ensure_ascii=False),
routing=routing, refresh=False) # , doc_type="_doc")
return True
except Exception as e:
es_logger.error(
"ES update exception: " + str(e) + " id" + str(id) + ", version:" + str(self.version()) +
json.dumps(script, ensure_ascii=False))
if str(e).find("Timeout") > 0:
continue
return False
def indexExist(self, idxnm):
s = Index(idxnm if idxnm else self.idxnm, self.es)
def indexExist(self, indexName: str, knowledgebaseId: str) -> bool:
s = Index(indexName, self.es)
for i in range(3):
try:
return s.exists()
except Exception as e:
es_logger.error("ES updateByQuery indexExist: " + str(e))
doc_store_logger.error("ES indexExist: " + str(e))
if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
continue
return False
def docExist(self, docid, idxnm=None):
"""
CRUD operations
"""
def search(self, selectFields: list[str], highlightFields: list[str], condition: dict, matchExprs: list[MatchExpr], orderBy: OrderByExpr, offset: int, limit: int, indexNames: str|list[str], knowledgebaseIds: list[str]) -> list[dict] | pl.DataFrame:
"""
Refers to https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl.html
"""
if isinstance(indexNames, str):
indexNames = indexNames.split(",")
assert isinstance(indexNames, list) and len(indexNames) > 0
assert "_id" not in condition
s = Search()
bqry = None
vector_similarity_weight = 0.5
for m in matchExprs:
if isinstance(m, FusionExpr) and m.method=="weighted_sum" and "weights" in m.fusion_params:
assert len(matchExprs)==3 and isinstance(matchExprs[0], MatchTextExpr) and isinstance(matchExprs[1], MatchDenseExpr) and isinstance(matchExprs[2], FusionExpr)
weights = m.fusion_params["weights"]
vector_similarity_weight = float(weights.split(",")[1])
for m in matchExprs:
if isinstance(m, MatchTextExpr):
minimum_should_match = "0%"
if "minimum_should_match" in m.extra_options:
minimum_should_match = str(int(m.extra_options["minimum_should_match"] * 100)) + "%"
bqry = Q("bool",
must=Q("query_string", fields=m.fields,
type="best_fields", query=m.matching_text,
minimum_should_match = minimum_should_match,
boost=1),
boost = 1.0 - vector_similarity_weight,
)
if condition:
for k, v in condition.items():
if not isinstance(k, str) or not v:
continue
if isinstance(v, list):
bqry.filter.append(Q("terms", **{k: v}))
elif isinstance(v, str) or isinstance(v, int):
bqry.filter.append(Q("term", **{k: v}))
else:
raise Exception(f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.")
elif isinstance(m, MatchDenseExpr):
assert(bqry is not None)
similarity = 0.0
if "similarity" in m.extra_options:
similarity = m.extra_options["similarity"]
s = s.knn(m.vector_column_name,
m.topn,
m.topn * 2,
query_vector = list(m.embedding_data),
filter = bqry.to_dict(),
similarity = similarity,
)
if matchExprs:
s.query = bqry
for field in highlightFields:
s = s.highlight(field)
if orderBy:
orders = list()
for field, order in orderBy.fields:
order = "asc" if order == 0 else "desc"
orders.append({field: {"order": order, "unmapped_type": "float",
"mode": "avg", "numeric_type": "double"}})
s = s.sort(*orders)
if limit > 0:
s = s[offset:limit]
q = s.to_dict()
doc_store_logger.info("ESConnection.search [Q]: " + json.dumps(q))
for i in range(3):
try:
return self.es.exists(index=(idxnm if idxnm else self.idxnm),
id=docid)
res = self.es.search(index=indexNames,
body=q,
timeout="600s",
# search_type="dfs_query_then_fetch",
track_total_hits=True,
_source=True)
if str(res.get("timed_out", "")).lower() == "true":
raise Exception("Es Timeout.")
doc_store_logger.info("ESConnection.search res: " + str(res))
return res
except Exception as e:
es_logger.error("ES Doc Exist: " + str(e))
if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
doc_store_logger.error(
"ES search exception: " +
str(e) +
"\n[Q]: " +
str(q))
if str(e).find("Timeout") > 0:
continue
raise e
doc_store_logger.error("ES search timeout for 3 times!")
raise Exception("ES search timeout.")
def get(self, chunkId: str, indexName: str, knowledgebaseIds: list[str]) -> dict | None:
for i in range(3):
try:
res = self.es.get(index=(indexName),
id=chunkId, source=True,)
if str(res.get("timed_out", "")).lower() == "true":
raise Exception("Es Timeout.")
if not res.get("found"):
return None
chunk = res["_source"]
chunk["id"] = chunkId
return chunk
except Exception as e:
doc_store_logger.error(
"ES get exception: " +
str(e) +
"[Q]: " +
chunkId)
if str(e).find("Timeout") > 0:
continue
raise e
doc_store_logger.error("ES search timeout for 3 times!")
raise Exception("ES search timeout.")
def insert(self, documents: list[dict], indexName: str, knowledgebaseId: str) -> list[str]:
# Refers to https://www.elastic.co/guide/en/elasticsearch/reference/current/docs-bulk.html
operations = []
for d in documents:
assert "_id" not in d
assert "id" in d
d_copy = copy.deepcopy(d)
meta_id = d_copy["id"]
del d_copy["id"]
operations.append(
{"index": {"_index": indexName, "_id": meta_id}})
operations.append(d_copy)
res = []
for _ in range(100):
try:
r = self.es.bulk(index=(indexName), operations=operations,
refresh=False, timeout="600s")
if re.search(r"False", str(r["errors"]), re.IGNORECASE):
return res
for item in r["items"]:
for action in ["create", "delete", "index", "update"]:
if action in item and "error" in item[action]:
res.append(str(item[action]["_id"]) + ":" + str(item[action]["error"]))
return res
except Exception as e:
doc_store_logger.warning("Fail to bulk: " + str(e))
if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE):
time.sleep(3)
continue
return res
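For reference, the bulk payload assembled in `insert()` alternates an action line and a document body, moving the caller-facing `id` into Elasticsearch's `_id`. A minimal sketch of the resulting `operations` list for two chunks (index name and chunk fields are illustrative):
```python
# Illustrative only: real chunks carry many more fields.
documents = [
    {"id": "chunk-1", "content_with_weight": "hello", "docnm_kwd": "a.pdf"},
    {"id": "chunk-2", "content_with_weight": "world", "docnm_kwd": "a.pdf"},
]

operations = []
for d in documents:
    body = {k: v for k, v in d.items() if k != "id"}   # "id" becomes the ES "_id"
    operations.append({"index": {"_index": "ragflow_tenant", "_id": d["id"]}})
    operations.append(body)

# operations ==
# [{'index': {'_index': 'ragflow_tenant', '_id': 'chunk-1'}},
#  {'content_with_weight': 'hello', 'docnm_kwd': 'a.pdf'},
#  {'index': {'_index': 'ragflow_tenant', '_id': 'chunk-2'}},
#  {'content_with_weight': 'world', 'docnm_kwd': 'a.pdf'}]
```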
def update(self, condition: dict, newValue: dict, indexName: str, knowledgebaseId: str) -> bool:
doc = copy.deepcopy(newValue)
del doc['id']
if "id" in condition and isinstance(condition["id"], str):
# update specific single document
chunkId = condition["id"]
for i in range(3):
try:
self.es.update(index=indexName, id=chunkId, doc=doc)
return True
except Exception as e:
doc_store_logger.error(
"ES update exception: " + str(e) + " id:" + str(id) +
json.dumps(newValue, ensure_ascii=False))
if str(e).find("Timeout") > 0:
continue
else:
# update unspecific maybe-multiple documents
bqry = Q("bool")
for k, v in condition.items():
if not isinstance(k, str) or not v:
continue
if isinstance(v, list):
bqry.filter.append(Q("terms", **{k: v}))
elif isinstance(v, str) or isinstance(v, int):
bqry.filter.append(Q("term", **{k: v}))
else:
raise Exception(f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.")
scripts = []
for k, v in newValue.items():
if not isinstance(k, str) or not v:
continue
if isinstance(v, str):
scripts.append(f"ctx._source.{k} = '{v}'")
elif isinstance(v, int):
scripts.append(f"ctx._source.{k} = {v}")
else:
raise Exception(f"newValue `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str.")
ubq = UpdateByQuery(
index=indexName).using(
self.es).query(bqry)
ubq = ubq.script(source="; ".join(scripts))
ubq = ubq.params(refresh=True)
ubq = ubq.params(slices=5)
ubq = ubq.params(conflicts="proceed")
for i in range(3):
try:
_ = ubq.execute()
return True
except Exception as e:
doc_store_logger.error("ES update exception: " +
str(e) + "[Q]:" + str(bqry.to_dict()))
if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
continue
return False
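The multi-document branch above turns `newValue` into a Painless update-by-query script: string values are single-quoted, integers are inlined, and falsy values are skipped by the `not v` guard. A small sketch with illustrative field names:
```python
# Illustrative conversion of newValue into the update-by-query script.
newValue = {"status": "1", "chunk_num": 5, "progress": 0}   # progress=0 is falsy and gets skipped

scripts = []
for k, v in newValue.items():
    if not isinstance(k, str) or not v:
        continue
    if isinstance(v, str):
        scripts.append(f"ctx._source.{k} = '{v}'")
    elif isinstance(v, int):
        scripts.append(f"ctx._source.{k} = {v}")

source = "; ".join(scripts)
# source == "ctx._source.status = '1'; ctx._source.chunk_num = 5"
```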
def createIdx(self, idxnm, mapping):
try:
if elasticsearch.__version__[0] < 8:
return self.es.indices.create(idxnm, body=mapping)
from elasticsearch.client import IndicesClient
return IndicesClient(self.es).create(index=idxnm,
settings=mapping["settings"],
mappings=mapping["mappings"])
except Exception as e:
es_logger.error("ES create index error %s ----%s" % (idxnm, str(e)))
def delete(self, condition: dict, indexName: str, knowledgebaseId: str) -> int:
qry = None
assert "_id" not in condition
if "id" in condition:
chunk_ids = condition["id"]
if not isinstance(chunk_ids, list):
chunk_ids = [chunk_ids]
qry = Q("ids", values=chunk_ids)
else:
qry = Q("bool")
for k, v in condition.items():
if isinstance(v, list):
qry.must.append(Q("terms", **{k: v}))
elif isinstance(v, str) or isinstance(v, int):
qry.must.append(Q("term", **{k: v}))
else:
raise Exception("Condition value must be int, str or list.")
doc_store_logger.info("ESConnection.delete [Q]: " + json.dumps(qry.to_dict()))
for _ in range(10):
try:
res = self.es.delete_by_query(
index=indexName,
body = Search().query(qry).to_dict(),
refresh=True)
return res["deleted"]
except Exception as e:
doc_store_logger.warning("Fail to delete: " + str(filter) + str(e))
if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE):
time.sleep(3)
continue
if re.search(r"(not_found)", str(e), re.IGNORECASE):
return 0
return 0
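The `condition` dict in `delete()` maps onto two query shapes: an `ids` query when an `id` key is present, otherwise a `bool`/`must` of `term`/`terms` clauses. Roughly, with illustrative field names:
```python
from elasticsearch_dsl import Q, Search

# Delete by chunk ids -> an "ids" query.
qry = Q("ids", values=["chunk-1", "chunk-2"])
Search().query(qry).to_dict()
# {'query': {'ids': {'values': ['chunk-1', 'chunk-2']}}}

# Delete by other fields -> bool/must of term(s) clauses.
qry = Q("bool")
qry.must.append(Q("term", doc_id="doc-42"))
qry.must.append(Q("terms", kb_id=["kb-1", "kb-2"]))
Search().query(qry).to_dict()
# {'query': {'bool': {'must': [{'term': {'doc_id': 'doc-42'}},
#                              {'terms': {'kb_id': ['kb-1', 'kb-2']}}]}}}
```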
def deleteIdx(self, idxnm):
try:
return self.es.indices.delete(idxnm, allow_no_indices=True)
except Exception as e:
es_logger.error("ES delete index error %s ----%s" % (idxnm, str(e)))
"""
Helper functions for search result
"""
def getTotal(self, res):
if isinstance(res["hits"]["total"], type({})):
return res["hits"]["total"]["value"]
return res["hits"]["total"]
def getDocIds(self, res):
def getChunkIds(self, res):
return [d["_id"] for d in res["hits"]["hits"]]
def getSource(self, res):
def __getSource(self, res):
rr = []
for d in res["hits"]["hits"]:
d["_source"]["id"] = d["_id"]
@ -425,40 +352,89 @@ class ESConnection:
rr.append(d["_source"])
return rr
def scrollIter(self, pagesize=100, scroll_time='2m', q={
"query": {"match_all": {}}, "sort": [{"updated_at": {"order": "desc"}}]}):
for _ in range(100):
def getFields(self, res, fields: List[str]) -> Dict[str, dict]:
res_fields = {}
if not fields:
return {}
for d in self.__getSource(res):
m = {n: d.get(n) for n in fields if d.get(n) is not None}
for n, v in m.items():
if isinstance(v, list):
m[n] = v
continue
if not isinstance(v, str):
m[n] = str(m[n])
# if n.find("tks") > 0:
# m[n] = rmSpace(m[n])
if m:
res_fields[d["id"]] = m
return res_fields
def getHighlight(self, res, keywords: List[str], fieldnm: str):
ans = {}
for d in res["hits"]["hits"]:
hlts = d.get("highlight")
if not hlts:
continue
txt = "...".join([a for a in list(hlts.items())[0][1]])
if not is_english(txt.split(" ")):
ans[d["_id"]] = txt
continue
txt = d["_source"][fieldnm]
txt = re.sub(r"[\r\n]", " ", txt, flags=re.IGNORECASE|re.MULTILINE)
txts = []
for t in re.split(r"[.?!;\n]", txt):
for w in keywords:
t = re.sub(r"(^|[ .?/'\"\(\)!,:;-])(%s)([ .?/'\"\(\)!,:;-])"%re.escape(w), r"\1<em>\2</em>\3", t, flags=re.IGNORECASE|re.MULTILINE)
if not re.search(r"<em>[^<>]+</em>", t, flags=re.IGNORECASE|re.MULTILINE):
continue
txts.append(t)
ans[d["_id"]] = "...".join(txts) if txts else "...".join([a for a in list(hlts.items())[0][1]])
return ans
def getAggregation(self, res, fieldnm: str):
agg_field = "aggs_" + fieldnm
if "aggregations" not in res or agg_field not in res["aggregations"]:
return list()
bkts = res["aggregations"][agg_field]["buckets"]
return [(b["key"], b["doc_count"]) for b in bkts]
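`getAggregation()` assumes the caller requested a terms aggregation named `aggs_<field>`; it only unpacks the bucket list. The response fragment it consumes looks roughly like this (values illustrative):
```python
# Shape of the ES response consumed by getAggregation() (values illustrative).
res = {
    "aggregations": {
        "aggs_docnm_kwd": {
            "buckets": [
                {"key": "a.pdf", "doc_count": 12},
                {"key": "b.pdf", "doc_count": 7},
            ]
        }
    }
}
# getAggregation(res, "docnm_kwd") -> [("a.pdf", 12), ("b.pdf", 7)]
```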
"""
SQL
"""
def sql(self, sql: str, fetch_size: int, format: str):
doc_store_logger.info(f"ESConnection.sql get sql: {sql}")
sql = re.sub(r"[ `]+", " ", sql)
sql = sql.replace("%", "")
replaces = []
for r in re.finditer(r" ([a-z_]+_l?tks)( like | ?= ?)'([^']+)'", sql):
fld, v = r.group(1), r.group(3)
match = " MATCH({}, '{}', 'operator=OR;minimum_should_match=30%') ".format(
fld, rag_tokenizer.fine_grained_tokenize(rag_tokenizer.tokenize(v)))
replaces.append(
("{}{}'{}'".format(
r.group(1),
r.group(2),
r.group(3)),
match))
for p, r in replaces:
sql = sql.replace(p, r, 1)
doc_store_logger.info(f"ESConnection.sql to es: {sql}")
for i in range(3):
try:
page = self.es.search(
index=self.idxnm,
scroll=scroll_time,
size=pagesize,
body=q,
_source=None
)
break
res = self.es.sql.query(body={"query": sql, "fetch_size": fetch_size}, format=format, request_timeout="2s")
return res
except ConnectionTimeout:
doc_store_logger.error("ESConnection.sql timeout [Q]: " + sql)
continue
except Exception as e:
es_logger.error("ES scrolling fail. " + str(e))
time.sleep(3)
sid = page['_scroll_id']
scroll_size = page['hits']['total']["value"]
es_logger.info("[TOTAL]%d" % scroll_size)
# Start scrolling
while scroll_size > 0:
yield page["hits"]["hits"]
for _ in range(100):
try:
page = self.es.scroll(scroll_id=sid, scroll=scroll_time)
break
except Exception as e:
es_logger.error("ES scrolling fail. " + str(e))
time.sleep(3)
# Update the scroll ID
sid = page['_scroll_id']
# Get the number of results that we returned in the last scroll
scroll_size = len(page['hits']['hits'])
ELASTICSEARCH = ESConnection()
doc_store_logger.error(f"ESConnection.sql failure: {sql} => " + str(e))
return None
doc_store_logger.error("ESConnection.sql timeout for 3 times!")
return None
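The `sql()` path rewrites equality/LIKE predicates on `*_tks`/`*_ltks` columns into Elasticsearch `MATCH()` clauses before calling the ES SQL endpoint. A hedged sketch of that rewrite, using a stand-in tokenizer since the real text comes from `rag_tokenizer`; table and column names are illustrative:
```python
import re

sql = "select doc_id, content_ltks from ragflow_xxx where content_ltks like 'foo bar'"

def fake_tokenize(text: str) -> str:
    # Stand-in for rag_tokenizer.fine_grained_tokenize(rag_tokenizer.tokenize(...)).
    return text

replaces = []
for r in re.finditer(r" ([a-z_]+_l?tks)( like | ?= ?)'([^']+)'", sql):
    fld, v = r.group(1), r.group(3)
    match = " MATCH({}, '{}', 'operator=OR;minimum_should_match=30%') ".format(fld, fake_tokenize(v))
    replaces.append(("{}{}'{}'".format(r.group(1), r.group(2), r.group(3)), match))
for p, r in replaces:
    sql = sql.replace(p, r, 1)

print(sql)
# select doc_id, content_ltks from ragflow_xxx where
#   MATCH(content_ltks, 'foo bar', 'operator=OR;minimum_should_match=30%')   (modulo whitespace)
```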

rag/utils/infinity_conn.py Normal file (436 lines added)

View File

@ -0,0 +1,436 @@
import os
import re
import json
from typing import List, Dict
import infinity
from infinity.common import ConflictType, InfinityException
from infinity.index import IndexInfo, IndexType
from infinity.connection_pool import ConnectionPool
from rag import settings
from rag.settings import doc_store_logger
from rag.utils import singleton
import polars as pl
from polars.series.series import Series
from api.utils.file_utils import get_project_base_directory
from rag.utils.doc_store_conn import (
DocStoreConnection,
MatchExpr,
MatchTextExpr,
MatchDenseExpr,
FusionExpr,
OrderByExpr,
)
def equivalent_condition_to_str(condition: dict) -> str:
assert "_id" not in condition
cond = list()
for k, v in condition.items():
if not isinstance(k, str) or not v:
continue
if isinstance(v, list):
inCond = list()
for item in v:
if isinstance(item, str):
inCond.append(f"'{item}'")
else:
inCond.append(str(item))
if inCond:
strInCond = ", ".join(inCond)
strInCond = f"{k} IN ({strInCond})"
cond.append(strInCond)
elif isinstance(v, str):
cond.append(f"{k}='{v}'")
else:
cond.append(f"{k}={str(v)}")
return " AND ".join(cond)
@singleton
class InfinityConnection(DocStoreConnection):
def __init__(self):
self.dbName = settings.INFINITY.get("db_name", "default_db")
infinity_uri = settings.INFINITY["uri"]
if ":" in infinity_uri:
host, port = infinity_uri.split(":")
infinity_uri = infinity.common.NetworkAddress(host, int(port))
self.connPool = ConnectionPool(infinity_uri)
doc_store_logger.info(f"Connected to infinity {infinity_uri}.")
"""
Database operations
"""
def dbType(self) -> str:
return "infinity"
def health(self) -> dict:
"""
Return the health status of the database.
TODO: Infinity-sdk provides health() to wrap `show global variables` and `show tables`
"""
inf_conn = self.connPool.get_conn()
res = infinity.show_current_node()
self.connPool.release_conn(inf_conn)
color = "green" if res.error_code == 0 else "red"
res2 = {
"type": "infinity",
"status": f"{res.role} {color}",
"error": res.error_msg,
}
return res2
"""
Table operations
"""
def createIdx(self, indexName: str, knowledgebaseId: str, vectorSize: int):
table_name = f"{indexName}_{knowledgebaseId}"
inf_conn = self.connPool.get_conn()
inf_db = inf_conn.create_database(self.dbName, ConflictType.Ignore)
fp_mapping = os.path.join(
get_project_base_directory(), "conf", "infinity_mapping.json"
)
if not os.path.exists(fp_mapping):
raise Exception(f"Mapping file not found at {fp_mapping}")
schema = json.load(open(fp_mapping))
vector_name = f"q_{vectorSize}_vec"
schema[vector_name] = {"type": f"vector,{vectorSize},float"}
inf_table = inf_db.create_table(
table_name,
schema,
ConflictType.Ignore,
)
inf_table.create_index(
"q_vec_idx",
IndexInfo(
vector_name,
IndexType.Hnsw,
{
"M": "16",
"ef_construction": "50",
"metric": "cosine",
"encode": "lvq",
},
),
ConflictType.Ignore,
)
text_suffix = ["_tks", "_ltks", "_kwd"]
for field_name, field_info in schema.items():
if field_info["type"] != "varchar":
continue
for suffix in text_suffix:
if field_name.endswith(suffix):
inf_table.create_index(
f"text_idx_{field_name}",
IndexInfo(
field_name, IndexType.FullText, {"ANALYZER": "standard"}
),
ConflictType.Ignore,
)
break
self.connPool.release_conn(inf_conn)
doc_store_logger.info(
f"INFINITY created table {table_name}, vector size {vectorSize}"
)
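`createIdx()` fixes the Infinity naming scheme used throughout this file: one table per (index, knowledge base) pair, plus one vector column per embedding size. A hedged summary for a 1024-dimension model (names illustrative):
```python
# conn.createIdx("ragflow_tenant1", "kb42", 1024) sets up, roughly:
#   - table  "ragflow_tenant1_kb42" in database "default_db"
#   - the columns declared in conf/infinity_mapping.json
#   - an extra column "q_1024_vec" of type "vector,1024,float"
#   - an HNSW index on it (M=16, ef_construction=50, cosine metric, lvq encoding)
#   - full-text indexes on every varchar column ending in _tks/_ltks/_kwd
```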
def deleteIdx(self, indexName: str, knowledgebaseId: str):
table_name = f"{indexName}_{knowledgebaseId}"
inf_conn = self.connPool.get_conn()
db_instance = inf_conn.get_database(self.dbName)
db_instance.drop_table(table_name, ConflictType.Ignore)
self.connPool.release_conn(inf_conn)
doc_store_logger.info(f"INFINITY dropped table {table_name}")
def indexExist(self, indexName: str, knowledgebaseId: str) -> bool:
table_name = f"{indexName}_{knowledgebaseId}"
try:
inf_conn = self.connPool.get_conn()
db_instance = inf_conn.get_database(self.dbName)
_ = db_instance.get_table(table_name)
self.connPool.release_conn(inf_conn)
return True
except Exception as e:
doc_store_logger.error("INFINITY indexExist: " + str(e))
return False
"""
CRUD operations
"""
def search(
self,
selectFields: list[str],
highlightFields: list[str],
condition: dict,
matchExprs: list[MatchExpr],
orderBy: OrderByExpr,
offset: int,
limit: int,
indexNames: str|list[str],
knowledgebaseIds: list[str],
) -> list[dict] | pl.DataFrame:
"""
TODO: Infinity doesn't provide highlight
"""
if isinstance(indexNames, str):
indexNames = indexNames.split(",")
assert isinstance(indexNames, list) and len(indexNames) > 0
inf_conn = self.connPool.get_conn()
db_instance = inf_conn.get_database(self.dbName)
df_list = list()
table_list = list()
if "id" not in selectFields:
selectFields.append("id")
# Prepare expressions common to all tables
filter_cond = ""
filter_fulltext = ""
if condition:
filter_cond = equivalent_condition_to_str(condition)
for matchExpr in matchExprs:
if isinstance(matchExpr, MatchTextExpr):
if len(filter_cond) != 0 and "filter" not in matchExpr.extra_options:
matchExpr.extra_options.update({"filter": filter_cond})
fields = ",".join(matchExpr.fields)
filter_fulltext = (
f"filter_fulltext('{fields}', '{matchExpr.matching_text}')"
)
if len(filter_cond) != 0:
filter_fulltext = f"({filter_cond}) AND {filter_fulltext}"
# doc_store_logger.info(f"filter_fulltext: {filter_fulltext}")
minimum_should_match = "0%"
if "minimum_should_match" in matchExpr.extra_options:
minimum_should_match = (
str(int(matchExpr.extra_options["minimum_should_match"] * 100))
+ "%"
)
matchExpr.extra_options.update(
{"minimum_should_match": minimum_should_match}
)
for k, v in matchExpr.extra_options.items():
if not isinstance(v, str):
matchExpr.extra_options[k] = str(v)
elif isinstance(matchExpr, MatchDenseExpr):
if len(filter_cond) != 0 and "filter" not in matchExpr.extra_options:
matchExpr.extra_options.update({"filter": filter_fulltext})
for k, v in matchExpr.extra_options.items():
if not isinstance(v, str):
matchExpr.extra_options[k] = str(v)
if orderBy.fields:
order_by_expr_list = list()
for order_field in orderBy.fields:
order_by_expr_list.append((order_field[0], order_field[1] == 0))
# Scatter search tables and gather the results
for indexName in indexNames:
for knowledgebaseId in knowledgebaseIds:
table_name = f"{indexName}_{knowledgebaseId}"
try:
table_instance = db_instance.get_table(table_name)
except Exception:
continue
table_list.append(table_name)
builder = table_instance.output(selectFields)
for matchExpr in matchExprs:
if isinstance(matchExpr, MatchTextExpr):
fields = ",".join(matchExpr.fields)
builder = builder.match_text(
fields,
matchExpr.matching_text,
matchExpr.topn,
matchExpr.extra_options,
)
elif isinstance(matchExpr, MatchDenseExpr):
builder = builder.match_dense(
matchExpr.vector_column_name,
matchExpr.embedding_data,
matchExpr.embedding_data_type,
matchExpr.distance_type,
matchExpr.topn,
matchExpr.extra_options,
)
elif isinstance(matchExpr, FusionExpr):
builder = builder.fusion(
matchExpr.method, matchExpr.topn, matchExpr.fusion_params
)
if orderBy.fields:
builder.sort(order_by_expr_list)
builder.offset(offset).limit(limit)
kb_res = builder.to_pl()
df_list.append(kb_res)
self.connPool.release_conn(inf_conn)
res = pl.concat(df_list)
doc_store_logger.info("INFINITY search tables: " + str(table_list))
return res
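For orientation, a hedged sketch of how a caller might drive this `search()` for hybrid retrieval. The constructors of `MatchTextExpr`, `MatchDenseExpr`, `FusionExpr`, and `OrderByExpr` are assumed to accept the attributes read above; their actual definitions live in `rag/utils/doc_store_conn.py`, so treat the argument names and values below as illustrative:
```python
from rag.utils.doc_store_conn import (
    MatchTextExpr, MatchDenseExpr, FusionExpr, OrderByExpr,
)

question_vec = [0.1] * 1024          # illustrative 1024-dim query embedding

match_text = MatchTextExpr(
    fields=["title_tks", "content_ltks"],
    matching_text="what is ragflow",
    topn=1024,
    extra_options={"minimum_should_match": 0.3},
)
match_dense = MatchDenseExpr(
    vector_column_name="q_1024_vec",
    embedding_data=question_vec,
    embedding_data_type="float",
    distance_type="cosine",
    topn=1024,
    extra_options={"similarity": 0.2},
)
fusion = FusionExpr(method="weighted_sum", topn=1024,
                    fusion_params={"weights": "0.05,0.95"})

# res = docStoreConn.search(
#     selectFields=["id", "content_with_weight"], highlightFields=[],
#     condition={"kb_id": ["kb-1"]},
#     matchExprs=[match_text, match_dense, fusion],
#     orderBy=OrderByExpr(), offset=0, limit=10,
#     indexNames="ragflow_tenant", knowledgebaseIds=["kb-1"],
# )
```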
def get(
self, chunkId: str, indexName: str, knowledgebaseIds: list[str]
) -> dict | None:
inf_conn = self.connPool.get_conn()
db_instance = inf_conn.get_database(self.dbName)
df_list = list()
assert isinstance(knowledgebaseIds, list)
for knowledgebaseId in knowledgebaseIds:
table_name = f"{indexName}_{knowledgebaseId}"
table_instance = db_instance.get_table(table_name)
kb_res = table_instance.output(["*"]).filter(f"id = '{chunkId}'").to_pl()
df_list.append(kb_res)
self.connPool.release_conn(inf_conn)
res = pl.concat(df_list)
res_fields = self.getFields(res, res.columns)
return res_fields.get(chunkId, None)
def insert(
self, documents: list[dict], indexName: str, knowledgebaseId: str
) -> list[str]:
inf_conn = self.connPool.get_conn()
db_instance = inf_conn.get_database(self.dbName)
table_name = f"{indexName}_{knowledgebaseId}"
try:
table_instance = db_instance.get_table(table_name)
except InfinityException as e:
# src/common/status.cppm, kTableNotExist = 3022
if e.error_code != 3022:
raise
vector_size = 0
patt = re.compile(r"q_(?P<vector_size>\d+)_vec")
for k in documents[0].keys():
m = patt.match(k)
if m:
vector_size = int(m.group("vector_size"))
break
if vector_size == 0:
raise ValueError("Cannot infer vector size from documents")
self.createIdx(indexName, knowledgebaseId, vector_size)
table_instance = db_instance.get_table(table_name)
for d in documents:
assert "_id" not in d
assert "id" in d
for k, v in d.items():
if k.endswith("_kwd") and isinstance(v, list):
d[k] = " ".join(v)
ids = [f"'{d["id"]}'" for d in documents]
str_ids = ", ".join(ids)
str_filter = f"id IN ({str_ids})"
table_instance.delete(str_filter)
# for doc in documents:
# doc_store_logger.info(f"insert position_list: {doc['position_list']}")
# doc_store_logger.info(f"InfinityConnection.insert {json.dumps(documents)}")
table_instance.insert(documents)
self.connPool.release_conn(inf_conn)
doc_store_logger.info(f"inserted into {table_name} {str_ids}.")
return []
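`insert()` behaves as an upsert: list-valued `*_kwd` fields are flattened to space-joined strings, any existing rows with the same ids are deleted, and only then are the documents inserted. Illustratively (field and table names assumed):
```python
doc = {
    "id": "chunk-1",
    "important_kwd": ["ragflow", "infinity"],      # flattened to "ragflow infinity"
    "content_with_weight": "RAGFlow meets Infinity",
    "q_1024_vec": [0.1] * 1024,
}
# insert([doc], "ragflow_tenant1", "kb42") roughly does:
#   table.delete("id IN ('chunk-1')")
#   table.insert([doc])
```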
def update(
self, condition: dict, newValue: dict, indexName: str, knowledgebaseId: str
) -> bool:
# if 'position_list' in newValue:
# doc_store_logger.info(f"update position_list: {newValue['position_list']}")
inf_conn = self.connPool.get_conn()
db_instance = inf_conn.get_database(self.dbName)
table_name = f"{indexName}_{knowledgebaseId}"
table_instance = db_instance.get_table(table_name)
filter = equivalent_condition_to_str(condition)
for k, v in newValue.items():
if k.endswith("_kwd") and isinstance(v, list):
newValue[k] = " ".join(v)
table_instance.update(filter, newValue)
self.connPool.release_conn(inf_conn)
return True
def delete(self, condition: dict, indexName: str, knowledgebaseId: str) -> int:
inf_conn = self.connPool.get_conn()
db_instance = inf_conn.get_database(self.dbName)
table_name = f"{indexName}_{knowledgebaseId}"
filter = equivalent_condition_to_str(condition)
try:
table_instance = db_instance.get_table(table_name)
except Exception:
doc_store_logger.warning(
f"Skipped deleting `{filter}` from table {table_name} since the table doesn't exist."
)
return 0
res = table_instance.delete(filter)
self.connPool.release_conn(inf_conn)
return res.deleted_rows
"""
Helper functions for search result
"""
def getTotal(self, res):
return len(res)
def getChunkIds(self, res):
return list(res["id"])
def getFields(self, res, fields: List[str]) -> Dict[str, dict]:
res_fields = {}
if not fields:
return {}
num_rows = len(res)
column_id = res["id"]
for i in range(num_rows):
id = column_id[i]
m = {"id": id}
for fieldnm in fields:
if fieldnm not in res:
m[fieldnm] = None
continue
v = res[fieldnm][i]
if isinstance(v, Series):
v = list(v)
elif fieldnm == "important_kwd":
assert isinstance(v, str)
v = v.split(" ")
else:
if not isinstance(v, str):
v = str(v)
# if fieldnm.endswith("_tks"):
# v = rmSpace(v)
m[fieldnm] = v
res_fields[id] = m
return res_fields
def getHighlight(self, res, keywords: List[str], fieldnm: str):
ans = {}
num_rows = len(res)
column_id = res["id"]
for i in range(num_rows):
id = column_id[i]
txt = res[fieldnm][i]
txt = re.sub(r"[\r\n]", " ", txt, flags=re.IGNORECASE | re.MULTILINE)
txts = []
for t in re.split(r"[.?!;\n]", txt):
for w in keywords:
t = re.sub(
r"(^|[ .?/'\"\(\)!,:;-])(%s)([ .?/'\"\(\)!,:;-])"
% re.escape(w),
r"\1<em>\2</em>\3",
t,
flags=re.IGNORECASE | re.MULTILINE,
)
if not re.search(
r"<em>[^<>]+</em>", t, flags=re.IGNORECASE | re.MULTILINE
):
continue
txts.append(t)
ans[id] = "...".join(txts)
return ans
def getAggregation(self, res, fieldnm: str):
"""
TODO: Infinity doesn't provide aggregation
"""
return list()
"""
SQL
"""
    def sql(self, sql: str, fetch_size: int, format: str):
raise NotImplementedError("Not implemented")

View File

@ -50,8 +50,8 @@ class Document(Base):
return res.content
def list_chunks(self,page=1, page_size=30, keywords="", id:str=None):
data={"keywords": keywords,"page":page,"page_size":page_size,"id":id}
def list_chunks(self,page=1, page_size=30, keywords=""):
data={"keywords": keywords,"page":page,"page_size":page_size}
res = self.get(f'/datasets/{self.dataset_id}/documents/{self.id}/chunks', data)
res = res.json()
if res.get("code") == 0:

View File

@ -126,6 +126,7 @@ def test_delete_chunk_with_success(get_api_key_fixture):
docs = ds.upload_documents(documents)
doc = docs[0]
chunk = doc.add_chunk(content="This is a chunk addition test")
sleep(5)
doc.delete_chunks([chunk.id])
@ -146,6 +147,8 @@ def test_update_chunk_content(get_api_key_fixture):
docs = ds.upload_documents(documents)
doc = docs[0]
chunk = doc.add_chunk(content="This is a chunk addition test")
    # For ElasticSearch, the chunk is not searchable within a short time (~2s).
sleep(3)
chunk.update({"content":"This is a updated content"})
def test_update_chunk_available(get_api_key_fixture):
@ -165,7 +168,9 @@ def test_update_chunk_available(get_api_key_fixture):
docs = ds.upload_documents(documents)
doc = docs[0]
chunk = doc.add_chunk(content="This is a chunk addition test")
chunk.update({"available":False})
    # For ElasticSearch, the chunk is not searchable within a short time (~2s).
sleep(3)
chunk.update({"available":0})
def test_retrieve_chunks(get_api_key_fixture):