mirror of
https://git.mirrors.martin98.com/https://github.com/infiniflow/ragflow.git
synced 2025-08-14 05:05:54 +08:00
Integration with Infinity (#2894)
### What problem does this PR solve? Integration with Infinity - Replaced ELASTICSEARCH with dataStoreConn - Renamed deleteByQuery with delete - Renamed bulk to upsertBulk - getHighlight, getAggregation - Fix KGSearch.search - Moved Dealer.sql_retrieval to es_conn.py ### Type of change - [x] Refactoring
This commit is contained in:
parent
00b6000b76
commit
f4c52371ab
2
.github/workflows/tests.yml
vendored
2
.github/workflows/tests.yml
vendored
@ -78,7 +78,7 @@ jobs:
|
||||
echo "Waiting for service to be available..."
|
||||
sleep 5
|
||||
done
|
||||
cd sdk/python && poetry install && source .venv/bin/activate && cd test && pytest t_dataset.py t_chat.py t_session.py t_document.py t_chunk.py
|
||||
cd sdk/python && poetry install && source .venv/bin/activate && cd test && pytest --tb=short t_dataset.py t_chat.py t_session.py t_document.py t_chunk.py
|
||||
|
||||
- name: Stop ragflow:dev
|
||||
if: always() # always run this step even if previous steps failed
|
||||
|
@ -285,7 +285,7 @@ docker build -f Dockerfile -t infiniflow/ragflow:dev .
|
||||
git clone https://github.com/infiniflow/ragflow.git
|
||||
cd ragflow/
|
||||
export POETRY_VIRTUALENVS_CREATE=true POETRY_VIRTUALENVS_IN_PROJECT=true
|
||||
~/.local/bin/poetry install --sync --no-root # install RAGFlow dependent python modules
|
||||
~/.local/bin/poetry install --sync --no-root --with=full # install RAGFlow dependent python modules
|
||||
```
|
||||
|
||||
3. Launch the dependent services (MinIO, Elasticsearch, Redis, and MySQL) using Docker Compose:
|
||||
@ -295,7 +295,7 @@ docker build -f Dockerfile -t infiniflow/ragflow:dev .
|
||||
|
||||
Add the following line to `/etc/hosts` to resolve all hosts specified in **docker/service_conf.yaml** to `127.0.0.1`:
|
||||
```
|
||||
127.0.0.1 es01 mysql minio redis
|
||||
127.0.0.1 es01 infinity mysql minio redis
|
||||
```
|
||||
In **docker/service_conf.yaml**, update mysql port to `5455` and es port to `1200`, as specified in **docker/.env**.
|
||||
|
||||
|
@ -250,7 +250,7 @@ docker build -f Dockerfile -t infiniflow/ragflow:dev .
|
||||
|
||||
`/etc/hosts` に以下の行を追加して、**docker/service_conf.yaml** に指定されたすべてのホストを `127.0.0.1` に解決します:
|
||||
```
|
||||
127.0.0.1 es01 mysql minio redis
|
||||
127.0.0.1 es01 infinity mysql minio redis
|
||||
```
|
||||
**docker/service_conf.yaml** で mysql のポートを `5455` に、es のポートを `1200` に更新します(**docker/.env** に指定された通り).
|
||||
|
||||
|
@ -254,7 +254,7 @@ docker build -f Dockerfile -t infiniflow/ragflow:dev .
|
||||
|
||||
`/etc/hosts` 에 다음 줄을 추가하여 **docker/service_conf.yaml** 에 지정된 모든 호스트를 `127.0.0.1` 로 해결합니다:
|
||||
```
|
||||
127.0.0.1 es01 mysql minio redis
|
||||
127.0.0.1 es01 infinity mysql minio redis
|
||||
```
|
||||
**docker/service_conf.yaml** 에서 mysql 포트를 `5455` 로, es 포트를 `1200` 으로 업데이트합니다( **docker/.env** 에 지정된 대로).
|
||||
|
||||
|
@ -252,7 +252,7 @@ docker build -f Dockerfile -t infiniflow/ragflow:dev .
|
||||
|
||||
在 `/etc/hosts` 中添加以下代码,将 **docker/service_conf.yaml** 文件中的所有 host 地址都解析为 `127.0.0.1`:
|
||||
```
|
||||
127.0.0.1 es01 mysql minio redis
|
||||
127.0.0.1 es01 infinity mysql minio redis
|
||||
```
|
||||
在文件 **docker/service_conf.yaml** 中,对照 **docker/.env** 的配置将 mysql 端口更新为 `5455`,es 端口更新为 `1200`。
|
||||
|
||||
|
@ -529,13 +529,14 @@ def list_chunks():
|
||||
return get_json_result(
|
||||
data=False, message="Can't find doc_name or doc_id"
|
||||
)
|
||||
kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
|
||||
|
||||
res = retrievaler.chunk_list(doc_id=doc_id, tenant_id=tenant_id)
|
||||
res = retrievaler.chunk_list(doc_id, tenant_id, kb_ids)
|
||||
res = [
|
||||
{
|
||||
"content": res_item["content_with_weight"],
|
||||
"doc_name": res_item["docnm_kwd"],
|
||||
"img_id": res_item["img_id"]
|
||||
"image_id": res_item["img_id"]
|
||||
} for res_item in res
|
||||
]
|
||||
|
||||
|
@ -18,12 +18,10 @@ import json
|
||||
|
||||
from flask import request
|
||||
from flask_login import login_required, current_user
|
||||
from elasticsearch_dsl import Q
|
||||
|
||||
from api.db.services.dialog_service import keyword_extraction
|
||||
from rag.app.qa import rmPrefix, beAdoc
|
||||
from rag.nlp import search, rag_tokenizer
|
||||
from rag.utils.es_conn import ELASTICSEARCH
|
||||
from rag.utils import rmSpace
|
||||
from api.db import LLMType, ParserType
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
@ -31,12 +29,11 @@ from api.db.services.llm_service import LLMBundle
|
||||
from api.db.services.user_service import UserTenantService
|
||||
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.settings import RetCode, retrievaler, kg_retrievaler
|
||||
from api.settings import RetCode, retrievaler, kg_retrievaler, docStoreConn
|
||||
from api.utils.api_utils import get_json_result
|
||||
import hashlib
|
||||
import re
|
||||
|
||||
|
||||
@manager.route('/list', methods=['POST'])
|
||||
@login_required
|
||||
@validate_request("doc_id")
|
||||
@ -53,12 +50,13 @@ def list_chunk():
|
||||
e, doc = DocumentService.get_by_id(doc_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="Document not found!")
|
||||
kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
|
||||
query = {
|
||||
"doc_ids": [doc_id], "page": page, "size": size, "question": question, "sort": True
|
||||
}
|
||||
if "available_int" in req:
|
||||
query["available_int"] = int(req["available_int"])
|
||||
sres = retrievaler.search(query, search.index_name(tenant_id), highlight=True)
|
||||
sres = retrievaler.search(query, search.index_name(tenant_id), kb_ids, highlight=True)
|
||||
res = {"total": sres.total, "chunks": [], "doc": doc.to_dict()}
|
||||
for id in sres.ids:
|
||||
d = {
|
||||
@ -69,16 +67,12 @@ def list_chunk():
|
||||
"doc_id": sres.field[id]["doc_id"],
|
||||
"docnm_kwd": sres.field[id]["docnm_kwd"],
|
||||
"important_kwd": sres.field[id].get("important_kwd", []),
|
||||
"img_id": sres.field[id].get("img_id", ""),
|
||||
"image_id": sres.field[id].get("img_id", ""),
|
||||
"available_int": sres.field[id].get("available_int", 1),
|
||||
"positions": sres.field[id].get("position_int", "").split("\t")
|
||||
"positions": json.loads(sres.field[id].get("position_list", "[]")),
|
||||
}
|
||||
if len(d["positions"]) % 5 == 0:
|
||||
poss = []
|
||||
for i in range(0, len(d["positions"]), 5):
|
||||
poss.append([float(d["positions"][i]), float(d["positions"][i + 1]), float(d["positions"][i + 2]),
|
||||
float(d["positions"][i + 3]), float(d["positions"][i + 4])])
|
||||
d["positions"] = poss
|
||||
assert isinstance(d["positions"], list)
|
||||
assert len(d["positions"])==0 or (isinstance(d["positions"][0], list) and len(d["positions"][0]) == 5)
|
||||
res["chunks"].append(d)
|
||||
return get_json_result(data=res)
|
||||
except Exception as e:
|
||||
@ -96,22 +90,20 @@ def get():
|
||||
tenants = UserTenantService.query(user_id=current_user.id)
|
||||
if not tenants:
|
||||
return get_data_error_result(message="Tenant not found!")
|
||||
res = ELASTICSEARCH.get(
|
||||
chunk_id, search.index_name(
|
||||
tenants[0].tenant_id))
|
||||
if not res.get("found"):
|
||||
tenant_id = tenants[0].tenant_id
|
||||
|
||||
kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
|
||||
chunk = docStoreConn.get(chunk_id, search.index_name(tenant_id), kb_ids)
|
||||
if chunk is None:
|
||||
return server_error_response("Chunk not found")
|
||||
id = res["_id"]
|
||||
res = res["_source"]
|
||||
res["chunk_id"] = id
|
||||
k = []
|
||||
for n in res.keys():
|
||||
for n in chunk.keys():
|
||||
if re.search(r"(_vec$|_sm_|_tks|_ltks)", n):
|
||||
k.append(n)
|
||||
for n in k:
|
||||
del res[n]
|
||||
del chunk[n]
|
||||
|
||||
return get_json_result(data=res)
|
||||
return get_json_result(data=chunk)
|
||||
except Exception as e:
|
||||
if str(e).find("NotFoundError") >= 0:
|
||||
return get_json_result(data=False, message='Chunk not found!',
|
||||
@ -162,7 +154,7 @@ def set():
|
||||
v, c = embd_mdl.encode([doc.name, req["content_with_weight"]])
|
||||
v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
|
||||
d["q_%d_vec" % len(v)] = v.tolist()
|
||||
ELASTICSEARCH.upsert([d], search.index_name(tenant_id))
|
||||
docStoreConn.insert([d], search.index_name(tenant_id), doc.kb_id)
|
||||
return get_json_result(data=True)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
@ -174,11 +166,11 @@ def set():
|
||||
def switch():
|
||||
req = request.json
|
||||
try:
|
||||
tenant_id = DocumentService.get_tenant_id(req["doc_id"])
|
||||
if not tenant_id:
|
||||
return get_data_error_result(message="Tenant not found!")
|
||||
if not ELASTICSEARCH.upsert([{"id": i, "available_int": int(req["available_int"])} for i in req["chunk_ids"]],
|
||||
search.index_name(tenant_id)):
|
||||
e, doc = DocumentService.get_by_id(req["doc_id"])
|
||||
if not e:
|
||||
return get_data_error_result(message="Document not found!")
|
||||
if not docStoreConn.update({"id": req["chunk_ids"]}, {"available_int": int(req["available_int"])},
|
||||
search.index_name(doc.tenant_id), doc.kb_id):
|
||||
return get_data_error_result(message="Index updating failure")
|
||||
return get_json_result(data=True)
|
||||
except Exception as e:
|
||||
@ -191,12 +183,11 @@ def switch():
|
||||
def rm():
|
||||
req = request.json
|
||||
try:
|
||||
if not ELASTICSEARCH.deleteByQuery(
|
||||
Q("ids", values=req["chunk_ids"]), search.index_name(current_user.id)):
|
||||
return get_data_error_result(message="Index updating failure")
|
||||
e, doc = DocumentService.get_by_id(req["doc_id"])
|
||||
if not e:
|
||||
return get_data_error_result(message="Document not found!")
|
||||
if not docStoreConn.delete({"id": req["chunk_ids"]}, search.index_name(current_user.id), doc.kb_id):
|
||||
return get_data_error_result(message="Index updating failure")
|
||||
deleted_chunk_ids = req["chunk_ids"]
|
||||
chunk_number = len(deleted_chunk_ids)
|
||||
DocumentService.decrement_chunk_num(doc.id, doc.kb_id, 1, chunk_number, 0)
|
||||
@ -239,7 +230,7 @@ def create():
|
||||
v, c = embd_mdl.encode([doc.name, req["content_with_weight"]])
|
||||
v = 0.1 * v[0] + 0.9 * v[1]
|
||||
d["q_%d_vec" % len(v)] = v.tolist()
|
||||
ELASTICSEARCH.upsert([d], search.index_name(tenant_id))
|
||||
docStoreConn.insert([d], search.index_name(tenant_id), doc.kb_id)
|
||||
|
||||
DocumentService.increment_chunk_num(
|
||||
doc.id, doc.kb_id, c, 1, 0)
|
||||
@ -256,8 +247,9 @@ def retrieval_test():
|
||||
page = int(req.get("page", 1))
|
||||
size = int(req.get("size", 30))
|
||||
question = req["question"]
|
||||
kb_id = req["kb_id"]
|
||||
if isinstance(kb_id, str): kb_id = [kb_id]
|
||||
kb_ids = req["kb_id"]
|
||||
if isinstance(kb_ids, str):
|
||||
kb_ids = [kb_ids]
|
||||
doc_ids = req.get("doc_ids", [])
|
||||
similarity_threshold = float(req.get("similarity_threshold", 0.0))
|
||||
vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3))
|
||||
@ -265,17 +257,17 @@ def retrieval_test():
|
||||
|
||||
try:
|
||||
tenants = UserTenantService.query(user_id=current_user.id)
|
||||
for kid in kb_id:
|
||||
for kb_id in kb_ids:
|
||||
for tenant in tenants:
|
||||
if KnowledgebaseService.query(
|
||||
tenant_id=tenant.tenant_id, id=kid):
|
||||
tenant_id=tenant.tenant_id, id=kb_id):
|
||||
break
|
||||
else:
|
||||
return get_json_result(
|
||||
data=False, message='Only owner of knowledgebase authorized for this operation.',
|
||||
code=RetCode.OPERATING_ERROR)
|
||||
|
||||
e, kb = KnowledgebaseService.get_by_id(kb_id[0])
|
||||
e, kb = KnowledgebaseService.get_by_id(kb_ids[0])
|
||||
if not e:
|
||||
return get_data_error_result(message="Knowledgebase not found!")
|
||||
|
||||
@ -290,7 +282,7 @@ def retrieval_test():
|
||||
question += keyword_extraction(chat_mdl, question)
|
||||
|
||||
retr = retrievaler if kb.parser_id != ParserType.KG else kg_retrievaler
|
||||
ranks = retr.retrieval(question, embd_mdl, kb.tenant_id, kb_id, page, size,
|
||||
ranks = retr.retrieval(question, embd_mdl, kb.tenant_id, kb_ids, page, size,
|
||||
similarity_threshold, vector_similarity_weight, top,
|
||||
doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"))
|
||||
for c in ranks["chunks"]:
|
||||
@ -309,12 +301,16 @@ def retrieval_test():
|
||||
@login_required
|
||||
def knowledge_graph():
|
||||
doc_id = request.args["doc_id"]
|
||||
e, doc = DocumentService.get_by_id(doc_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="Document not found!")
|
||||
tenant_id = DocumentService.get_tenant_id(doc_id)
|
||||
kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
|
||||
req = {
|
||||
"doc_ids":[doc_id],
|
||||
"knowledge_graph_kwd": ["graph", "mind_map"]
|
||||
}
|
||||
tenant_id = DocumentService.get_tenant_id(doc_id)
|
||||
sres = retrievaler.search(req, search.index_name(tenant_id))
|
||||
sres = retrievaler.search(req, search.index_name(tenant_id), kb_ids, doc.kb_id)
|
||||
obj = {"graph": {}, "mind_map": {}}
|
||||
for id in sres.ids[:2]:
|
||||
ty = sres.field[id]["knowledge_graph_kwd"]
|
||||
|
@ -17,7 +17,6 @@ import pathlib
|
||||
import re
|
||||
|
||||
import flask
|
||||
from elasticsearch_dsl import Q
|
||||
from flask import request
|
||||
from flask_login import login_required, current_user
|
||||
|
||||
@ -27,14 +26,13 @@ from api.db.services.file_service import FileService
|
||||
from api.db.services.task_service import TaskService, queue_tasks
|
||||
from api.db.services.user_service import UserTenantService
|
||||
from rag.nlp import search
|
||||
from rag.utils.es_conn import ELASTICSEARCH
|
||||
from api.db.services import duplicate_name
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
|
||||
from api.utils import get_uuid
|
||||
from api.db import FileType, TaskStatus, ParserType, FileSource
|
||||
from api.db.services.document_service import DocumentService, doc_upload_and_parse
|
||||
from api.settings import RetCode
|
||||
from api.settings import RetCode, docStoreConn
|
||||
from api.utils.api_utils import get_json_result
|
||||
from rag.utils.storage_factory import STORAGE_IMPL
|
||||
from api.utils.file_utils import filename_type, thumbnail
|
||||
@ -275,18 +273,8 @@ def change_status():
|
||||
return get_data_error_result(
|
||||
message="Database error (Document update)!")
|
||||
|
||||
if str(req["status"]) == "0":
|
||||
ELASTICSEARCH.updateScriptByQuery(Q("term", doc_id=req["doc_id"]),
|
||||
scripts="ctx._source.available_int=0;",
|
||||
idxnm=search.index_name(
|
||||
kb.tenant_id)
|
||||
)
|
||||
else:
|
||||
ELASTICSEARCH.updateScriptByQuery(Q("term", doc_id=req["doc_id"]),
|
||||
scripts="ctx._source.available_int=1;",
|
||||
idxnm=search.index_name(
|
||||
kb.tenant_id)
|
||||
)
|
||||
status = int(req["status"])
|
||||
docStoreConn.update({"doc_id": req["doc_id"]}, {"available_int": status}, search.index_name(kb.tenant_id), doc.kb_id)
|
||||
return get_json_result(data=True)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
@ -365,8 +353,11 @@ def run():
|
||||
tenant_id = DocumentService.get_tenant_id(id)
|
||||
if not tenant_id:
|
||||
return get_data_error_result(message="Tenant not found!")
|
||||
ELASTICSEARCH.deleteByQuery(
|
||||
Q("match", doc_id=id), idxnm=search.index_name(tenant_id))
|
||||
e, doc = DocumentService.get_by_id(id)
|
||||
if not e:
|
||||
return get_data_error_result(message="Document not found!")
|
||||
if docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
|
||||
docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), doc.kb_id)
|
||||
|
||||
if str(req["run"]) == TaskStatus.RUNNING.value:
|
||||
TaskService.filter_delete([Task.doc_id == id])
|
||||
@ -490,8 +481,8 @@ def change_parser():
|
||||
tenant_id = DocumentService.get_tenant_id(req["doc_id"])
|
||||
if not tenant_id:
|
||||
return get_data_error_result(message="Tenant not found!")
|
||||
ELASTICSEARCH.deleteByQuery(
|
||||
Q("match", doc_id=doc.id), idxnm=search.index_name(tenant_id))
|
||||
if docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
|
||||
docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
|
||||
|
||||
return get_json_result(data=True)
|
||||
except Exception as e:
|
||||
|
@ -28,6 +28,8 @@ from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.db_models import File
|
||||
from api.settings import RetCode
|
||||
from api.utils.api_utils import get_json_result
|
||||
from api.settings import docStoreConn
|
||||
from rag.nlp import search
|
||||
|
||||
|
||||
@manager.route('/create', methods=['post'])
|
||||
@ -166,6 +168,9 @@ def rm():
|
||||
if not KnowledgebaseService.delete_by_id(req["kb_id"]):
|
||||
return get_data_error_result(
|
||||
message="Database error (Knowledgebase removal)!")
|
||||
tenants = UserTenantService.query(user_id=current_user.id)
|
||||
for tenant in tenants:
|
||||
docStoreConn.deleteIdx(search.index_name(tenant.tenant_id), req["kb_id"])
|
||||
return get_json_result(data=True)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
@ -30,7 +30,6 @@ from api.db.services.task_service import TaskService, queue_tasks
|
||||
from api.utils.api_utils import server_error_response
|
||||
from api.utils.api_utils import get_result, get_error_data_result
|
||||
from io import BytesIO
|
||||
from elasticsearch_dsl import Q
|
||||
from flask import request, send_file
|
||||
from api.db import FileSource, TaskStatus, FileType
|
||||
from api.db.db_models import File
|
||||
@ -42,7 +41,7 @@ from api.settings import RetCode, retrievaler
|
||||
from api.utils.api_utils import construct_json_result, get_parser_config
|
||||
from rag.nlp import search
|
||||
from rag.utils import rmSpace
|
||||
from rag.utils.es_conn import ELASTICSEARCH
|
||||
from api.settings import docStoreConn
|
||||
from rag.utils.storage_factory import STORAGE_IMPL
|
||||
import os
|
||||
|
||||
@ -293,9 +292,7 @@ def update_doc(tenant_id, dataset_id, document_id):
|
||||
)
|
||||
if not e:
|
||||
return get_error_data_result(message="Document not found!")
|
||||
ELASTICSEARCH.deleteByQuery(
|
||||
Q("match", doc_id=doc.id), idxnm=search.index_name(tenant_id)
|
||||
)
|
||||
docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), dataset_id)
|
||||
|
||||
return get_result()
|
||||
|
||||
@ -647,9 +644,7 @@ def parse(tenant_id, dataset_id):
|
||||
info["chunk_num"] = 0
|
||||
info["token_num"] = 0
|
||||
DocumentService.update_by_id(id, info)
|
||||
ELASTICSEARCH.deleteByQuery(
|
||||
Q("match", doc_id=id), idxnm=search.index_name(tenant_id)
|
||||
)
|
||||
docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), dataset_id)
|
||||
TaskService.filter_delete([Task.doc_id == id])
|
||||
e, doc = DocumentService.get_by_id(id)
|
||||
doc = doc.to_dict()
|
||||
@ -713,9 +708,7 @@ def stop_parsing(tenant_id, dataset_id):
|
||||
)
|
||||
info = {"run": "2", "progress": 0, "chunk_num": 0}
|
||||
DocumentService.update_by_id(id, info)
|
||||
ELASTICSEARCH.deleteByQuery(
|
||||
Q("match", doc_id=id), idxnm=search.index_name(tenant_id)
|
||||
)
|
||||
docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), dataset_id)
|
||||
return get_result()
|
||||
|
||||
|
||||
@ -812,7 +805,6 @@ def list_chunks(tenant_id, dataset_id, document_id):
|
||||
"question": question,
|
||||
"sort": True,
|
||||
}
|
||||
sres = retrievaler.search(query, search.index_name(tenant_id), highlight=True)
|
||||
key_mapping = {
|
||||
"chunk_num": "chunk_count",
|
||||
"kb_id": "dataset_id",
|
||||
@ -833,51 +825,56 @@ def list_chunks(tenant_id, dataset_id, document_id):
|
||||
renamed_doc[new_key] = value
|
||||
if key == "run":
|
||||
renamed_doc["run"] = run_mapping.get(str(value))
|
||||
res = {"total": sres.total, "chunks": [], "doc": renamed_doc}
|
||||
origin_chunks = []
|
||||
sign = 0
|
||||
for id in sres.ids:
|
||||
d = {
|
||||
"chunk_id": id,
|
||||
"content_with_weight": (
|
||||
rmSpace(sres.highlight[id])
|
||||
if question and id in sres.highlight
|
||||
else sres.field[id].get("content_with_weight", "")
|
||||
),
|
||||
"doc_id": sres.field[id]["doc_id"],
|
||||
"docnm_kwd": sres.field[id]["docnm_kwd"],
|
||||
"important_kwd": sres.field[id].get("important_kwd", []),
|
||||
"img_id": sres.field[id].get("img_id", ""),
|
||||
"available_int": sres.field[id].get("available_int", 1),
|
||||
"positions": sres.field[id].get("position_int", "").split("\t"),
|
||||
}
|
||||
if len(d["positions"]) % 5 == 0:
|
||||
poss = []
|
||||
for i in range(0, len(d["positions"]), 5):
|
||||
poss.append(
|
||||
[
|
||||
float(d["positions"][i]),
|
||||
float(d["positions"][i + 1]),
|
||||
float(d["positions"][i + 2]),
|
||||
float(d["positions"][i + 3]),
|
||||
float(d["positions"][i + 4]),
|
||||
]
|
||||
)
|
||||
d["positions"] = poss
|
||||
|
||||
origin_chunks.append(d)
|
||||
res = {"total": 0, "chunks": [], "doc": renamed_doc}
|
||||
origin_chunks = []
|
||||
if docStoreConn.indexExist(search.index_name(tenant_id), dataset_id):
|
||||
sres = retrievaler.search(query, search.index_name(tenant_id), [dataset_id], emb_mdl=None, highlight=True)
|
||||
res["total"] = sres.total
|
||||
sign = 0
|
||||
for id in sres.ids:
|
||||
d = {
|
||||
"id": id,
|
||||
"content_with_weight": (
|
||||
rmSpace(sres.highlight[id])
|
||||
if question and id in sres.highlight
|
||||
else sres.field[id].get("content_with_weight", "")
|
||||
),
|
||||
"doc_id": sres.field[id]["doc_id"],
|
||||
"docnm_kwd": sres.field[id]["docnm_kwd"],
|
||||
"important_kwd": sres.field[id].get("important_kwd", []),
|
||||
"img_id": sres.field[id].get("img_id", ""),
|
||||
"available_int": sres.field[id].get("available_int", 1),
|
||||
"positions": sres.field[id].get("position_int", "").split("\t"),
|
||||
}
|
||||
if len(d["positions"]) % 5 == 0:
|
||||
poss = []
|
||||
for i in range(0, len(d["positions"]), 5):
|
||||
poss.append(
|
||||
[
|
||||
float(d["positions"][i]),
|
||||
float(d["positions"][i + 1]),
|
||||
float(d["positions"][i + 2]),
|
||||
float(d["positions"][i + 3]),
|
||||
float(d["positions"][i + 4]),
|
||||
]
|
||||
)
|
||||
d["positions"] = poss
|
||||
|
||||
origin_chunks.append(d)
|
||||
if req.get("id"):
|
||||
if req.get("id") == id:
|
||||
origin_chunks.clear()
|
||||
origin_chunks.append(d)
|
||||
sign = 1
|
||||
break
|
||||
if req.get("id"):
|
||||
if req.get("id") == id:
|
||||
origin_chunks.clear()
|
||||
origin_chunks.append(d)
|
||||
sign = 1
|
||||
break
|
||||
if req.get("id"):
|
||||
if sign == 0:
|
||||
return get_error_data_result(f"Can't find this chunk {req.get('id')}")
|
||||
if sign == 0:
|
||||
return get_error_data_result(f"Can't find this chunk {req.get('id')}")
|
||||
|
||||
for chunk in origin_chunks:
|
||||
key_mapping = {
|
||||
"chunk_id": "id",
|
||||
"id": "id",
|
||||
"content_with_weight": "content",
|
||||
"doc_id": "document_id",
|
||||
"important_kwd": "important_keywords",
|
||||
@ -996,9 +993,9 @@ def add_chunk(tenant_id, dataset_id, document_id):
|
||||
)
|
||||
d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19]
|
||||
d["create_timestamp_flt"] = datetime.datetime.now().timestamp()
|
||||
d["kb_id"] = [doc.kb_id]
|
||||
d["kb_id"] = dataset_id
|
||||
d["docnm_kwd"] = doc.name
|
||||
d["doc_id"] = doc.id
|
||||
d["doc_id"] = document_id
|
||||
embd_id = DocumentService.get_embd_id(document_id)
|
||||
embd_mdl = TenantLLMService.model_instance(
|
||||
tenant_id, LLMType.EMBEDDING.value, embd_id
|
||||
@ -1006,14 +1003,12 @@ def add_chunk(tenant_id, dataset_id, document_id):
|
||||
v, c = embd_mdl.encode([doc.name, req["content"]])
|
||||
v = 0.1 * v[0] + 0.9 * v[1]
|
||||
d["q_%d_vec" % len(v)] = v.tolist()
|
||||
ELASTICSEARCH.upsert([d], search.index_name(tenant_id))
|
||||
docStoreConn.insert([d], search.index_name(tenant_id), dataset_id)
|
||||
|
||||
DocumentService.increment_chunk_num(doc.id, doc.kb_id, c, 1, 0)
|
||||
d["chunk_id"] = chunk_id
|
||||
d["kb_id"] = doc.kb_id
|
||||
# rename keys
|
||||
key_mapping = {
|
||||
"chunk_id": "id",
|
||||
"id": "id",
|
||||
"content_with_weight": "content",
|
||||
"doc_id": "document_id",
|
||||
"important_kwd": "important_keywords",
|
||||
@ -1079,36 +1074,16 @@ def rm_chunk(tenant_id, dataset_id, document_id):
|
||||
"""
|
||||
if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
|
||||
return get_error_data_result(message=f"You don't own the dataset {dataset_id}.")
|
||||
doc = DocumentService.query(id=document_id, kb_id=dataset_id)
|
||||
if not doc:
|
||||
return get_error_data_result(
|
||||
message=f"You don't own the document {document_id}."
|
||||
)
|
||||
doc = doc[0]
|
||||
req = request.json
|
||||
if not req.get("chunk_ids"):
|
||||
return get_error_data_result("`chunk_ids` is required")
|
||||
query = {"doc_ids": [doc.id], "page": 1, "size": 1024, "question": "", "sort": True}
|
||||
sres = retrievaler.search(query, search.index_name(tenant_id), highlight=True)
|
||||
if not req:
|
||||
chunk_ids = None
|
||||
else:
|
||||
chunk_ids = req.get("chunk_ids")
|
||||
if not chunk_ids:
|
||||
chunk_list = sres.ids
|
||||
else:
|
||||
chunk_list = chunk_ids
|
||||
for chunk_id in chunk_list:
|
||||
if chunk_id not in sres.ids:
|
||||
return get_error_data_result(f"Chunk {chunk_id} not found")
|
||||
if not ELASTICSEARCH.deleteByQuery(
|
||||
Q("ids", values=chunk_list), search.index_name(tenant_id)
|
||||
):
|
||||
return get_error_data_result(message="Index updating failure")
|
||||
deleted_chunk_ids = chunk_list
|
||||
chunk_number = len(deleted_chunk_ids)
|
||||
DocumentService.decrement_chunk_num(doc.id, doc.kb_id, 1, chunk_number, 0)
|
||||
return get_result()
|
||||
condition = {"doc_id": document_id}
|
||||
if "chunk_ids" in req:
|
||||
condition["id"] = req["chunk_ids"]
|
||||
chunk_number = docStoreConn.delete(condition, search.index_name(tenant_id), dataset_id)
|
||||
if chunk_number != 0:
|
||||
DocumentService.decrement_chunk_num(document_id, dataset_id, 1, chunk_number, 0)
|
||||
if "chunk_ids" in req and chunk_number != len(req["chunk_ids"]):
|
||||
return get_error_data_result(message=f"rm_chunk deleted chunks {chunk_number}, expect {len(req["chunk_ids"])}")
|
||||
return get_result(message=f"deleted {chunk_number} chunks")
|
||||
|
||||
|
||||
@manager.route(
|
||||
@ -1168,9 +1143,8 @@ def update_chunk(tenant_id, dataset_id, document_id, chunk_id):
|
||||
schema:
|
||||
type: object
|
||||
"""
|
||||
try:
|
||||
res = ELASTICSEARCH.get(chunk_id, search.index_name(tenant_id))
|
||||
except Exception:
|
||||
chunk = docStoreConn.get(chunk_id, search.index_name(tenant_id), [dataset_id])
|
||||
if chunk is None:
|
||||
return get_error_data_result(f"Can't find this chunk {chunk_id}")
|
||||
if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
|
||||
return get_error_data_result(message=f"You don't own the dataset {dataset_id}.")
|
||||
@ -1180,19 +1154,12 @@ def update_chunk(tenant_id, dataset_id, document_id, chunk_id):
|
||||
message=f"You don't own the document {document_id}."
|
||||
)
|
||||
doc = doc[0]
|
||||
query = {
|
||||
"doc_ids": [document_id],
|
||||
"page": 1,
|
||||
"size": 1024,
|
||||
"question": "",
|
||||
"sort": True,
|
||||
}
|
||||
sres = retrievaler.search(query, search.index_name(tenant_id), highlight=True)
|
||||
if chunk_id not in sres.ids:
|
||||
return get_error_data_result(f"You don't own the chunk {chunk_id}")
|
||||
req = request.json
|
||||
content = res["_source"].get("content_with_weight")
|
||||
d = {"id": chunk_id, "content_with_weight": req.get("content", content)}
|
||||
if "content" in req:
|
||||
content = req["content"]
|
||||
else:
|
||||
content = chunk.get("content_with_weight", "")
|
||||
d = {"id": chunk_id, "content_with_weight": content}
|
||||
d["content_ltks"] = rag_tokenizer.tokenize(d["content_with_weight"])
|
||||
d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
|
||||
if "important_keywords" in req:
|
||||
@ -1220,7 +1187,7 @@ def update_chunk(tenant_id, dataset_id, document_id, chunk_id):
|
||||
v, c = embd_mdl.encode([doc.name, d["content_with_weight"]])
|
||||
v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
|
||||
d["q_%d_vec" % len(v)] = v.tolist()
|
||||
ELASTICSEARCH.upsert([d], search.index_name(tenant_id))
|
||||
docStoreConn.update({"id": chunk_id}, d, search.index_name(tenant_id), dataset_id)
|
||||
return get_result()
|
||||
|
||||
|
||||
|
@ -31,7 +31,7 @@ from api.utils.api_utils import (
|
||||
generate_confirmation_token,
|
||||
)
|
||||
from api.versions import get_rag_version
|
||||
from rag.utils.es_conn import ELASTICSEARCH
|
||||
from api.settings import docStoreConn
|
||||
from rag.utils.storage_factory import STORAGE_IMPL, STORAGE_IMPL_TYPE
|
||||
from timeit import default_timer as timer
|
||||
|
||||
@ -98,10 +98,11 @@ def status():
|
||||
res = {}
|
||||
st = timer()
|
||||
try:
|
||||
res["es"] = ELASTICSEARCH.health()
|
||||
res["es"]["elapsed"] = "{:.1f}".format((timer() - st) * 1000.0)
|
||||
res["doc_store"] = docStoreConn.health()
|
||||
res["doc_store"]["elapsed"] = "{:.1f}".format((timer() - st) * 1000.0)
|
||||
except Exception as e:
|
||||
res["es"] = {
|
||||
res["doc_store"] = {
|
||||
"type": "unknown",
|
||||
"status": "red",
|
||||
"elapsed": "{:.1f}".format((timer() - st) * 1000.0),
|
||||
"error": str(e),
|
||||
|
@ -470,7 +470,7 @@ class User(DataBaseModel, UserMixin):
|
||||
status = CharField(
|
||||
max_length=1,
|
||||
null=True,
|
||||
help_text="is it validate(0: wasted,1: validate)",
|
||||
help_text="is it validate(0: wasted, 1: validate)",
|
||||
default="1",
|
||||
index=True)
|
||||
is_superuser = BooleanField(null=True, help_text="is root", default=False, index=True)
|
||||
@ -525,7 +525,7 @@ class Tenant(DataBaseModel):
|
||||
status = CharField(
|
||||
max_length=1,
|
||||
null=True,
|
||||
help_text="is it validate(0: wasted,1: validate)",
|
||||
help_text="is it validate(0: wasted, 1: validate)",
|
||||
default="1",
|
||||
index=True)
|
||||
|
||||
@ -542,7 +542,7 @@ class UserTenant(DataBaseModel):
|
||||
status = CharField(
|
||||
max_length=1,
|
||||
null=True,
|
||||
help_text="is it validate(0: wasted,1: validate)",
|
||||
help_text="is it validate(0: wasted, 1: validate)",
|
||||
default="1",
|
||||
index=True)
|
||||
|
||||
@ -559,7 +559,7 @@ class InvitationCode(DataBaseModel):
|
||||
status = CharField(
|
||||
max_length=1,
|
||||
null=True,
|
||||
help_text="is it validate(0: wasted,1: validate)",
|
||||
help_text="is it validate(0: wasted, 1: validate)",
|
||||
default="1",
|
||||
index=True)
|
||||
|
||||
@ -582,7 +582,7 @@ class LLMFactories(DataBaseModel):
|
||||
status = CharField(
|
||||
max_length=1,
|
||||
null=True,
|
||||
help_text="is it validate(0: wasted,1: validate)",
|
||||
help_text="is it validate(0: wasted, 1: validate)",
|
||||
default="1",
|
||||
index=True)
|
||||
|
||||
@ -616,7 +616,7 @@ class LLM(DataBaseModel):
|
||||
status = CharField(
|
||||
max_length=1,
|
||||
null=True,
|
||||
help_text="is it validate(0: wasted,1: validate)",
|
||||
help_text="is it validate(0: wasted, 1: validate)",
|
||||
default="1",
|
||||
index=True)
|
||||
|
||||
@ -703,7 +703,7 @@ class Knowledgebase(DataBaseModel):
|
||||
status = CharField(
|
||||
max_length=1,
|
||||
null=True,
|
||||
help_text="is it validate(0: wasted,1: validate)",
|
||||
help_text="is it validate(0: wasted, 1: validate)",
|
||||
default="1",
|
||||
index=True)
|
||||
|
||||
@ -767,7 +767,7 @@ class Document(DataBaseModel):
|
||||
status = CharField(
|
||||
max_length=1,
|
||||
null=True,
|
||||
help_text="is it validate(0: wasted,1: validate)",
|
||||
help_text="is it validate(0: wasted, 1: validate)",
|
||||
default="1",
|
||||
index=True)
|
||||
|
||||
@ -904,7 +904,7 @@ class Dialog(DataBaseModel):
|
||||
status = CharField(
|
||||
max_length=1,
|
||||
null=True,
|
||||
help_text="is it validate(0: wasted,1: validate)",
|
||||
help_text="is it validate(0: wasted, 1: validate)",
|
||||
default="1",
|
||||
index=True)
|
||||
|
||||
@ -987,7 +987,7 @@ def migrate_db():
|
||||
help_text="where dose this document come from",
|
||||
index=True))
|
||||
)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(
|
||||
@ -996,7 +996,7 @@ def migrate_db():
|
||||
help_text="default rerank model ID"))
|
||||
|
||||
)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(
|
||||
@ -1004,59 +1004,59 @@ def migrate_db():
|
||||
help_text="default rerank model ID"))
|
||||
|
||||
)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(
|
||||
migrator.add_column('dialog', 'top_k', IntegerField(default=1024))
|
||||
|
||||
)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(
|
||||
migrator.alter_column_type('tenant_llm', 'api_key',
|
||||
CharField(max_length=1024, null=True, help_text="API KEY", index=True))
|
||||
)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(
|
||||
migrator.add_column('api_token', 'source',
|
||||
CharField(max_length=16, null=True, help_text="none|agent|dialog", index=True))
|
||||
)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(
|
||||
migrator.add_column("tenant","tts_id",
|
||||
CharField(max_length=256,null=True,help_text="default tts model ID",index=True))
|
||||
)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(
|
||||
migrator.add_column('api_4_conversation', 'source',
|
||||
CharField(max_length=16, null=True, help_text="none|agent|dialog", index=True))
|
||||
)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
DB.execute_sql('ALTER TABLE llm DROP PRIMARY KEY;')
|
||||
DB.execute_sql('ALTER TABLE llm ADD PRIMARY KEY (llm_name,fid);')
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(
|
||||
migrator.add_column('task', 'retry_count', IntegerField(default=0))
|
||||
)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(
|
||||
migrator.alter_column_type('api_token', 'dialog_id',
|
||||
CharField(max_length=32, null=True, index=True))
|
||||
)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
@ -15,7 +15,6 @@
|
||||
#
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import traceback
|
||||
@ -24,16 +23,13 @@ from copy import deepcopy
|
||||
from datetime import datetime
|
||||
from io import BytesIO
|
||||
|
||||
from elasticsearch_dsl import Q
|
||||
from peewee import fn
|
||||
|
||||
from api.db.db_utils import bulk_insert_into_db
|
||||
from api.settings import stat_logger
|
||||
from api.settings import stat_logger, docStoreConn
|
||||
from api.utils import current_timestamp, get_format_time, get_uuid
|
||||
from api.utils.file_utils import get_project_base_directory
|
||||
from graphrag.mind_map_extractor import MindMapExtractor
|
||||
from rag.settings import SVR_QUEUE_NAME
|
||||
from rag.utils.es_conn import ELASTICSEARCH
|
||||
from rag.utils.storage_factory import STORAGE_IMPL
|
||||
from rag.nlp import search, rag_tokenizer
|
||||
|
||||
@ -112,8 +108,7 @@ class DocumentService(CommonService):
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def remove_document(cls, doc, tenant_id):
|
||||
ELASTICSEARCH.deleteByQuery(
|
||||
Q("match", doc_id=doc.id), idxnm=search.index_name(tenant_id))
|
||||
docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
|
||||
cls.clear_chunk_num(doc.id)
|
||||
return cls.delete_by_id(doc.id)
|
||||
|
||||
@ -225,6 +220,15 @@ class DocumentService(CommonService):
|
||||
return
|
||||
return docs[0]["tenant_id"]
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_knowledgebase_id(cls, doc_id):
|
||||
docs = cls.model.select(cls.model.kb_id).where(cls.model.id == doc_id)
|
||||
docs = docs.dicts()
|
||||
if not docs:
|
||||
return
|
||||
return docs[0]["kb_id"]
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_tenant_id_by_name(cls, name):
|
||||
@ -438,11 +442,6 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
|
||||
if not e:
|
||||
raise LookupError("Can't find this knowledgebase!")
|
||||
|
||||
idxnm = search.index_name(kb.tenant_id)
|
||||
if not ELASTICSEARCH.indexExist(idxnm):
|
||||
ELASTICSEARCH.createIdx(idxnm, json.load(
|
||||
open(os.path.join(get_project_base_directory(), "conf", "mapping.json"), "r")))
|
||||
|
||||
embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING, llm_name=kb.embd_id, lang=kb.language)
|
||||
|
||||
err, files = FileService.upload_document(kb, file_objs, user_id)
|
||||
@ -486,7 +485,7 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
|
||||
md5 = hashlib.md5()
|
||||
md5.update((ck["content_with_weight"] +
|
||||
str(d["doc_id"])).encode("utf-8"))
|
||||
d["_id"] = md5.hexdigest()
|
||||
d["id"] = md5.hexdigest()
|
||||
d["create_time"] = str(datetime.now()).replace("T", " ")[:19]
|
||||
d["create_timestamp_flt"] = datetime.now().timestamp()
|
||||
if not d.get("image"):
|
||||
@ -499,8 +498,8 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
|
||||
else:
|
||||
d["image"].save(output_buffer, format='JPEG')
|
||||
|
||||
STORAGE_IMPL.put(kb.id, d["_id"], output_buffer.getvalue())
|
||||
d["img_id"] = "{}-{}".format(kb.id, d["_id"])
|
||||
STORAGE_IMPL.put(kb.id, d["id"], output_buffer.getvalue())
|
||||
d["img_id"] = "{}-{}".format(kb.id, d["id"])
|
||||
del d["image"]
|
||||
docs.append(d)
|
||||
|
||||
@ -520,6 +519,9 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
|
||||
token_counts[doc_id] += c
|
||||
return vects
|
||||
|
||||
idxnm = search.index_name(kb.tenant_id)
|
||||
try_create_idx = True
|
||||
|
||||
_, tenant = TenantService.get_by_id(kb.tenant_id)
|
||||
llm_bdl = LLMBundle(kb.tenant_id, LLMType.CHAT, tenant.llm_id)
|
||||
for doc_id in docids:
|
||||
@ -550,7 +552,11 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
|
||||
v = vects[i]
|
||||
d["q_%d_vec" % len(v)] = v
|
||||
for b in range(0, len(cks), es_bulk_size):
|
||||
ELASTICSEARCH.bulk(cks[b:b + es_bulk_size], idxnm)
|
||||
if try_create_idx:
|
||||
if not docStoreConn.indexExist(idxnm, kb_id):
|
||||
docStoreConn.createIdx(idxnm, kb_id, len(vects[0]))
|
||||
try_create_idx = False
|
||||
docStoreConn.insert(cks[b:b + es_bulk_size], idxnm, kb_id)
|
||||
|
||||
DocumentService.increment_chunk_num(
|
||||
doc_id, kb.id, token_counts[doc_id], chunk_counts[doc_id], 0)
|
||||
|
@ -66,6 +66,16 @@ class KnowledgebaseService(CommonService):
|
||||
|
||||
return list(kbs.dicts())
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_kb_ids(cls, tenant_id):
|
||||
fields = [
|
||||
cls.model.id,
|
||||
]
|
||||
kbs = cls.model.select(*fields).where(cls.model.tenant_id == tenant_id)
|
||||
kb_ids = [kb["id"] for kb in kbs]
|
||||
return kb_ids
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_detail(cls, kb_id):
|
||||
|
@ -18,6 +18,8 @@ from datetime import date
|
||||
from enum import IntEnum, Enum
|
||||
from api.utils.file_utils import get_project_base_directory
|
||||
from api.utils.log_utils import LoggerFactory, getLogger
|
||||
import rag.utils.es_conn
|
||||
import rag.utils.infinity_conn
|
||||
|
||||
# Logger
|
||||
LoggerFactory.set_directory(
|
||||
@ -33,7 +35,7 @@ access_logger = getLogger("access")
|
||||
database_logger = getLogger("database")
|
||||
chat_logger = getLogger("chat")
|
||||
|
||||
from rag.utils.es_conn import ELASTICSEARCH
|
||||
import rag.utils
|
||||
from rag.nlp import search
|
||||
from graphrag import search as kg_search
|
||||
from api.utils import get_base_config, decrypt_database_config
|
||||
@ -206,8 +208,12 @@ AUTHENTICATION_DEFAULT_TIMEOUT = 7 * 24 * 60 * 60 # s
|
||||
PRIVILEGE_COMMAND_WHITELIST = []
|
||||
CHECK_NODES_IDENTITY = False
|
||||
|
||||
retrievaler = search.Dealer(ELASTICSEARCH)
|
||||
kg_retrievaler = kg_search.KGSearch(ELASTICSEARCH)
|
||||
if 'username' in get_base_config("es", {}):
|
||||
docStoreConn = rag.utils.es_conn.ESConnection()
|
||||
else:
|
||||
docStoreConn = rag.utils.infinity_conn.InfinityConnection()
|
||||
retrievaler = search.Dealer(docStoreConn)
|
||||
kg_retrievaler = kg_search.KGSearch(docStoreConn)
|
||||
|
||||
|
||||
class CustomEnum(Enum):
|
||||
|
@ -126,10 +126,6 @@ def server_error_response(e):
|
||||
if len(e.args) > 1:
|
||||
return get_json_result(
|
||||
code=RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=e.args[1])
|
||||
if repr(e).find("index_not_found_exception") >= 0:
|
||||
return get_json_result(code=RetCode.EXCEPTION_ERROR,
|
||||
message="No chunk found, please upload file and parse it.")
|
||||
|
||||
return get_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e))
|
||||
|
||||
|
||||
@ -270,10 +266,6 @@ def construct_error_response(e):
|
||||
pass
|
||||
if len(e.args) > 1:
|
||||
return construct_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=e.args[1])
|
||||
if repr(e).find("index_not_found_exception") >= 0:
|
||||
return construct_json_result(code=RetCode.EXCEPTION_ERROR,
|
||||
message="No chunk found, please upload file and parse it.")
|
||||
|
||||
return construct_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e))
|
||||
|
||||
|
||||
@ -295,7 +287,7 @@ def token_required(func):
|
||||
return decorated_function
|
||||
|
||||
|
||||
def get_result(code=RetCode.SUCCESS, message='error', data=None):
|
||||
def get_result(code=RetCode.SUCCESS, message="", data=None):
|
||||
if code == 0:
|
||||
if data is not None:
|
||||
response = {"code": code, "data": data}
|
||||
|
26
conf/infinity_mapping.json
Normal file
26
conf/infinity_mapping.json
Normal file
@ -0,0 +1,26 @@
|
||||
{
|
||||
"id": {"type": "varchar", "default": ""},
|
||||
"doc_id": {"type": "varchar", "default": ""},
|
||||
"kb_id": {"type": "varchar", "default": ""},
|
||||
"create_time": {"type": "varchar", "default": ""},
|
||||
"create_timestamp_flt": {"type": "float", "default": 0.0},
|
||||
"img_id": {"type": "varchar", "default": ""},
|
||||
"docnm_kwd": {"type": "varchar", "default": ""},
|
||||
"title_tks": {"type": "varchar", "default": ""},
|
||||
"title_sm_tks": {"type": "varchar", "default": ""},
|
||||
"name_kwd": {"type": "varchar", "default": ""},
|
||||
"important_kwd": {"type": "varchar", "default": ""},
|
||||
"important_tks": {"type": "varchar", "default": ""},
|
||||
"content_with_weight": {"type": "varchar", "default": ""},
|
||||
"content_ltks": {"type": "varchar", "default": ""},
|
||||
"content_sm_ltks": {"type": "varchar", "default": ""},
|
||||
"page_num_list": {"type": "varchar", "default": ""},
|
||||
"top_list": {"type": "varchar", "default": ""},
|
||||
"position_list": {"type": "varchar", "default": ""},
|
||||
"weight_int": {"type": "integer", "default": 0},
|
||||
"weight_flt": {"type": "float", "default": 0.0},
|
||||
"rank_int": {"type": "integer", "default": 0},
|
||||
"available_int": {"type": "integer", "default": 1},
|
||||
"knowledge_graph_kwd": {"type": "varchar", "default": ""},
|
||||
"entities_kwd": {"type": "varchar", "default": ""}
|
||||
}
|
@ -1,200 +1,203 @@
|
||||
{
|
||||
{
|
||||
"settings": {
|
||||
"index": {
|
||||
"number_of_shards": 2,
|
||||
"number_of_replicas": 0,
|
||||
"refresh_interval" : "1000ms"
|
||||
"refresh_interval": "1000ms"
|
||||
},
|
||||
"similarity": {
|
||||
"scripted_sim": {
|
||||
"type": "scripted",
|
||||
"script": {
|
||||
"source": "double idf = Math.log(1+(field.docCount-term.docFreq+0.5)/(term.docFreq + 0.5))/Math.log(1+((field.docCount-0.5)/1.5)); return query.boost * idf * Math.min(doc.freq, 1);"
|
||||
}
|
||||
"scripted_sim": {
|
||||
"type": "scripted",
|
||||
"script": {
|
||||
"source": "double idf = Math.log(1+(field.docCount-term.docFreq+0.5)/(term.docFreq + 0.5))/Math.log(1+((field.docCount-0.5)/1.5)); return query.boost * idf * Math.min(doc.freq, 1);"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"mappings": {
|
||||
"properties": {
|
||||
"lat_lon": {"type": "geo_point", "store":"true"}
|
||||
},
|
||||
"date_detection": "true",
|
||||
"dynamic_templates": [
|
||||
{
|
||||
"int": {
|
||||
"match": "*_int",
|
||||
"mapping": {
|
||||
"type": "integer",
|
||||
"store": "true"
|
||||
}
|
||||
"properties": {
|
||||
"lat_lon": {
|
||||
"type": "geo_point",
|
||||
"store": "true"
|
||||
}
|
||||
},
|
||||
"date_detection": "true",
|
||||
"dynamic_templates": [
|
||||
{
|
||||
"int": {
|
||||
"match": "*_int",
|
||||
"mapping": {
|
||||
"type": "integer",
|
||||
"store": "true"
|
||||
}
|
||||
},
|
||||
{
|
||||
"ulong": {
|
||||
"match": "*_ulong",
|
||||
"mapping": {
|
||||
"type": "unsigned_long",
|
||||
"store": "true"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"long": {
|
||||
"match": "*_long",
|
||||
"mapping": {
|
||||
"type": "long",
|
||||
"store": "true"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"short": {
|
||||
"match": "*_short",
|
||||
"mapping": {
|
||||
"type": "short",
|
||||
"store": "true"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"numeric": {
|
||||
"match": "*_flt",
|
||||
"mapping": {
|
||||
"type": "float",
|
||||
"store": true
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"tks": {
|
||||
"match": "*_tks",
|
||||
"mapping": {
|
||||
"type": "text",
|
||||
"similarity": "scripted_sim",
|
||||
"analyzer": "whitespace",
|
||||
"store": true
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"ltks":{
|
||||
"match": "*_ltks",
|
||||
"mapping": {
|
||||
"type": "text",
|
||||
"analyzer": "whitespace",
|
||||
"store": true
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"kwd": {
|
||||
"match_pattern": "regex",
|
||||
"match": "^(.*_(kwd|id|ids|uid|uids)|uid)$",
|
||||
"mapping": {
|
||||
"type": "keyword",
|
||||
"similarity": "boolean",
|
||||
"store": true
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"dt": {
|
||||
"match_pattern": "regex",
|
||||
"match": "^.*(_dt|_time|_at)$",
|
||||
"mapping": {
|
||||
"type": "date",
|
||||
"format": "yyyy-MM-dd HH:mm:ss||yyyy-MM-dd||yyyy-MM-dd_HH:mm:ss",
|
||||
"store": true
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"nested": {
|
||||
"match": "*_nst",
|
||||
"mapping": {
|
||||
"type": "nested"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"object": {
|
||||
"match": "*_obj",
|
||||
"mapping": {
|
||||
"type": "object",
|
||||
"dynamic": "true"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"string": {
|
||||
"match": "*_with_weight",
|
||||
"mapping": {
|
||||
"type": "text",
|
||||
"index": "false",
|
||||
"store": true
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"string": {
|
||||
"match": "*_fea",
|
||||
"mapping": {
|
||||
"type": "rank_feature"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"dense_vector": {
|
||||
"match": "*_512_vec",
|
||||
"mapping": {
|
||||
"type": "dense_vector",
|
||||
"index": true,
|
||||
"similarity": "cosine",
|
||||
"dims": 512
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"dense_vector": {
|
||||
"match": "*_768_vec",
|
||||
"mapping": {
|
||||
"type": "dense_vector",
|
||||
"index": true,
|
||||
"similarity": "cosine",
|
||||
"dims": 768
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"dense_vector": {
|
||||
"match": "*_1024_vec",
|
||||
"mapping": {
|
||||
"type": "dense_vector",
|
||||
"index": true,
|
||||
"similarity": "cosine",
|
||||
"dims": 1024
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"dense_vector": {
|
||||
"match": "*_1536_vec",
|
||||
"mapping": {
|
||||
"type": "dense_vector",
|
||||
"index": true,
|
||||
"similarity": "cosine",
|
||||
"dims": 1536
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"binary": {
|
||||
"match": "*_bin",
|
||||
"mapping": {
|
||||
"type": "binary"
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
"ulong": {
|
||||
"match": "*_ulong",
|
||||
"mapping": {
|
||||
"type": "unsigned_long",
|
||||
"store": "true"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"long": {
|
||||
"match": "*_long",
|
||||
"mapping": {
|
||||
"type": "long",
|
||||
"store": "true"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"short": {
|
||||
"match": "*_short",
|
||||
"mapping": {
|
||||
"type": "short",
|
||||
"store": "true"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"numeric": {
|
||||
"match": "*_flt",
|
||||
"mapping": {
|
||||
"type": "float",
|
||||
"store": true
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"tks": {
|
||||
"match": "*_tks",
|
||||
"mapping": {
|
||||
"type": "text",
|
||||
"similarity": "scripted_sim",
|
||||
"analyzer": "whitespace",
|
||||
"store": true
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"ltks": {
|
||||
"match": "*_ltks",
|
||||
"mapping": {
|
||||
"type": "text",
|
||||
"analyzer": "whitespace",
|
||||
"store": true
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"kwd": {
|
||||
"match_pattern": "regex",
|
||||
"match": "^(.*_(kwd|id|ids|uid|uids)|uid)$",
|
||||
"mapping": {
|
||||
"type": "keyword",
|
||||
"similarity": "boolean",
|
||||
"store": true
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"dt": {
|
||||
"match_pattern": "regex",
|
||||
"match": "^.*(_dt|_time|_at)$",
|
||||
"mapping": {
|
||||
"type": "date",
|
||||
"format": "yyyy-MM-dd HH:mm:ss||yyyy-MM-dd||yyyy-MM-dd_HH:mm:ss",
|
||||
"store": true
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"nested": {
|
||||
"match": "*_nst",
|
||||
"mapping": {
|
||||
"type": "nested"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"object": {
|
||||
"match": "*_obj",
|
||||
"mapping": {
|
||||
"type": "object",
|
||||
"dynamic": "true"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"string": {
|
||||
"match": "*_(with_weight|list)$",
|
||||
"mapping": {
|
||||
"type": "text",
|
||||
"index": "false",
|
||||
"store": true
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"string": {
|
||||
"match": "*_fea",
|
||||
"mapping": {
|
||||
"type": "rank_feature"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"dense_vector": {
|
||||
"match": "*_512_vec",
|
||||
"mapping": {
|
||||
"type": "dense_vector",
|
||||
"index": true,
|
||||
"similarity": "cosine",
|
||||
"dims": 512
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"dense_vector": {
|
||||
"match": "*_768_vec",
|
||||
"mapping": {
|
||||
"type": "dense_vector",
|
||||
"index": true,
|
||||
"similarity": "cosine",
|
||||
"dims": 768
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"dense_vector": {
|
||||
"match": "*_1024_vec",
|
||||
"mapping": {
|
||||
"type": "dense_vector",
|
||||
"index": true,
|
||||
"similarity": "cosine",
|
||||
"dims": 1024
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"dense_vector": {
|
||||
"match": "*_1536_vec",
|
||||
"mapping": {
|
||||
"type": "dense_vector",
|
||||
"index": true,
|
||||
"similarity": "cosine",
|
||||
"dims": 1536
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"binary": {
|
||||
"match": "*_bin",
|
||||
"mapping": {
|
||||
"type": "binary"
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
@ -19,6 +19,11 @@ KIBANA_PASSWORD=infini_rag_flow
|
||||
# Update it according to the available memory in the host machine.
|
||||
MEM_LIMIT=8073741824
|
||||
|
||||
# Port to expose Infinity API to the host
|
||||
INFINITY_THRIFT_PORT=23817
|
||||
INFINITY_HTTP_PORT=23820
|
||||
INFINITY_PSQL_PORT=5432
|
||||
|
||||
# The password for MySQL.
|
||||
# When updated, you must revise the `mysql.password` entry in service_conf.yaml.
|
||||
MYSQL_PASSWORD=infini_rag_flow
|
||||
|
@ -6,6 +6,7 @@ services:
|
||||
- esdata01:/usr/share/elasticsearch/data
|
||||
ports:
|
||||
- ${ES_PORT}:9200
|
||||
env_file: .env
|
||||
environment:
|
||||
- node.name=es01
|
||||
- ELASTIC_PASSWORD=${ELASTIC_PASSWORD}
|
||||
@ -27,12 +28,40 @@ services:
|
||||
retries: 120
|
||||
networks:
|
||||
- ragflow
|
||||
restart: always
|
||||
restart: on-failure
|
||||
|
||||
# infinity:
|
||||
# container_name: ragflow-infinity
|
||||
# image: infiniflow/infinity:v0.5.0-dev2
|
||||
# volumes:
|
||||
# - infinity_data:/var/infinity
|
||||
# ports:
|
||||
# - ${INFINITY_THRIFT_PORT}:23817
|
||||
# - ${INFINITY_HTTP_PORT}:23820
|
||||
# - ${INFINITY_PSQL_PORT}:5432
|
||||
# env_file: .env
|
||||
# environment:
|
||||
# - TZ=${TIMEZONE}
|
||||
# mem_limit: ${MEM_LIMIT}
|
||||
# ulimits:
|
||||
# nofile:
|
||||
# soft: 500000
|
||||
# hard: 500000
|
||||
# networks:
|
||||
# - ragflow
|
||||
# healthcheck:
|
||||
# test: ["CMD", "curl", "http://localhost:23820/admin/node/current"]
|
||||
# interval: 10s
|
||||
# timeout: 10s
|
||||
# retries: 120
|
||||
# restart: on-failure
|
||||
|
||||
|
||||
mysql:
|
||||
# mysql:5.7 linux/arm64 image is unavailable.
|
||||
image: mysql:8.0.39
|
||||
container_name: ragflow-mysql
|
||||
env_file: .env
|
||||
environment:
|
||||
- MYSQL_ROOT_PASSWORD=${MYSQL_PASSWORD}
|
||||
- TZ=${TIMEZONE}
|
||||
@ -55,7 +84,7 @@ services:
|
||||
interval: 10s
|
||||
timeout: 10s
|
||||
retries: 3
|
||||
restart: always
|
||||
restart: on-failure
|
||||
|
||||
minio:
|
||||
image: quay.io/minio/minio:RELEASE.2023-12-20T01-00-02Z
|
||||
@ -64,6 +93,7 @@ services:
|
||||
ports:
|
||||
- ${MINIO_PORT}:9000
|
||||
- ${MINIO_CONSOLE_PORT}:9001
|
||||
env_file: .env
|
||||
environment:
|
||||
- MINIO_ROOT_USER=${MINIO_USER}
|
||||
- MINIO_ROOT_PASSWORD=${MINIO_PASSWORD}
|
||||
@ -72,25 +102,28 @@ services:
|
||||
- minio_data:/data
|
||||
networks:
|
||||
- ragflow
|
||||
restart: always
|
||||
restart: on-failure
|
||||
|
||||
redis:
|
||||
image: valkey/valkey:8
|
||||
container_name: ragflow-redis
|
||||
command: redis-server --requirepass ${REDIS_PASSWORD} --maxmemory 128mb --maxmemory-policy allkeys-lru
|
||||
env_file: .env
|
||||
ports:
|
||||
- ${REDIS_PORT}:6379
|
||||
volumes:
|
||||
- redis_data:/data
|
||||
networks:
|
||||
- ragflow
|
||||
restart: always
|
||||
restart: on-failure
|
||||
|
||||
|
||||
|
||||
volumes:
|
||||
esdata01:
|
||||
driver: local
|
||||
infinity_data:
|
||||
driver: local
|
||||
mysql_data:
|
||||
driver: local
|
||||
minio_data:
|
||||
|
@ -1,6 +1,5 @@
|
||||
include:
|
||||
- path: ./docker-compose-base.yml
|
||||
env_file: ./.env
|
||||
- ./docker-compose-base.yml
|
||||
|
||||
services:
|
||||
ragflow:
|
||||
@ -15,19 +14,21 @@ services:
|
||||
- ${SVR_HTTP_PORT}:9380
|
||||
- 80:80
|
||||
- 443:443
|
||||
- 5678:5678
|
||||
volumes:
|
||||
- ./service_conf.yaml:/ragflow/conf/service_conf.yaml
|
||||
- ./ragflow-logs:/ragflow/logs
|
||||
- ./nginx/ragflow.conf:/etc/nginx/conf.d/ragflow.conf
|
||||
- ./nginx/proxy.conf:/etc/nginx/proxy.conf
|
||||
- ./nginx/nginx.conf:/etc/nginx/nginx.conf
|
||||
env_file: .env
|
||||
environment:
|
||||
- TZ=${TIMEZONE}
|
||||
- HF_ENDPOINT=${HF_ENDPOINT}
|
||||
- MACOS=${MACOS}
|
||||
networks:
|
||||
- ragflow
|
||||
restart: always
|
||||
restart: on-failure
|
||||
# https://docs.docker.com/engine/daemon/prometheus/#create-a-prometheus-configuration
|
||||
# If you're using Docker Desktop, the --add-host flag is optional. This flag makes sure that the host's internal IP gets exposed to the Prometheus container.
|
||||
extra_hosts:
|
||||
- "host.docker.internal:host-gateway"
|
||||
|
@ -67,7 +67,7 @@ docker compose -f docker/docker-compose-base.yml up -d
|
||||
1. Add the following line to `/etc/hosts` to resolve all hosts specified in **docker/service_conf.yaml** to `127.0.0.1`:
|
||||
|
||||
```
|
||||
127.0.0.1 es01 mysql minio redis
|
||||
127.0.0.1 es01 infinity mysql minio redis
|
||||
```
|
||||
|
||||
2. In **docker/service_conf.yaml**, update mysql port to `5455` and es port to `1200`, as specified in **docker/.env**.
|
||||
|
@ -1280,7 +1280,7 @@ Success:
|
||||
"document_keyword": "1.txt",
|
||||
"highlight": "<em>ragflow</em> content",
|
||||
"id": "d78435d142bd5cf6704da62c778795c5",
|
||||
"img_id": "",
|
||||
"image_id": "",
|
||||
"important_keywords": [
|
||||
""
|
||||
],
|
||||
|
@ -1351,7 +1351,7 @@ A list of `Chunk` objects representing references to the message, each containin
|
||||
The chunk ID.
|
||||
- `content` `str`
|
||||
The content of the chunk.
|
||||
- `image_id` `str`
|
||||
- `img_id` `str`
|
||||
The ID of the snapshot of the chunk. Applicable only when the source of the chunk is an image, PPT, PPTX, or PDF file.
|
||||
- `document_id` `str`
|
||||
The ID of the referenced document.
|
||||
|
@ -254,9 +254,12 @@ if __name__ == "__main__":
|
||||
from api.db import LLMType
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from api.settings import retrievaler
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
|
||||
kb_ids = KnowledgebaseService.get_kb_ids(args.tenant_id)
|
||||
|
||||
ex = ClaimExtractor(LLMBundle(args.tenant_id, LLMType.CHAT))
|
||||
docs = [d["content_with_weight"] for d in retrievaler.chunk_list(args.doc_id, args.tenant_id, max_count=12, fields=["content_with_weight"])]
|
||||
docs = [d["content_with_weight"] for d in retrievaler.chunk_list(args.doc_id, args.tenant_id, kb_ids, max_count=12, fields=["content_with_weight"])]
|
||||
info = {
|
||||
"input_text": docs,
|
||||
"entity_specs": "organization, person",
|
||||
|
@ -15,95 +15,90 @@
|
||||
#
|
||||
import json
|
||||
from copy import deepcopy
|
||||
from typing import Dict
|
||||
|
||||
import pandas as pd
|
||||
from elasticsearch_dsl import Q, Search
|
||||
from rag.utils.doc_store_conn import OrderByExpr, FusionExpr
|
||||
|
||||
from rag.nlp.search import Dealer
|
||||
|
||||
|
||||
class KGSearch(Dealer):
|
||||
def search(self, req, idxnm, emb_mdl=None, highlight=False):
|
||||
def merge_into_first(sres, title=""):
|
||||
df,texts = [],[]
|
||||
for d in sres["hits"]["hits"]:
|
||||
def search(self, req, idxnm, kb_ids, emb_mdl, highlight=False):
|
||||
def merge_into_first(sres, title="") -> Dict[str, str]:
|
||||
if not sres:
|
||||
return {}
|
||||
content_with_weight = ""
|
||||
df, texts = [],[]
|
||||
for d in sres.values():
|
||||
try:
|
||||
df.append(json.loads(d["_source"]["content_with_weight"]))
|
||||
except Exception as e:
|
||||
texts.append(d["_source"]["content_with_weight"])
|
||||
pass
|
||||
if not df and not texts: return False
|
||||
df.append(json.loads(d["content_with_weight"]))
|
||||
except Exception:
|
||||
texts.append(d["content_with_weight"])
|
||||
if df:
|
||||
try:
|
||||
sres["hits"]["hits"][0]["_source"]["content_with_weight"] = title + "\n" + pd.DataFrame(df).to_csv()
|
||||
except Exception as e:
|
||||
pass
|
||||
content_with_weight = title + "\n" + pd.DataFrame(df).to_csv()
|
||||
else:
|
||||
sres["hits"]["hits"][0]["_source"]["content_with_weight"] = title + "\n" + "\n".join(texts)
|
||||
return True
|
||||
content_with_weight = title + "\n" + "\n".join(texts)
|
||||
first_id = ""
|
||||
first_source = {}
|
||||
for k, v in sres.items():
|
||||
first_id = id
|
||||
first_source = deepcopy(v)
|
||||
break
|
||||
first_source["content_with_weight"] = content_with_weight
|
||||
first_id = next(iter(sres))
|
||||
return {first_id: first_source}
|
||||
|
||||
qst = req.get("question", "")
|
||||
matchText, keywords = self.qryr.question(qst, min_match=0.05)
|
||||
condition = self.get_filters(req)
|
||||
|
||||
## Entity retrieval
|
||||
condition.update({"knowledge_graph_kwd": ["entity"]})
|
||||
assert emb_mdl, "No embedding model selected"
|
||||
matchDense = self.get_vector(qst, emb_mdl, 1024, req.get("similarity", 0.1))
|
||||
q_vec = matchDense.embedding_data
|
||||
src = req.get("fields", ["docnm_kwd", "content_ltks", "kb_id", "img_id", "title_tks", "important_kwd",
|
||||
"image_id", "doc_id", "q_512_vec", "q_768_vec", "position_int", "name_kwd",
|
||||
"doc_id", f"q_{len(q_vec)}_vec", "position_list", "name_kwd",
|
||||
"q_1024_vec", "q_1536_vec", "available_int", "content_with_weight",
|
||||
"weight_int", "weight_flt", "rank_int"
|
||||
])
|
||||
|
||||
qst = req.get("question", "")
|
||||
binary_query, keywords = self.qryr.question(qst, min_match="5%")
|
||||
binary_query = self._add_filters(binary_query, req)
|
||||
fusionExpr = FusionExpr("weighted_sum", 32, {"weights": "0.5, 0.5"})
|
||||
|
||||
## Entity retrieval
|
||||
bqry = deepcopy(binary_query)
|
||||
bqry.filter.append(Q("terms", knowledge_graph_kwd=["entity"]))
|
||||
s = Search()
|
||||
s = s.query(bqry)[0: 32]
|
||||
|
||||
s = s.to_dict()
|
||||
q_vec = []
|
||||
if req.get("vector"):
|
||||
assert emb_mdl, "No embedding model selected"
|
||||
s["knn"] = self._vector(
|
||||
qst, emb_mdl, req.get(
|
||||
"similarity", 0.1), 1024)
|
||||
s["knn"]["filter"] = bqry.to_dict()
|
||||
q_vec = s["knn"]["query_vector"]
|
||||
|
||||
ent_res = self.es.search(deepcopy(s), idxnms=idxnm, timeout="600s", src=src)
|
||||
entities = [d["name_kwd"] for d in self.es.getSource(ent_res)]
|
||||
ent_ids = self.es.getDocIds(ent_res)
|
||||
if merge_into_first(ent_res, "-Entities-"):
|
||||
ent_ids = ent_ids[0:1]
|
||||
ent_res = self.dataStore.search(src, list(), condition, [matchText, matchDense, fusionExpr], OrderByExpr(), 0, 32, idxnm, kb_ids)
|
||||
ent_res_fields = self.dataStore.getFields(ent_res, src)
|
||||
entities = [d["name_kwd"] for d in ent_res_fields.values()]
|
||||
ent_ids = self.dataStore.getChunkIds(ent_res)
|
||||
ent_content = merge_into_first(ent_res_fields, "-Entities-")
|
||||
if ent_content:
|
||||
ent_ids = list(ent_content.keys())
|
||||
|
||||
## Community retrieval
|
||||
bqry = deepcopy(binary_query)
|
||||
bqry.filter.append(Q("terms", entities_kwd=entities))
|
||||
bqry.filter.append(Q("terms", knowledge_graph_kwd=["community_report"]))
|
||||
s = Search()
|
||||
s = s.query(bqry)[0: 32]
|
||||
s = s.to_dict()
|
||||
comm_res = self.es.search(deepcopy(s), idxnms=idxnm, timeout="600s", src=src)
|
||||
comm_ids = self.es.getDocIds(comm_res)
|
||||
if merge_into_first(comm_res, "-Community Report-"):
|
||||
comm_ids = comm_ids[0:1]
|
||||
condition = self.get_filters(req)
|
||||
condition.update({"entities_kwd": entities, "knowledge_graph_kwd": ["community_report"]})
|
||||
comm_res = self.dataStore.search(src, list(), condition, [matchText, matchDense, fusionExpr], OrderByExpr(), 0, 32, idxnm, kb_ids)
|
||||
comm_res_fields = self.dataStore.getFields(comm_res, src)
|
||||
comm_ids = self.dataStore.getChunkIds(comm_res)
|
||||
comm_content = merge_into_first(comm_res_fields, "-Community Report-")
|
||||
if comm_content:
|
||||
comm_ids = list(comm_content.keys())
|
||||
|
||||
## Text content retrieval
|
||||
bqry = deepcopy(binary_query)
|
||||
bqry.filter.append(Q("terms", knowledge_graph_kwd=["text"]))
|
||||
s = Search()
|
||||
s = s.query(bqry)[0: 6]
|
||||
s = s.to_dict()
|
||||
txt_res = self.es.search(deepcopy(s), idxnms=idxnm, timeout="600s", src=src)
|
||||
txt_ids = self.es.getDocIds(txt_res)
|
||||
if merge_into_first(txt_res, "-Original Content-"):
|
||||
txt_ids = txt_ids[0:1]
|
||||
condition = self.get_filters(req)
|
||||
condition.update({"knowledge_graph_kwd": ["text"]})
|
||||
txt_res = self.dataStore.search(src, list(), condition, [matchText, matchDense, fusionExpr], OrderByExpr(), 0, 6, idxnm, kb_ids)
|
||||
txt_res_fields = self.dataStore.getFields(txt_res, src)
|
||||
txt_ids = self.dataStore.getChunkIds(txt_res)
|
||||
txt_content = merge_into_first(txt_res_fields, "-Original Content-")
|
||||
if txt_content:
|
||||
txt_ids = list(txt_content.keys())
|
||||
|
||||
return self.SearchResult(
|
||||
total=len(ent_ids) + len(comm_ids) + len(txt_ids),
|
||||
ids=[*ent_ids, *comm_ids, *txt_ids],
|
||||
query_vector=q_vec,
|
||||
aggregation=None,
|
||||
highlight=None,
|
||||
field={**self.getFields(ent_res, src), **self.getFields(comm_res, src), **self.getFields(txt_res, src)},
|
||||
field={**ent_content, **comm_content, **txt_content},
|
||||
keywords=[]
|
||||
)
|
||||
|
||||
|
@ -31,10 +31,13 @@ if __name__ == "__main__":
|
||||
from api.db import LLMType
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from api.settings import retrievaler
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
|
||||
kb_ids = KnowledgebaseService.get_kb_ids(args.tenant_id)
|
||||
|
||||
ex = GraphExtractor(LLMBundle(args.tenant_id, LLMType.CHAT))
|
||||
docs = [d["content_with_weight"] for d in
|
||||
retrievaler.chunk_list(args.doc_id, args.tenant_id, max_count=6, fields=["content_with_weight"])]
|
||||
retrievaler.chunk_list(args.doc_id, args.tenant_id, kb_ids, max_count=6, fields=["content_with_weight"])]
|
||||
graph = ex(docs)
|
||||
|
||||
er = EntityResolution(LLMBundle(args.tenant_id, LLMType.CHAT))
|
||||
|
871
poetry.lock
generated
871
poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@ -46,22 +46,23 @@ hanziconv = "0.3.2"
|
||||
html-text = "0.6.2"
|
||||
httpx = "0.27.0"
|
||||
huggingface-hub = "^0.25.0"
|
||||
infinity-emb = "0.0.51"
|
||||
infinity-sdk = "0.5.0.dev2"
|
||||
infinity-emb = "^0.0.66"
|
||||
itsdangerous = "2.1.2"
|
||||
markdown = "3.6"
|
||||
markdown-to-json = "2.1.1"
|
||||
minio = "7.2.4"
|
||||
mistralai = "0.4.2"
|
||||
nltk = "3.9.1"
|
||||
numpy = "1.26.4"
|
||||
numpy = "^1.26.0"
|
||||
ollama = "0.2.1"
|
||||
onnxruntime = "1.19.2"
|
||||
openai = "1.45.0"
|
||||
opencv-python = "4.10.0.84"
|
||||
opencv-python-headless = "4.10.0.84"
|
||||
openpyxl = "3.1.2"
|
||||
openpyxl = "^3.1.0"
|
||||
ormsgpack = "1.5.0"
|
||||
pandas = "2.2.2"
|
||||
pandas = "^2.2.0"
|
||||
pdfplumber = "0.10.4"
|
||||
peewee = "3.17.1"
|
||||
pillow = "10.4.0"
|
||||
@ -70,7 +71,7 @@ psycopg2-binary = "2.9.9"
|
||||
pyclipper = "1.3.0.post5"
|
||||
pycryptodomex = "3.20.0"
|
||||
pypdf = "^5.0.0"
|
||||
pytest = "8.2.2"
|
||||
pytest = "^8.3.0"
|
||||
python-dotenv = "1.0.1"
|
||||
python-dateutil = "2.8.2"
|
||||
python-pptx = "^1.0.2"
|
||||
@ -86,7 +87,7 @@ ruamel-base = "1.0.0"
|
||||
scholarly = "1.7.11"
|
||||
scikit-learn = "1.5.0"
|
||||
selenium = "4.22.0"
|
||||
setuptools = "70.0.0"
|
||||
setuptools = "^75.2.0"
|
||||
shapely = "2.0.5"
|
||||
six = "1.16.0"
|
||||
strenum = "0.4.15"
|
||||
@ -115,6 +116,7 @@ pymysql = "^1.1.1"
|
||||
mini-racer = "^0.12.4"
|
||||
pyicu = "^2.13.1"
|
||||
flasgger = "^0.9.7.1"
|
||||
polars = "^1.9.0"
|
||||
|
||||
|
||||
[tool.poetry.group.full]
|
||||
|
@ -20,6 +20,7 @@ from rag.nlp import tokenize, is_english
|
||||
from rag.nlp import rag_tokenizer
|
||||
from deepdoc.parser import PdfParser, PptParser, PlainParser
|
||||
from PyPDF2 import PdfReader as pdf2_read
|
||||
import json
|
||||
|
||||
|
||||
class Ppt(PptParser):
|
||||
@ -107,9 +108,9 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
|
||||
d = copy.deepcopy(doc)
|
||||
pn += from_page
|
||||
d["image"] = img
|
||||
d["page_num_int"] = [pn + 1]
|
||||
d["top_int"] = [0]
|
||||
d["position_int"] = [(pn + 1, 0, img.size[0], 0, img.size[1])]
|
||||
d["page_num_list"] = json.dumps([pn + 1])
|
||||
d["top_list"] = json.dumps([0])
|
||||
d["position_list"] = json.dumps([(pn + 1, 0, img.size[0], 0, img.size[1])])
|
||||
tokenize(d, txt, eng)
|
||||
res.append(d)
|
||||
return res
|
||||
@ -123,10 +124,10 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
|
||||
pn += from_page
|
||||
if img:
|
||||
d["image"] = img
|
||||
d["page_num_int"] = [pn + 1]
|
||||
d["top_int"] = [0]
|
||||
d["position_int"] = [
|
||||
(pn + 1, 0, img.size[0] if img else 0, 0, img.size[1] if img else 0)]
|
||||
d["page_num_list"] = json.dumps([pn + 1])
|
||||
d["top_list"] = json.dumps([0])
|
||||
d["position_list"] = json.dumps([
|
||||
(pn + 1, 0, img.size[0] if img else 0, 0, img.size[1] if img else 0)])
|
||||
tokenize(d, txt, eng)
|
||||
res.append(d)
|
||||
return res
|
||||
|
@ -74,7 +74,7 @@ class Excel(ExcelParser):
|
||||
def trans_datatime(s):
|
||||
try:
|
||||
return datetime_parse(s.strip()).strftime("%Y-%m-%d %H:%M:%S")
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
@ -112,7 +112,7 @@ def column_data_type(arr):
|
||||
continue
|
||||
try:
|
||||
arr[i] = trans[ty](str(arr[i]))
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
arr[i] = None
|
||||
# if ty == "text":
|
||||
# if len(arr) > 128 and uni / len(arr) < 0.1:
|
||||
@ -182,7 +182,7 @@ def chunk(filename, binary=None, from_page=0, to_page=10000000000,
|
||||
"datetime": "_dt",
|
||||
"bool": "_kwd"}
|
||||
for df in dfs:
|
||||
for n in ["id", "_id", "index", "idx"]:
|
||||
for n in ["id", "index", "idx"]:
|
||||
if n in df.columns:
|
||||
del df[n]
|
||||
clmns = df.columns.values
|
||||
|
196
rag/benchmark.py
196
rag/benchmark.py
@ -15,50 +15,51 @@
|
||||
#
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import argparse
|
||||
from collections import defaultdict
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from copy import deepcopy
|
||||
|
||||
from api.db import LLMType
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.settings import retrievaler
|
||||
from api.settings import retrievaler, docStoreConn
|
||||
from api.utils import get_uuid
|
||||
from api.utils.file_utils import get_project_base_directory
|
||||
from rag.nlp import tokenize, search
|
||||
from rag.utils.es_conn import ELASTICSEARCH
|
||||
from ranx import evaluate
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
from ranx import Qrels, Run
|
||||
|
||||
global max_docs
|
||||
max_docs = sys.maxsize
|
||||
|
||||
class Benchmark:
|
||||
def __init__(self, kb_id):
|
||||
self.kb_id = kb_id
|
||||
e, self.kb = KnowledgebaseService.get_by_id(kb_id)
|
||||
self.similarity_threshold = self.kb.similarity_threshold
|
||||
self.vector_similarity_weight = self.kb.vector_similarity_weight
|
||||
self.embd_mdl = LLMBundle(self.kb.tenant_id, LLMType.EMBEDDING, llm_name=self.kb.embd_id, lang=self.kb.language)
|
||||
self.tenant_id = ''
|
||||
self.index_name = ''
|
||||
self.initialized_index = False
|
||||
|
||||
def _get_benchmarks(self, query, dataset_idxnm, count=16):
|
||||
|
||||
req = {"question": query, "size": count, "vector": True, "similarity": self.similarity_threshold}
|
||||
sres = retrievaler.search(req, search.index_name(dataset_idxnm), self.embd_mdl)
|
||||
return sres
|
||||
|
||||
def _get_retrieval(self, qrels, dataset_idxnm):
|
||||
def _get_retrieval(self, qrels):
|
||||
# Need to wait for the ES and Infinity index to be ready
|
||||
time.sleep(20)
|
||||
run = defaultdict(dict)
|
||||
query_list = list(qrels.keys())
|
||||
for query in query_list:
|
||||
|
||||
ranks = retrievaler.retrieval(query, self.embd_mdl,
|
||||
dataset_idxnm, [self.kb.id], 1, 30,
|
||||
0.0, self.vector_similarity_weight)
|
||||
ranks = retrievaler.retrieval(query, self.embd_mdl, self.tenant_id, [self.kb.id], 1, 30,
|
||||
0.0, self.vector_similarity_weight)
|
||||
if len(ranks["chunks"]) == 0:
|
||||
print(f"deleted query: {query}")
|
||||
del qrels[query]
|
||||
continue
|
||||
for c in ranks["chunks"]:
|
||||
if "vector" in c:
|
||||
del c["vector"]
|
||||
run[query][c["chunk_id"]] = c["similarity"]
|
||||
|
||||
return run
|
||||
|
||||
def embedding(self, docs, batch_size=16):
|
||||
@ -68,40 +69,37 @@ class Benchmark:
|
||||
vts, c = self.embd_mdl.encode(cnts[i: i + batch_size])
|
||||
vects.extend(vts.tolist())
|
||||
assert len(docs) == len(vects)
|
||||
vector_size = 0
|
||||
for i, d in enumerate(docs):
|
||||
v = vects[i]
|
||||
vector_size = len(v)
|
||||
d["q_%d_vec" % len(v)] = v
|
||||
return docs
|
||||
return docs, vector_size
|
||||
|
||||
@staticmethod
|
||||
def init_kb(index_name):
|
||||
idxnm = search.index_name(index_name)
|
||||
if ELASTICSEARCH.indexExist(idxnm):
|
||||
ELASTICSEARCH.deleteIdx(search.index_name(index_name))
|
||||
|
||||
return ELASTICSEARCH.createIdx(idxnm, json.load(
|
||||
open(os.path.join(get_project_base_directory(), "conf", "mapping.json"), "r")))
|
||||
def init_index(self, vector_size: int):
|
||||
if self.initialized_index:
|
||||
return
|
||||
if docStoreConn.indexExist(self.index_name, self.kb_id):
|
||||
docStoreConn.deleteIdx(self.index_name, self.kb_id)
|
||||
docStoreConn.createIdx(self.index_name, self.kb_id, vector_size)
|
||||
self.initialized_index = True
|
||||
|
||||
def ms_marco_index(self, file_path, index_name):
|
||||
qrels = defaultdict(dict)
|
||||
texts = defaultdict(dict)
|
||||
docs_count = 0
|
||||
docs = []
|
||||
filelist = os.listdir(file_path)
|
||||
self.init_kb(index_name)
|
||||
|
||||
max_workers = int(os.environ.get('MAX_WORKERS', 3))
|
||||
exe = ThreadPoolExecutor(max_workers=max_workers)
|
||||
threads = []
|
||||
|
||||
def slow_actions(es_docs, idx_nm):
|
||||
es_docs = self.embedding(es_docs)
|
||||
ELASTICSEARCH.bulk(es_docs, idx_nm)
|
||||
return True
|
||||
|
||||
for dir in filelist:
|
||||
data = pd.read_parquet(os.path.join(file_path, dir))
|
||||
for i in tqdm(range(len(data)), colour="green", desc="Tokenizing:" + dir):
|
||||
filelist = sorted(os.listdir(file_path))
|
||||
|
||||
for fn in filelist:
|
||||
if docs_count >= max_docs:
|
||||
break
|
||||
if not fn.endswith(".parquet"):
|
||||
continue
|
||||
data = pd.read_parquet(os.path.join(file_path, fn))
|
||||
for i in tqdm(range(len(data)), colour="green", desc="Tokenizing:" + fn):
|
||||
if docs_count >= max_docs:
|
||||
break
|
||||
query = data.iloc[i]['query']
|
||||
for rel, text in zip(data.iloc[i]['passages']['is_selected'], data.iloc[i]['passages']['passage_text']):
|
||||
d = {
|
||||
@ -115,27 +113,33 @@ class Benchmark:
|
||||
texts[d["id"]] = text
|
||||
qrels[query][d["id"]] = int(rel)
|
||||
if len(docs) >= 32:
|
||||
threads.append(
|
||||
exe.submit(slow_actions, deepcopy(docs), search.index_name(index_name)))
|
||||
docs_count += len(docs)
|
||||
docs, vector_size = self.embedding(docs)
|
||||
self.init_index(vector_size)
|
||||
docStoreConn.insert(docs, self.index_name, self.kb_id)
|
||||
docs = []
|
||||
|
||||
threads.append(
|
||||
exe.submit(slow_actions, deepcopy(docs), search.index_name(index_name)))
|
||||
|
||||
for i in tqdm(range(len(threads)), colour="red", desc="Indexing:" + dir):
|
||||
if not threads[i].result().output:
|
||||
print("Indexing error...")
|
||||
|
||||
if docs:
|
||||
docs, vector_size = self.embedding(docs)
|
||||
self.init_index(vector_size)
|
||||
docStoreConn.insert(docs, self.index_name, self.kb_id)
|
||||
return qrels, texts
|
||||
|
||||
def trivia_qa_index(self, file_path, index_name):
|
||||
qrels = defaultdict(dict)
|
||||
texts = defaultdict(dict)
|
||||
docs_count = 0
|
||||
docs = []
|
||||
filelist = os.listdir(file_path)
|
||||
for dir in filelist:
|
||||
data = pd.read_parquet(os.path.join(file_path, dir))
|
||||
for i in tqdm(range(len(data)), colour="green", desc="Indexing:" + dir):
|
||||
filelist = sorted(os.listdir(file_path))
|
||||
for fn in filelist:
|
||||
if docs_count >= max_docs:
|
||||
break
|
||||
if not fn.endswith(".parquet"):
|
||||
continue
|
||||
data = pd.read_parquet(os.path.join(file_path, fn))
|
||||
for i in tqdm(range(len(data)), colour="green", desc="Indexing:" + fn):
|
||||
if docs_count >= max_docs:
|
||||
break
|
||||
query = data.iloc[i]['question']
|
||||
for rel, text in zip(data.iloc[i]["search_results"]['rank'],
|
||||
data.iloc[i]["search_results"]['search_context']):
|
||||
@ -150,16 +154,18 @@ class Benchmark:
|
||||
texts[d["id"]] = text
|
||||
qrels[query][d["id"]] = int(rel)
|
||||
if len(docs) >= 32:
|
||||
docs = self.embedding(docs)
|
||||
ELASTICSEARCH.bulk(docs, search.index_name(index_name))
|
||||
docs_count += len(docs)
|
||||
docs, vector_size = self.embedding(docs)
|
||||
self.init_index(vector_size)
|
||||
docStoreConn.insert(docs,self.index_name)
|
||||
docs = []
|
||||
|
||||
docs = self.embedding(docs)
|
||||
ELASTICSEARCH.bulk(docs, search.index_name(index_name))
|
||||
docs, vector_size = self.embedding(docs)
|
||||
self.init_index(vector_size)
|
||||
docStoreConn.insert(docs, self.index_name)
|
||||
return qrels, texts
|
||||
|
||||
def miracl_index(self, file_path, corpus_path, index_name):
|
||||
|
||||
corpus_total = {}
|
||||
for corpus_file in os.listdir(corpus_path):
|
||||
tmp_data = pd.read_json(os.path.join(corpus_path, corpus_file), lines=True)
|
||||
@ -176,14 +182,19 @@ class Benchmark:
|
||||
|
||||
qrels = defaultdict(dict)
|
||||
texts = defaultdict(dict)
|
||||
docs_count = 0
|
||||
docs = []
|
||||
for qrels_file in os.listdir(os.path.join(file_path, 'qrels')):
|
||||
if 'test' in qrels_file:
|
||||
continue
|
||||
if docs_count >= max_docs:
|
||||
break
|
||||
|
||||
tmp_data = pd.read_csv(os.path.join(file_path, 'qrels', qrels_file), sep='\t',
|
||||
names=['qid', 'Q0', 'docid', 'relevance'])
|
||||
for i in tqdm(range(len(tmp_data)), colour="green", desc="Indexing:" + qrels_file):
|
||||
if docs_count >= max_docs:
|
||||
break
|
||||
query = topics_total[tmp_data.iloc[i]['qid']]
|
||||
text = corpus_total[tmp_data.iloc[i]['docid']]
|
||||
rel = tmp_data.iloc[i]['relevance']
|
||||
@ -198,13 +209,15 @@ class Benchmark:
|
||||
texts[d["id"]] = text
|
||||
qrels[query][d["id"]] = int(rel)
|
||||
if len(docs) >= 32:
|
||||
docs = self.embedding(docs)
|
||||
ELASTICSEARCH.bulk(docs, search.index_name(index_name))
|
||||
docs_count += len(docs)
|
||||
docs, vector_size = self.embedding(docs)
|
||||
self.init_index(vector_size)
|
||||
docStoreConn.insert(docs, self.index_name)
|
||||
docs = []
|
||||
|
||||
docs = self.embedding(docs)
|
||||
ELASTICSEARCH.bulk(docs, search.index_name(index_name))
|
||||
|
||||
docs, vector_size = self.embedding(docs)
|
||||
self.init_index(vector_size)
|
||||
docStoreConn.insert(docs, self.index_name)
|
||||
return qrels, texts
|
||||
|
||||
def save_results(self, qrels, run, texts, dataset, file_path):
|
||||
@ -213,7 +226,7 @@ class Benchmark:
|
||||
for run_i in tqdm(range(len(run_keys)), desc="Calculating ndcg@10 for single query"):
|
||||
key = run_keys[run_i]
|
||||
keep_result.append({'query': key, 'qrel': qrels[key], 'run': run[key],
|
||||
'ndcg@10': evaluate(Qrels({key: qrels[key]}), Run({key: run[key]}), "ndcg@10")})
|
||||
'ndcg@10': evaluate({key: qrels[key]}, {key: run[key]}, "ndcg@10")})
|
||||
keep_result = sorted(keep_result, key=lambda kk: kk['ndcg@10'])
|
||||
with open(os.path.join(file_path, dataset + 'result.md'), 'w', encoding='utf-8') as f:
|
||||
f.write('## Score For Every Query\n')
|
||||
@ -229,14 +242,18 @@ class Benchmark:
|
||||
|
||||
def __call__(self, dataset, file_path, miracl_corpus=''):
|
||||
if dataset == "ms_marco_v1.1":
|
||||
self.tenant_id = "benchmark_ms_marco_v11"
|
||||
self.index_name = search.index_name(self.tenant_id)
|
||||
qrels, texts = self.ms_marco_index(file_path, "benchmark_ms_marco_v1.1")
|
||||
run = self._get_retrieval(qrels, "benchmark_ms_marco_v1.1")
|
||||
print(dataset, evaluate(Qrels(qrels), Run(run), ["ndcg@10", "map@5", "mrr"]))
|
||||
run = self._get_retrieval(qrels)
|
||||
print(dataset, evaluate(qrels, run, ["ndcg@10", "map@5", "mrr"]))
|
||||
self.save_results(qrels, run, texts, dataset, file_path)
|
||||
if dataset == "trivia_qa":
|
||||
self.tenant_id = "benchmark_trivia_qa"
|
||||
self.index_name = search.index_name(self.tenant_id)
|
||||
qrels, texts = self.trivia_qa_index(file_path, "benchmark_trivia_qa")
|
||||
run = self._get_retrieval(qrels, "benchmark_trivia_qa")
|
||||
print(dataset, evaluate((qrels), Run(run), ["ndcg@10", "map@5", "mrr"]))
|
||||
run = self._get_retrieval(qrels)
|
||||
print(dataset, evaluate(qrels, run, ["ndcg@10", "map@5", "mrr"]))
|
||||
self.save_results(qrels, run, texts, dataset, file_path)
|
||||
if dataset == "miracl":
|
||||
for lang in ['ar', 'bn', 'de', 'en', 'es', 'fa', 'fi', 'fr', 'hi', 'id', 'ja', 'ko', 'ru', 'sw', 'te', 'th',
|
||||
@ -253,28 +270,41 @@ class Benchmark:
|
||||
if not os.path.isdir(os.path.join(miracl_corpus, 'miracl-corpus-v1.0-' + lang)):
|
||||
print('Directory: ' + os.path.join(miracl_corpus, 'miracl-corpus-v1.0-' + lang) + ' not found!')
|
||||
continue
|
||||
self.tenant_id = "benchmark_miracl_" + lang
|
||||
self.index_name = search.index_name(self.tenant_id)
|
||||
self.initialized_index = False
|
||||
qrels, texts = self.miracl_index(os.path.join(file_path, 'miracl-v1.0-' + lang),
|
||||
os.path.join(miracl_corpus, 'miracl-corpus-v1.0-' + lang),
|
||||
"benchmark_miracl_" + lang)
|
||||
run = self._get_retrieval(qrels, "benchmark_miracl_" + lang)
|
||||
print(dataset, evaluate(Qrels(qrels), Run(run), ["ndcg@10", "map@5", "mrr"]))
|
||||
run = self._get_retrieval(qrels)
|
||||
print(dataset, evaluate(qrels, run, ["ndcg@10", "map@5", "mrr"]))
|
||||
self.save_results(qrels, run, texts, dataset, file_path)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
print('*****************RAGFlow Benchmark*****************')
|
||||
kb_id = input('Please input kb_id:\n')
|
||||
parser = argparse.ArgumentParser(usage="benchmark.py <max_docs> <kb_id> <dataset> <dataset_path> [<miracl_corpus_path>])", description='RAGFlow Benchmark')
|
||||
parser.add_argument('max_docs', metavar='max_docs', type=int, help='max docs to evaluate')
|
||||
parser.add_argument('kb_id', metavar='kb_id', help='knowledgebase id')
|
||||
parser.add_argument('dataset', metavar='dataset', help='dataset name, shall be one of ms_marco_v1.1(https://huggingface.co/datasets/microsoft/ms_marco), trivia_qa(https://huggingface.co/datasets/mandarjoshi/trivia_qa>), miracl(https://huggingface.co/datasets/miracl/miracl')
|
||||
parser.add_argument('dataset_path', metavar='dataset_path', help='dataset path')
|
||||
parser.add_argument('miracl_corpus_path', metavar='miracl_corpus_path', nargs='?', default="", help='miracl corpus path. Only needed when dataset is miracl')
|
||||
|
||||
args = parser.parse_args()
|
||||
max_docs = args.max_docs
|
||||
kb_id = args.kb_id
|
||||
ex = Benchmark(kb_id)
|
||||
dataset = input(
|
||||
'RAGFlow Benchmark Support:\n\tms_marco_v1.1:<https://huggingface.co/datasets/microsoft/ms_marco>\n\ttrivia_qa:<https://huggingface.co/datasets/mandarjoshi/trivia_qa>\n\tmiracl:<https://huggingface.co/datasets/miracl/miracl>\nPlease input dataset choice:\n')
|
||||
if dataset in ['ms_marco_v1.1', 'trivia_qa']:
|
||||
if dataset == "ms_marco_v1.1":
|
||||
print("Notice: Please provide the ms_marco_v1.1 dataset only. ms_marco_v2.1 is not supported!")
|
||||
dataset_path = input('Please input ' + dataset + ' dataset path:\n')
|
||||
|
||||
dataset = args.dataset
|
||||
dataset_path = args.dataset_path
|
||||
|
||||
if dataset == "ms_marco_v1.1" or dataset == "trivia_qa":
|
||||
ex(dataset, dataset_path)
|
||||
elif dataset == 'miracl':
|
||||
dataset_path = input('Please input ' + dataset + ' dataset path:\n')
|
||||
corpus_path = input('Please input ' + dataset + '-corpus dataset path:\n')
|
||||
ex(dataset, dataset_path, miracl_corpus=corpus_path)
|
||||
elif dataset == "miracl":
|
||||
if len(args) < 5:
|
||||
print('Please input the correct parameters!')
|
||||
exit(1)
|
||||
miracl_corpus_path = args[4]
|
||||
ex(dataset, dataset_path, miracl_corpus=args.miracl_corpus_path)
|
||||
else:
|
||||
print("Dataset: ", dataset, "not supported!")
|
||||
|
@ -25,6 +25,7 @@ import roman_numbers as r
|
||||
from word2number import w2n
|
||||
from cn2an import cn2an
|
||||
from PIL import Image
|
||||
import json
|
||||
|
||||
all_codecs = [
|
||||
'utf-8', 'gb2312', 'gbk', 'utf_16', 'ascii', 'big5', 'big5hkscs',
|
||||
@ -51,12 +52,12 @@ def find_codec(blob):
|
||||
try:
|
||||
blob[:1024].decode(c)
|
||||
return c
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
blob.decode(c)
|
||||
return c
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return "utf-8"
|
||||
@ -241,7 +242,7 @@ def tokenize_chunks(chunks, doc, eng, pdf_parser=None):
|
||||
d["image"], poss = pdf_parser.crop(ck, need_position=True)
|
||||
add_positions(d, poss)
|
||||
ck = pdf_parser.remove_tag(ck)
|
||||
except NotImplementedError as e:
|
||||
except NotImplementedError:
|
||||
pass
|
||||
tokenize(d, ck, eng)
|
||||
res.append(d)
|
||||
@ -289,13 +290,16 @@ def tokenize_table(tbls, doc, eng, batch_size=10):
|
||||
def add_positions(d, poss):
|
||||
if not poss:
|
||||
return
|
||||
d["page_num_int"] = []
|
||||
d["position_int"] = []
|
||||
d["top_int"] = []
|
||||
page_num_list = []
|
||||
position_list = []
|
||||
top_list = []
|
||||
for pn, left, right, top, bottom in poss:
|
||||
d["page_num_int"].append(int(pn + 1))
|
||||
d["top_int"].append(int(top))
|
||||
d["position_int"].append((int(pn + 1), int(left), int(right), int(top), int(bottom)))
|
||||
page_num_list.append(int(pn + 1))
|
||||
top_list.append(int(top))
|
||||
position_list.append((int(pn + 1), int(left), int(right), int(top), int(bottom)))
|
||||
d["page_num_list"] = json.dumps(page_num_list)
|
||||
d["position_list"] = json.dumps(position_list)
|
||||
d["top_list"] = json.dumps(top_list)
|
||||
|
||||
|
||||
def remove_contents_table(sections, eng=False):
|
||||
|
112
rag/nlp/query.py
112
rag/nlp/query.py
@ -15,20 +15,25 @@
|
||||
#
|
||||
|
||||
import json
|
||||
import math
|
||||
import re
|
||||
import logging
|
||||
import copy
|
||||
from elasticsearch_dsl import Q
|
||||
from rag.utils.doc_store_conn import MatchTextExpr
|
||||
|
||||
from rag.nlp import rag_tokenizer, term_weight, synonym
|
||||
|
||||
class EsQueryer:
|
||||
def __init__(self, es):
|
||||
|
||||
class FulltextQueryer:
|
||||
def __init__(self):
|
||||
self.tw = term_weight.Dealer()
|
||||
self.es = es
|
||||
self.syn = synonym.Dealer()
|
||||
self.flds = ["ask_tks^10", "ask_small_tks"]
|
||||
self.query_fields = [
|
||||
"title_tks^10",
|
||||
"title_sm_tks^5",
|
||||
"important_kwd^30",
|
||||
"important_tks^20",
|
||||
"content_ltks^2",
|
||||
"content_sm_ltks",
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def subSpecialChar(line):
|
||||
@ -43,12 +48,15 @@ class EsQueryer:
|
||||
for t in arr:
|
||||
if not re.match(r"[a-zA-Z]+$", t):
|
||||
e += 1
|
||||
return e * 1. / len(arr) >= 0.7
|
||||
return e * 1.0 / len(arr) >= 0.7
|
||||
|
||||
@staticmethod
|
||||
def rmWWW(txt):
|
||||
patts = [
|
||||
(r"是*(什么样的|哪家|一下|那家|请问|啥样|咋样了|什么时候|何时|何地|何人|是否|是不是|多少|哪里|怎么|哪儿|怎么样|如何|哪些|是啥|啥是|啊|吗|呢|吧|咋|什么|有没有|呀)是*", ""),
|
||||
(
|
||||
r"是*(什么样的|哪家|一下|那家|请问|啥样|咋样了|什么时候|何时|何地|何人|是否|是不是|多少|哪里|怎么|哪儿|怎么样|如何|哪些|是啥|啥是|啊|吗|呢|吧|咋|什么|有没有|呀)是*",
|
||||
"",
|
||||
),
|
||||
(r"(^| )(what|who|how|which|where|why)('re|'s)? ", " "),
|
||||
(r"(^| )('s|'re|is|are|were|was|do|does|did|don't|doesn't|didn't|has|have|be|there|you|me|your|my|mine|just|please|may|i|should|would|wouldn't|will|won't|done|go|for|with|so|the|a|an|by|i'm|it's|he's|she's|they|they're|you're|as|by|on|in|at|up|out|down|of) ", " ")
|
||||
]
|
||||
@ -56,16 +64,16 @@ class EsQueryer:
|
||||
txt = re.sub(r, p, txt, flags=re.IGNORECASE)
|
||||
return txt
|
||||
|
||||
def question(self, txt, tbl="qa", min_match="60%"):
|
||||
def question(self, txt, tbl="qa", min_match:float=0.6):
|
||||
txt = re.sub(
|
||||
r"[ :\r\n\t,,。??/`!!&\^%%]+",
|
||||
" ",
|
||||
rag_tokenizer.tradi2simp(
|
||||
rag_tokenizer.strQ2B(
|
||||
txt.lower()))).strip()
|
||||
rag_tokenizer.tradi2simp(rag_tokenizer.strQ2B(txt.lower())),
|
||||
).strip()
|
||||
txt = FulltextQueryer.rmWWW(txt)
|
||||
|
||||
if not self.isChinese(txt):
|
||||
txt = EsQueryer.rmWWW(txt)
|
||||
txt = FulltextQueryer.rmWWW(txt)
|
||||
tks = rag_tokenizer.tokenize(txt).split(" ")
|
||||
tks_w = self.tw.weights(tks)
|
||||
tks_w = [(re.sub(r"[ \\\"'^]", "", tk), w) for tk, w in tks_w]
|
||||
@ -73,14 +81,20 @@ class EsQueryer:
|
||||
tks_w = [(re.sub(r"^[\+-]", "", tk), w) for tk, w in tks_w if tk]
|
||||
q = ["{}^{:.4f}".format(tk, w) for tk, w in tks_w if tk]
|
||||
for i in range(1, len(tks_w)):
|
||||
q.append("\"%s %s\"^%.4f" % (tks_w[i - 1][0], tks_w[i][0], max(tks_w[i - 1][1], tks_w[i][1])*2))
|
||||
q.append(
|
||||
'"%s %s"^%.4f'
|
||||
% (
|
||||
tks_w[i - 1][0],
|
||||
tks_w[i][0],
|
||||
max(tks_w[i - 1][1], tks_w[i][1]) * 2,
|
||||
)
|
||||
)
|
||||
if not q:
|
||||
q.append(txt)
|
||||
return Q("bool",
|
||||
must=Q("query_string", fields=self.flds,
|
||||
type="best_fields", query=" ".join(q),
|
||||
boost=1)#, minimum_should_match=min_match)
|
||||
), list(set([t for t in txt.split(" ") if t]))
|
||||
query = " ".join(q)
|
||||
return MatchTextExpr(
|
||||
self.query_fields, query, 100
|
||||
), tks
|
||||
|
||||
def need_fine_grained_tokenize(tk):
|
||||
if len(tk) < 3:
|
||||
@ -89,7 +103,7 @@ class EsQueryer:
|
||||
return False
|
||||
return True
|
||||
|
||||
txt = EsQueryer.rmWWW(txt)
|
||||
txt = FulltextQueryer.rmWWW(txt)
|
||||
qs, keywords = [], []
|
||||
for tt in self.tw.split(txt)[:256]: # .split(" "):
|
||||
if not tt:
|
||||
@ -101,65 +115,71 @@ class EsQueryer:
|
||||
logging.info(json.dumps(twts, ensure_ascii=False))
|
||||
tms = []
|
||||
for tk, w in sorted(twts, key=lambda x: x[1] * -1):
|
||||
sm = rag_tokenizer.fine_grained_tokenize(tk).split(" ") if need_fine_grained_tokenize(tk) else []
|
||||
sm = (
|
||||
rag_tokenizer.fine_grained_tokenize(tk).split(" ")
|
||||
if need_fine_grained_tokenize(tk)
|
||||
else []
|
||||
)
|
||||
sm = [
|
||||
re.sub(
|
||||
r"[ ,\./;'\[\]\\`~!@#$%\^&\*\(\)=\+_<>\?:\"\{\}\|,。;‘’【】、!¥……()——《》?:“”-]+",
|
||||
"",
|
||||
m) for m in sm]
|
||||
sm = [EsQueryer.subSpecialChar(m) for m in sm if len(m) > 1]
|
||||
m,
|
||||
)
|
||||
for m in sm
|
||||
]
|
||||
sm = [FulltextQueryer.subSpecialChar(m) for m in sm if len(m) > 1]
|
||||
sm = [m for m in sm if len(m) > 1]
|
||||
|
||||
keywords.append(re.sub(r"[ \\\"']+", "", tk))
|
||||
keywords.extend(sm)
|
||||
if len(keywords) >= 12: break
|
||||
if len(keywords) >= 12:
|
||||
break
|
||||
|
||||
tk_syns = self.syn.lookup(tk)
|
||||
tk = EsQueryer.subSpecialChar(tk)
|
||||
tk = FulltextQueryer.subSpecialChar(tk)
|
||||
if tk.find(" ") > 0:
|
||||
tk = "\"%s\"" % tk
|
||||
tk = '"%s"' % tk
|
||||
if tk_syns:
|
||||
tk = f"({tk} %s)" % " ".join(tk_syns)
|
||||
if sm:
|
||||
tk = f"{tk} OR \"%s\" OR (\"%s\"~2)^0.5" % (
|
||||
" ".join(sm), " ".join(sm))
|
||||
tk = f'{tk} OR "%s" OR ("%s"~2)^0.5' % (" ".join(sm), " ".join(sm))
|
||||
if tk.strip():
|
||||
tms.append((tk, w))
|
||||
|
||||
tms = " ".join([f"({t})^{w}" for t, w in tms])
|
||||
|
||||
if len(twts) > 1:
|
||||
tms += f" (\"%s\"~4)^1.5" % (" ".join([t for t, _ in twts]))
|
||||
tms += ' ("%s"~4)^1.5' % (" ".join([t for t, _ in twts]))
|
||||
if re.match(r"[0-9a-z ]+$", tt):
|
||||
tms = f"(\"{tt}\" OR \"%s\")" % rag_tokenizer.tokenize(tt)
|
||||
tms = f'("{tt}" OR "%s")' % rag_tokenizer.tokenize(tt)
|
||||
|
||||
syns = " OR ".join(
|
||||
["\"%s\"^0.7" % EsQueryer.subSpecialChar(rag_tokenizer.tokenize(s)) for s in syns])
|
||||
[
|
||||
'"%s"^0.7'
|
||||
% FulltextQueryer.subSpecialChar(rag_tokenizer.tokenize(s))
|
||||
for s in syns
|
||||
]
|
||||
)
|
||||
if syns:
|
||||
tms = f"({tms})^5 OR ({syns})^0.7"
|
||||
|
||||
qs.append(tms)
|
||||
|
||||
flds = copy.deepcopy(self.flds)
|
||||
mst = []
|
||||
if qs:
|
||||
mst.append(
|
||||
Q("query_string", fields=flds, type="best_fields",
|
||||
query=" OR ".join([f"({t})" for t in qs if t]), boost=1, minimum_should_match=min_match)
|
||||
)
|
||||
query = " OR ".join([f"({t})" for t in qs if t])
|
||||
return MatchTextExpr(
|
||||
self.query_fields, query, 100, {"minimum_should_match": min_match}
|
||||
), keywords
|
||||
return None, keywords
|
||||
|
||||
return Q("bool",
|
||||
must=mst,
|
||||
), list(set(keywords))
|
||||
|
||||
def hybrid_similarity(self, avec, bvecs, atks, btkss, tkweight=0.3,
|
||||
vtweight=0.7):
|
||||
def hybrid_similarity(self, avec, bvecs, atks, btkss, tkweight=0.3, vtweight=0.7):
|
||||
from sklearn.metrics.pairwise import cosine_similarity as CosineSimilarity
|
||||
import numpy as np
|
||||
|
||||
sims = CosineSimilarity([avec], bvecs)
|
||||
tksim = self.token_similarity(atks, btkss)
|
||||
return np.array(sims[0]) * vtweight + \
|
||||
np.array(tksim) * tkweight, tksim, sims[0]
|
||||
return np.array(sims[0]) * vtweight + np.array(tksim) * tkweight, tksim, sims[0]
|
||||
|
||||
def token_similarity(self, atks, btkss):
|
||||
def toDict(tks):
|
||||
|
@ -14,34 +14,25 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import json
|
||||
import re
|
||||
from copy import deepcopy
|
||||
|
||||
from elasticsearch_dsl import Q, Search
|
||||
import json
|
||||
from typing import List, Optional, Dict, Union
|
||||
from dataclasses import dataclass
|
||||
|
||||
from rag.settings import es_logger
|
||||
from rag.settings import doc_store_logger
|
||||
from rag.utils import rmSpace
|
||||
from rag.nlp import rag_tokenizer, query, is_english
|
||||
from rag.nlp import rag_tokenizer, query
|
||||
import numpy as np
|
||||
from rag.utils.doc_store_conn import DocStoreConnection, MatchDenseExpr, FusionExpr, OrderByExpr
|
||||
|
||||
|
||||
def index_name(uid): return f"ragflow_{uid}"
|
||||
|
||||
|
||||
class Dealer:
|
||||
def __init__(self, es):
|
||||
self.qryr = query.EsQueryer(es)
|
||||
self.qryr.flds = [
|
||||
"title_tks^10",
|
||||
"title_sm_tks^5",
|
||||
"important_kwd^30",
|
||||
"important_tks^20",
|
||||
"content_ltks^2",
|
||||
"content_sm_ltks"]
|
||||
self.es = es
|
||||
def __init__(self, dataStore: DocStoreConnection):
|
||||
self.qryr = query.FulltextQueryer()
|
||||
self.dataStore = dataStore
|
||||
|
||||
@dataclass
|
||||
class SearchResult:
|
||||
@ -54,170 +45,99 @@ class Dealer:
|
||||
keywords: Optional[List[str]] = None
|
||||
group_docs: List[List] = None
|
||||
|
||||
def _vector(self, txt, emb_mdl, sim=0.8, topk=10):
|
||||
qv, c = emb_mdl.encode_queries(txt)
|
||||
return {
|
||||
"field": "q_%d_vec" % len(qv),
|
||||
"k": topk,
|
||||
"similarity": sim,
|
||||
"num_candidates": topk * 2,
|
||||
"query_vector": [float(v) for v in qv]
|
||||
}
|
||||
def get_vector(self, txt, emb_mdl, topk=10, similarity=0.1):
|
||||
qv, _ = emb_mdl.encode_queries(txt)
|
||||
embedding_data = [float(v) for v in qv]
|
||||
vector_column_name = f"q_{len(embedding_data)}_vec"
|
||||
return MatchDenseExpr(vector_column_name, embedding_data, 'float', 'cosine', topk, {"similarity": similarity})
|
||||
|
||||
def _add_filters(self, bqry, req):
|
||||
if req.get("kb_ids"):
|
||||
bqry.filter.append(Q("terms", kb_id=req["kb_ids"]))
|
||||
if req.get("doc_ids"):
|
||||
bqry.filter.append(Q("terms", doc_id=req["doc_ids"]))
|
||||
if req.get("knowledge_graph_kwd"):
|
||||
bqry.filter.append(Q("terms", knowledge_graph_kwd=req["knowledge_graph_kwd"]))
|
||||
if "available_int" in req:
|
||||
if req["available_int"] == 0:
|
||||
bqry.filter.append(Q("range", available_int={"lt": 1}))
|
||||
else:
|
||||
bqry.filter.append(
|
||||
Q("bool", must_not=Q("range", available_int={"lt": 1})))
|
||||
return bqry
|
||||
def get_filters(self, req):
|
||||
condition = dict()
|
||||
for key, field in {"kb_ids": "kb_id", "doc_ids": "doc_id"}.items():
|
||||
if key in req and req[key] is not None:
|
||||
condition[field] = req[key]
|
||||
# TODO(yzc): `available_int` is nullable however infinity doesn't support nullable columns.
|
||||
for key in ["knowledge_graph_kwd"]:
|
||||
if key in req and req[key] is not None:
|
||||
condition[key] = req[key]
|
||||
return condition
|
||||
|
||||
def search(self, req, idxnms, emb_mdl=None, highlight=False):
|
||||
qst = req.get("question", "")
|
||||
bqry, keywords = self.qryr.question(qst, min_match="30%")
|
||||
bqry = self._add_filters(bqry, req)
|
||||
bqry.boost = 0.05
|
||||
def search(self, req, idx_names: list[str], kb_ids: list[str], emb_mdl=None, highlight = False):
|
||||
filters = self.get_filters(req)
|
||||
orderBy = OrderByExpr()
|
||||
|
||||
s = Search()
|
||||
pg = int(req.get("page", 1)) - 1
|
||||
topk = int(req.get("topk", 1024))
|
||||
ps = int(req.get("size", topk))
|
||||
offset, limit = pg * ps, (pg + 1) * ps
|
||||
|
||||
src = req.get("fields", ["docnm_kwd", "content_ltks", "kb_id", "img_id", "title_tks", "important_kwd",
|
||||
"image_id", "doc_id", "q_512_vec", "q_768_vec", "position_int", "knowledge_graph_kwd",
|
||||
"q_1024_vec", "q_1536_vec", "available_int", "content_with_weight"])
|
||||
|
||||
s = s.query(bqry)[pg * ps:(pg + 1) * ps]
|
||||
s = s.highlight("content_ltks")
|
||||
s = s.highlight("title_ltks")
|
||||
if not qst:
|
||||
if not req.get("sort"):
|
||||
s = s.sort(
|
||||
#{"create_time": {"order": "desc", "unmapped_type": "date"}},
|
||||
{"create_timestamp_flt": {
|
||||
"order": "desc", "unmapped_type": "float"}}
|
||||
)
|
||||
else:
|
||||
s = s.sort(
|
||||
{"page_num_int": {"order": "asc", "unmapped_type": "float",
|
||||
"mode": "avg", "numeric_type": "double"}},
|
||||
{"top_int": {"order": "asc", "unmapped_type": "float",
|
||||
"mode": "avg", "numeric_type": "double"}},
|
||||
#{"create_time": {"order": "desc", "unmapped_type": "date"}},
|
||||
{"create_timestamp_flt": {
|
||||
"order": "desc", "unmapped_type": "float"}}
|
||||
)
|
||||
|
||||
if qst:
|
||||
s = s.highlight_options(
|
||||
fragment_size=120,
|
||||
number_of_fragments=5,
|
||||
boundary_scanner_locale="zh-CN",
|
||||
boundary_scanner="SENTENCE",
|
||||
boundary_chars=",./;:\\!(),。?:!……()——、"
|
||||
)
|
||||
s = s.to_dict()
|
||||
q_vec = []
|
||||
if req.get("vector"):
|
||||
assert emb_mdl, "No embedding model selected"
|
||||
s["knn"] = self._vector(
|
||||
qst, emb_mdl, req.get(
|
||||
"similarity", 0.1), topk)
|
||||
s["knn"]["filter"] = bqry.to_dict()
|
||||
if not highlight and "highlight" in s:
|
||||
del s["highlight"]
|
||||
q_vec = s["knn"]["query_vector"]
|
||||
es_logger.info("【Q】: {}".format(json.dumps(s)))
|
||||
res = self.es.search(deepcopy(s), idxnms=idxnms, timeout="600s", src=src)
|
||||
es_logger.info("TOTAL: {}".format(self.es.getTotal(res)))
|
||||
if self.es.getTotal(res) == 0 and "knn" in s:
|
||||
bqry, _ = self.qryr.question(qst, min_match="10%")
|
||||
if req.get("doc_ids"):
|
||||
bqry = Q("bool", must=[])
|
||||
bqry = self._add_filters(bqry, req)
|
||||
s["query"] = bqry.to_dict()
|
||||
s["knn"]["filter"] = bqry.to_dict()
|
||||
s["knn"]["similarity"] = 0.17
|
||||
res = self.es.search(s, idxnms=idxnms, timeout="600s", src=src)
|
||||
es_logger.info("【Q】: {}".format(json.dumps(s)))
|
||||
|
||||
"doc_id", "position_list", "knowledge_graph_kwd",
|
||||
"available_int", "content_with_weight"])
|
||||
kwds = set([])
|
||||
for k in keywords:
|
||||
kwds.add(k)
|
||||
for kk in rag_tokenizer.fine_grained_tokenize(k).split(" "):
|
||||
if len(kk) < 2:
|
||||
continue
|
||||
if kk in kwds:
|
||||
continue
|
||||
kwds.add(kk)
|
||||
|
||||
aggs = self.getAggregation(res, "docnm_kwd")
|
||||
qst = req.get("question", "")
|
||||
q_vec = []
|
||||
if not qst:
|
||||
if req.get("sort"):
|
||||
orderBy.desc("create_timestamp_flt")
|
||||
res = self.dataStore.search(src, [], filters, [], orderBy, offset, limit, idx_names, kb_ids)
|
||||
total=self.dataStore.getTotal(res)
|
||||
doc_store_logger.info("Dealer.search TOTAL: {}".format(total))
|
||||
else:
|
||||
highlightFields = ["content_ltks", "title_tks"] if highlight else []
|
||||
matchText, keywords = self.qryr.question(qst, min_match=0.3)
|
||||
if emb_mdl is None:
|
||||
matchExprs = [matchText]
|
||||
res = self.dataStore.search(src, highlightFields, filters, matchExprs, orderBy, offset, limit, idx_names, kb_ids)
|
||||
total=self.dataStore.getTotal(res)
|
||||
doc_store_logger.info("Dealer.search TOTAL: {}".format(total))
|
||||
else:
|
||||
matchDense = self.get_vector(qst, emb_mdl, topk, req.get("similarity", 0.1))
|
||||
q_vec = matchDense.embedding_data
|
||||
src.append(f"q_{len(q_vec)}_vec")
|
||||
|
||||
fusionExpr = FusionExpr("weighted_sum", topk, {"weights": "0.05, 0.95"})
|
||||
matchExprs = [matchText, matchDense, fusionExpr]
|
||||
|
||||
res = self.dataStore.search(src, highlightFields, filters, matchExprs, orderBy, offset, limit, idx_names, kb_ids)
|
||||
total=self.dataStore.getTotal(res)
|
||||
doc_store_logger.info("Dealer.search TOTAL: {}".format(total))
|
||||
|
||||
# If result is empty, try again with lower min_match
|
||||
if total == 0:
|
||||
matchText, _ = self.qryr.question(qst, min_match=0.1)
|
||||
if "doc_ids" in filters:
|
||||
del filters["doc_ids"]
|
||||
matchDense.extra_options["similarity"] = 0.17
|
||||
res = self.dataStore.search(src, highlightFields, filters, [matchText, matchDense, fusionExpr], orderBy, offset, limit, idx_names, kb_ids)
|
||||
total=self.dataStore.getTotal(res)
|
||||
doc_store_logger.info("Dealer.search 2 TOTAL: {}".format(total))
|
||||
|
||||
for k in keywords:
|
||||
kwds.add(k)
|
||||
for kk in rag_tokenizer.fine_grained_tokenize(k).split(" "):
|
||||
if len(kk) < 2:
|
||||
continue
|
||||
if kk in kwds:
|
||||
continue
|
||||
kwds.add(kk)
|
||||
|
||||
doc_store_logger.info(f"TOTAL: {total}")
|
||||
ids=self.dataStore.getChunkIds(res)
|
||||
keywords=list(kwds)
|
||||
highlight = self.dataStore.getHighlight(res, keywords, "content_with_weight")
|
||||
aggs = self.dataStore.getAggregation(res, "docnm_kwd")
|
||||
return self.SearchResult(
|
||||
total=self.es.getTotal(res),
|
||||
ids=self.es.getDocIds(res),
|
||||
total=total,
|
||||
ids=ids,
|
||||
query_vector=q_vec,
|
||||
aggregation=aggs,
|
||||
highlight=self.getHighlight(res, keywords, "content_with_weight"),
|
||||
field=self.getFields(res, src),
|
||||
keywords=list(kwds)
|
||||
highlight=highlight,
|
||||
field=self.dataStore.getFields(res, src),
|
||||
keywords=keywords
|
||||
)
|
||||
|
||||
def getAggregation(self, res, g):
|
||||
if not "aggregations" in res or "aggs_" + g not in res["aggregations"]:
|
||||
return
|
||||
bkts = res["aggregations"]["aggs_" + g]["buckets"]
|
||||
return [(b["key"], b["doc_count"]) for b in bkts]
|
||||
|
||||
def getHighlight(self, res, keywords, fieldnm):
|
||||
ans = {}
|
||||
for d in res["hits"]["hits"]:
|
||||
hlts = d.get("highlight")
|
||||
if not hlts:
|
||||
continue
|
||||
txt = "...".join([a for a in list(hlts.items())[0][1]])
|
||||
if not is_english(txt.split(" ")):
|
||||
ans[d["_id"]] = txt
|
||||
continue
|
||||
|
||||
txt = d["_source"][fieldnm]
|
||||
txt = re.sub(r"[\r\n]", " ", txt, flags=re.IGNORECASE|re.MULTILINE)
|
||||
txts = []
|
||||
for t in re.split(r"[.?!;\n]", txt):
|
||||
for w in keywords:
|
||||
t = re.sub(r"(^|[ .?/'\"\(\)!,:;-])(%s)([ .?/'\"\(\)!,:;-])"%re.escape(w), r"\1<em>\2</em>\3", t, flags=re.IGNORECASE|re.MULTILINE)
|
||||
if not re.search(r"<em>[^<>]+</em>", t, flags=re.IGNORECASE|re.MULTILINE): continue
|
||||
txts.append(t)
|
||||
ans[d["_id"]] = "...".join(txts) if txts else "...".join([a for a in list(hlts.items())[0][1]])
|
||||
|
||||
return ans
|
||||
|
||||
def getFields(self, sres, flds):
|
||||
res = {}
|
||||
if not flds:
|
||||
return {}
|
||||
for d in self.es.getSource(sres):
|
||||
m = {n: d.get(n) for n in flds if d.get(n) is not None}
|
||||
for n, v in m.items():
|
||||
if isinstance(v, type([])):
|
||||
m[n] = "\t".join([str(vv) if not isinstance(
|
||||
vv, list) else "\t".join([str(vvv) for vvv in vv]) for vv in v])
|
||||
continue
|
||||
if not isinstance(v, type("")):
|
||||
m[n] = str(m[n])
|
||||
#if n.find("tks") > 0:
|
||||
# m[n] = rmSpace(m[n])
|
||||
|
||||
if m:
|
||||
res[d["id"]] = m
|
||||
return res
|
||||
|
||||
@staticmethod
|
||||
def trans2floats(txt):
|
||||
return [float(t) for t in txt.split("\t")]
|
||||
@ -260,7 +180,7 @@ class Dealer:
|
||||
continue
|
||||
idx.append(i)
|
||||
pieces_.append(t)
|
||||
es_logger.info("{} => {}".format(answer, pieces_))
|
||||
doc_store_logger.info("{} => {}".format(answer, pieces_))
|
||||
if not pieces_:
|
||||
return answer, set([])
|
||||
|
||||
@ -281,7 +201,7 @@ class Dealer:
|
||||
chunks_tks,
|
||||
tkweight, vtweight)
|
||||
mx = np.max(sim) * 0.99
|
||||
es_logger.info("{} SIM: {}".format(pieces_[i], mx))
|
||||
doc_store_logger.info("{} SIM: {}".format(pieces_[i], mx))
|
||||
if mx < thr:
|
||||
continue
|
||||
cites[idx[i]] = list(
|
||||
@ -309,9 +229,15 @@ class Dealer:
|
||||
def rerank(self, sres, query, tkweight=0.3,
|
||||
vtweight=0.7, cfield="content_ltks"):
|
||||
_, keywords = self.qryr.question(query)
|
||||
ins_embd = [
|
||||
Dealer.trans2floats(
|
||||
sres.field[i].get("q_%d_vec" % len(sres.query_vector), "\t".join(["0"] * len(sres.query_vector)))) for i in sres.ids]
|
||||
vector_size = len(sres.query_vector)
|
||||
vector_column = f"q_{vector_size}_vec"
|
||||
zero_vector = [0.0] * vector_size
|
||||
ins_embd = []
|
||||
for chunk_id in sres.ids:
|
||||
vector = sres.field[chunk_id].get(vector_column, zero_vector)
|
||||
if isinstance(vector, str):
|
||||
vector = [float(v) for v in vector.split("\t")]
|
||||
ins_embd.append(vector)
|
||||
if not ins_embd:
|
||||
return [], [], []
|
||||
|
||||
@ -377,7 +303,7 @@ class Dealer:
|
||||
if isinstance(tenant_ids, str):
|
||||
tenant_ids = tenant_ids.split(",")
|
||||
|
||||
sres = self.search(req, [index_name(tid) for tid in tenant_ids], embd_mdl, highlight)
|
||||
sres = self.search(req, [index_name(tid) for tid in tenant_ids], kb_ids, embd_mdl, highlight)
|
||||
ranks["total"] = sres.total
|
||||
|
||||
if page <= RERANK_PAGE_LIMIT:
|
||||
@ -393,6 +319,8 @@ class Dealer:
|
||||
idx = list(range(len(sres.ids)))
|
||||
|
||||
dim = len(sres.query_vector)
|
||||
vector_column = f"q_{dim}_vec"
|
||||
zero_vector = [0.0] * dim
|
||||
for i in idx:
|
||||
if sim[i] < similarity_threshold:
|
||||
break
|
||||
@ -401,34 +329,32 @@ class Dealer:
|
||||
continue
|
||||
break
|
||||
id = sres.ids[i]
|
||||
dnm = sres.field[id]["docnm_kwd"]
|
||||
did = sres.field[id]["doc_id"]
|
||||
chunk = sres.field[id]
|
||||
dnm = chunk["docnm_kwd"]
|
||||
did = chunk["doc_id"]
|
||||
position_list = chunk.get("position_list", "[]")
|
||||
if not position_list:
|
||||
position_list = "[]"
|
||||
d = {
|
||||
"chunk_id": id,
|
||||
"content_ltks": sres.field[id]["content_ltks"],
|
||||
"content_with_weight": sres.field[id]["content_with_weight"],
|
||||
"doc_id": sres.field[id]["doc_id"],
|
||||
"content_ltks": chunk["content_ltks"],
|
||||
"content_with_weight": chunk["content_with_weight"],
|
||||
"doc_id": chunk["doc_id"],
|
||||
"docnm_kwd": dnm,
|
||||
"kb_id": sres.field[id]["kb_id"],
|
||||
"important_kwd": sres.field[id].get("important_kwd", []),
|
||||
"img_id": sres.field[id].get("img_id", ""),
|
||||
"kb_id": chunk["kb_id"],
|
||||
"important_kwd": chunk.get("important_kwd", []),
|
||||
"image_id": chunk.get("img_id", ""),
|
||||
"similarity": sim[i],
|
||||
"vector_similarity": vsim[i],
|
||||
"term_similarity": tsim[i],
|
||||
"vector": self.trans2floats(sres.field[id].get("q_%d_vec" % dim, "\t".join(["0"] * dim))),
|
||||
"positions": sres.field[id].get("position_int", "").split("\t")
|
||||
"vector": chunk.get(vector_column, zero_vector),
|
||||
"positions": json.loads(position_list)
|
||||
}
|
||||
if highlight:
|
||||
if id in sres.highlight:
|
||||
d["highlight"] = rmSpace(sres.highlight[id])
|
||||
else:
|
||||
d["highlight"] = d["content_with_weight"]
|
||||
if len(d["positions"]) % 5 == 0:
|
||||
poss = []
|
||||
for i in range(0, len(d["positions"]), 5):
|
||||
poss.append([float(d["positions"][i]), float(d["positions"][i + 1]), float(d["positions"][i + 2]),
|
||||
float(d["positions"][i + 3]), float(d["positions"][i + 4])])
|
||||
d["positions"] = poss
|
||||
ranks["chunks"].append(d)
|
||||
if dnm not in ranks["doc_aggs"]:
|
||||
ranks["doc_aggs"][dnm] = {"doc_id": did, "count": 0}
|
||||
@ -442,39 +368,11 @@ class Dealer:
|
||||
return ranks
|
||||
|
||||
def sql_retrieval(self, sql, fetch_size=128, format="json"):
|
||||
from api.settings import chat_logger
|
||||
sql = re.sub(r"[ `]+", " ", sql)
|
||||
sql = sql.replace("%", "")
|
||||
es_logger.info(f"Get es sql: {sql}")
|
||||
replaces = []
|
||||
for r in re.finditer(r" ([a-z_]+_l?tks)( like | ?= ?)'([^']+)'", sql):
|
||||
fld, v = r.group(1), r.group(3)
|
||||
match = " MATCH({}, '{}', 'operator=OR;minimum_should_match=30%') ".format(
|
||||
fld, rag_tokenizer.fine_grained_tokenize(rag_tokenizer.tokenize(v)))
|
||||
replaces.append(
|
||||
("{}{}'{}'".format(
|
||||
r.group(1),
|
||||
r.group(2),
|
||||
r.group(3)),
|
||||
match))
|
||||
tbl = self.dataStore.sql(sql, fetch_size, format)
|
||||
return tbl
|
||||
|
||||
for p, r in replaces:
|
||||
sql = sql.replace(p, r, 1)
|
||||
chat_logger.info(f"To es: {sql}")
|
||||
|
||||
try:
|
||||
tbl = self.es.sql(sql, fetch_size, format)
|
||||
return tbl
|
||||
except Exception as e:
|
||||
chat_logger.error(f"SQL failure: {sql} =>" + str(e))
|
||||
return {"error": str(e)}
|
||||
|
||||
def chunk_list(self, doc_id, tenant_id, max_count=1024, fields=["docnm_kwd", "content_with_weight", "img_id"]):
|
||||
s = Search()
|
||||
s = s.query(Q("match", doc_id=doc_id))[0:max_count]
|
||||
s = s.to_dict()
|
||||
es_res = self.es.search(s, idxnms=index_name(tenant_id), timeout="600s", src=fields)
|
||||
res = []
|
||||
for index, chunk in enumerate(es_res['hits']['hits']):
|
||||
res.append({fld: chunk['_source'].get(fld) for fld in fields})
|
||||
return res
|
||||
def chunk_list(self, doc_id: str, tenant_id: str, kb_ids: list[str], max_count=1024, fields=["docnm_kwd", "content_with_weight", "img_id"]):
|
||||
condition = {"doc_id": doc_id}
|
||||
res = self.dataStore.search(fields, [], condition, [], OrderByExpr(), 0, max_count, index_name(tenant_id), kb_ids)
|
||||
dict_chunks = self.dataStore.getFields(res, fields)
|
||||
return dict_chunks.values()
|
||||
|
@ -25,12 +25,13 @@ RAG_CONF_PATH = os.path.join(get_project_base_directory(), "conf")
|
||||
SUBPROCESS_STD_LOG_NAME = "std.log"
|
||||
|
||||
ES = get_base_config("es", {})
|
||||
INFINITY = get_base_config("infinity", {"uri": "infinity:23817"})
|
||||
AZURE = get_base_config("azure", {})
|
||||
S3 = get_base_config("s3", {})
|
||||
MINIO = decrypt_database_config(name="minio")
|
||||
try:
|
||||
REDIS = decrypt_database_config(name="redis")
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
REDIS = {}
|
||||
pass
|
||||
DOC_MAXIMUM_SIZE = int(os.environ.get("MAX_CONTENT_LENGTH", 128 * 1024 * 1024))
|
||||
@ -44,7 +45,7 @@ LoggerFactory.set_directory(
|
||||
# {CRITICAL: 50, FATAL:50, ERROR:40, WARNING:30, WARN:30, INFO:20, DEBUG:10, NOTSET:0}
|
||||
LoggerFactory.LEVEL = 30
|
||||
|
||||
es_logger = getLogger("es")
|
||||
doc_store_logger = getLogger("doc_store")
|
||||
minio_logger = getLogger("minio")
|
||||
s3_logger = getLogger("s3")
|
||||
azure_logger = getLogger("azure")
|
||||
@ -53,7 +54,7 @@ chunk_logger = getLogger("chunk_logger")
|
||||
database_logger = getLogger("database")
|
||||
|
||||
formatter = logging.Formatter("%(asctime)-15s %(levelname)-8s (%(process)d) %(message)s")
|
||||
for logger in [es_logger, minio_logger, s3_logger, azure_logger, cron_logger, chunk_logger, database_logger]:
|
||||
for logger in [doc_store_logger, minio_logger, s3_logger, azure_logger, cron_logger, chunk_logger, database_logger]:
|
||||
logger.setLevel(logging.INFO)
|
||||
for handler in logger.handlers:
|
||||
handler.setFormatter(fmt=formatter)
|
||||
|
@ -31,7 +31,6 @@ from timeit import default_timer as timer
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from elasticsearch_dsl import Q
|
||||
|
||||
from api.db import LLMType, ParserType
|
||||
from api.db.services.dialog_service import keyword_extraction, question_proposal
|
||||
@ -39,8 +38,7 @@ from api.db.services.document_service import DocumentService
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from api.db.services.task_service import TaskService
|
||||
from api.db.services.file2document_service import File2DocumentService
|
||||
from api.settings import retrievaler
|
||||
from api.utils.file_utils import get_project_base_directory
|
||||
from api.settings import retrievaler, docStoreConn
|
||||
from api.db.db_models import close_connection
|
||||
from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive, one, audio, knowledge_graph, email
|
||||
from rag.nlp import search, rag_tokenizer
|
||||
@ -48,7 +46,6 @@ from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as
|
||||
from rag.settings import database_logger, SVR_QUEUE_NAME
|
||||
from rag.settings import cron_logger, DOC_MAXIMUM_SIZE
|
||||
from rag.utils import rmSpace, num_tokens_from_string
|
||||
from rag.utils.es_conn import ELASTICSEARCH
|
||||
from rag.utils.redis_conn import REDIS_CONN, Payload
|
||||
from rag.utils.storage_factory import STORAGE_IMPL
|
||||
|
||||
@ -126,7 +123,7 @@ def collect():
|
||||
return pd.DataFrame()
|
||||
tasks = TaskService.get_tasks(msg["id"])
|
||||
if not tasks:
|
||||
cron_logger.warn("{} empty task!".format(msg["id"]))
|
||||
cron_logger.warning("{} empty task!".format(msg["id"]))
|
||||
return []
|
||||
|
||||
tasks = pd.DataFrame(tasks)
|
||||
@ -187,7 +184,7 @@ def build(row):
|
||||
docs = []
|
||||
doc = {
|
||||
"doc_id": row["doc_id"],
|
||||
"kb_id": [str(row["kb_id"])]
|
||||
"kb_id": str(row["kb_id"])
|
||||
}
|
||||
el = 0
|
||||
for ck in cks:
|
||||
@ -196,10 +193,14 @@ def build(row):
|
||||
md5 = hashlib.md5()
|
||||
md5.update((ck["content_with_weight"] +
|
||||
str(d["doc_id"])).encode("utf-8"))
|
||||
d["_id"] = md5.hexdigest()
|
||||
d["id"] = md5.hexdigest()
|
||||
d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19]
|
||||
d["create_timestamp_flt"] = datetime.datetime.now().timestamp()
|
||||
if not d.get("image"):
|
||||
d["img_id"] = ""
|
||||
d["page_num_list"] = json.dumps([])
|
||||
d["position_list"] = json.dumps([])
|
||||
d["top_list"] = json.dumps([])
|
||||
docs.append(d)
|
||||
continue
|
||||
|
||||
@ -211,13 +212,13 @@ def build(row):
|
||||
d["image"].save(output_buffer, format='JPEG')
|
||||
|
||||
st = timer()
|
||||
STORAGE_IMPL.put(row["kb_id"], d["_id"], output_buffer.getvalue())
|
||||
STORAGE_IMPL.put(row["kb_id"], d["id"], output_buffer.getvalue())
|
||||
el += timer() - st
|
||||
except Exception as e:
|
||||
cron_logger.error(str(e))
|
||||
traceback.print_exc()
|
||||
|
||||
d["img_id"] = "{}-{}".format(row["kb_id"], d["_id"])
|
||||
d["img_id"] = "{}-{}".format(row["kb_id"], d["id"])
|
||||
del d["image"]
|
||||
docs.append(d)
|
||||
cron_logger.info("MINIO PUT({}):{}".format(row["name"], el))
|
||||
@ -245,12 +246,9 @@ def build(row):
|
||||
return docs
|
||||
|
||||
|
||||
def init_kb(row):
|
||||
def init_kb(row, vector_size: int):
|
||||
idxnm = search.index_name(row["tenant_id"])
|
||||
if ELASTICSEARCH.indexExist(idxnm):
|
||||
return
|
||||
return ELASTICSEARCH.createIdx(idxnm, json.load(
|
||||
open(os.path.join(get_project_base_directory(), "conf", "mapping.json"), "r")))
|
||||
return docStoreConn.createIdx(idxnm, row["kb_id"], vector_size)
|
||||
|
||||
|
||||
def embedding(docs, mdl, parser_config=None, callback=None):
|
||||
@ -288,17 +286,20 @@ def embedding(docs, mdl, parser_config=None, callback=None):
|
||||
cnts) if len(tts) == len(cnts) else cnts
|
||||
|
||||
assert len(vects) == len(docs)
|
||||
vector_size = 0
|
||||
for i, d in enumerate(docs):
|
||||
v = vects[i].tolist()
|
||||
vector_size = len(v)
|
||||
d["q_%d_vec" % len(v)] = v
|
||||
return tk_count
|
||||
return tk_count, vector_size
|
||||
|
||||
|
||||
def run_raptor(row, chat_mdl, embd_mdl, callback=None):
|
||||
vts, _ = embd_mdl.encode(["ok"])
|
||||
vctr_nm = "q_%d_vec" % len(vts[0])
|
||||
vector_size = len(vts[0])
|
||||
vctr_nm = "q_%d_vec" % vector_size
|
||||
chunks = []
|
||||
for d in retrievaler.chunk_list(row["doc_id"], row["tenant_id"], fields=["content_with_weight", vctr_nm]):
|
||||
for d in retrievaler.chunk_list(row["doc_id"], row["tenant_id"], [str(row["kb_id"])], fields=["content_with_weight", vctr_nm]):
|
||||
chunks.append((d["content_with_weight"], np.array(d[vctr_nm])))
|
||||
|
||||
raptor = Raptor(
|
||||
@ -323,7 +324,7 @@ def run_raptor(row, chat_mdl, embd_mdl, callback=None):
|
||||
d = copy.deepcopy(doc)
|
||||
md5 = hashlib.md5()
|
||||
md5.update((content + str(d["doc_id"])).encode("utf-8"))
|
||||
d["_id"] = md5.hexdigest()
|
||||
d["id"] = md5.hexdigest()
|
||||
d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19]
|
||||
d["create_timestamp_flt"] = datetime.datetime.now().timestamp()
|
||||
d[vctr_nm] = vctr.tolist()
|
||||
@ -332,7 +333,7 @@ def run_raptor(row, chat_mdl, embd_mdl, callback=None):
|
||||
d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
|
||||
res.append(d)
|
||||
tk_count += num_tokens_from_string(content)
|
||||
return res, tk_count
|
||||
return res, tk_count, vector_size
|
||||
|
||||
|
||||
def main():
|
||||
@ -352,7 +353,7 @@ def main():
|
||||
if r.get("task_type", "") == "raptor":
|
||||
try:
|
||||
chat_mdl = LLMBundle(r["tenant_id"], LLMType.CHAT, llm_name=r["llm_id"], lang=r["language"])
|
||||
cks, tk_count = run_raptor(r, chat_mdl, embd_mdl, callback)
|
||||
cks, tk_count, vector_size = run_raptor(r, chat_mdl, embd_mdl, callback)
|
||||
except Exception as e:
|
||||
callback(-1, msg=str(e))
|
||||
cron_logger.error(str(e))
|
||||
@ -373,7 +374,7 @@ def main():
|
||||
len(cks))
|
||||
st = timer()
|
||||
try:
|
||||
tk_count = embedding(cks, embd_mdl, r["parser_config"], callback)
|
||||
tk_count, vector_size = embedding(cks, embd_mdl, r["parser_config"], callback)
|
||||
except Exception as e:
|
||||
callback(-1, "Embedding error:{}".format(str(e)))
|
||||
cron_logger.error(str(e))
|
||||
@ -381,26 +382,25 @@ def main():
|
||||
cron_logger.info("Embedding elapsed({}): {:.2f}".format(r["name"], timer() - st))
|
||||
callback(msg="Finished embedding({:.2f})! Start to build index!".format(timer() - st))
|
||||
|
||||
init_kb(r)
|
||||
chunk_count = len(set([c["_id"] for c in cks]))
|
||||
# cron_logger.info(f"task_executor init_kb index {search.index_name(r["tenant_id"])} embd_mdl {embd_mdl.llm_name} vector length {vector_size}")
|
||||
init_kb(r, vector_size)
|
||||
chunk_count = len(set([c["id"] for c in cks]))
|
||||
st = timer()
|
||||
es_r = ""
|
||||
es_bulk_size = 4
|
||||
for b in range(0, len(cks), es_bulk_size):
|
||||
es_r = ELASTICSEARCH.bulk(cks[b:b + es_bulk_size], search.index_name(r["tenant_id"]))
|
||||
es_r = docStoreConn.insert(cks[b:b + es_bulk_size], search.index_name(r["tenant_id"]), r["kb_id"])
|
||||
if b % 128 == 0:
|
||||
callback(prog=0.8 + 0.1 * (b + 1) / len(cks), msg="")
|
||||
|
||||
cron_logger.info("Indexing elapsed({}): {:.2f}".format(r["name"], timer() - st))
|
||||
if es_r:
|
||||
callback(-1, "Insert chunk error, detail info please check ragflow-logs/api/cron_logger.log. Please also check ES status!")
|
||||
ELASTICSEARCH.deleteByQuery(
|
||||
Q("match", doc_id=r["doc_id"]), idxnm=search.index_name(r["tenant_id"]))
|
||||
cron_logger.error(str(es_r))
|
||||
docStoreConn.delete({"doc_id": r["doc_id"]}, search.index_name(r["tenant_id"]), r["kb_id"])
|
||||
cron_logger.error('Insert chunk error: ' + str(es_r))
|
||||
else:
|
||||
if TaskService.do_cancel(r["id"]):
|
||||
ELASTICSEARCH.deleteByQuery(
|
||||
Q("match", doc_id=r["doc_id"]), idxnm=search.index_name(r["tenant_id"]))
|
||||
docStoreConn.delete({"doc_id": r["doc_id"]}, search.index_name(r["tenant_id"]), r["kb_id"])
|
||||
continue
|
||||
callback(1., "Done!")
|
||||
DocumentService.increment_chunk_num(
|
||||
|
251
rag/utils/doc_store_conn.py
Normal file
251
rag/utils/doc_store_conn.py
Normal file
@ -0,0 +1,251 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional, Union
|
||||
from dataclasses import dataclass
|
||||
import numpy as np
|
||||
import polars as pl
|
||||
from typing import List, Dict
|
||||
|
||||
DEFAULT_MATCH_VECTOR_TOPN = 10
|
||||
DEFAULT_MATCH_SPARSE_TOPN = 10
|
||||
VEC = Union[list, np.ndarray]
|
||||
|
||||
|
||||
@dataclass
|
||||
class SparseVector:
|
||||
indices: list[int]
|
||||
values: Union[list[float], list[int], None] = None
|
||||
|
||||
def __post_init__(self):
|
||||
assert (self.values is None) or (len(self.indices) == len(self.values))
|
||||
|
||||
def to_dict_old(self):
|
||||
d = {"indices": self.indices}
|
||||
if self.values is not None:
|
||||
d["values"] = self.values
|
||||
return d
|
||||
|
||||
def to_dict(self):
|
||||
if self.values is None:
|
||||
raise ValueError("SparseVector.values is None")
|
||||
result = {}
|
||||
for i, v in zip(self.indices, self.values):
|
||||
result[str(i)] = v
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def from_dict(d):
|
||||
return SparseVector(d["indices"], d.get("values"))
|
||||
|
||||
def __str__(self):
|
||||
return f"SparseVector(indices={self.indices}{'' if self.values is None else f', values={self.values}'})"
|
||||
|
||||
def __repr__(self):
|
||||
return str(self)
|
||||
|
||||
|
||||
class MatchTextExpr(ABC):
|
||||
def __init__(
|
||||
self,
|
||||
fields: str,
|
||||
matching_text: str,
|
||||
topn: int,
|
||||
extra_options: dict = dict(),
|
||||
):
|
||||
self.fields = fields
|
||||
self.matching_text = matching_text
|
||||
self.topn = topn
|
||||
self.extra_options = extra_options
|
||||
|
||||
|
||||
class MatchDenseExpr(ABC):
|
||||
def __init__(
|
||||
self,
|
||||
vector_column_name: str,
|
||||
embedding_data: VEC,
|
||||
embedding_data_type: str,
|
||||
distance_type: str,
|
||||
topn: int = DEFAULT_MATCH_VECTOR_TOPN,
|
||||
extra_options: dict = dict(),
|
||||
):
|
||||
self.vector_column_name = vector_column_name
|
||||
self.embedding_data = embedding_data
|
||||
self.embedding_data_type = embedding_data_type
|
||||
self.distance_type = distance_type
|
||||
self.topn = topn
|
||||
self.extra_options = extra_options
|
||||
|
||||
|
||||
class MatchSparseExpr(ABC):
|
||||
def __init__(
|
||||
self,
|
||||
vector_column_name: str,
|
||||
sparse_data: SparseVector | dict,
|
||||
distance_type: str,
|
||||
topn: int,
|
||||
opt_params: Optional[dict] = None,
|
||||
):
|
||||
self.vector_column_name = vector_column_name
|
||||
self.sparse_data = sparse_data
|
||||
self.distance_type = distance_type
|
||||
self.topn = topn
|
||||
self.opt_params = opt_params
|
||||
|
||||
|
||||
class MatchTensorExpr(ABC):
|
||||
def __init__(
|
||||
self,
|
||||
column_name: str,
|
||||
query_data: VEC,
|
||||
query_data_type: str,
|
||||
topn: int,
|
||||
extra_option: Optional[dict] = None,
|
||||
):
|
||||
self.column_name = column_name
|
||||
self.query_data = query_data
|
||||
self.query_data_type = query_data_type
|
||||
self.topn = topn
|
||||
self.extra_option = extra_option
|
||||
|
||||
|
||||
class FusionExpr(ABC):
|
||||
def __init__(self, method: str, topn: int, fusion_params: Optional[dict] = None):
|
||||
self.method = method
|
||||
self.topn = topn
|
||||
self.fusion_params = fusion_params
|
||||
|
||||
|
||||
MatchExpr = Union[
|
||||
MatchTextExpr, MatchDenseExpr, MatchSparseExpr, MatchTensorExpr, FusionExpr
|
||||
]
|
||||
|
||||
|
||||
class OrderByExpr(ABC):
|
||||
def __init__(self):
|
||||
self.fields = list()
|
||||
def asc(self, field: str):
|
||||
self.fields.append((field, 0))
|
||||
return self
|
||||
def desc(self, field: str):
|
||||
self.fields.append((field, 1))
|
||||
return self
|
||||
def fields(self):
|
||||
return self.fields
|
||||
|
||||
class DocStoreConnection(ABC):
|
||||
"""
|
||||
Database operations
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def dbType(self) -> str:
|
||||
"""
|
||||
Return the type of the database.
|
||||
"""
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
@abstractmethod
|
||||
def health(self) -> dict:
|
||||
"""
|
||||
Return the health status of the database.
|
||||
"""
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
"""
|
||||
Table operations
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def createIdx(self, indexName: str, knowledgebaseId: str, vectorSize: int):
|
||||
"""
|
||||
Create an index with given name
|
||||
"""
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
@abstractmethod
|
||||
def deleteIdx(self, indexName: str, knowledgebaseId: str):
|
||||
"""
|
||||
Delete an index with given name
|
||||
"""
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
@abstractmethod
|
||||
def indexExist(self, indexName: str, knowledgebaseId: str) -> bool:
|
||||
"""
|
||||
Check if an index with given name exists
|
||||
"""
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
"""
|
||||
CRUD operations
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def search(
|
||||
self, selectFields: list[str], highlight: list[str], condition: dict, matchExprs: list[MatchExpr], orderBy: OrderByExpr, offset: int, limit: int, indexNames: str|list[str], knowledgebaseIds: list[str]
|
||||
) -> list[dict] | pl.DataFrame:
|
||||
"""
|
||||
Search with given conjunctive equivalent filtering condition and return all fields of matched documents
|
||||
"""
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
@abstractmethod
|
||||
def get(self, chunkId: str, indexName: str, knowledgebaseIds: list[str]) -> dict | None:
|
||||
"""
|
||||
Get single chunk with given id
|
||||
"""
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
@abstractmethod
|
||||
def insert(self, rows: list[dict], indexName: str, knowledgebaseId: str) -> list[str]:
|
||||
"""
|
||||
Update or insert a bulk of rows
|
||||
"""
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
@abstractmethod
|
||||
def update(self, condition: dict, newValue: dict, indexName: str, knowledgebaseId: str) -> bool:
|
||||
"""
|
||||
Update rows with given conjunctive equivalent filtering condition
|
||||
"""
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, condition: dict, indexName: str, knowledgebaseId: str) -> int:
|
||||
"""
|
||||
Delete rows with given conjunctive equivalent filtering condition
|
||||
"""
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
"""
|
||||
Helper functions for search result
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def getTotal(self, res):
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
@abstractmethod
|
||||
def getChunkIds(self, res):
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
@abstractmethod
|
||||
def getFields(self, res, fields: List[str]) -> Dict[str, dict]:
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
@abstractmethod
|
||||
def getHighlight(self, res, keywords: List[str], fieldnm: str):
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
@abstractmethod
|
||||
def getAggregation(self, res, fieldnm: str):
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
"""
|
||||
SQL
|
||||
"""
|
||||
@abstractmethod
|
||||
def sql(sql: str, fetch_size: int, format: str):
|
||||
"""
|
||||
Run the sql generated by text-to-sql
|
||||
"""
|
||||
raise NotImplementedError("Not implemented")
|
@ -1,29 +1,29 @@
|
||||
import re
|
||||
import json
|
||||
import time
|
||||
import copy
|
||||
import os
|
||||
from typing import List, Dict
|
||||
|
||||
import elasticsearch
|
||||
from elastic_transport import ConnectionTimeout
|
||||
import copy
|
||||
from elasticsearch import Elasticsearch
|
||||
from elasticsearch_dsl import UpdateByQuery, Search, Index
|
||||
from rag.settings import es_logger
|
||||
from elasticsearch_dsl import UpdateByQuery, Q, Search, Index
|
||||
from elastic_transport import ConnectionTimeout
|
||||
from rag.settings import doc_store_logger
|
||||
from rag import settings
|
||||
from rag.utils import singleton
|
||||
from api.utils.file_utils import get_project_base_directory
|
||||
import polars as pl
|
||||
from rag.utils.doc_store_conn import DocStoreConnection, MatchExpr, OrderByExpr, MatchTextExpr, MatchDenseExpr, FusionExpr
|
||||
from rag.nlp import is_english, rag_tokenizer
|
||||
|
||||
es_logger.info("Elasticsearch version: "+str(elasticsearch.__version__))
|
||||
doc_store_logger.info("Elasticsearch sdk version: "+str(elasticsearch.__version__))
|
||||
|
||||
|
||||
@singleton
|
||||
class ESConnection:
|
||||
class ESConnection(DocStoreConnection):
|
||||
def __init__(self):
|
||||
self.info = {}
|
||||
self.conn()
|
||||
self.idxnm = settings.ES.get("index_name", "")
|
||||
if not self.es.ping():
|
||||
raise Exception("Can't connect to ES cluster")
|
||||
|
||||
def conn(self):
|
||||
for _ in range(10):
|
||||
try:
|
||||
self.es = Elasticsearch(
|
||||
@ -34,390 +34,317 @@ class ESConnection:
|
||||
)
|
||||
if self.es:
|
||||
self.info = self.es.info()
|
||||
es_logger.info("Connect to es.")
|
||||
doc_store_logger.info("Connect to es.")
|
||||
break
|
||||
except Exception as e:
|
||||
es_logger.error("Fail to connect to es: " + str(e))
|
||||
doc_store_logger.error("Fail to connect to es: " + str(e))
|
||||
time.sleep(1)
|
||||
|
||||
def version(self):
|
||||
if not self.es.ping():
|
||||
raise Exception("Can't connect to ES cluster")
|
||||
v = self.info.get("version", {"number": "5.6"})
|
||||
v = v["number"].split(".")[0]
|
||||
return int(v) >= 7
|
||||
if int(v) < 8:
|
||||
raise Exception(f"ES version must be greater than or equal to 8, current version: {v}")
|
||||
fp_mapping = os.path.join(get_project_base_directory(), "conf", "mapping.json")
|
||||
if not os.path.exists(fp_mapping):
|
||||
raise Exception(f"Mapping file not found at {fp_mapping}")
|
||||
self.mapping = json.load(open(fp_mapping, "r"))
|
||||
|
||||
def health(self):
|
||||
return dict(self.es.cluster.health())
|
||||
"""
|
||||
Database operations
|
||||
"""
|
||||
def dbType(self) -> str:
|
||||
return "elasticsearch"
|
||||
|
||||
def upsert(self, df, idxnm=""):
|
||||
res = []
|
||||
for d in df:
|
||||
id = d["id"]
|
||||
del d["id"]
|
||||
d = {"doc": d, "doc_as_upsert": "true"}
|
||||
T = False
|
||||
for _ in range(10):
|
||||
try:
|
||||
if not self.version():
|
||||
r = self.es.update(
|
||||
index=(
|
||||
self.idxnm if not idxnm else idxnm),
|
||||
body=d,
|
||||
id=id,
|
||||
doc_type="doc",
|
||||
refresh=True,
|
||||
retry_on_conflict=100)
|
||||
else:
|
||||
r = self.es.update(
|
||||
index=(
|
||||
self.idxnm if not idxnm else idxnm),
|
||||
body=d,
|
||||
id=id,
|
||||
refresh=True,
|
||||
retry_on_conflict=100)
|
||||
es_logger.info("Successfully upsert: %s" % id)
|
||||
T = True
|
||||
break
|
||||
except Exception as e:
|
||||
es_logger.warning("Fail to index: " +
|
||||
json.dumps(d, ensure_ascii=False) + str(e))
|
||||
if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE):
|
||||
time.sleep(3)
|
||||
continue
|
||||
self.conn()
|
||||
T = False
|
||||
def health(self) -> dict:
|
||||
return dict(self.es.cluster.health()) + {"type": "elasticsearch"}
|
||||
|
||||
if not T:
|
||||
res.append(d)
|
||||
es_logger.error(
|
||||
"Fail to index: " +
|
||||
re.sub(
|
||||
"[\r\n]",
|
||||
"",
|
||||
json.dumps(
|
||||
d,
|
||||
ensure_ascii=False)))
|
||||
d["id"] = id
|
||||
d["_index"] = self.idxnm
|
||||
|
||||
if not res:
|
||||
"""
|
||||
Table operations
|
||||
"""
|
||||
def createIdx(self, indexName: str, knowledgebaseId: str, vectorSize: int):
|
||||
if self.indexExist(indexName, knowledgebaseId):
|
||||
return True
|
||||
return False
|
||||
try:
|
||||
from elasticsearch.client import IndicesClient
|
||||
return IndicesClient(self.es).create(index=indexName,
|
||||
settings=self.mapping["settings"],
|
||||
mappings=self.mapping["mappings"])
|
||||
except Exception as e:
|
||||
doc_store_logger.error("ES create index error %s ----%s" % (indexName, str(e)))
|
||||
|
||||
def bulk(self, df, idx_nm=None):
|
||||
ids, acts = {}, []
|
||||
for d in df:
|
||||
id = d["id"] if "id" in d else d["_id"]
|
||||
ids[id] = copy.deepcopy(d)
|
||||
ids[id]["_index"] = self.idxnm if not idx_nm else idx_nm
|
||||
if "id" in d:
|
||||
del d["id"]
|
||||
if "_id" in d:
|
||||
del d["_id"]
|
||||
acts.append(
|
||||
{"update": {"_id": id, "_index": ids[id]["_index"]}, "retry_on_conflict": 100})
|
||||
acts.append({"doc": d, "doc_as_upsert": "true"})
|
||||
def deleteIdx(self, indexName: str, knowledgebaseId: str):
|
||||
try:
|
||||
return self.es.indices.delete(indexName, allow_no_indices=True)
|
||||
except Exception as e:
|
||||
doc_store_logger.error("ES delete index error %s ----%s" % (indexName, str(e)))
|
||||
|
||||
res = []
|
||||
for _ in range(100):
|
||||
try:
|
||||
if elasticsearch.__version__[0] < 8:
|
||||
r = self.es.bulk(
|
||||
index=(
|
||||
self.idxnm if not idx_nm else idx_nm),
|
||||
body=acts,
|
||||
refresh=False,
|
||||
timeout="600s")
|
||||
else:
|
||||
r = self.es.bulk(index=(self.idxnm if not idx_nm else
|
||||
idx_nm), operations=acts,
|
||||
refresh=False, timeout="600s")
|
||||
if re.search(r"False", str(r["errors"]), re.IGNORECASE):
|
||||
return res
|
||||
|
||||
for it in r["items"]:
|
||||
if "error" in it["update"]:
|
||||
res.append(str(it["update"]["_id"]) +
|
||||
":" + str(it["update"]["error"]))
|
||||
|
||||
return res
|
||||
except Exception as e:
|
||||
es_logger.warn("Fail to bulk: " + str(e))
|
||||
if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE):
|
||||
time.sleep(3)
|
||||
continue
|
||||
self.conn()
|
||||
|
||||
return res
|
||||
|
||||
def bulk4script(self, df):
|
||||
ids, acts = {}, []
|
||||
for d in df:
|
||||
id = d["id"]
|
||||
ids[id] = copy.deepcopy(d["raw"])
|
||||
acts.append({"update": {"_id": id, "_index": self.idxnm}})
|
||||
acts.append(d["script"])
|
||||
es_logger.info("bulk upsert: %s" % id)
|
||||
|
||||
res = []
|
||||
for _ in range(10):
|
||||
try:
|
||||
if not self.version():
|
||||
r = self.es.bulk(
|
||||
index=self.idxnm,
|
||||
body=acts,
|
||||
refresh=False,
|
||||
timeout="600s",
|
||||
doc_type="doc")
|
||||
else:
|
||||
r = self.es.bulk(
|
||||
index=self.idxnm,
|
||||
body=acts,
|
||||
refresh=False,
|
||||
timeout="600s")
|
||||
if re.search(r"False", str(r["errors"]), re.IGNORECASE):
|
||||
return res
|
||||
|
||||
for it in r["items"]:
|
||||
if "error" in it["update"]:
|
||||
res.append(str(it["update"]["_id"]))
|
||||
|
||||
return res
|
||||
except Exception as e:
|
||||
es_logger.warning("Fail to bulk: " + str(e))
|
||||
if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE):
|
||||
time.sleep(3)
|
||||
continue
|
||||
self.conn()
|
||||
|
||||
return res
|
||||
|
||||
def rm(self, d):
|
||||
for _ in range(10):
|
||||
try:
|
||||
if not self.version():
|
||||
r = self.es.delete(
|
||||
index=self.idxnm,
|
||||
id=d["id"],
|
||||
doc_type="doc",
|
||||
refresh=True)
|
||||
else:
|
||||
r = self.es.delete(
|
||||
index=self.idxnm,
|
||||
id=d["id"],
|
||||
refresh=True,
|
||||
doc_type="_doc")
|
||||
es_logger.info("Remove %s" % d["id"])
|
||||
return True
|
||||
except Exception as e:
|
||||
es_logger.warn("Fail to delete: " + str(d) + str(e))
|
||||
if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE):
|
||||
time.sleep(3)
|
||||
continue
|
||||
if re.search(r"(not_found)", str(e), re.IGNORECASE):
|
||||
return True
|
||||
self.conn()
|
||||
|
||||
es_logger.error("Fail to delete: " + str(d))
|
||||
|
||||
return False
|
||||
|
||||
def search(self, q, idxnms=None, src=False, timeout="2s"):
|
||||
if not isinstance(q, dict):
|
||||
q = Search().query(q).to_dict()
|
||||
if isinstance(idxnms, str):
|
||||
idxnms = idxnms.split(",")
|
||||
for i in range(3):
|
||||
try:
|
||||
res = self.es.search(index=(self.idxnm if not idxnms else idxnms),
|
||||
body=q,
|
||||
timeout=timeout,
|
||||
# search_type="dfs_query_then_fetch",
|
||||
track_total_hits=True,
|
||||
_source=src)
|
||||
if str(res.get("timed_out", "")).lower() == "true":
|
||||
raise Exception("Es Timeout.")
|
||||
return res
|
||||
except Exception as e:
|
||||
es_logger.error(
|
||||
"ES search exception: " +
|
||||
str(e) +
|
||||
"【Q】:" +
|
||||
str(q))
|
||||
if str(e).find("Timeout") > 0:
|
||||
continue
|
||||
raise e
|
||||
es_logger.error("ES search timeout for 3 times!")
|
||||
raise Exception("ES search timeout.")
|
||||
|
||||
def sql(self, sql, fetch_size=128, format="json", timeout="2s"):
|
||||
for i in range(3):
|
||||
try:
|
||||
res = self.es.sql.query(body={"query": sql, "fetch_size": fetch_size}, format=format, request_timeout=timeout)
|
||||
return res
|
||||
except ConnectionTimeout as e:
|
||||
es_logger.error("Timeout【Q】:" + sql)
|
||||
continue
|
||||
except Exception as e:
|
||||
raise e
|
||||
es_logger.error("ES search timeout for 3 times!")
|
||||
raise ConnectionTimeout()
|
||||
|
||||
|
||||
def get(self, doc_id, idxnm=None):
|
||||
for i in range(3):
|
||||
try:
|
||||
res = self.es.get(index=(self.idxnm if not idxnm else idxnm),
|
||||
id=doc_id)
|
||||
if str(res.get("timed_out", "")).lower() == "true":
|
||||
raise Exception("Es Timeout.")
|
||||
return res
|
||||
except Exception as e:
|
||||
es_logger.error(
|
||||
"ES get exception: " +
|
||||
str(e) +
|
||||
"【Q】:" +
|
||||
doc_id)
|
||||
if str(e).find("Timeout") > 0:
|
||||
continue
|
||||
raise e
|
||||
es_logger.error("ES search timeout for 3 times!")
|
||||
raise Exception("ES search timeout.")
|
||||
|
||||
def updateByQuery(self, q, d):
|
||||
ubq = UpdateByQuery(index=self.idxnm).using(self.es).query(q)
|
||||
scripts = ""
|
||||
for k, v in d.items():
|
||||
scripts += "ctx._source.%s = params.%s;" % (str(k), str(k))
|
||||
ubq = ubq.script(source=scripts, params=d)
|
||||
ubq = ubq.params(refresh=False)
|
||||
ubq = ubq.params(slices=5)
|
||||
ubq = ubq.params(conflicts="proceed")
|
||||
for i in range(3):
|
||||
try:
|
||||
r = ubq.execute()
|
||||
return True
|
||||
except Exception as e:
|
||||
es_logger.error("ES updateByQuery exception: " +
|
||||
str(e) + "【Q】:" + str(q.to_dict()))
|
||||
if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
|
||||
continue
|
||||
self.conn()
|
||||
|
||||
return False
|
||||
|
||||
def updateScriptByQuery(self, q, scripts, idxnm=None):
|
||||
ubq = UpdateByQuery(
|
||||
index=self.idxnm if not idxnm else idxnm).using(
|
||||
self.es).query(q)
|
||||
ubq = ubq.script(source=scripts)
|
||||
ubq = ubq.params(refresh=True)
|
||||
ubq = ubq.params(slices=5)
|
||||
ubq = ubq.params(conflicts="proceed")
|
||||
for i in range(3):
|
||||
try:
|
||||
r = ubq.execute()
|
||||
return True
|
||||
except Exception as e:
|
||||
es_logger.error("ES updateByQuery exception: " +
|
||||
str(e) + "【Q】:" + str(q.to_dict()))
|
||||
if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
|
||||
continue
|
||||
self.conn()
|
||||
|
||||
return False
|
||||
|
||||
def deleteByQuery(self, query, idxnm=""):
|
||||
for i in range(3):
|
||||
try:
|
||||
r = self.es.delete_by_query(
|
||||
index=idxnm if idxnm else self.idxnm,
|
||||
refresh = True,
|
||||
body=Search().query(query).to_dict())
|
||||
return True
|
||||
except Exception as e:
|
||||
es_logger.error("ES updateByQuery deleteByQuery: " +
|
||||
str(e) + "【Q】:" + str(query.to_dict()))
|
||||
if str(e).find("NotFoundError") > 0: return True
|
||||
if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
|
||||
continue
|
||||
|
||||
return False
|
||||
|
||||
def update(self, id, script, routing=None):
|
||||
for i in range(3):
|
||||
try:
|
||||
if not self.version():
|
||||
r = self.es.update(
|
||||
index=self.idxnm,
|
||||
id=id,
|
||||
body=json.dumps(
|
||||
script,
|
||||
ensure_ascii=False),
|
||||
doc_type="doc",
|
||||
routing=routing,
|
||||
refresh=False)
|
||||
else:
|
||||
r = self.es.update(index=self.idxnm, id=id, body=json.dumps(script, ensure_ascii=False),
|
||||
routing=routing, refresh=False) # , doc_type="_doc")
|
||||
return True
|
||||
except Exception as e:
|
||||
es_logger.error(
|
||||
"ES update exception: " + str(e) + " id:" + str(id) + ", version:" + str(self.version()) +
|
||||
json.dumps(script, ensure_ascii=False))
|
||||
if str(e).find("Timeout") > 0:
|
||||
continue
|
||||
|
||||
return False
|
||||
|
||||
def indexExist(self, idxnm):
|
||||
s = Index(idxnm if idxnm else self.idxnm, self.es)
|
||||
def indexExist(self, indexName: str, knowledgebaseId: str) -> bool:
|
||||
s = Index(indexName, self.es)
|
||||
for i in range(3):
|
||||
try:
|
||||
return s.exists()
|
||||
except Exception as e:
|
||||
es_logger.error("ES updateByQuery indexExist: " + str(e))
|
||||
doc_store_logger.error("ES indexExist: " + str(e))
|
||||
if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
|
||||
continue
|
||||
|
||||
return False
|
||||
|
||||
def docExist(self, docid, idxnm=None):
|
||||
"""
|
||||
CRUD operations
|
||||
"""
|
||||
def search(self, selectFields: list[str], highlightFields: list[str], condition: dict, matchExprs: list[MatchExpr], orderBy: OrderByExpr, offset: int, limit: int, indexNames: str|list[str], knowledgebaseIds: list[str]) -> list[dict] | pl.DataFrame:
|
||||
"""
|
||||
Refers to https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl.html
|
||||
"""
|
||||
if isinstance(indexNames, str):
|
||||
indexNames = indexNames.split(",")
|
||||
assert isinstance(indexNames, list) and len(indexNames) > 0
|
||||
assert "_id" not in condition
|
||||
s = Search()
|
||||
bqry = None
|
||||
vector_similarity_weight = 0.5
|
||||
for m in matchExprs:
|
||||
if isinstance(m, FusionExpr) and m.method=="weighted_sum" and "weights" in m.fusion_params:
|
||||
assert len(matchExprs)==3 and isinstance(matchExprs[0], MatchTextExpr) and isinstance(matchExprs[1], MatchDenseExpr) and isinstance(matchExprs[2], FusionExpr)
|
||||
weights = m.fusion_params["weights"]
|
||||
vector_similarity_weight = float(weights.split(",")[1])
|
||||
for m in matchExprs:
|
||||
if isinstance(m, MatchTextExpr):
|
||||
minimum_should_match = "0%"
|
||||
if "minimum_should_match" in m.extra_options:
|
||||
minimum_should_match = str(int(m.extra_options["minimum_should_match"] * 100)) + "%"
|
||||
bqry = Q("bool",
|
||||
must=Q("query_string", fields=m.fields,
|
||||
type="best_fields", query=m.matching_text,
|
||||
minimum_should_match = minimum_should_match,
|
||||
boost=1),
|
||||
boost = 1.0 - vector_similarity_weight,
|
||||
)
|
||||
if condition:
|
||||
for k, v in condition.items():
|
||||
if not isinstance(k, str) or not v:
|
||||
continue
|
||||
if isinstance(v, list):
|
||||
bqry.filter.append(Q("terms", **{k: v}))
|
||||
elif isinstance(v, str) or isinstance(v, int):
|
||||
bqry.filter.append(Q("term", **{k: v}))
|
||||
else:
|
||||
raise Exception(f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.")
|
||||
elif isinstance(m, MatchDenseExpr):
|
||||
assert(bqry is not None)
|
||||
similarity = 0.0
|
||||
if "similarity" in m.extra_options:
|
||||
similarity = m.extra_options["similarity"]
|
||||
s = s.knn(m.vector_column_name,
|
||||
m.topn,
|
||||
m.topn * 2,
|
||||
query_vector = list(m.embedding_data),
|
||||
filter = bqry.to_dict(),
|
||||
similarity = similarity,
|
||||
)
|
||||
if matchExprs:
|
||||
s.query = bqry
|
||||
for field in highlightFields:
|
||||
s = s.highlight(field)
|
||||
|
||||
if orderBy:
|
||||
orders = list()
|
||||
for field, order in orderBy.fields:
|
||||
order = "asc" if order == 0 else "desc"
|
||||
orders.append({field: {"order": order, "unmapped_type": "float",
|
||||
"mode": "avg", "numeric_type": "double"}})
|
||||
s = s.sort(*orders)
|
||||
|
||||
if limit > 0:
|
||||
s = s[offset:limit]
|
||||
q = s.to_dict()
|
||||
doc_store_logger.info("ESConnection.search [Q]: " + json.dumps(q))
|
||||
|
||||
for i in range(3):
|
||||
try:
|
||||
return self.es.exists(index=(idxnm if idxnm else self.idxnm),
|
||||
id=docid)
|
||||
res = self.es.search(index=indexNames,
|
||||
body=q,
|
||||
timeout="600s",
|
||||
# search_type="dfs_query_then_fetch",
|
||||
track_total_hits=True,
|
||||
_source=True)
|
||||
if str(res.get("timed_out", "")).lower() == "true":
|
||||
raise Exception("Es Timeout.")
|
||||
doc_store_logger.info("ESConnection.search res: " + str(res))
|
||||
return res
|
||||
except Exception as e:
|
||||
es_logger.error("ES Doc Exist: " + str(e))
|
||||
if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
|
||||
doc_store_logger.error(
|
||||
"ES search exception: " +
|
||||
str(e) +
|
||||
"\n[Q]: " +
|
||||
str(q))
|
||||
if str(e).find("Timeout") > 0:
|
||||
continue
|
||||
raise e
|
||||
doc_store_logger.error("ES search timeout for 3 times!")
|
||||
raise Exception("ES search timeout.")
|
||||
|
||||
def get(self, chunkId: str, indexName: str, knowledgebaseIds: list[str]) -> dict | None:
|
||||
for i in range(3):
|
||||
try:
|
||||
res = self.es.get(index=(indexName),
|
||||
id=chunkId, source=True,)
|
||||
if str(res.get("timed_out", "")).lower() == "true":
|
||||
raise Exception("Es Timeout.")
|
||||
if not res.get("found"):
|
||||
return None
|
||||
chunk = res["_source"]
|
||||
chunk["id"] = chunkId
|
||||
return chunk
|
||||
except Exception as e:
|
||||
doc_store_logger.error(
|
||||
"ES get exception: " +
|
||||
str(e) +
|
||||
"[Q]: " +
|
||||
chunkId)
|
||||
if str(e).find("Timeout") > 0:
|
||||
continue
|
||||
raise e
|
||||
doc_store_logger.error("ES search timeout for 3 times!")
|
||||
raise Exception("ES search timeout.")
|
||||
|
||||
def insert(self, documents: list[dict], indexName: str, knowledgebaseId: str) -> list[str]:
|
||||
# Refers to https://www.elastic.co/guide/en/elasticsearch/reference/current/docs-bulk.html
|
||||
operations = []
|
||||
for d in documents:
|
||||
assert "_id" not in d
|
||||
assert "id" in d
|
||||
d_copy = copy.deepcopy(d)
|
||||
meta_id = d_copy["id"]
|
||||
del d_copy["id"]
|
||||
operations.append(
|
||||
{"index": {"_index": indexName, "_id": meta_id}})
|
||||
operations.append(d_copy)
|
||||
|
||||
res = []
|
||||
for _ in range(100):
|
||||
try:
|
||||
r = self.es.bulk(index=(indexName), operations=operations,
|
||||
refresh=False, timeout="600s")
|
||||
if re.search(r"False", str(r["errors"]), re.IGNORECASE):
|
||||
return res
|
||||
|
||||
for item in r["items"]:
|
||||
for action in ["create", "delete", "index", "update"]:
|
||||
if action in item and "error" in item[action]:
|
||||
res.append(str(item[action]["_id"]) + ":" + str(item[action]["error"]))
|
||||
return res
|
||||
except Exception as e:
|
||||
doc_store_logger.warning("Fail to bulk: " + str(e))
|
||||
if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE):
|
||||
time.sleep(3)
|
||||
continue
|
||||
return res
|
||||
|
||||
def update(self, condition: dict, newValue: dict, indexName: str, knowledgebaseId: str) -> bool:
|
||||
doc = copy.deepcopy(newValue)
|
||||
del doc['id']
|
||||
if "id" in condition and isinstance(condition["id"], str):
|
||||
# update specific single document
|
||||
chunkId = condition["id"]
|
||||
for i in range(3):
|
||||
try:
|
||||
self.es.update(index=indexName, id=chunkId, doc=doc)
|
||||
return True
|
||||
except Exception as e:
|
||||
doc_store_logger.error(
|
||||
"ES update exception: " + str(e) + " id:" + str(id) +
|
||||
json.dumps(newValue, ensure_ascii=False))
|
||||
if str(e).find("Timeout") > 0:
|
||||
continue
|
||||
else:
|
||||
# update unspecific maybe-multiple documents
|
||||
bqry = Q("bool")
|
||||
for k, v in condition.items():
|
||||
if not isinstance(k, str) or not v:
|
||||
continue
|
||||
if isinstance(v, list):
|
||||
bqry.filter.append(Q("terms", **{k: v}))
|
||||
elif isinstance(v, str) or isinstance(v, int):
|
||||
bqry.filter.append(Q("term", **{k: v}))
|
||||
else:
|
||||
raise Exception(f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.")
|
||||
scripts = []
|
||||
for k, v in newValue.items():
|
||||
if not isinstance(k, str) or not v:
|
||||
continue
|
||||
if isinstance(v, str):
|
||||
scripts.append(f"ctx._source.{k} = '{v}'")
|
||||
elif isinstance(v, int):
|
||||
scripts.append(f"ctx._source.{k} = {v}")
|
||||
else:
|
||||
raise Exception(f"newValue `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str.")
|
||||
ubq = UpdateByQuery(
|
||||
index=indexName).using(
|
||||
self.es).query(bqry)
|
||||
ubq = ubq.script(source="; ".join(scripts))
|
||||
ubq = ubq.params(refresh=True)
|
||||
ubq = ubq.params(slices=5)
|
||||
ubq = ubq.params(conflicts="proceed")
|
||||
for i in range(3):
|
||||
try:
|
||||
_ = ubq.execute()
|
||||
return True
|
||||
except Exception as e:
|
||||
doc_store_logger.error("ES update exception: " +
|
||||
str(e) + "[Q]:" + str(bqry.to_dict()))
|
||||
if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
|
||||
continue
|
||||
return False
|
||||
|
||||
def createIdx(self, idxnm, mapping):
|
||||
try:
|
||||
if elasticsearch.__version__[0] < 8:
|
||||
return self.es.indices.create(idxnm, body=mapping)
|
||||
from elasticsearch.client import IndicesClient
|
||||
return IndicesClient(self.es).create(index=idxnm,
|
||||
settings=mapping["settings"],
|
||||
mappings=mapping["mappings"])
|
||||
except Exception as e:
|
||||
es_logger.error("ES create index error %s ----%s" % (idxnm, str(e)))
|
||||
def delete(self, condition: dict, indexName: str, knowledgebaseId: str) -> int:
|
||||
qry = None
|
||||
assert "_id" not in condition
|
||||
if "id" in condition:
|
||||
chunk_ids = condition["id"]
|
||||
if not isinstance(chunk_ids, list):
|
||||
chunk_ids = [chunk_ids]
|
||||
qry = Q("ids", values=chunk_ids)
|
||||
else:
|
||||
qry = Q("bool")
|
||||
for k, v in condition.items():
|
||||
if isinstance(v, list):
|
||||
qry.must.append(Q("terms", **{k: v}))
|
||||
elif isinstance(v, str) or isinstance(v, int):
|
||||
qry.must.append(Q("term", **{k: v}))
|
||||
else:
|
||||
raise Exception("Condition value must be int, str or list.")
|
||||
doc_store_logger.info("ESConnection.delete [Q]: " + json.dumps(qry.to_dict()))
|
||||
for _ in range(10):
|
||||
try:
|
||||
res = self.es.delete_by_query(
|
||||
index=indexName,
|
||||
body = Search().query(qry).to_dict(),
|
||||
refresh=True)
|
||||
return res["deleted"]
|
||||
except Exception as e:
|
||||
doc_store_logger.warning("Fail to delete: " + str(filter) + str(e))
|
||||
if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE):
|
||||
time.sleep(3)
|
||||
continue
|
||||
if re.search(r"(not_found)", str(e), re.IGNORECASE):
|
||||
return 0
|
||||
return 0
|
||||
|
||||
def deleteIdx(self, idxnm):
|
||||
try:
|
||||
return self.es.indices.delete(idxnm, allow_no_indices=True)
|
||||
except Exception as e:
|
||||
es_logger.error("ES delete index error %s ----%s" % (idxnm, str(e)))
|
||||
|
||||
"""
|
||||
Helper functions for search result
|
||||
"""
|
||||
def getTotal(self, res):
|
||||
if isinstance(res["hits"]["total"], type({})):
|
||||
return res["hits"]["total"]["value"]
|
||||
return res["hits"]["total"]
|
||||
|
||||
def getDocIds(self, res):
|
||||
def getChunkIds(self, res):
|
||||
return [d["_id"] for d in res["hits"]["hits"]]
|
||||
|
||||
def getSource(self, res):
|
||||
def __getSource(self, res):
|
||||
rr = []
|
||||
for d in res["hits"]["hits"]:
|
||||
d["_source"]["id"] = d["_id"]
|
||||
@ -425,40 +352,89 @@ class ESConnection:
|
||||
rr.append(d["_source"])
|
||||
return rr
|
||||
|
||||
def scrollIter(self, pagesize=100, scroll_time='2m', q={
|
||||
"query": {"match_all": {}}, "sort": [{"updated_at": {"order": "desc"}}]}):
|
||||
for _ in range(100):
|
||||
def getFields(self, res, fields: List[str]) -> Dict[str, dict]:
|
||||
res_fields = {}
|
||||
if not fields:
|
||||
return {}
|
||||
for d in self.__getSource(res):
|
||||
m = {n: d.get(n) for n in fields if d.get(n) is not None}
|
||||
for n, v in m.items():
|
||||
if isinstance(v, list):
|
||||
m[n] = v
|
||||
continue
|
||||
if not isinstance(v, str):
|
||||
m[n] = str(m[n])
|
||||
# if n.find("tks") > 0:
|
||||
# m[n] = rmSpace(m[n])
|
||||
|
||||
if m:
|
||||
res_fields[d["id"]] = m
|
||||
return res_fields
|
||||
|
||||
def getHighlight(self, res, keywords: List[str], fieldnm: str):
|
||||
ans = {}
|
||||
for d in res["hits"]["hits"]:
|
||||
hlts = d.get("highlight")
|
||||
if not hlts:
|
||||
continue
|
||||
txt = "...".join([a for a in list(hlts.items())[0][1]])
|
||||
if not is_english(txt.split(" ")):
|
||||
ans[d["_id"]] = txt
|
||||
continue
|
||||
|
||||
txt = d["_source"][fieldnm]
|
||||
txt = re.sub(r"[\r\n]", " ", txt, flags=re.IGNORECASE|re.MULTILINE)
|
||||
txts = []
|
||||
for t in re.split(r"[.?!;\n]", txt):
|
||||
for w in keywords:
|
||||
t = re.sub(r"(^|[ .?/'\"\(\)!,:;-])(%s)([ .?/'\"\(\)!,:;-])"%re.escape(w), r"\1<em>\2</em>\3", t, flags=re.IGNORECASE|re.MULTILINE)
|
||||
if not re.search(r"<em>[^<>]+</em>", t, flags=re.IGNORECASE|re.MULTILINE):
|
||||
continue
|
||||
txts.append(t)
|
||||
ans[d["_id"]] = "...".join(txts) if txts else "...".join([a for a in list(hlts.items())[0][1]])
|
||||
|
||||
return ans
|
||||
|
||||
def getAggregation(self, res, fieldnm: str):
|
||||
agg_field = "aggs_" + fieldnm
|
||||
if "aggregations" not in res or agg_field not in res["aggregations"]:
|
||||
return list()
|
||||
bkts = res["aggregations"][agg_field]["buckets"]
|
||||
return [(b["key"], b["doc_count"]) for b in bkts]
|
||||
|
||||
|
||||
"""
|
||||
SQL
|
||||
"""
|
||||
def sql(self, sql: str, fetch_size: int, format: str):
|
||||
doc_store_logger.info(f"ESConnection.sql get sql: {sql}")
|
||||
sql = re.sub(r"[ `]+", " ", sql)
|
||||
sql = sql.replace("%", "")
|
||||
replaces = []
|
||||
for r in re.finditer(r" ([a-z_]+_l?tks)( like | ?= ?)'([^']+)'", sql):
|
||||
fld, v = r.group(1), r.group(3)
|
||||
match = " MATCH({}, '{}', 'operator=OR;minimum_should_match=30%') ".format(
|
||||
fld, rag_tokenizer.fine_grained_tokenize(rag_tokenizer.tokenize(v)))
|
||||
replaces.append(
|
||||
("{}{}'{}'".format(
|
||||
r.group(1),
|
||||
r.group(2),
|
||||
r.group(3)),
|
||||
match))
|
||||
|
||||
for p, r in replaces:
|
||||
sql = sql.replace(p, r, 1)
|
||||
doc_store_logger.info(f"ESConnection.sql to es: {sql}")
|
||||
|
||||
for i in range(3):
|
||||
try:
|
||||
page = self.es.search(
|
||||
index=self.idxnm,
|
||||
scroll=scroll_time,
|
||||
size=pagesize,
|
||||
body=q,
|
||||
_source=None
|
||||
)
|
||||
break
|
||||
res = self.es.sql.query(body={"query": sql, "fetch_size": fetch_size}, format=format, request_timeout="2s")
|
||||
return res
|
||||
except ConnectionTimeout:
|
||||
doc_store_logger.error("ESConnection.sql timeout [Q]: " + sql)
|
||||
continue
|
||||
except Exception as e:
|
||||
es_logger.error("ES scrolling fail. " + str(e))
|
||||
time.sleep(3)
|
||||
|
||||
sid = page['_scroll_id']
|
||||
scroll_size = page['hits']['total']["value"]
|
||||
es_logger.info("[TOTAL]%d" % scroll_size)
|
||||
# Start scrolling
|
||||
while scroll_size > 0:
|
||||
yield page["hits"]["hits"]
|
||||
for _ in range(100):
|
||||
try:
|
||||
page = self.es.scroll(scroll_id=sid, scroll=scroll_time)
|
||||
break
|
||||
except Exception as e:
|
||||
es_logger.error("ES scrolling fail. " + str(e))
|
||||
time.sleep(3)
|
||||
|
||||
# Update the scroll ID
|
||||
sid = page['_scroll_id']
|
||||
# Get the number of results that we returned in the last scroll
|
||||
scroll_size = len(page['hits']['hits'])
|
||||
|
||||
|
||||
ELASTICSEARCH = ESConnection()
|
||||
doc_store_logger.error(f"ESConnection.sql failure: {sql} => " + str(e))
|
||||
return None
|
||||
doc_store_logger.error("ESConnection.sql timeout for 3 times!")
|
||||
return None
|
||||
|
436
rag/utils/infinity_conn.py
Normal file
436
rag/utils/infinity_conn.py
Normal file
@ -0,0 +1,436 @@
|
||||
import os
|
||||
import re
|
||||
import json
|
||||
from typing import List, Dict
|
||||
import infinity
|
||||
from infinity.common import ConflictType, InfinityException
|
||||
from infinity.index import IndexInfo, IndexType
|
||||
from infinity.connection_pool import ConnectionPool
|
||||
from rag import settings
|
||||
from rag.settings import doc_store_logger
|
||||
from rag.utils import singleton
|
||||
import polars as pl
|
||||
from polars.series.series import Series
|
||||
from api.utils.file_utils import get_project_base_directory
|
||||
|
||||
from rag.utils.doc_store_conn import (
|
||||
DocStoreConnection,
|
||||
MatchExpr,
|
||||
MatchTextExpr,
|
||||
MatchDenseExpr,
|
||||
FusionExpr,
|
||||
OrderByExpr,
|
||||
)
|
||||
|
||||
|
||||
def equivalent_condition_to_str(condition: dict) -> str:
|
||||
assert "_id" not in condition
|
||||
cond = list()
|
||||
for k, v in condition.items():
|
||||
if not isinstance(k, str) or not v:
|
||||
continue
|
||||
if isinstance(v, list):
|
||||
inCond = list()
|
||||
for item in v:
|
||||
if isinstance(item, str):
|
||||
inCond.append(f"'{item}'")
|
||||
else:
|
||||
inCond.append(str(item))
|
||||
if inCond:
|
||||
strInCond = ", ".join(inCond)
|
||||
strInCond = f"{k} IN ({strInCond})"
|
||||
cond.append(strInCond)
|
||||
elif isinstance(v, str):
|
||||
cond.append(f"{k}='{v}'")
|
||||
else:
|
||||
cond.append(f"{k}={str(v)}")
|
||||
return " AND ".join(cond)
|
||||
|
||||
|
||||
@singleton
|
||||
class InfinityConnection(DocStoreConnection):
|
||||
def __init__(self):
|
||||
self.dbName = settings.INFINITY.get("db_name", "default_db")
|
||||
infinity_uri = settings.INFINITY["uri"]
|
||||
if ":" in infinity_uri:
|
||||
host, port = infinity_uri.split(":")
|
||||
infinity_uri = infinity.common.NetworkAddress(host, int(port))
|
||||
self.connPool = ConnectionPool(infinity_uri)
|
||||
doc_store_logger.info(f"Connected to infinity {infinity_uri}.")
|
||||
|
||||
"""
|
||||
Database operations
|
||||
"""
|
||||
|
||||
def dbType(self) -> str:
|
||||
return "infinity"
|
||||
|
||||
def health(self) -> dict:
|
||||
"""
|
||||
Return the health status of the database.
|
||||
TODO: Infinity-sdk provides health() to wrap `show global variables` and `show tables`
|
||||
"""
|
||||
inf_conn = self.connPool.get_conn()
|
||||
res = infinity.show_current_node()
|
||||
self.connPool.release_conn(inf_conn)
|
||||
color = "green" if res.error_code == 0 else "red"
|
||||
res2 = {
|
||||
"type": "infinity",
|
||||
"status": f"{res.role} {color}",
|
||||
"error": res.error_msg,
|
||||
}
|
||||
return res2
|
||||
|
||||
"""
|
||||
Table operations
|
||||
"""
|
||||
|
||||
def createIdx(self, indexName: str, knowledgebaseId: str, vectorSize: int):
|
||||
table_name = f"{indexName}_{knowledgebaseId}"
|
||||
inf_conn = self.connPool.get_conn()
|
||||
inf_db = inf_conn.create_database(self.dbName, ConflictType.Ignore)
|
||||
|
||||
fp_mapping = os.path.join(
|
||||
get_project_base_directory(), "conf", "infinity_mapping.json"
|
||||
)
|
||||
if not os.path.exists(fp_mapping):
|
||||
raise Exception(f"Mapping file not found at {fp_mapping}")
|
||||
schema = json.load(open(fp_mapping))
|
||||
vector_name = f"q_{vectorSize}_vec"
|
||||
schema[vector_name] = {"type": f"vector,{vectorSize},float"}
|
||||
inf_table = inf_db.create_table(
|
||||
table_name,
|
||||
schema,
|
||||
ConflictType.Ignore,
|
||||
)
|
||||
inf_table.create_index(
|
||||
"q_vec_idx",
|
||||
IndexInfo(
|
||||
vector_name,
|
||||
IndexType.Hnsw,
|
||||
{
|
||||
"M": "16",
|
||||
"ef_construction": "50",
|
||||
"metric": "cosine",
|
||||
"encode": "lvq",
|
||||
},
|
||||
),
|
||||
ConflictType.Ignore,
|
||||
)
|
||||
text_suffix = ["_tks", "_ltks", "_kwd"]
|
||||
for field_name, field_info in schema.items():
|
||||
if field_info["type"] != "varchar":
|
||||
continue
|
||||
for suffix in text_suffix:
|
||||
if field_name.endswith(suffix):
|
||||
inf_table.create_index(
|
||||
f"text_idx_{field_name}",
|
||||
IndexInfo(
|
||||
field_name, IndexType.FullText, {"ANALYZER": "standard"}
|
||||
),
|
||||
ConflictType.Ignore,
|
||||
)
|
||||
break
|
||||
self.connPool.release_conn(inf_conn)
|
||||
doc_store_logger.info(
|
||||
f"INFINITY created table {table_name}, vector size {vectorSize}"
|
||||
)
|
||||
|
||||
def deleteIdx(self, indexName: str, knowledgebaseId: str):
|
||||
table_name = f"{indexName}_{knowledgebaseId}"
|
||||
inf_conn = self.connPool.get_conn()
|
||||
db_instance = inf_conn.get_database(self.dbName)
|
||||
db_instance.drop_table(table_name, ConflictType.Ignore)
|
||||
self.connPool.release_conn(inf_conn)
|
||||
doc_store_logger.info(f"INFINITY dropped table {table_name}")
|
||||
|
||||
def indexExist(self, indexName: str, knowledgebaseId: str) -> bool:
|
||||
table_name = f"{indexName}_{knowledgebaseId}"
|
||||
try:
|
||||
inf_conn = self.connPool.get_conn()
|
||||
db_instance = inf_conn.get_database(self.dbName)
|
||||
_ = db_instance.get_table(table_name)
|
||||
self.connPool.release_conn(inf_conn)
|
||||
return True
|
||||
except Exception as e:
|
||||
doc_store_logger.error("INFINITY indexExist: " + str(e))
|
||||
return False
|
||||
|
||||
"""
|
||||
CRUD operations
|
||||
"""
|
||||
|
||||
def search(
|
||||
self,
|
||||
selectFields: list[str],
|
||||
highlightFields: list[str],
|
||||
condition: dict,
|
||||
matchExprs: list[MatchExpr],
|
||||
orderBy: OrderByExpr,
|
||||
offset: int,
|
||||
limit: int,
|
||||
indexNames: str|list[str],
|
||||
knowledgebaseIds: list[str],
|
||||
) -> list[dict] | pl.DataFrame:
|
||||
"""
|
||||
TODO: Infinity doesn't provide highlight
|
||||
"""
|
||||
if isinstance(indexNames, str):
|
||||
indexNames = indexNames.split(",")
|
||||
assert isinstance(indexNames, list) and len(indexNames) > 0
|
||||
inf_conn = self.connPool.get_conn()
|
||||
db_instance = inf_conn.get_database(self.dbName)
|
||||
df_list = list()
|
||||
table_list = list()
|
||||
if "id" not in selectFields:
|
||||
selectFields.append("id")
|
||||
|
||||
# Prepare expressions common to all tables
|
||||
filter_cond = ""
|
||||
filter_fulltext = ""
|
||||
if condition:
|
||||
filter_cond = equivalent_condition_to_str(condition)
|
||||
for matchExpr in matchExprs:
|
||||
if isinstance(matchExpr, MatchTextExpr):
|
||||
if len(filter_cond) != 0 and "filter" not in matchExpr.extra_options:
|
||||
matchExpr.extra_options.update({"filter": filter_cond})
|
||||
fields = ",".join(matchExpr.fields)
|
||||
filter_fulltext = (
|
||||
f"filter_fulltext('{fields}', '{matchExpr.matching_text}')"
|
||||
)
|
||||
if len(filter_cond) != 0:
|
||||
filter_fulltext = f"({filter_cond}) AND {filter_fulltext}"
|
||||
# doc_store_logger.info(f"filter_fulltext: {filter_fulltext}")
|
||||
minimum_should_match = "0%"
|
||||
if "minimum_should_match" in matchExpr.extra_options:
|
||||
minimum_should_match = (
|
||||
str(int(matchExpr.extra_options["minimum_should_match"] * 100))
|
||||
+ "%"
|
||||
)
|
||||
matchExpr.extra_options.update(
|
||||
{"minimum_should_match": minimum_should_match}
|
||||
)
|
||||
for k, v in matchExpr.extra_options.items():
|
||||
if not isinstance(v, str):
|
||||
matchExpr.extra_options[k] = str(v)
|
||||
elif isinstance(matchExpr, MatchDenseExpr):
|
||||
if len(filter_cond) != 0 and "filter" not in matchExpr.extra_options:
|
||||
matchExpr.extra_options.update({"filter": filter_fulltext})
|
||||
for k, v in matchExpr.extra_options.items():
|
||||
if not isinstance(v, str):
|
||||
matchExpr.extra_options[k] = str(v)
|
||||
if orderBy.fields:
|
||||
order_by_expr_list = list()
|
||||
for order_field in orderBy.fields:
|
||||
order_by_expr_list.append((order_field[0], order_field[1] == 0))
|
||||
|
||||
# Scatter search tables and gather the results
|
||||
for indexName in indexNames:
|
||||
for knowledgebaseId in knowledgebaseIds:
|
||||
table_name = f"{indexName}_{knowledgebaseId}"
|
||||
try:
|
||||
table_instance = db_instance.get_table(table_name)
|
||||
except Exception:
|
||||
continue
|
||||
table_list.append(table_name)
|
||||
builder = table_instance.output(selectFields)
|
||||
for matchExpr in matchExprs:
|
||||
if isinstance(matchExpr, MatchTextExpr):
|
||||
fields = ",".join(matchExpr.fields)
|
||||
builder = builder.match_text(
|
||||
fields,
|
||||
matchExpr.matching_text,
|
||||
matchExpr.topn,
|
||||
matchExpr.extra_options,
|
||||
)
|
||||
elif isinstance(matchExpr, MatchDenseExpr):
|
||||
builder = builder.match_dense(
|
||||
matchExpr.vector_column_name,
|
||||
matchExpr.embedding_data,
|
||||
matchExpr.embedding_data_type,
|
||||
matchExpr.distance_type,
|
||||
matchExpr.topn,
|
||||
matchExpr.extra_options,
|
||||
)
|
||||
elif isinstance(matchExpr, FusionExpr):
|
||||
builder = builder.fusion(
|
||||
matchExpr.method, matchExpr.topn, matchExpr.fusion_params
|
||||
)
|
||||
if orderBy.fields:
|
||||
builder.sort(order_by_expr_list)
|
||||
builder.offset(offset).limit(limit)
|
||||
kb_res = builder.to_pl()
|
||||
df_list.append(kb_res)
|
||||
self.connPool.release_conn(inf_conn)
|
||||
res = pl.concat(df_list)
|
||||
doc_store_logger.info("INFINITY search tables: " + str(table_list))
|
||||
return res
|
||||
|
||||
def get(
|
||||
self, chunkId: str, indexName: str, knowledgebaseIds: list[str]
|
||||
) -> dict | None:
|
||||
inf_conn = self.connPool.get_conn()
|
||||
db_instance = inf_conn.get_database(self.dbName)
|
||||
df_list = list()
|
||||
assert isinstance(knowledgebaseIds, list)
|
||||
for knowledgebaseId in knowledgebaseIds:
|
||||
table_name = f"{indexName}_{knowledgebaseId}"
|
||||
table_instance = db_instance.get_table(table_name)
|
||||
kb_res = table_instance.output(["*"]).filter(f"id = '{chunkId}'").to_pl()
|
||||
df_list.append(kb_res)
|
||||
self.connPool.release_conn(inf_conn)
|
||||
res = pl.concat(df_list)
|
||||
res_fields = self.getFields(res, res.columns)
|
||||
return res_fields.get(chunkId, None)
|
||||
|
||||
def insert(
|
||||
self, documents: list[dict], indexName: str, knowledgebaseId: str
|
||||
) -> list[str]:
|
||||
inf_conn = self.connPool.get_conn()
|
||||
db_instance = inf_conn.get_database(self.dbName)
|
||||
table_name = f"{indexName}_{knowledgebaseId}"
|
||||
try:
|
||||
table_instance = db_instance.get_table(table_name)
|
||||
except InfinityException as e:
|
||||
# src/common/status.cppm, kTableNotExist = 3022
|
||||
if e.error_code != 3022:
|
||||
raise
|
||||
vector_size = 0
|
||||
patt = re.compile(r"q_(?P<vector_size>\d+)_vec")
|
||||
for k in documents[0].keys():
|
||||
m = patt.match(k)
|
||||
if m:
|
||||
vector_size = int(m.group("vector_size"))
|
||||
break
|
||||
if vector_size == 0:
|
||||
raise ValueError("Cannot infer vector size from documents")
|
||||
self.createIdx(indexName, knowledgebaseId, vector_size)
|
||||
table_instance = db_instance.get_table(table_name)
|
||||
|
||||
for d in documents:
|
||||
assert "_id" not in d
|
||||
assert "id" in d
|
||||
for k, v in d.items():
|
||||
if k.endswith("_kwd") and isinstance(v, list):
|
||||
d[k] = " ".join(v)
|
||||
ids = [f"'{d["id"]}'" for d in documents]
|
||||
str_ids = ", ".join(ids)
|
||||
str_filter = f"id IN ({str_ids})"
|
||||
table_instance.delete(str_filter)
|
||||
# for doc in documents:
|
||||
# doc_store_logger.info(f"insert position_list: {doc['position_list']}")
|
||||
# doc_store_logger.info(f"InfinityConnection.insert {json.dumps(documents)}")
|
||||
table_instance.insert(documents)
|
||||
self.connPool.release_conn(inf_conn)
|
||||
doc_store_logger.info(f"inserted into {table_name} {str_ids}.")
|
||||
return []
|
||||
|
||||
def update(
|
||||
self, condition: dict, newValue: dict, indexName: str, knowledgebaseId: str
|
||||
) -> bool:
|
||||
# if 'position_list' in newValue:
|
||||
# doc_store_logger.info(f"update position_list: {newValue['position_list']}")
|
||||
inf_conn = self.connPool.get_conn()
|
||||
db_instance = inf_conn.get_database(self.dbName)
|
||||
table_name = f"{indexName}_{knowledgebaseId}"
|
||||
table_instance = db_instance.get_table(table_name)
|
||||
filter = equivalent_condition_to_str(condition)
|
||||
for k, v in newValue.items():
|
||||
if k.endswith("_kwd") and isinstance(v, list):
|
||||
newValue[k] = " ".join(v)
|
||||
table_instance.update(filter, newValue)
|
||||
self.connPool.release_conn(inf_conn)
|
||||
return True
|
||||
|
||||
def delete(self, condition: dict, indexName: str, knowledgebaseId: str) -> int:
|
||||
inf_conn = self.connPool.get_conn()
|
||||
db_instance = inf_conn.get_database(self.dbName)
|
||||
table_name = f"{indexName}_{knowledgebaseId}"
|
||||
filter = equivalent_condition_to_str(condition)
|
||||
try:
|
||||
table_instance = db_instance.get_table(table_name)
|
||||
except Exception:
|
||||
doc_store_logger.warning(
|
||||
f"Skipped deleting `{filter}` from table {table_name} since the table doesn't exist."
|
||||
)
|
||||
return 0
|
||||
res = table_instance.delete(filter)
|
||||
self.connPool.release_conn(inf_conn)
|
||||
return res.deleted_rows
|
||||
|
||||
"""
|
||||
Helper functions for search result
|
||||
"""
|
||||
|
||||
def getTotal(self, res):
|
||||
return len(res)
|
||||
|
||||
def getChunkIds(self, res):
|
||||
return list(res["id"])
|
||||
|
||||
def getFields(self, res, fields: List[str]) -> Dict[str, dict]:
|
||||
res_fields = {}
|
||||
if not fields:
|
||||
return {}
|
||||
num_rows = len(res)
|
||||
column_id = res["id"]
|
||||
for i in range(num_rows):
|
||||
id = column_id[i]
|
||||
m = {"id": id}
|
||||
for fieldnm in fields:
|
||||
if fieldnm not in res:
|
||||
m[fieldnm] = None
|
||||
continue
|
||||
v = res[fieldnm][i]
|
||||
if isinstance(v, Series):
|
||||
v = list(v)
|
||||
elif fieldnm == "important_kwd":
|
||||
assert isinstance(v, str)
|
||||
v = v.split(" ")
|
||||
else:
|
||||
if not isinstance(v, str):
|
||||
v = str(v)
|
||||
# if fieldnm.endswith("_tks"):
|
||||
# v = rmSpace(v)
|
||||
m[fieldnm] = v
|
||||
res_fields[id] = m
|
||||
return res_fields
|
||||
|
||||
def getHighlight(self, res, keywords: List[str], fieldnm: str):
|
||||
ans = {}
|
||||
num_rows = len(res)
|
||||
column_id = res["id"]
|
||||
for i in range(num_rows):
|
||||
id = column_id[i]
|
||||
txt = res[fieldnm][i]
|
||||
txt = re.sub(r"[\r\n]", " ", txt, flags=re.IGNORECASE | re.MULTILINE)
|
||||
txts = []
|
||||
for t in re.split(r"[.?!;\n]", txt):
|
||||
for w in keywords:
|
||||
t = re.sub(
|
||||
r"(^|[ .?/'\"\(\)!,:;-])(%s)([ .?/'\"\(\)!,:;-])"
|
||||
% re.escape(w),
|
||||
r"\1<em>\2</em>\3",
|
||||
t,
|
||||
flags=re.IGNORECASE | re.MULTILINE,
|
||||
)
|
||||
if not re.search(
|
||||
r"<em>[^<>]+</em>", t, flags=re.IGNORECASE | re.MULTILINE
|
||||
):
|
||||
continue
|
||||
txts.append(t)
|
||||
ans[id] = "...".join(txts)
|
||||
return ans
|
||||
|
||||
def getAggregation(self, res, fieldnm: str):
|
||||
"""
|
||||
TODO: Infinity doesn't provide aggregation
|
||||
"""
|
||||
return list()
|
||||
|
||||
"""
|
||||
SQL
|
||||
"""
|
||||
|
||||
def sql(sql: str, fetch_size: int, format: str):
|
||||
raise NotImplementedError("Not implemented")
|
@ -50,8 +50,8 @@ class Document(Base):
|
||||
return res.content
|
||||
|
||||
|
||||
def list_chunks(self,page=1, page_size=30, keywords="", id:str=None):
|
||||
data={"keywords": keywords,"page":page,"page_size":page_size,"id":id}
|
||||
def list_chunks(self,page=1, page_size=30, keywords=""):
|
||||
data={"keywords": keywords,"page":page,"page_size":page_size}
|
||||
res = self.get(f'/datasets/{self.dataset_id}/documents/{self.id}/chunks', data)
|
||||
res = res.json()
|
||||
if res.get("code") == 0:
|
||||
|
@ -126,6 +126,7 @@ def test_delete_chunk_with_success(get_api_key_fixture):
|
||||
docs = ds.upload_documents(documents)
|
||||
doc = docs[0]
|
||||
chunk = doc.add_chunk(content="This is a chunk addition test")
|
||||
sleep(5)
|
||||
doc.delete_chunks([chunk.id])
|
||||
|
||||
|
||||
@ -146,6 +147,8 @@ def test_update_chunk_content(get_api_key_fixture):
|
||||
docs = ds.upload_documents(documents)
|
||||
doc = docs[0]
|
||||
chunk = doc.add_chunk(content="This is a chunk addition test")
|
||||
# For ElasticSearch, the chunk is not searchable in shot time (~2s).
|
||||
sleep(3)
|
||||
chunk.update({"content":"This is a updated content"})
|
||||
|
||||
def test_update_chunk_available(get_api_key_fixture):
|
||||
@ -165,7 +168,9 @@ def test_update_chunk_available(get_api_key_fixture):
|
||||
docs = ds.upload_documents(documents)
|
||||
doc = docs[0]
|
||||
chunk = doc.add_chunk(content="This is a chunk addition test")
|
||||
chunk.update({"available":False})
|
||||
# For ElasticSearch, the chunk is not searchable in shot time (~2s).
|
||||
sleep(3)
|
||||
chunk.update({"available":0})
|
||||
|
||||
|
||||
def test_retrieve_chunks(get_api_key_fixture):
|
||||
|
Loading…
x
Reference in New Issue
Block a user