mirror of
https://git.mirrors.martin98.com/https://github.com/infiniflow/ragflow.git
synced 2025-08-13 22:05:56 +08:00
Integration with Infinity (#2894)
### What problem does this PR solve? Integration with Infinity - Replaced ELASTICSEARCH with dataStoreConn - Renamed deleteByQuery with delete - Renamed bulk to upsertBulk - getHighlight, getAggregation - Fix KGSearch.search - Moved Dealer.sql_retrieval to es_conn.py ### Type of change - [x] Refactoring
This commit is contained in:
parent
00b6000b76
commit
f4c52371ab
2
.github/workflows/tests.yml
vendored
2
.github/workflows/tests.yml
vendored
@ -78,7 +78,7 @@ jobs:
|
|||||||
echo "Waiting for service to be available..."
|
echo "Waiting for service to be available..."
|
||||||
sleep 5
|
sleep 5
|
||||||
done
|
done
|
||||||
cd sdk/python && poetry install && source .venv/bin/activate && cd test && pytest t_dataset.py t_chat.py t_session.py t_document.py t_chunk.py
|
cd sdk/python && poetry install && source .venv/bin/activate && cd test && pytest --tb=short t_dataset.py t_chat.py t_session.py t_document.py t_chunk.py
|
||||||
|
|
||||||
- name: Stop ragflow:dev
|
- name: Stop ragflow:dev
|
||||||
if: always() # always run this step even if previous steps failed
|
if: always() # always run this step even if previous steps failed
|
||||||
|
@ -285,7 +285,7 @@ docker build -f Dockerfile -t infiniflow/ragflow:dev .
|
|||||||
git clone https://github.com/infiniflow/ragflow.git
|
git clone https://github.com/infiniflow/ragflow.git
|
||||||
cd ragflow/
|
cd ragflow/
|
||||||
export POETRY_VIRTUALENVS_CREATE=true POETRY_VIRTUALENVS_IN_PROJECT=true
|
export POETRY_VIRTUALENVS_CREATE=true POETRY_VIRTUALENVS_IN_PROJECT=true
|
||||||
~/.local/bin/poetry install --sync --no-root # install RAGFlow dependent python modules
|
~/.local/bin/poetry install --sync --no-root --with=full # install RAGFlow dependent python modules
|
||||||
```
|
```
|
||||||
|
|
||||||
3. Launch the dependent services (MinIO, Elasticsearch, Redis, and MySQL) using Docker Compose:
|
3. Launch the dependent services (MinIO, Elasticsearch, Redis, and MySQL) using Docker Compose:
|
||||||
@ -295,7 +295,7 @@ docker build -f Dockerfile -t infiniflow/ragflow:dev .
|
|||||||
|
|
||||||
Add the following line to `/etc/hosts` to resolve all hosts specified in **docker/service_conf.yaml** to `127.0.0.1`:
|
Add the following line to `/etc/hosts` to resolve all hosts specified in **docker/service_conf.yaml** to `127.0.0.1`:
|
||||||
```
|
```
|
||||||
127.0.0.1 es01 mysql minio redis
|
127.0.0.1 es01 infinity mysql minio redis
|
||||||
```
|
```
|
||||||
In **docker/service_conf.yaml**, update mysql port to `5455` and es port to `1200`, as specified in **docker/.env**.
|
In **docker/service_conf.yaml**, update mysql port to `5455` and es port to `1200`, as specified in **docker/.env**.
|
||||||
|
|
||||||
|
@ -250,7 +250,7 @@ docker build -f Dockerfile -t infiniflow/ragflow:dev .
|
|||||||
|
|
||||||
`/etc/hosts` に以下の行を追加して、**docker/service_conf.yaml** に指定されたすべてのホストを `127.0.0.1` に解決します:
|
`/etc/hosts` に以下の行を追加して、**docker/service_conf.yaml** に指定されたすべてのホストを `127.0.0.1` に解決します:
|
||||||
```
|
```
|
||||||
127.0.0.1 es01 mysql minio redis
|
127.0.0.1 es01 infinity mysql minio redis
|
||||||
```
|
```
|
||||||
**docker/service_conf.yaml** で mysql のポートを `5455` に、es のポートを `1200` に更新します(**docker/.env** に指定された通り).
|
**docker/service_conf.yaml** で mysql のポートを `5455` に、es のポートを `1200` に更新します(**docker/.env** に指定された通り).
|
||||||
|
|
||||||
|
@ -254,7 +254,7 @@ docker build -f Dockerfile -t infiniflow/ragflow:dev .
|
|||||||
|
|
||||||
`/etc/hosts` 에 다음 줄을 추가하여 **docker/service_conf.yaml** 에 지정된 모든 호스트를 `127.0.0.1` 로 해결합니다:
|
`/etc/hosts` 에 다음 줄을 추가하여 **docker/service_conf.yaml** 에 지정된 모든 호스트를 `127.0.0.1` 로 해결합니다:
|
||||||
```
|
```
|
||||||
127.0.0.1 es01 mysql minio redis
|
127.0.0.1 es01 infinity mysql minio redis
|
||||||
```
|
```
|
||||||
**docker/service_conf.yaml** 에서 mysql 포트를 `5455` 로, es 포트를 `1200` 으로 업데이트합니다( **docker/.env** 에 지정된 대로).
|
**docker/service_conf.yaml** 에서 mysql 포트를 `5455` 로, es 포트를 `1200` 으로 업데이트합니다( **docker/.env** 에 지정된 대로).
|
||||||
|
|
||||||
|
@ -252,7 +252,7 @@ docker build -f Dockerfile -t infiniflow/ragflow:dev .
|
|||||||
|
|
||||||
在 `/etc/hosts` 中添加以下代码,将 **docker/service_conf.yaml** 文件中的所有 host 地址都解析为 `127.0.0.1`:
|
在 `/etc/hosts` 中添加以下代码,将 **docker/service_conf.yaml** 文件中的所有 host 地址都解析为 `127.0.0.1`:
|
||||||
```
|
```
|
||||||
127.0.0.1 es01 mysql minio redis
|
127.0.0.1 es01 infinity mysql minio redis
|
||||||
```
|
```
|
||||||
在文件 **docker/service_conf.yaml** 中,对照 **docker/.env** 的配置将 mysql 端口更新为 `5455`,es 端口更新为 `1200`。
|
在文件 **docker/service_conf.yaml** 中,对照 **docker/.env** 的配置将 mysql 端口更新为 `5455`,es 端口更新为 `1200`。
|
||||||
|
|
||||||
|
@ -529,13 +529,14 @@ def list_chunks():
|
|||||||
return get_json_result(
|
return get_json_result(
|
||||||
data=False, message="Can't find doc_name or doc_id"
|
data=False, message="Can't find doc_name or doc_id"
|
||||||
)
|
)
|
||||||
|
kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
|
||||||
|
|
||||||
res = retrievaler.chunk_list(doc_id=doc_id, tenant_id=tenant_id)
|
res = retrievaler.chunk_list(doc_id, tenant_id, kb_ids)
|
||||||
res = [
|
res = [
|
||||||
{
|
{
|
||||||
"content": res_item["content_with_weight"],
|
"content": res_item["content_with_weight"],
|
||||||
"doc_name": res_item["docnm_kwd"],
|
"doc_name": res_item["docnm_kwd"],
|
||||||
"img_id": res_item["img_id"]
|
"image_id": res_item["img_id"]
|
||||||
} for res_item in res
|
} for res_item in res
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -18,12 +18,10 @@ import json
|
|||||||
|
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask_login import login_required, current_user
|
from flask_login import login_required, current_user
|
||||||
from elasticsearch_dsl import Q
|
|
||||||
|
|
||||||
from api.db.services.dialog_service import keyword_extraction
|
from api.db.services.dialog_service import keyword_extraction
|
||||||
from rag.app.qa import rmPrefix, beAdoc
|
from rag.app.qa import rmPrefix, beAdoc
|
||||||
from rag.nlp import search, rag_tokenizer
|
from rag.nlp import search, rag_tokenizer
|
||||||
from rag.utils.es_conn import ELASTICSEARCH
|
|
||||||
from rag.utils import rmSpace
|
from rag.utils import rmSpace
|
||||||
from api.db import LLMType, ParserType
|
from api.db import LLMType, ParserType
|
||||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||||
@ -31,12 +29,11 @@ from api.db.services.llm_service import LLMBundle
|
|||||||
from api.db.services.user_service import UserTenantService
|
from api.db.services.user_service import UserTenantService
|
||||||
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
|
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
|
||||||
from api.db.services.document_service import DocumentService
|
from api.db.services.document_service import DocumentService
|
||||||
from api.settings import RetCode, retrievaler, kg_retrievaler
|
from api.settings import RetCode, retrievaler, kg_retrievaler, docStoreConn
|
||||||
from api.utils.api_utils import get_json_result
|
from api.utils.api_utils import get_json_result
|
||||||
import hashlib
|
import hashlib
|
||||||
import re
|
import re
|
||||||
|
|
||||||
|
|
||||||
@manager.route('/list', methods=['POST'])
|
@manager.route('/list', methods=['POST'])
|
||||||
@login_required
|
@login_required
|
||||||
@validate_request("doc_id")
|
@validate_request("doc_id")
|
||||||
@ -53,12 +50,13 @@ def list_chunk():
|
|||||||
e, doc = DocumentService.get_by_id(doc_id)
|
e, doc = DocumentService.get_by_id(doc_id)
|
||||||
if not e:
|
if not e:
|
||||||
return get_data_error_result(message="Document not found!")
|
return get_data_error_result(message="Document not found!")
|
||||||
|
kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
|
||||||
query = {
|
query = {
|
||||||
"doc_ids": [doc_id], "page": page, "size": size, "question": question, "sort": True
|
"doc_ids": [doc_id], "page": page, "size": size, "question": question, "sort": True
|
||||||
}
|
}
|
||||||
if "available_int" in req:
|
if "available_int" in req:
|
||||||
query["available_int"] = int(req["available_int"])
|
query["available_int"] = int(req["available_int"])
|
||||||
sres = retrievaler.search(query, search.index_name(tenant_id), highlight=True)
|
sres = retrievaler.search(query, search.index_name(tenant_id), kb_ids, highlight=True)
|
||||||
res = {"total": sres.total, "chunks": [], "doc": doc.to_dict()}
|
res = {"total": sres.total, "chunks": [], "doc": doc.to_dict()}
|
||||||
for id in sres.ids:
|
for id in sres.ids:
|
||||||
d = {
|
d = {
|
||||||
@ -69,16 +67,12 @@ def list_chunk():
|
|||||||
"doc_id": sres.field[id]["doc_id"],
|
"doc_id": sres.field[id]["doc_id"],
|
||||||
"docnm_kwd": sres.field[id]["docnm_kwd"],
|
"docnm_kwd": sres.field[id]["docnm_kwd"],
|
||||||
"important_kwd": sres.field[id].get("important_kwd", []),
|
"important_kwd": sres.field[id].get("important_kwd", []),
|
||||||
"img_id": sres.field[id].get("img_id", ""),
|
"image_id": sres.field[id].get("img_id", ""),
|
||||||
"available_int": sres.field[id].get("available_int", 1),
|
"available_int": sres.field[id].get("available_int", 1),
|
||||||
"positions": sres.field[id].get("position_int", "").split("\t")
|
"positions": json.loads(sres.field[id].get("position_list", "[]")),
|
||||||
}
|
}
|
||||||
if len(d["positions"]) % 5 == 0:
|
assert isinstance(d["positions"], list)
|
||||||
poss = []
|
assert len(d["positions"])==0 or (isinstance(d["positions"][0], list) and len(d["positions"][0]) == 5)
|
||||||
for i in range(0, len(d["positions"]), 5):
|
|
||||||
poss.append([float(d["positions"][i]), float(d["positions"][i + 1]), float(d["positions"][i + 2]),
|
|
||||||
float(d["positions"][i + 3]), float(d["positions"][i + 4])])
|
|
||||||
d["positions"] = poss
|
|
||||||
res["chunks"].append(d)
|
res["chunks"].append(d)
|
||||||
return get_json_result(data=res)
|
return get_json_result(data=res)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -96,22 +90,20 @@ def get():
|
|||||||
tenants = UserTenantService.query(user_id=current_user.id)
|
tenants = UserTenantService.query(user_id=current_user.id)
|
||||||
if not tenants:
|
if not tenants:
|
||||||
return get_data_error_result(message="Tenant not found!")
|
return get_data_error_result(message="Tenant not found!")
|
||||||
res = ELASTICSEARCH.get(
|
tenant_id = tenants[0].tenant_id
|
||||||
chunk_id, search.index_name(
|
|
||||||
tenants[0].tenant_id))
|
kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
|
||||||
if not res.get("found"):
|
chunk = docStoreConn.get(chunk_id, search.index_name(tenant_id), kb_ids)
|
||||||
|
if chunk is None:
|
||||||
return server_error_response("Chunk not found")
|
return server_error_response("Chunk not found")
|
||||||
id = res["_id"]
|
|
||||||
res = res["_source"]
|
|
||||||
res["chunk_id"] = id
|
|
||||||
k = []
|
k = []
|
||||||
for n in res.keys():
|
for n in chunk.keys():
|
||||||
if re.search(r"(_vec$|_sm_|_tks|_ltks)", n):
|
if re.search(r"(_vec$|_sm_|_tks|_ltks)", n):
|
||||||
k.append(n)
|
k.append(n)
|
||||||
for n in k:
|
for n in k:
|
||||||
del res[n]
|
del chunk[n]
|
||||||
|
|
||||||
return get_json_result(data=res)
|
return get_json_result(data=chunk)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if str(e).find("NotFoundError") >= 0:
|
if str(e).find("NotFoundError") >= 0:
|
||||||
return get_json_result(data=False, message='Chunk not found!',
|
return get_json_result(data=False, message='Chunk not found!',
|
||||||
@ -162,7 +154,7 @@ def set():
|
|||||||
v, c = embd_mdl.encode([doc.name, req["content_with_weight"]])
|
v, c = embd_mdl.encode([doc.name, req["content_with_weight"]])
|
||||||
v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
|
v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
|
||||||
d["q_%d_vec" % len(v)] = v.tolist()
|
d["q_%d_vec" % len(v)] = v.tolist()
|
||||||
ELASTICSEARCH.upsert([d], search.index_name(tenant_id))
|
docStoreConn.insert([d], search.index_name(tenant_id), doc.kb_id)
|
||||||
return get_json_result(data=True)
|
return get_json_result(data=True)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return server_error_response(e)
|
return server_error_response(e)
|
||||||
@ -174,11 +166,11 @@ def set():
|
|||||||
def switch():
|
def switch():
|
||||||
req = request.json
|
req = request.json
|
||||||
try:
|
try:
|
||||||
tenant_id = DocumentService.get_tenant_id(req["doc_id"])
|
e, doc = DocumentService.get_by_id(req["doc_id"])
|
||||||
if not tenant_id:
|
if not e:
|
||||||
return get_data_error_result(message="Tenant not found!")
|
return get_data_error_result(message="Document not found!")
|
||||||
if not ELASTICSEARCH.upsert([{"id": i, "available_int": int(req["available_int"])} for i in req["chunk_ids"]],
|
if not docStoreConn.update({"id": req["chunk_ids"]}, {"available_int": int(req["available_int"])},
|
||||||
search.index_name(tenant_id)):
|
search.index_name(doc.tenant_id), doc.kb_id):
|
||||||
return get_data_error_result(message="Index updating failure")
|
return get_data_error_result(message="Index updating failure")
|
||||||
return get_json_result(data=True)
|
return get_json_result(data=True)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -191,12 +183,11 @@ def switch():
|
|||||||
def rm():
|
def rm():
|
||||||
req = request.json
|
req = request.json
|
||||||
try:
|
try:
|
||||||
if not ELASTICSEARCH.deleteByQuery(
|
|
||||||
Q("ids", values=req["chunk_ids"]), search.index_name(current_user.id)):
|
|
||||||
return get_data_error_result(message="Index updating failure")
|
|
||||||
e, doc = DocumentService.get_by_id(req["doc_id"])
|
e, doc = DocumentService.get_by_id(req["doc_id"])
|
||||||
if not e:
|
if not e:
|
||||||
return get_data_error_result(message="Document not found!")
|
return get_data_error_result(message="Document not found!")
|
||||||
|
if not docStoreConn.delete({"id": req["chunk_ids"]}, search.index_name(current_user.id), doc.kb_id):
|
||||||
|
return get_data_error_result(message="Index updating failure")
|
||||||
deleted_chunk_ids = req["chunk_ids"]
|
deleted_chunk_ids = req["chunk_ids"]
|
||||||
chunk_number = len(deleted_chunk_ids)
|
chunk_number = len(deleted_chunk_ids)
|
||||||
DocumentService.decrement_chunk_num(doc.id, doc.kb_id, 1, chunk_number, 0)
|
DocumentService.decrement_chunk_num(doc.id, doc.kb_id, 1, chunk_number, 0)
|
||||||
@ -239,7 +230,7 @@ def create():
|
|||||||
v, c = embd_mdl.encode([doc.name, req["content_with_weight"]])
|
v, c = embd_mdl.encode([doc.name, req["content_with_weight"]])
|
||||||
v = 0.1 * v[0] + 0.9 * v[1]
|
v = 0.1 * v[0] + 0.9 * v[1]
|
||||||
d["q_%d_vec" % len(v)] = v.tolist()
|
d["q_%d_vec" % len(v)] = v.tolist()
|
||||||
ELASTICSEARCH.upsert([d], search.index_name(tenant_id))
|
docStoreConn.insert([d], search.index_name(tenant_id), doc.kb_id)
|
||||||
|
|
||||||
DocumentService.increment_chunk_num(
|
DocumentService.increment_chunk_num(
|
||||||
doc.id, doc.kb_id, c, 1, 0)
|
doc.id, doc.kb_id, c, 1, 0)
|
||||||
@ -256,8 +247,9 @@ def retrieval_test():
|
|||||||
page = int(req.get("page", 1))
|
page = int(req.get("page", 1))
|
||||||
size = int(req.get("size", 30))
|
size = int(req.get("size", 30))
|
||||||
question = req["question"]
|
question = req["question"]
|
||||||
kb_id = req["kb_id"]
|
kb_ids = req["kb_id"]
|
||||||
if isinstance(kb_id, str): kb_id = [kb_id]
|
if isinstance(kb_ids, str):
|
||||||
|
kb_ids = [kb_ids]
|
||||||
doc_ids = req.get("doc_ids", [])
|
doc_ids = req.get("doc_ids", [])
|
||||||
similarity_threshold = float(req.get("similarity_threshold", 0.0))
|
similarity_threshold = float(req.get("similarity_threshold", 0.0))
|
||||||
vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3))
|
vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3))
|
||||||
@ -265,17 +257,17 @@ def retrieval_test():
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
tenants = UserTenantService.query(user_id=current_user.id)
|
tenants = UserTenantService.query(user_id=current_user.id)
|
||||||
for kid in kb_id:
|
for kb_id in kb_ids:
|
||||||
for tenant in tenants:
|
for tenant in tenants:
|
||||||
if KnowledgebaseService.query(
|
if KnowledgebaseService.query(
|
||||||
tenant_id=tenant.tenant_id, id=kid):
|
tenant_id=tenant.tenant_id, id=kb_id):
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
return get_json_result(
|
return get_json_result(
|
||||||
data=False, message='Only owner of knowledgebase authorized for this operation.',
|
data=False, message='Only owner of knowledgebase authorized for this operation.',
|
||||||
code=RetCode.OPERATING_ERROR)
|
code=RetCode.OPERATING_ERROR)
|
||||||
|
|
||||||
e, kb = KnowledgebaseService.get_by_id(kb_id[0])
|
e, kb = KnowledgebaseService.get_by_id(kb_ids[0])
|
||||||
if not e:
|
if not e:
|
||||||
return get_data_error_result(message="Knowledgebase not found!")
|
return get_data_error_result(message="Knowledgebase not found!")
|
||||||
|
|
||||||
@ -290,7 +282,7 @@ def retrieval_test():
|
|||||||
question += keyword_extraction(chat_mdl, question)
|
question += keyword_extraction(chat_mdl, question)
|
||||||
|
|
||||||
retr = retrievaler if kb.parser_id != ParserType.KG else kg_retrievaler
|
retr = retrievaler if kb.parser_id != ParserType.KG else kg_retrievaler
|
||||||
ranks = retr.retrieval(question, embd_mdl, kb.tenant_id, kb_id, page, size,
|
ranks = retr.retrieval(question, embd_mdl, kb.tenant_id, kb_ids, page, size,
|
||||||
similarity_threshold, vector_similarity_weight, top,
|
similarity_threshold, vector_similarity_weight, top,
|
||||||
doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"))
|
doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"))
|
||||||
for c in ranks["chunks"]:
|
for c in ranks["chunks"]:
|
||||||
@ -309,12 +301,16 @@ def retrieval_test():
|
|||||||
@login_required
|
@login_required
|
||||||
def knowledge_graph():
|
def knowledge_graph():
|
||||||
doc_id = request.args["doc_id"]
|
doc_id = request.args["doc_id"]
|
||||||
|
e, doc = DocumentService.get_by_id(doc_id)
|
||||||
|
if not e:
|
||||||
|
return get_data_error_result(message="Document not found!")
|
||||||
|
tenant_id = DocumentService.get_tenant_id(doc_id)
|
||||||
|
kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
|
||||||
req = {
|
req = {
|
||||||
"doc_ids":[doc_id],
|
"doc_ids":[doc_id],
|
||||||
"knowledge_graph_kwd": ["graph", "mind_map"]
|
"knowledge_graph_kwd": ["graph", "mind_map"]
|
||||||
}
|
}
|
||||||
tenant_id = DocumentService.get_tenant_id(doc_id)
|
sres = retrievaler.search(req, search.index_name(tenant_id), kb_ids, doc.kb_id)
|
||||||
sres = retrievaler.search(req, search.index_name(tenant_id))
|
|
||||||
obj = {"graph": {}, "mind_map": {}}
|
obj = {"graph": {}, "mind_map": {}}
|
||||||
for id in sres.ids[:2]:
|
for id in sres.ids[:2]:
|
||||||
ty = sres.field[id]["knowledge_graph_kwd"]
|
ty = sres.field[id]["knowledge_graph_kwd"]
|
||||||
|
@ -17,7 +17,6 @@ import pathlib
|
|||||||
import re
|
import re
|
||||||
|
|
||||||
import flask
|
import flask
|
||||||
from elasticsearch_dsl import Q
|
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask_login import login_required, current_user
|
from flask_login import login_required, current_user
|
||||||
|
|
||||||
@ -27,14 +26,13 @@ from api.db.services.file_service import FileService
|
|||||||
from api.db.services.task_service import TaskService, queue_tasks
|
from api.db.services.task_service import TaskService, queue_tasks
|
||||||
from api.db.services.user_service import UserTenantService
|
from api.db.services.user_service import UserTenantService
|
||||||
from rag.nlp import search
|
from rag.nlp import search
|
||||||
from rag.utils.es_conn import ELASTICSEARCH
|
|
||||||
from api.db.services import duplicate_name
|
from api.db.services import duplicate_name
|
||||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||||
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
|
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
|
||||||
from api.utils import get_uuid
|
from api.utils import get_uuid
|
||||||
from api.db import FileType, TaskStatus, ParserType, FileSource
|
from api.db import FileType, TaskStatus, ParserType, FileSource
|
||||||
from api.db.services.document_service import DocumentService, doc_upload_and_parse
|
from api.db.services.document_service import DocumentService, doc_upload_and_parse
|
||||||
from api.settings import RetCode
|
from api.settings import RetCode, docStoreConn
|
||||||
from api.utils.api_utils import get_json_result
|
from api.utils.api_utils import get_json_result
|
||||||
from rag.utils.storage_factory import STORAGE_IMPL
|
from rag.utils.storage_factory import STORAGE_IMPL
|
||||||
from api.utils.file_utils import filename_type, thumbnail
|
from api.utils.file_utils import filename_type, thumbnail
|
||||||
@ -275,18 +273,8 @@ def change_status():
|
|||||||
return get_data_error_result(
|
return get_data_error_result(
|
||||||
message="Database error (Document update)!")
|
message="Database error (Document update)!")
|
||||||
|
|
||||||
if str(req["status"]) == "0":
|
status = int(req["status"])
|
||||||
ELASTICSEARCH.updateScriptByQuery(Q("term", doc_id=req["doc_id"]),
|
docStoreConn.update({"doc_id": req["doc_id"]}, {"available_int": status}, search.index_name(kb.tenant_id), doc.kb_id)
|
||||||
scripts="ctx._source.available_int=0;",
|
|
||||||
idxnm=search.index_name(
|
|
||||||
kb.tenant_id)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
ELASTICSEARCH.updateScriptByQuery(Q("term", doc_id=req["doc_id"]),
|
|
||||||
scripts="ctx._source.available_int=1;",
|
|
||||||
idxnm=search.index_name(
|
|
||||||
kb.tenant_id)
|
|
||||||
)
|
|
||||||
return get_json_result(data=True)
|
return get_json_result(data=True)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return server_error_response(e)
|
return server_error_response(e)
|
||||||
@ -365,8 +353,11 @@ def run():
|
|||||||
tenant_id = DocumentService.get_tenant_id(id)
|
tenant_id = DocumentService.get_tenant_id(id)
|
||||||
if not tenant_id:
|
if not tenant_id:
|
||||||
return get_data_error_result(message="Tenant not found!")
|
return get_data_error_result(message="Tenant not found!")
|
||||||
ELASTICSEARCH.deleteByQuery(
|
e, doc = DocumentService.get_by_id(id)
|
||||||
Q("match", doc_id=id), idxnm=search.index_name(tenant_id))
|
if not e:
|
||||||
|
return get_data_error_result(message="Document not found!")
|
||||||
|
if docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
|
||||||
|
docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), doc.kb_id)
|
||||||
|
|
||||||
if str(req["run"]) == TaskStatus.RUNNING.value:
|
if str(req["run"]) == TaskStatus.RUNNING.value:
|
||||||
TaskService.filter_delete([Task.doc_id == id])
|
TaskService.filter_delete([Task.doc_id == id])
|
||||||
@ -490,8 +481,8 @@ def change_parser():
|
|||||||
tenant_id = DocumentService.get_tenant_id(req["doc_id"])
|
tenant_id = DocumentService.get_tenant_id(req["doc_id"])
|
||||||
if not tenant_id:
|
if not tenant_id:
|
||||||
return get_data_error_result(message="Tenant not found!")
|
return get_data_error_result(message="Tenant not found!")
|
||||||
ELASTICSEARCH.deleteByQuery(
|
if docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
|
||||||
Q("match", doc_id=doc.id), idxnm=search.index_name(tenant_id))
|
docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
|
||||||
|
|
||||||
return get_json_result(data=True)
|
return get_json_result(data=True)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -28,6 +28,8 @@ from api.db.services.knowledgebase_service import KnowledgebaseService
|
|||||||
from api.db.db_models import File
|
from api.db.db_models import File
|
||||||
from api.settings import RetCode
|
from api.settings import RetCode
|
||||||
from api.utils.api_utils import get_json_result
|
from api.utils.api_utils import get_json_result
|
||||||
|
from api.settings import docStoreConn
|
||||||
|
from rag.nlp import search
|
||||||
|
|
||||||
|
|
||||||
@manager.route('/create', methods=['post'])
|
@manager.route('/create', methods=['post'])
|
||||||
@ -166,6 +168,9 @@ def rm():
|
|||||||
if not KnowledgebaseService.delete_by_id(req["kb_id"]):
|
if not KnowledgebaseService.delete_by_id(req["kb_id"]):
|
||||||
return get_data_error_result(
|
return get_data_error_result(
|
||||||
message="Database error (Knowledgebase removal)!")
|
message="Database error (Knowledgebase removal)!")
|
||||||
|
tenants = UserTenantService.query(user_id=current_user.id)
|
||||||
|
for tenant in tenants:
|
||||||
|
docStoreConn.deleteIdx(search.index_name(tenant.tenant_id), req["kb_id"])
|
||||||
return get_json_result(data=True)
|
return get_json_result(data=True)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return server_error_response(e)
|
return server_error_response(e)
|
||||||
|
@ -30,7 +30,6 @@ from api.db.services.task_service import TaskService, queue_tasks
|
|||||||
from api.utils.api_utils import server_error_response
|
from api.utils.api_utils import server_error_response
|
||||||
from api.utils.api_utils import get_result, get_error_data_result
|
from api.utils.api_utils import get_result, get_error_data_result
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from elasticsearch_dsl import Q
|
|
||||||
from flask import request, send_file
|
from flask import request, send_file
|
||||||
from api.db import FileSource, TaskStatus, FileType
|
from api.db import FileSource, TaskStatus, FileType
|
||||||
from api.db.db_models import File
|
from api.db.db_models import File
|
||||||
@ -42,7 +41,7 @@ from api.settings import RetCode, retrievaler
|
|||||||
from api.utils.api_utils import construct_json_result, get_parser_config
|
from api.utils.api_utils import construct_json_result, get_parser_config
|
||||||
from rag.nlp import search
|
from rag.nlp import search
|
||||||
from rag.utils import rmSpace
|
from rag.utils import rmSpace
|
||||||
from rag.utils.es_conn import ELASTICSEARCH
|
from api.settings import docStoreConn
|
||||||
from rag.utils.storage_factory import STORAGE_IMPL
|
from rag.utils.storage_factory import STORAGE_IMPL
|
||||||
import os
|
import os
|
||||||
|
|
||||||
@ -293,9 +292,7 @@ def update_doc(tenant_id, dataset_id, document_id):
|
|||||||
)
|
)
|
||||||
if not e:
|
if not e:
|
||||||
return get_error_data_result(message="Document not found!")
|
return get_error_data_result(message="Document not found!")
|
||||||
ELASTICSEARCH.deleteByQuery(
|
docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), dataset_id)
|
||||||
Q("match", doc_id=doc.id), idxnm=search.index_name(tenant_id)
|
|
||||||
)
|
|
||||||
|
|
||||||
return get_result()
|
return get_result()
|
||||||
|
|
||||||
@ -647,9 +644,7 @@ def parse(tenant_id, dataset_id):
|
|||||||
info["chunk_num"] = 0
|
info["chunk_num"] = 0
|
||||||
info["token_num"] = 0
|
info["token_num"] = 0
|
||||||
DocumentService.update_by_id(id, info)
|
DocumentService.update_by_id(id, info)
|
||||||
ELASTICSEARCH.deleteByQuery(
|
docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), dataset_id)
|
||||||
Q("match", doc_id=id), idxnm=search.index_name(tenant_id)
|
|
||||||
)
|
|
||||||
TaskService.filter_delete([Task.doc_id == id])
|
TaskService.filter_delete([Task.doc_id == id])
|
||||||
e, doc = DocumentService.get_by_id(id)
|
e, doc = DocumentService.get_by_id(id)
|
||||||
doc = doc.to_dict()
|
doc = doc.to_dict()
|
||||||
@ -713,9 +708,7 @@ def stop_parsing(tenant_id, dataset_id):
|
|||||||
)
|
)
|
||||||
info = {"run": "2", "progress": 0, "chunk_num": 0}
|
info = {"run": "2", "progress": 0, "chunk_num": 0}
|
||||||
DocumentService.update_by_id(id, info)
|
DocumentService.update_by_id(id, info)
|
||||||
ELASTICSEARCH.deleteByQuery(
|
docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), dataset_id)
|
||||||
Q("match", doc_id=id), idxnm=search.index_name(tenant_id)
|
|
||||||
)
|
|
||||||
return get_result()
|
return get_result()
|
||||||
|
|
||||||
|
|
||||||
@ -812,7 +805,6 @@ def list_chunks(tenant_id, dataset_id, document_id):
|
|||||||
"question": question,
|
"question": question,
|
||||||
"sort": True,
|
"sort": True,
|
||||||
}
|
}
|
||||||
sres = retrievaler.search(query, search.index_name(tenant_id), highlight=True)
|
|
||||||
key_mapping = {
|
key_mapping = {
|
||||||
"chunk_num": "chunk_count",
|
"chunk_num": "chunk_count",
|
||||||
"kb_id": "dataset_id",
|
"kb_id": "dataset_id",
|
||||||
@ -833,51 +825,56 @@ def list_chunks(tenant_id, dataset_id, document_id):
|
|||||||
renamed_doc[new_key] = value
|
renamed_doc[new_key] = value
|
||||||
if key == "run":
|
if key == "run":
|
||||||
renamed_doc["run"] = run_mapping.get(str(value))
|
renamed_doc["run"] = run_mapping.get(str(value))
|
||||||
res = {"total": sres.total, "chunks": [], "doc": renamed_doc}
|
|
||||||
origin_chunks = []
|
|
||||||
sign = 0
|
|
||||||
for id in sres.ids:
|
|
||||||
d = {
|
|
||||||
"chunk_id": id,
|
|
||||||
"content_with_weight": (
|
|
||||||
rmSpace(sres.highlight[id])
|
|
||||||
if question and id in sres.highlight
|
|
||||||
else sres.field[id].get("content_with_weight", "")
|
|
||||||
),
|
|
||||||
"doc_id": sres.field[id]["doc_id"],
|
|
||||||
"docnm_kwd": sres.field[id]["docnm_kwd"],
|
|
||||||
"important_kwd": sres.field[id].get("important_kwd", []),
|
|
||||||
"img_id": sres.field[id].get("img_id", ""),
|
|
||||||
"available_int": sres.field[id].get("available_int", 1),
|
|
||||||
"positions": sres.field[id].get("position_int", "").split("\t"),
|
|
||||||
}
|
|
||||||
if len(d["positions"]) % 5 == 0:
|
|
||||||
poss = []
|
|
||||||
for i in range(0, len(d["positions"]), 5):
|
|
||||||
poss.append(
|
|
||||||
[
|
|
||||||
float(d["positions"][i]),
|
|
||||||
float(d["positions"][i + 1]),
|
|
||||||
float(d["positions"][i + 2]),
|
|
||||||
float(d["positions"][i + 3]),
|
|
||||||
float(d["positions"][i + 4]),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
d["positions"] = poss
|
|
||||||
|
|
||||||
origin_chunks.append(d)
|
res = {"total": 0, "chunks": [], "doc": renamed_doc}
|
||||||
|
origin_chunks = []
|
||||||
|
if docStoreConn.indexExist(search.index_name(tenant_id), dataset_id):
|
||||||
|
sres = retrievaler.search(query, search.index_name(tenant_id), [dataset_id], emb_mdl=None, highlight=True)
|
||||||
|
res["total"] = sres.total
|
||||||
|
sign = 0
|
||||||
|
for id in sres.ids:
|
||||||
|
d = {
|
||||||
|
"id": id,
|
||||||
|
"content_with_weight": (
|
||||||
|
rmSpace(sres.highlight[id])
|
||||||
|
if question and id in sres.highlight
|
||||||
|
else sres.field[id].get("content_with_weight", "")
|
||||||
|
),
|
||||||
|
"doc_id": sres.field[id]["doc_id"],
|
||||||
|
"docnm_kwd": sres.field[id]["docnm_kwd"],
|
||||||
|
"important_kwd": sres.field[id].get("important_kwd", []),
|
||||||
|
"img_id": sres.field[id].get("img_id", ""),
|
||||||
|
"available_int": sres.field[id].get("available_int", 1),
|
||||||
|
"positions": sres.field[id].get("position_int", "").split("\t"),
|
||||||
|
}
|
||||||
|
if len(d["positions"]) % 5 == 0:
|
||||||
|
poss = []
|
||||||
|
for i in range(0, len(d["positions"]), 5):
|
||||||
|
poss.append(
|
||||||
|
[
|
||||||
|
float(d["positions"][i]),
|
||||||
|
float(d["positions"][i + 1]),
|
||||||
|
float(d["positions"][i + 2]),
|
||||||
|
float(d["positions"][i + 3]),
|
||||||
|
float(d["positions"][i + 4]),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
d["positions"] = poss
|
||||||
|
|
||||||
|
origin_chunks.append(d)
|
||||||
|
if req.get("id"):
|
||||||
|
if req.get("id") == id:
|
||||||
|
origin_chunks.clear()
|
||||||
|
origin_chunks.append(d)
|
||||||
|
sign = 1
|
||||||
|
break
|
||||||
if req.get("id"):
|
if req.get("id"):
|
||||||
if req.get("id") == id:
|
if sign == 0:
|
||||||
origin_chunks.clear()
|
return get_error_data_result(f"Can't find this chunk {req.get('id')}")
|
||||||
origin_chunks.append(d)
|
|
||||||
sign = 1
|
|
||||||
break
|
|
||||||
if req.get("id"):
|
|
||||||
if sign == 0:
|
|
||||||
return get_error_data_result(f"Can't find this chunk {req.get('id')}")
|
|
||||||
for chunk in origin_chunks:
|
for chunk in origin_chunks:
|
||||||
key_mapping = {
|
key_mapping = {
|
||||||
"chunk_id": "id",
|
"id": "id",
|
||||||
"content_with_weight": "content",
|
"content_with_weight": "content",
|
||||||
"doc_id": "document_id",
|
"doc_id": "document_id",
|
||||||
"important_kwd": "important_keywords",
|
"important_kwd": "important_keywords",
|
||||||
@ -996,9 +993,9 @@ def add_chunk(tenant_id, dataset_id, document_id):
|
|||||||
)
|
)
|
||||||
d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19]
|
d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19]
|
||||||
d["create_timestamp_flt"] = datetime.datetime.now().timestamp()
|
d["create_timestamp_flt"] = datetime.datetime.now().timestamp()
|
||||||
d["kb_id"] = [doc.kb_id]
|
d["kb_id"] = dataset_id
|
||||||
d["docnm_kwd"] = doc.name
|
d["docnm_kwd"] = doc.name
|
||||||
d["doc_id"] = doc.id
|
d["doc_id"] = document_id
|
||||||
embd_id = DocumentService.get_embd_id(document_id)
|
embd_id = DocumentService.get_embd_id(document_id)
|
||||||
embd_mdl = TenantLLMService.model_instance(
|
embd_mdl = TenantLLMService.model_instance(
|
||||||
tenant_id, LLMType.EMBEDDING.value, embd_id
|
tenant_id, LLMType.EMBEDDING.value, embd_id
|
||||||
@ -1006,14 +1003,12 @@ def add_chunk(tenant_id, dataset_id, document_id):
|
|||||||
v, c = embd_mdl.encode([doc.name, req["content"]])
|
v, c = embd_mdl.encode([doc.name, req["content"]])
|
||||||
v = 0.1 * v[0] + 0.9 * v[1]
|
v = 0.1 * v[0] + 0.9 * v[1]
|
||||||
d["q_%d_vec" % len(v)] = v.tolist()
|
d["q_%d_vec" % len(v)] = v.tolist()
|
||||||
ELASTICSEARCH.upsert([d], search.index_name(tenant_id))
|
docStoreConn.insert([d], search.index_name(tenant_id), dataset_id)
|
||||||
|
|
||||||
DocumentService.increment_chunk_num(doc.id, doc.kb_id, c, 1, 0)
|
DocumentService.increment_chunk_num(doc.id, doc.kb_id, c, 1, 0)
|
||||||
d["chunk_id"] = chunk_id
|
|
||||||
d["kb_id"] = doc.kb_id
|
|
||||||
# rename keys
|
# rename keys
|
||||||
key_mapping = {
|
key_mapping = {
|
||||||
"chunk_id": "id",
|
"id": "id",
|
||||||
"content_with_weight": "content",
|
"content_with_weight": "content",
|
||||||
"doc_id": "document_id",
|
"doc_id": "document_id",
|
||||||
"important_kwd": "important_keywords",
|
"important_kwd": "important_keywords",
|
||||||
@ -1079,36 +1074,16 @@ def rm_chunk(tenant_id, dataset_id, document_id):
|
|||||||
"""
|
"""
|
||||||
if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
|
if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
|
||||||
return get_error_data_result(message=f"You don't own the dataset {dataset_id}.")
|
return get_error_data_result(message=f"You don't own the dataset {dataset_id}.")
|
||||||
doc = DocumentService.query(id=document_id, kb_id=dataset_id)
|
|
||||||
if not doc:
|
|
||||||
return get_error_data_result(
|
|
||||||
message=f"You don't own the document {document_id}."
|
|
||||||
)
|
|
||||||
doc = doc[0]
|
|
||||||
req = request.json
|
req = request.json
|
||||||
if not req.get("chunk_ids"):
|
condition = {"doc_id": document_id}
|
||||||
return get_error_data_result("`chunk_ids` is required")
|
if "chunk_ids" in req:
|
||||||
query = {"doc_ids": [doc.id], "page": 1, "size": 1024, "question": "", "sort": True}
|
condition["id"] = req["chunk_ids"]
|
||||||
sres = retrievaler.search(query, search.index_name(tenant_id), highlight=True)
|
chunk_number = docStoreConn.delete(condition, search.index_name(tenant_id), dataset_id)
|
||||||
if not req:
|
if chunk_number != 0:
|
||||||
chunk_ids = None
|
DocumentService.decrement_chunk_num(document_id, dataset_id, 1, chunk_number, 0)
|
||||||
else:
|
if "chunk_ids" in req and chunk_number != len(req["chunk_ids"]):
|
||||||
chunk_ids = req.get("chunk_ids")
|
return get_error_data_result(message=f"rm_chunk deleted chunks {chunk_number}, expect {len(req["chunk_ids"])}")
|
||||||
if not chunk_ids:
|
return get_result(message=f"deleted {chunk_number} chunks")
|
||||||
chunk_list = sres.ids
|
|
||||||
else:
|
|
||||||
chunk_list = chunk_ids
|
|
||||||
for chunk_id in chunk_list:
|
|
||||||
if chunk_id not in sres.ids:
|
|
||||||
return get_error_data_result(f"Chunk {chunk_id} not found")
|
|
||||||
if not ELASTICSEARCH.deleteByQuery(
|
|
||||||
Q("ids", values=chunk_list), search.index_name(tenant_id)
|
|
||||||
):
|
|
||||||
return get_error_data_result(message="Index updating failure")
|
|
||||||
deleted_chunk_ids = chunk_list
|
|
||||||
chunk_number = len(deleted_chunk_ids)
|
|
||||||
DocumentService.decrement_chunk_num(doc.id, doc.kb_id, 1, chunk_number, 0)
|
|
||||||
return get_result()
|
|
||||||
|
|
||||||
|
|
||||||
@manager.route(
|
@manager.route(
|
||||||
@ -1168,9 +1143,8 @@ def update_chunk(tenant_id, dataset_id, document_id, chunk_id):
|
|||||||
schema:
|
schema:
|
||||||
type: object
|
type: object
|
||||||
"""
|
"""
|
||||||
try:
|
chunk = docStoreConn.get(chunk_id, search.index_name(tenant_id), [dataset_id])
|
||||||
res = ELASTICSEARCH.get(chunk_id, search.index_name(tenant_id))
|
if chunk is None:
|
||||||
except Exception:
|
|
||||||
return get_error_data_result(f"Can't find this chunk {chunk_id}")
|
return get_error_data_result(f"Can't find this chunk {chunk_id}")
|
||||||
if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
|
if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
|
||||||
return get_error_data_result(message=f"You don't own the dataset {dataset_id}.")
|
return get_error_data_result(message=f"You don't own the dataset {dataset_id}.")
|
||||||
@ -1180,19 +1154,12 @@ def update_chunk(tenant_id, dataset_id, document_id, chunk_id):
|
|||||||
message=f"You don't own the document {document_id}."
|
message=f"You don't own the document {document_id}."
|
||||||
)
|
)
|
||||||
doc = doc[0]
|
doc = doc[0]
|
||||||
query = {
|
|
||||||
"doc_ids": [document_id],
|
|
||||||
"page": 1,
|
|
||||||
"size": 1024,
|
|
||||||
"question": "",
|
|
||||||
"sort": True,
|
|
||||||
}
|
|
||||||
sres = retrievaler.search(query, search.index_name(tenant_id), highlight=True)
|
|
||||||
if chunk_id not in sres.ids:
|
|
||||||
return get_error_data_result(f"You don't own the chunk {chunk_id}")
|
|
||||||
req = request.json
|
req = request.json
|
||||||
content = res["_source"].get("content_with_weight")
|
if "content" in req:
|
||||||
d = {"id": chunk_id, "content_with_weight": req.get("content", content)}
|
content = req["content"]
|
||||||
|
else:
|
||||||
|
content = chunk.get("content_with_weight", "")
|
||||||
|
d = {"id": chunk_id, "content_with_weight": content}
|
||||||
d["content_ltks"] = rag_tokenizer.tokenize(d["content_with_weight"])
|
d["content_ltks"] = rag_tokenizer.tokenize(d["content_with_weight"])
|
||||||
d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
|
d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
|
||||||
if "important_keywords" in req:
|
if "important_keywords" in req:
|
||||||
@ -1220,7 +1187,7 @@ def update_chunk(tenant_id, dataset_id, document_id, chunk_id):
|
|||||||
v, c = embd_mdl.encode([doc.name, d["content_with_weight"]])
|
v, c = embd_mdl.encode([doc.name, d["content_with_weight"]])
|
||||||
v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
|
v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
|
||||||
d["q_%d_vec" % len(v)] = v.tolist()
|
d["q_%d_vec" % len(v)] = v.tolist()
|
||||||
ELASTICSEARCH.upsert([d], search.index_name(tenant_id))
|
docStoreConn.update({"id": chunk_id}, d, search.index_name(tenant_id), dataset_id)
|
||||||
return get_result()
|
return get_result()
|
||||||
|
|
||||||
|
|
||||||
|
@ -31,7 +31,7 @@ from api.utils.api_utils import (
|
|||||||
generate_confirmation_token,
|
generate_confirmation_token,
|
||||||
)
|
)
|
||||||
from api.versions import get_rag_version
|
from api.versions import get_rag_version
|
||||||
from rag.utils.es_conn import ELASTICSEARCH
|
from api.settings import docStoreConn
|
||||||
from rag.utils.storage_factory import STORAGE_IMPL, STORAGE_IMPL_TYPE
|
from rag.utils.storage_factory import STORAGE_IMPL, STORAGE_IMPL_TYPE
|
||||||
from timeit import default_timer as timer
|
from timeit import default_timer as timer
|
||||||
|
|
||||||
@ -98,10 +98,11 @@ def status():
|
|||||||
res = {}
|
res = {}
|
||||||
st = timer()
|
st = timer()
|
||||||
try:
|
try:
|
||||||
res["es"] = ELASTICSEARCH.health()
|
res["doc_store"] = docStoreConn.health()
|
||||||
res["es"]["elapsed"] = "{:.1f}".format((timer() - st) * 1000.0)
|
res["doc_store"]["elapsed"] = "{:.1f}".format((timer() - st) * 1000.0)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
res["es"] = {
|
res["doc_store"] = {
|
||||||
|
"type": "unknown",
|
||||||
"status": "red",
|
"status": "red",
|
||||||
"elapsed": "{:.1f}".format((timer() - st) * 1000.0),
|
"elapsed": "{:.1f}".format((timer() - st) * 1000.0),
|
||||||
"error": str(e),
|
"error": str(e),
|
||||||
|
@ -470,7 +470,7 @@ class User(DataBaseModel, UserMixin):
|
|||||||
status = CharField(
|
status = CharField(
|
||||||
max_length=1,
|
max_length=1,
|
||||||
null=True,
|
null=True,
|
||||||
help_text="is it validate(0: wasted,1: validate)",
|
help_text="is it validate(0: wasted, 1: validate)",
|
||||||
default="1",
|
default="1",
|
||||||
index=True)
|
index=True)
|
||||||
is_superuser = BooleanField(null=True, help_text="is root", default=False, index=True)
|
is_superuser = BooleanField(null=True, help_text="is root", default=False, index=True)
|
||||||
@ -525,7 +525,7 @@ class Tenant(DataBaseModel):
|
|||||||
status = CharField(
|
status = CharField(
|
||||||
max_length=1,
|
max_length=1,
|
||||||
null=True,
|
null=True,
|
||||||
help_text="is it validate(0: wasted,1: validate)",
|
help_text="is it validate(0: wasted, 1: validate)",
|
||||||
default="1",
|
default="1",
|
||||||
index=True)
|
index=True)
|
||||||
|
|
||||||
@ -542,7 +542,7 @@ class UserTenant(DataBaseModel):
|
|||||||
status = CharField(
|
status = CharField(
|
||||||
max_length=1,
|
max_length=1,
|
||||||
null=True,
|
null=True,
|
||||||
help_text="is it validate(0: wasted,1: validate)",
|
help_text="is it validate(0: wasted, 1: validate)",
|
||||||
default="1",
|
default="1",
|
||||||
index=True)
|
index=True)
|
||||||
|
|
||||||
@ -559,7 +559,7 @@ class InvitationCode(DataBaseModel):
|
|||||||
status = CharField(
|
status = CharField(
|
||||||
max_length=1,
|
max_length=1,
|
||||||
null=True,
|
null=True,
|
||||||
help_text="is it validate(0: wasted,1: validate)",
|
help_text="is it validate(0: wasted, 1: validate)",
|
||||||
default="1",
|
default="1",
|
||||||
index=True)
|
index=True)
|
||||||
|
|
||||||
@ -582,7 +582,7 @@ class LLMFactories(DataBaseModel):
|
|||||||
status = CharField(
|
status = CharField(
|
||||||
max_length=1,
|
max_length=1,
|
||||||
null=True,
|
null=True,
|
||||||
help_text="is it validate(0: wasted,1: validate)",
|
help_text="is it validate(0: wasted, 1: validate)",
|
||||||
default="1",
|
default="1",
|
||||||
index=True)
|
index=True)
|
||||||
|
|
||||||
@ -616,7 +616,7 @@ class LLM(DataBaseModel):
|
|||||||
status = CharField(
|
status = CharField(
|
||||||
max_length=1,
|
max_length=1,
|
||||||
null=True,
|
null=True,
|
||||||
help_text="is it validate(0: wasted,1: validate)",
|
help_text="is it validate(0: wasted, 1: validate)",
|
||||||
default="1",
|
default="1",
|
||||||
index=True)
|
index=True)
|
||||||
|
|
||||||
@ -703,7 +703,7 @@ class Knowledgebase(DataBaseModel):
|
|||||||
status = CharField(
|
status = CharField(
|
||||||
max_length=1,
|
max_length=1,
|
||||||
null=True,
|
null=True,
|
||||||
help_text="is it validate(0: wasted,1: validate)",
|
help_text="is it validate(0: wasted, 1: validate)",
|
||||||
default="1",
|
default="1",
|
||||||
index=True)
|
index=True)
|
||||||
|
|
||||||
@ -767,7 +767,7 @@ class Document(DataBaseModel):
|
|||||||
status = CharField(
|
status = CharField(
|
||||||
max_length=1,
|
max_length=1,
|
||||||
null=True,
|
null=True,
|
||||||
help_text="is it validate(0: wasted,1: validate)",
|
help_text="is it validate(0: wasted, 1: validate)",
|
||||||
default="1",
|
default="1",
|
||||||
index=True)
|
index=True)
|
||||||
|
|
||||||
@ -904,7 +904,7 @@ class Dialog(DataBaseModel):
|
|||||||
status = CharField(
|
status = CharField(
|
||||||
max_length=1,
|
max_length=1,
|
||||||
null=True,
|
null=True,
|
||||||
help_text="is it validate(0: wasted,1: validate)",
|
help_text="is it validate(0: wasted, 1: validate)",
|
||||||
default="1",
|
default="1",
|
||||||
index=True)
|
index=True)
|
||||||
|
|
||||||
@ -987,7 +987,7 @@ def migrate_db():
|
|||||||
help_text="where dose this document come from",
|
help_text="where dose this document come from",
|
||||||
index=True))
|
index=True))
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
try:
|
try:
|
||||||
migrate(
|
migrate(
|
||||||
@ -996,7 +996,7 @@ def migrate_db():
|
|||||||
help_text="default rerank model ID"))
|
help_text="default rerank model ID"))
|
||||||
|
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
try:
|
try:
|
||||||
migrate(
|
migrate(
|
||||||
@ -1004,59 +1004,59 @@ def migrate_db():
|
|||||||
help_text="default rerank model ID"))
|
help_text="default rerank model ID"))
|
||||||
|
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
try:
|
try:
|
||||||
migrate(
|
migrate(
|
||||||
migrator.add_column('dialog', 'top_k', IntegerField(default=1024))
|
migrator.add_column('dialog', 'top_k', IntegerField(default=1024))
|
||||||
|
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
try:
|
try:
|
||||||
migrate(
|
migrate(
|
||||||
migrator.alter_column_type('tenant_llm', 'api_key',
|
migrator.alter_column_type('tenant_llm', 'api_key',
|
||||||
CharField(max_length=1024, null=True, help_text="API KEY", index=True))
|
CharField(max_length=1024, null=True, help_text="API KEY", index=True))
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
try:
|
try:
|
||||||
migrate(
|
migrate(
|
||||||
migrator.add_column('api_token', 'source',
|
migrator.add_column('api_token', 'source',
|
||||||
CharField(max_length=16, null=True, help_text="none|agent|dialog", index=True))
|
CharField(max_length=16, null=True, help_text="none|agent|dialog", index=True))
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
try:
|
try:
|
||||||
migrate(
|
migrate(
|
||||||
migrator.add_column("tenant","tts_id",
|
migrator.add_column("tenant","tts_id",
|
||||||
CharField(max_length=256,null=True,help_text="default tts model ID",index=True))
|
CharField(max_length=256,null=True,help_text="default tts model ID",index=True))
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
try:
|
try:
|
||||||
migrate(
|
migrate(
|
||||||
migrator.add_column('api_4_conversation', 'source',
|
migrator.add_column('api_4_conversation', 'source',
|
||||||
CharField(max_length=16, null=True, help_text="none|agent|dialog", index=True))
|
CharField(max_length=16, null=True, help_text="none|agent|dialog", index=True))
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
try:
|
try:
|
||||||
DB.execute_sql('ALTER TABLE llm DROP PRIMARY KEY;')
|
DB.execute_sql('ALTER TABLE llm DROP PRIMARY KEY;')
|
||||||
DB.execute_sql('ALTER TABLE llm ADD PRIMARY KEY (llm_name,fid);')
|
DB.execute_sql('ALTER TABLE llm ADD PRIMARY KEY (llm_name,fid);')
|
||||||
except Exception as e:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
try:
|
try:
|
||||||
migrate(
|
migrate(
|
||||||
migrator.add_column('task', 'retry_count', IntegerField(default=0))
|
migrator.add_column('task', 'retry_count', IntegerField(default=0))
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
try:
|
try:
|
||||||
migrate(
|
migrate(
|
||||||
migrator.alter_column_type('api_token', 'dialog_id',
|
migrator.alter_column_type('api_token', 'dialog_id',
|
||||||
CharField(max_length=32, null=True, index=True))
|
CharField(max_length=32, null=True, index=True))
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -15,7 +15,6 @@
|
|||||||
#
|
#
|
||||||
import hashlib
|
import hashlib
|
||||||
import json
|
import json
|
||||||
import os
|
|
||||||
import random
|
import random
|
||||||
import re
|
import re
|
||||||
import traceback
|
import traceback
|
||||||
@ -24,16 +23,13 @@ from copy import deepcopy
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
|
|
||||||
from elasticsearch_dsl import Q
|
|
||||||
from peewee import fn
|
from peewee import fn
|
||||||
|
|
||||||
from api.db.db_utils import bulk_insert_into_db
|
from api.db.db_utils import bulk_insert_into_db
|
||||||
from api.settings import stat_logger
|
from api.settings import stat_logger, docStoreConn
|
||||||
from api.utils import current_timestamp, get_format_time, get_uuid
|
from api.utils import current_timestamp, get_format_time, get_uuid
|
||||||
from api.utils.file_utils import get_project_base_directory
|
|
||||||
from graphrag.mind_map_extractor import MindMapExtractor
|
from graphrag.mind_map_extractor import MindMapExtractor
|
||||||
from rag.settings import SVR_QUEUE_NAME
|
from rag.settings import SVR_QUEUE_NAME
|
||||||
from rag.utils.es_conn import ELASTICSEARCH
|
|
||||||
from rag.utils.storage_factory import STORAGE_IMPL
|
from rag.utils.storage_factory import STORAGE_IMPL
|
||||||
from rag.nlp import search, rag_tokenizer
|
from rag.nlp import search, rag_tokenizer
|
||||||
|
|
||||||
@ -112,8 +108,7 @@ class DocumentService(CommonService):
|
|||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def remove_document(cls, doc, tenant_id):
|
def remove_document(cls, doc, tenant_id):
|
||||||
ELASTICSEARCH.deleteByQuery(
|
docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
|
||||||
Q("match", doc_id=doc.id), idxnm=search.index_name(tenant_id))
|
|
||||||
cls.clear_chunk_num(doc.id)
|
cls.clear_chunk_num(doc.id)
|
||||||
return cls.delete_by_id(doc.id)
|
return cls.delete_by_id(doc.id)
|
||||||
|
|
||||||
@ -225,6 +220,15 @@ class DocumentService(CommonService):
|
|||||||
return
|
return
|
||||||
return docs[0]["tenant_id"]
|
return docs[0]["tenant_id"]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@DB.connection_context()
|
||||||
|
def get_knowledgebase_id(cls, doc_id):
|
||||||
|
docs = cls.model.select(cls.model.kb_id).where(cls.model.id == doc_id)
|
||||||
|
docs = docs.dicts()
|
||||||
|
if not docs:
|
||||||
|
return
|
||||||
|
return docs[0]["kb_id"]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def get_tenant_id_by_name(cls, name):
|
def get_tenant_id_by_name(cls, name):
|
||||||
@ -438,11 +442,6 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
|
|||||||
if not e:
|
if not e:
|
||||||
raise LookupError("Can't find this knowledgebase!")
|
raise LookupError("Can't find this knowledgebase!")
|
||||||
|
|
||||||
idxnm = search.index_name(kb.tenant_id)
|
|
||||||
if not ELASTICSEARCH.indexExist(idxnm):
|
|
||||||
ELASTICSEARCH.createIdx(idxnm, json.load(
|
|
||||||
open(os.path.join(get_project_base_directory(), "conf", "mapping.json"), "r")))
|
|
||||||
|
|
||||||
embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING, llm_name=kb.embd_id, lang=kb.language)
|
embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING, llm_name=kb.embd_id, lang=kb.language)
|
||||||
|
|
||||||
err, files = FileService.upload_document(kb, file_objs, user_id)
|
err, files = FileService.upload_document(kb, file_objs, user_id)
|
||||||
@ -486,7 +485,7 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
|
|||||||
md5 = hashlib.md5()
|
md5 = hashlib.md5()
|
||||||
md5.update((ck["content_with_weight"] +
|
md5.update((ck["content_with_weight"] +
|
||||||
str(d["doc_id"])).encode("utf-8"))
|
str(d["doc_id"])).encode("utf-8"))
|
||||||
d["_id"] = md5.hexdigest()
|
d["id"] = md5.hexdigest()
|
||||||
d["create_time"] = str(datetime.now()).replace("T", " ")[:19]
|
d["create_time"] = str(datetime.now()).replace("T", " ")[:19]
|
||||||
d["create_timestamp_flt"] = datetime.now().timestamp()
|
d["create_timestamp_flt"] = datetime.now().timestamp()
|
||||||
if not d.get("image"):
|
if not d.get("image"):
|
||||||
@ -499,8 +498,8 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
|
|||||||
else:
|
else:
|
||||||
d["image"].save(output_buffer, format='JPEG')
|
d["image"].save(output_buffer, format='JPEG')
|
||||||
|
|
||||||
STORAGE_IMPL.put(kb.id, d["_id"], output_buffer.getvalue())
|
STORAGE_IMPL.put(kb.id, d["id"], output_buffer.getvalue())
|
||||||
d["img_id"] = "{}-{}".format(kb.id, d["_id"])
|
d["img_id"] = "{}-{}".format(kb.id, d["id"])
|
||||||
del d["image"]
|
del d["image"]
|
||||||
docs.append(d)
|
docs.append(d)
|
||||||
|
|
||||||
@ -520,6 +519,9 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
|
|||||||
token_counts[doc_id] += c
|
token_counts[doc_id] += c
|
||||||
return vects
|
return vects
|
||||||
|
|
||||||
|
idxnm = search.index_name(kb.tenant_id)
|
||||||
|
try_create_idx = True
|
||||||
|
|
||||||
_, tenant = TenantService.get_by_id(kb.tenant_id)
|
_, tenant = TenantService.get_by_id(kb.tenant_id)
|
||||||
llm_bdl = LLMBundle(kb.tenant_id, LLMType.CHAT, tenant.llm_id)
|
llm_bdl = LLMBundle(kb.tenant_id, LLMType.CHAT, tenant.llm_id)
|
||||||
for doc_id in docids:
|
for doc_id in docids:
|
||||||
@ -550,7 +552,11 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
|
|||||||
v = vects[i]
|
v = vects[i]
|
||||||
d["q_%d_vec" % len(v)] = v
|
d["q_%d_vec" % len(v)] = v
|
||||||
for b in range(0, len(cks), es_bulk_size):
|
for b in range(0, len(cks), es_bulk_size):
|
||||||
ELASTICSEARCH.bulk(cks[b:b + es_bulk_size], idxnm)
|
if try_create_idx:
|
||||||
|
if not docStoreConn.indexExist(idxnm, kb_id):
|
||||||
|
docStoreConn.createIdx(idxnm, kb_id, len(vects[0]))
|
||||||
|
try_create_idx = False
|
||||||
|
docStoreConn.insert(cks[b:b + es_bulk_size], idxnm, kb_id)
|
||||||
|
|
||||||
DocumentService.increment_chunk_num(
|
DocumentService.increment_chunk_num(
|
||||||
doc_id, kb.id, token_counts[doc_id], chunk_counts[doc_id], 0)
|
doc_id, kb.id, token_counts[doc_id], chunk_counts[doc_id], 0)
|
||||||
|
@ -66,6 +66,16 @@ class KnowledgebaseService(CommonService):
|
|||||||
|
|
||||||
return list(kbs.dicts())
|
return list(kbs.dicts())
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@DB.connection_context()
|
||||||
|
def get_kb_ids(cls, tenant_id):
|
||||||
|
fields = [
|
||||||
|
cls.model.id,
|
||||||
|
]
|
||||||
|
kbs = cls.model.select(*fields).where(cls.model.tenant_id == tenant_id)
|
||||||
|
kb_ids = [kb["id"] for kb in kbs]
|
||||||
|
return kb_ids
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def get_detail(cls, kb_id):
|
def get_detail(cls, kb_id):
|
||||||
|
@ -18,6 +18,8 @@ from datetime import date
|
|||||||
from enum import IntEnum, Enum
|
from enum import IntEnum, Enum
|
||||||
from api.utils.file_utils import get_project_base_directory
|
from api.utils.file_utils import get_project_base_directory
|
||||||
from api.utils.log_utils import LoggerFactory, getLogger
|
from api.utils.log_utils import LoggerFactory, getLogger
|
||||||
|
import rag.utils.es_conn
|
||||||
|
import rag.utils.infinity_conn
|
||||||
|
|
||||||
# Logger
|
# Logger
|
||||||
LoggerFactory.set_directory(
|
LoggerFactory.set_directory(
|
||||||
@ -33,7 +35,7 @@ access_logger = getLogger("access")
|
|||||||
database_logger = getLogger("database")
|
database_logger = getLogger("database")
|
||||||
chat_logger = getLogger("chat")
|
chat_logger = getLogger("chat")
|
||||||
|
|
||||||
from rag.utils.es_conn import ELASTICSEARCH
|
import rag.utils
|
||||||
from rag.nlp import search
|
from rag.nlp import search
|
||||||
from graphrag import search as kg_search
|
from graphrag import search as kg_search
|
||||||
from api.utils import get_base_config, decrypt_database_config
|
from api.utils import get_base_config, decrypt_database_config
|
||||||
@ -206,8 +208,12 @@ AUTHENTICATION_DEFAULT_TIMEOUT = 7 * 24 * 60 * 60 # s
|
|||||||
PRIVILEGE_COMMAND_WHITELIST = []
|
PRIVILEGE_COMMAND_WHITELIST = []
|
||||||
CHECK_NODES_IDENTITY = False
|
CHECK_NODES_IDENTITY = False
|
||||||
|
|
||||||
retrievaler = search.Dealer(ELASTICSEARCH)
|
if 'username' in get_base_config("es", {}):
|
||||||
kg_retrievaler = kg_search.KGSearch(ELASTICSEARCH)
|
docStoreConn = rag.utils.es_conn.ESConnection()
|
||||||
|
else:
|
||||||
|
docStoreConn = rag.utils.infinity_conn.InfinityConnection()
|
||||||
|
retrievaler = search.Dealer(docStoreConn)
|
||||||
|
kg_retrievaler = kg_search.KGSearch(docStoreConn)
|
||||||
|
|
||||||
|
|
||||||
class CustomEnum(Enum):
|
class CustomEnum(Enum):
|
||||||
|
@ -126,10 +126,6 @@ def server_error_response(e):
|
|||||||
if len(e.args) > 1:
|
if len(e.args) > 1:
|
||||||
return get_json_result(
|
return get_json_result(
|
||||||
code=RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=e.args[1])
|
code=RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=e.args[1])
|
||||||
if repr(e).find("index_not_found_exception") >= 0:
|
|
||||||
return get_json_result(code=RetCode.EXCEPTION_ERROR,
|
|
||||||
message="No chunk found, please upload file and parse it.")
|
|
||||||
|
|
||||||
return get_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e))
|
return get_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e))
|
||||||
|
|
||||||
|
|
||||||
@ -270,10 +266,6 @@ def construct_error_response(e):
|
|||||||
pass
|
pass
|
||||||
if len(e.args) > 1:
|
if len(e.args) > 1:
|
||||||
return construct_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=e.args[1])
|
return construct_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=e.args[1])
|
||||||
if repr(e).find("index_not_found_exception") >= 0:
|
|
||||||
return construct_json_result(code=RetCode.EXCEPTION_ERROR,
|
|
||||||
message="No chunk found, please upload file and parse it.")
|
|
||||||
|
|
||||||
return construct_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e))
|
return construct_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e))
|
||||||
|
|
||||||
|
|
||||||
@ -295,7 +287,7 @@ def token_required(func):
|
|||||||
return decorated_function
|
return decorated_function
|
||||||
|
|
||||||
|
|
||||||
def get_result(code=RetCode.SUCCESS, message='error', data=None):
|
def get_result(code=RetCode.SUCCESS, message="", data=None):
|
||||||
if code == 0:
|
if code == 0:
|
||||||
if data is not None:
|
if data is not None:
|
||||||
response = {"code": code, "data": data}
|
response = {"code": code, "data": data}
|
||||||
|
26
conf/infinity_mapping.json
Normal file
26
conf/infinity_mapping.json
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
{
|
||||||
|
"id": {"type": "varchar", "default": ""},
|
||||||
|
"doc_id": {"type": "varchar", "default": ""},
|
||||||
|
"kb_id": {"type": "varchar", "default": ""},
|
||||||
|
"create_time": {"type": "varchar", "default": ""},
|
||||||
|
"create_timestamp_flt": {"type": "float", "default": 0.0},
|
||||||
|
"img_id": {"type": "varchar", "default": ""},
|
||||||
|
"docnm_kwd": {"type": "varchar", "default": ""},
|
||||||
|
"title_tks": {"type": "varchar", "default": ""},
|
||||||
|
"title_sm_tks": {"type": "varchar", "default": ""},
|
||||||
|
"name_kwd": {"type": "varchar", "default": ""},
|
||||||
|
"important_kwd": {"type": "varchar", "default": ""},
|
||||||
|
"important_tks": {"type": "varchar", "default": ""},
|
||||||
|
"content_with_weight": {"type": "varchar", "default": ""},
|
||||||
|
"content_ltks": {"type": "varchar", "default": ""},
|
||||||
|
"content_sm_ltks": {"type": "varchar", "default": ""},
|
||||||
|
"page_num_list": {"type": "varchar", "default": ""},
|
||||||
|
"top_list": {"type": "varchar", "default": ""},
|
||||||
|
"position_list": {"type": "varchar", "default": ""},
|
||||||
|
"weight_int": {"type": "integer", "default": 0},
|
||||||
|
"weight_flt": {"type": "float", "default": 0.0},
|
||||||
|
"rank_int": {"type": "integer", "default": 0},
|
||||||
|
"available_int": {"type": "integer", "default": 1},
|
||||||
|
"knowledge_graph_kwd": {"type": "varchar", "default": ""},
|
||||||
|
"entities_kwd": {"type": "varchar", "default": ""}
|
||||||
|
}
|
@ -1,200 +1,203 @@
|
|||||||
{
|
{
|
||||||
"settings": {
|
"settings": {
|
||||||
"index": {
|
"index": {
|
||||||
"number_of_shards": 2,
|
"number_of_shards": 2,
|
||||||
"number_of_replicas": 0,
|
"number_of_replicas": 0,
|
||||||
"refresh_interval" : "1000ms"
|
"refresh_interval": "1000ms"
|
||||||
},
|
},
|
||||||
"similarity": {
|
"similarity": {
|
||||||
"scripted_sim": {
|
"scripted_sim": {
|
||||||
"type": "scripted",
|
"type": "scripted",
|
||||||
"script": {
|
"script": {
|
||||||
"source": "double idf = Math.log(1+(field.docCount-term.docFreq+0.5)/(term.docFreq + 0.5))/Math.log(1+((field.docCount-0.5)/1.5)); return query.boost * idf * Math.min(doc.freq, 1);"
|
"source": "double idf = Math.log(1+(field.docCount-term.docFreq+0.5)/(term.docFreq + 0.5))/Math.log(1+((field.docCount-0.5)/1.5)); return query.boost * idf * Math.min(doc.freq, 1);"
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"mappings": {
|
"mappings": {
|
||||||
"properties": {
|
"properties": {
|
||||||
"lat_lon": {"type": "geo_point", "store":"true"}
|
"lat_lon": {
|
||||||
},
|
"type": "geo_point",
|
||||||
"date_detection": "true",
|
"store": "true"
|
||||||
"dynamic_templates": [
|
}
|
||||||
{
|
},
|
||||||
"int": {
|
"date_detection": "true",
|
||||||
"match": "*_int",
|
"dynamic_templates": [
|
||||||
"mapping": {
|
{
|
||||||
"type": "integer",
|
"int": {
|
||||||
"store": "true"
|
"match": "*_int",
|
||||||
}
|
"mapping": {
|
||||||
|
"type": "integer",
|
||||||
|
"store": "true"
|
||||||
}
|
}
|
||||||
},
|
|
||||||
{
|
|
||||||
"ulong": {
|
|
||||||
"match": "*_ulong",
|
|
||||||
"mapping": {
|
|
||||||
"type": "unsigned_long",
|
|
||||||
"store": "true"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"long": {
|
|
||||||
"match": "*_long",
|
|
||||||
"mapping": {
|
|
||||||
"type": "long",
|
|
||||||
"store": "true"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"short": {
|
|
||||||
"match": "*_short",
|
|
||||||
"mapping": {
|
|
||||||
"type": "short",
|
|
||||||
"store": "true"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"numeric": {
|
|
||||||
"match": "*_flt",
|
|
||||||
"mapping": {
|
|
||||||
"type": "float",
|
|
||||||
"store": true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"tks": {
|
|
||||||
"match": "*_tks",
|
|
||||||
"mapping": {
|
|
||||||
"type": "text",
|
|
||||||
"similarity": "scripted_sim",
|
|
||||||
"analyzer": "whitespace",
|
|
||||||
"store": true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"ltks":{
|
|
||||||
"match": "*_ltks",
|
|
||||||
"mapping": {
|
|
||||||
"type": "text",
|
|
||||||
"analyzer": "whitespace",
|
|
||||||
"store": true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"kwd": {
|
|
||||||
"match_pattern": "regex",
|
|
||||||
"match": "^(.*_(kwd|id|ids|uid|uids)|uid)$",
|
|
||||||
"mapping": {
|
|
||||||
"type": "keyword",
|
|
||||||
"similarity": "boolean",
|
|
||||||
"store": true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"dt": {
|
|
||||||
"match_pattern": "regex",
|
|
||||||
"match": "^.*(_dt|_time|_at)$",
|
|
||||||
"mapping": {
|
|
||||||
"type": "date",
|
|
||||||
"format": "yyyy-MM-dd HH:mm:ss||yyyy-MM-dd||yyyy-MM-dd_HH:mm:ss",
|
|
||||||
"store": true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"nested": {
|
|
||||||
"match": "*_nst",
|
|
||||||
"mapping": {
|
|
||||||
"type": "nested"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"object": {
|
|
||||||
"match": "*_obj",
|
|
||||||
"mapping": {
|
|
||||||
"type": "object",
|
|
||||||
"dynamic": "true"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"string": {
|
|
||||||
"match": "*_with_weight",
|
|
||||||
"mapping": {
|
|
||||||
"type": "text",
|
|
||||||
"index": "false",
|
|
||||||
"store": true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"string": {
|
|
||||||
"match": "*_fea",
|
|
||||||
"mapping": {
|
|
||||||
"type": "rank_feature"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"dense_vector": {
|
|
||||||
"match": "*_512_vec",
|
|
||||||
"mapping": {
|
|
||||||
"type": "dense_vector",
|
|
||||||
"index": true,
|
|
||||||
"similarity": "cosine",
|
|
||||||
"dims": 512
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"dense_vector": {
|
|
||||||
"match": "*_768_vec",
|
|
||||||
"mapping": {
|
|
||||||
"type": "dense_vector",
|
|
||||||
"index": true,
|
|
||||||
"similarity": "cosine",
|
|
||||||
"dims": 768
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"dense_vector": {
|
|
||||||
"match": "*_1024_vec",
|
|
||||||
"mapping": {
|
|
||||||
"type": "dense_vector",
|
|
||||||
"index": true,
|
|
||||||
"similarity": "cosine",
|
|
||||||
"dims": 1024
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"dense_vector": {
|
|
||||||
"match": "*_1536_vec",
|
|
||||||
"mapping": {
|
|
||||||
"type": "dense_vector",
|
|
||||||
"index": true,
|
|
||||||
"similarity": "cosine",
|
|
||||||
"dims": 1536
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"binary": {
|
|
||||||
"match": "*_bin",
|
|
||||||
"mapping": {
|
|
||||||
"type": "binary"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
]
|
},
|
||||||
}
|
{
|
||||||
}
|
"ulong": {
|
||||||
|
"match": "*_ulong",
|
||||||
|
"mapping": {
|
||||||
|
"type": "unsigned_long",
|
||||||
|
"store": "true"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"long": {
|
||||||
|
"match": "*_long",
|
||||||
|
"mapping": {
|
||||||
|
"type": "long",
|
||||||
|
"store": "true"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"short": {
|
||||||
|
"match": "*_short",
|
||||||
|
"mapping": {
|
||||||
|
"type": "short",
|
||||||
|
"store": "true"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"numeric": {
|
||||||
|
"match": "*_flt",
|
||||||
|
"mapping": {
|
||||||
|
"type": "float",
|
||||||
|
"store": true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"tks": {
|
||||||
|
"match": "*_tks",
|
||||||
|
"mapping": {
|
||||||
|
"type": "text",
|
||||||
|
"similarity": "scripted_sim",
|
||||||
|
"analyzer": "whitespace",
|
||||||
|
"store": true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"ltks": {
|
||||||
|
"match": "*_ltks",
|
||||||
|
"mapping": {
|
||||||
|
"type": "text",
|
||||||
|
"analyzer": "whitespace",
|
||||||
|
"store": true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"kwd": {
|
||||||
|
"match_pattern": "regex",
|
||||||
|
"match": "^(.*_(kwd|id|ids|uid|uids)|uid)$",
|
||||||
|
"mapping": {
|
||||||
|
"type": "keyword",
|
||||||
|
"similarity": "boolean",
|
||||||
|
"store": true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"dt": {
|
||||||
|
"match_pattern": "regex",
|
||||||
|
"match": "^.*(_dt|_time|_at)$",
|
||||||
|
"mapping": {
|
||||||
|
"type": "date",
|
||||||
|
"format": "yyyy-MM-dd HH:mm:ss||yyyy-MM-dd||yyyy-MM-dd_HH:mm:ss",
|
||||||
|
"store": true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"nested": {
|
||||||
|
"match": "*_nst",
|
||||||
|
"mapping": {
|
||||||
|
"type": "nested"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"object": {
|
||||||
|
"match": "*_obj",
|
||||||
|
"mapping": {
|
||||||
|
"type": "object",
|
||||||
|
"dynamic": "true"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"string": {
|
||||||
|
"match": "*_(with_weight|list)$",
|
||||||
|
"mapping": {
|
||||||
|
"type": "text",
|
||||||
|
"index": "false",
|
||||||
|
"store": true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"string": {
|
||||||
|
"match": "*_fea",
|
||||||
|
"mapping": {
|
||||||
|
"type": "rank_feature"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"dense_vector": {
|
||||||
|
"match": "*_512_vec",
|
||||||
|
"mapping": {
|
||||||
|
"type": "dense_vector",
|
||||||
|
"index": true,
|
||||||
|
"similarity": "cosine",
|
||||||
|
"dims": 512
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"dense_vector": {
|
||||||
|
"match": "*_768_vec",
|
||||||
|
"mapping": {
|
||||||
|
"type": "dense_vector",
|
||||||
|
"index": true,
|
||||||
|
"similarity": "cosine",
|
||||||
|
"dims": 768
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"dense_vector": {
|
||||||
|
"match": "*_1024_vec",
|
||||||
|
"mapping": {
|
||||||
|
"type": "dense_vector",
|
||||||
|
"index": true,
|
||||||
|
"similarity": "cosine",
|
||||||
|
"dims": 1024
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"dense_vector": {
|
||||||
|
"match": "*_1536_vec",
|
||||||
|
"mapping": {
|
||||||
|
"type": "dense_vector",
|
||||||
|
"index": true,
|
||||||
|
"similarity": "cosine",
|
||||||
|
"dims": 1536
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"binary": {
|
||||||
|
"match": "*_bin",
|
||||||
|
"mapping": {
|
||||||
|
"type": "binary"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
@ -19,6 +19,11 @@ KIBANA_PASSWORD=infini_rag_flow
|
|||||||
# Update it according to the available memory in the host machine.
|
# Update it according to the available memory in the host machine.
|
||||||
MEM_LIMIT=8073741824
|
MEM_LIMIT=8073741824
|
||||||
|
|
||||||
|
# Port to expose Infinity API to the host
|
||||||
|
INFINITY_THRIFT_PORT=23817
|
||||||
|
INFINITY_HTTP_PORT=23820
|
||||||
|
INFINITY_PSQL_PORT=5432
|
||||||
|
|
||||||
# The password for MySQL.
|
# The password for MySQL.
|
||||||
# When updated, you must revise the `mysql.password` entry in service_conf.yaml.
|
# When updated, you must revise the `mysql.password` entry in service_conf.yaml.
|
||||||
MYSQL_PASSWORD=infini_rag_flow
|
MYSQL_PASSWORD=infini_rag_flow
|
||||||
|
@ -6,6 +6,7 @@ services:
|
|||||||
- esdata01:/usr/share/elasticsearch/data
|
- esdata01:/usr/share/elasticsearch/data
|
||||||
ports:
|
ports:
|
||||||
- ${ES_PORT}:9200
|
- ${ES_PORT}:9200
|
||||||
|
env_file: .env
|
||||||
environment:
|
environment:
|
||||||
- node.name=es01
|
- node.name=es01
|
||||||
- ELASTIC_PASSWORD=${ELASTIC_PASSWORD}
|
- ELASTIC_PASSWORD=${ELASTIC_PASSWORD}
|
||||||
@ -27,12 +28,40 @@ services:
|
|||||||
retries: 120
|
retries: 120
|
||||||
networks:
|
networks:
|
||||||
- ragflow
|
- ragflow
|
||||||
restart: always
|
restart: on-failure
|
||||||
|
|
||||||
|
# infinity:
|
||||||
|
# container_name: ragflow-infinity
|
||||||
|
# image: infiniflow/infinity:v0.5.0-dev2
|
||||||
|
# volumes:
|
||||||
|
# - infinity_data:/var/infinity
|
||||||
|
# ports:
|
||||||
|
# - ${INFINITY_THRIFT_PORT}:23817
|
||||||
|
# - ${INFINITY_HTTP_PORT}:23820
|
||||||
|
# - ${INFINITY_PSQL_PORT}:5432
|
||||||
|
# env_file: .env
|
||||||
|
# environment:
|
||||||
|
# - TZ=${TIMEZONE}
|
||||||
|
# mem_limit: ${MEM_LIMIT}
|
||||||
|
# ulimits:
|
||||||
|
# nofile:
|
||||||
|
# soft: 500000
|
||||||
|
# hard: 500000
|
||||||
|
# networks:
|
||||||
|
# - ragflow
|
||||||
|
# healthcheck:
|
||||||
|
# test: ["CMD", "curl", "http://localhost:23820/admin/node/current"]
|
||||||
|
# interval: 10s
|
||||||
|
# timeout: 10s
|
||||||
|
# retries: 120
|
||||||
|
# restart: on-failure
|
||||||
|
|
||||||
|
|
||||||
mysql:
|
mysql:
|
||||||
# mysql:5.7 linux/arm64 image is unavailable.
|
# mysql:5.7 linux/arm64 image is unavailable.
|
||||||
image: mysql:8.0.39
|
image: mysql:8.0.39
|
||||||
container_name: ragflow-mysql
|
container_name: ragflow-mysql
|
||||||
|
env_file: .env
|
||||||
environment:
|
environment:
|
||||||
- MYSQL_ROOT_PASSWORD=${MYSQL_PASSWORD}
|
- MYSQL_ROOT_PASSWORD=${MYSQL_PASSWORD}
|
||||||
- TZ=${TIMEZONE}
|
- TZ=${TIMEZONE}
|
||||||
@ -55,7 +84,7 @@ services:
|
|||||||
interval: 10s
|
interval: 10s
|
||||||
timeout: 10s
|
timeout: 10s
|
||||||
retries: 3
|
retries: 3
|
||||||
restart: always
|
restart: on-failure
|
||||||
|
|
||||||
minio:
|
minio:
|
||||||
image: quay.io/minio/minio:RELEASE.2023-12-20T01-00-02Z
|
image: quay.io/minio/minio:RELEASE.2023-12-20T01-00-02Z
|
||||||
@ -64,6 +93,7 @@ services:
|
|||||||
ports:
|
ports:
|
||||||
- ${MINIO_PORT}:9000
|
- ${MINIO_PORT}:9000
|
||||||
- ${MINIO_CONSOLE_PORT}:9001
|
- ${MINIO_CONSOLE_PORT}:9001
|
||||||
|
env_file: .env
|
||||||
environment:
|
environment:
|
||||||
- MINIO_ROOT_USER=${MINIO_USER}
|
- MINIO_ROOT_USER=${MINIO_USER}
|
||||||
- MINIO_ROOT_PASSWORD=${MINIO_PASSWORD}
|
- MINIO_ROOT_PASSWORD=${MINIO_PASSWORD}
|
||||||
@ -72,25 +102,28 @@ services:
|
|||||||
- minio_data:/data
|
- minio_data:/data
|
||||||
networks:
|
networks:
|
||||||
- ragflow
|
- ragflow
|
||||||
restart: always
|
restart: on-failure
|
||||||
|
|
||||||
redis:
|
redis:
|
||||||
image: valkey/valkey:8
|
image: valkey/valkey:8
|
||||||
container_name: ragflow-redis
|
container_name: ragflow-redis
|
||||||
command: redis-server --requirepass ${REDIS_PASSWORD} --maxmemory 128mb --maxmemory-policy allkeys-lru
|
command: redis-server --requirepass ${REDIS_PASSWORD} --maxmemory 128mb --maxmemory-policy allkeys-lru
|
||||||
|
env_file: .env
|
||||||
ports:
|
ports:
|
||||||
- ${REDIS_PORT}:6379
|
- ${REDIS_PORT}:6379
|
||||||
volumes:
|
volumes:
|
||||||
- redis_data:/data
|
- redis_data:/data
|
||||||
networks:
|
networks:
|
||||||
- ragflow
|
- ragflow
|
||||||
restart: always
|
restart: on-failure
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
volumes:
|
volumes:
|
||||||
esdata01:
|
esdata01:
|
||||||
driver: local
|
driver: local
|
||||||
|
infinity_data:
|
||||||
|
driver: local
|
||||||
mysql_data:
|
mysql_data:
|
||||||
driver: local
|
driver: local
|
||||||
minio_data:
|
minio_data:
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
include:
|
include:
|
||||||
- path: ./docker-compose-base.yml
|
- ./docker-compose-base.yml
|
||||||
env_file: ./.env
|
|
||||||
|
|
||||||
services:
|
services:
|
||||||
ragflow:
|
ragflow:
|
||||||
@ -15,19 +14,21 @@ services:
|
|||||||
- ${SVR_HTTP_PORT}:9380
|
- ${SVR_HTTP_PORT}:9380
|
||||||
- 80:80
|
- 80:80
|
||||||
- 443:443
|
- 443:443
|
||||||
- 5678:5678
|
|
||||||
volumes:
|
volumes:
|
||||||
- ./service_conf.yaml:/ragflow/conf/service_conf.yaml
|
- ./service_conf.yaml:/ragflow/conf/service_conf.yaml
|
||||||
- ./ragflow-logs:/ragflow/logs
|
- ./ragflow-logs:/ragflow/logs
|
||||||
- ./nginx/ragflow.conf:/etc/nginx/conf.d/ragflow.conf
|
- ./nginx/ragflow.conf:/etc/nginx/conf.d/ragflow.conf
|
||||||
- ./nginx/proxy.conf:/etc/nginx/proxy.conf
|
- ./nginx/proxy.conf:/etc/nginx/proxy.conf
|
||||||
- ./nginx/nginx.conf:/etc/nginx/nginx.conf
|
- ./nginx/nginx.conf:/etc/nginx/nginx.conf
|
||||||
|
env_file: .env
|
||||||
environment:
|
environment:
|
||||||
- TZ=${TIMEZONE}
|
- TZ=${TIMEZONE}
|
||||||
- HF_ENDPOINT=${HF_ENDPOINT}
|
- HF_ENDPOINT=${HF_ENDPOINT}
|
||||||
- MACOS=${MACOS}
|
- MACOS=${MACOS}
|
||||||
networks:
|
networks:
|
||||||
- ragflow
|
- ragflow
|
||||||
restart: always
|
restart: on-failure
|
||||||
|
# https://docs.docker.com/engine/daemon/prometheus/#create-a-prometheus-configuration
|
||||||
|
# If you're using Docker Desktop, the --add-host flag is optional. This flag makes sure that the host's internal IP gets exposed to the Prometheus container.
|
||||||
extra_hosts:
|
extra_hosts:
|
||||||
- "host.docker.internal:host-gateway"
|
- "host.docker.internal:host-gateway"
|
||||||
|
@ -67,7 +67,7 @@ docker compose -f docker/docker-compose-base.yml up -d
|
|||||||
1. Add the following line to `/etc/hosts` to resolve all hosts specified in **docker/service_conf.yaml** to `127.0.0.1`:
|
1. Add the following line to `/etc/hosts` to resolve all hosts specified in **docker/service_conf.yaml** to `127.0.0.1`:
|
||||||
|
|
||||||
```
|
```
|
||||||
127.0.0.1 es01 mysql minio redis
|
127.0.0.1 es01 infinity mysql minio redis
|
||||||
```
|
```
|
||||||
|
|
||||||
2. In **docker/service_conf.yaml**, update mysql port to `5455` and es port to `1200`, as specified in **docker/.env**.
|
2. In **docker/service_conf.yaml**, update mysql port to `5455` and es port to `1200`, as specified in **docker/.env**.
|
||||||
|
@ -1280,7 +1280,7 @@ Success:
|
|||||||
"document_keyword": "1.txt",
|
"document_keyword": "1.txt",
|
||||||
"highlight": "<em>ragflow</em> content",
|
"highlight": "<em>ragflow</em> content",
|
||||||
"id": "d78435d142bd5cf6704da62c778795c5",
|
"id": "d78435d142bd5cf6704da62c778795c5",
|
||||||
"img_id": "",
|
"image_id": "",
|
||||||
"important_keywords": [
|
"important_keywords": [
|
||||||
""
|
""
|
||||||
],
|
],
|
||||||
|
@ -1351,7 +1351,7 @@ A list of `Chunk` objects representing references to the message, each containin
|
|||||||
The chunk ID.
|
The chunk ID.
|
||||||
- `content` `str`
|
- `content` `str`
|
||||||
The content of the chunk.
|
The content of the chunk.
|
||||||
- `image_id` `str`
|
- `img_id` `str`
|
||||||
The ID of the snapshot of the chunk. Applicable only when the source of the chunk is an image, PPT, PPTX, or PDF file.
|
The ID of the snapshot of the chunk. Applicable only when the source of the chunk is an image, PPT, PPTX, or PDF file.
|
||||||
- `document_id` `str`
|
- `document_id` `str`
|
||||||
The ID of the referenced document.
|
The ID of the referenced document.
|
||||||
|
@ -254,9 +254,12 @@ if __name__ == "__main__":
|
|||||||
from api.db import LLMType
|
from api.db import LLMType
|
||||||
from api.db.services.llm_service import LLMBundle
|
from api.db.services.llm_service import LLMBundle
|
||||||
from api.settings import retrievaler
|
from api.settings import retrievaler
|
||||||
|
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||||
|
|
||||||
|
kb_ids = KnowledgebaseService.get_kb_ids(args.tenant_id)
|
||||||
|
|
||||||
ex = ClaimExtractor(LLMBundle(args.tenant_id, LLMType.CHAT))
|
ex = ClaimExtractor(LLMBundle(args.tenant_id, LLMType.CHAT))
|
||||||
docs = [d["content_with_weight"] for d in retrievaler.chunk_list(args.doc_id, args.tenant_id, max_count=12, fields=["content_with_weight"])]
|
docs = [d["content_with_weight"] for d in retrievaler.chunk_list(args.doc_id, args.tenant_id, kb_ids, max_count=12, fields=["content_with_weight"])]
|
||||||
info = {
|
info = {
|
||||||
"input_text": docs,
|
"input_text": docs,
|
||||||
"entity_specs": "organization, person",
|
"entity_specs": "organization, person",
|
||||||
|
@ -15,95 +15,90 @@
|
|||||||
#
|
#
|
||||||
import json
|
import json
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from elasticsearch_dsl import Q, Search
|
from rag.utils.doc_store_conn import OrderByExpr, FusionExpr
|
||||||
|
|
||||||
from rag.nlp.search import Dealer
|
from rag.nlp.search import Dealer
|
||||||
|
|
||||||
|
|
||||||
class KGSearch(Dealer):
|
class KGSearch(Dealer):
|
||||||
def search(self, req, idxnm, emb_mdl=None, highlight=False):
|
def search(self, req, idxnm, kb_ids, emb_mdl, highlight=False):
|
||||||
def merge_into_first(sres, title=""):
|
def merge_into_first(sres, title="") -> Dict[str, str]:
|
||||||
df,texts = [],[]
|
if not sres:
|
||||||
for d in sres["hits"]["hits"]:
|
return {}
|
||||||
|
content_with_weight = ""
|
||||||
|
df, texts = [],[]
|
||||||
|
for d in sres.values():
|
||||||
try:
|
try:
|
||||||
df.append(json.loads(d["_source"]["content_with_weight"]))
|
df.append(json.loads(d["content_with_weight"]))
|
||||||
except Exception as e:
|
except Exception:
|
||||||
texts.append(d["_source"]["content_with_weight"])
|
texts.append(d["content_with_weight"])
|
||||||
pass
|
|
||||||
if not df and not texts: return False
|
|
||||||
if df:
|
if df:
|
||||||
try:
|
content_with_weight = title + "\n" + pd.DataFrame(df).to_csv()
|
||||||
sres["hits"]["hits"][0]["_source"]["content_with_weight"] = title + "\n" + pd.DataFrame(df).to_csv()
|
|
||||||
except Exception as e:
|
|
||||||
pass
|
|
||||||
else:
|
else:
|
||||||
sres["hits"]["hits"][0]["_source"]["content_with_weight"] = title + "\n" + "\n".join(texts)
|
content_with_weight = title + "\n" + "\n".join(texts)
|
||||||
return True
|
first_id = ""
|
||||||
|
first_source = {}
|
||||||
|
for k, v in sres.items():
|
||||||
|
first_id = id
|
||||||
|
first_source = deepcopy(v)
|
||||||
|
break
|
||||||
|
first_source["content_with_weight"] = content_with_weight
|
||||||
|
first_id = next(iter(sres))
|
||||||
|
return {first_id: first_source}
|
||||||
|
|
||||||
|
qst = req.get("question", "")
|
||||||
|
matchText, keywords = self.qryr.question(qst, min_match=0.05)
|
||||||
|
condition = self.get_filters(req)
|
||||||
|
|
||||||
|
## Entity retrieval
|
||||||
|
condition.update({"knowledge_graph_kwd": ["entity"]})
|
||||||
|
assert emb_mdl, "No embedding model selected"
|
||||||
|
matchDense = self.get_vector(qst, emb_mdl, 1024, req.get("similarity", 0.1))
|
||||||
|
q_vec = matchDense.embedding_data
|
||||||
src = req.get("fields", ["docnm_kwd", "content_ltks", "kb_id", "img_id", "title_tks", "important_kwd",
|
src = req.get("fields", ["docnm_kwd", "content_ltks", "kb_id", "img_id", "title_tks", "important_kwd",
|
||||||
"image_id", "doc_id", "q_512_vec", "q_768_vec", "position_int", "name_kwd",
|
"doc_id", f"q_{len(q_vec)}_vec", "position_list", "name_kwd",
|
||||||
"q_1024_vec", "q_1536_vec", "available_int", "content_with_weight",
|
"q_1024_vec", "q_1536_vec", "available_int", "content_with_weight",
|
||||||
"weight_int", "weight_flt", "rank_int"
|
"weight_int", "weight_flt", "rank_int"
|
||||||
])
|
])
|
||||||
|
|
||||||
qst = req.get("question", "")
|
fusionExpr = FusionExpr("weighted_sum", 32, {"weights": "0.5, 0.5"})
|
||||||
binary_query, keywords = self.qryr.question(qst, min_match="5%")
|
|
||||||
binary_query = self._add_filters(binary_query, req)
|
|
||||||
|
|
||||||
## Entity retrieval
|
ent_res = self.dataStore.search(src, list(), condition, [matchText, matchDense, fusionExpr], OrderByExpr(), 0, 32, idxnm, kb_ids)
|
||||||
bqry = deepcopy(binary_query)
|
ent_res_fields = self.dataStore.getFields(ent_res, src)
|
||||||
bqry.filter.append(Q("terms", knowledge_graph_kwd=["entity"]))
|
entities = [d["name_kwd"] for d in ent_res_fields.values()]
|
||||||
s = Search()
|
ent_ids = self.dataStore.getChunkIds(ent_res)
|
||||||
s = s.query(bqry)[0: 32]
|
ent_content = merge_into_first(ent_res_fields, "-Entities-")
|
||||||
|
if ent_content:
|
||||||
s = s.to_dict()
|
ent_ids = list(ent_content.keys())
|
||||||
q_vec = []
|
|
||||||
if req.get("vector"):
|
|
||||||
assert emb_mdl, "No embedding model selected"
|
|
||||||
s["knn"] = self._vector(
|
|
||||||
qst, emb_mdl, req.get(
|
|
||||||
"similarity", 0.1), 1024)
|
|
||||||
s["knn"]["filter"] = bqry.to_dict()
|
|
||||||
q_vec = s["knn"]["query_vector"]
|
|
||||||
|
|
||||||
ent_res = self.es.search(deepcopy(s), idxnms=idxnm, timeout="600s", src=src)
|
|
||||||
entities = [d["name_kwd"] for d in self.es.getSource(ent_res)]
|
|
||||||
ent_ids = self.es.getDocIds(ent_res)
|
|
||||||
if merge_into_first(ent_res, "-Entities-"):
|
|
||||||
ent_ids = ent_ids[0:1]
|
|
||||||
|
|
||||||
## Community retrieval
|
## Community retrieval
|
||||||
bqry = deepcopy(binary_query)
|
condition = self.get_filters(req)
|
||||||
bqry.filter.append(Q("terms", entities_kwd=entities))
|
condition.update({"entities_kwd": entities, "knowledge_graph_kwd": ["community_report"]})
|
||||||
bqry.filter.append(Q("terms", knowledge_graph_kwd=["community_report"]))
|
comm_res = self.dataStore.search(src, list(), condition, [matchText, matchDense, fusionExpr], OrderByExpr(), 0, 32, idxnm, kb_ids)
|
||||||
s = Search()
|
comm_res_fields = self.dataStore.getFields(comm_res, src)
|
||||||
s = s.query(bqry)[0: 32]
|
comm_ids = self.dataStore.getChunkIds(comm_res)
|
||||||
s = s.to_dict()
|
comm_content = merge_into_first(comm_res_fields, "-Community Report-")
|
||||||
comm_res = self.es.search(deepcopy(s), idxnms=idxnm, timeout="600s", src=src)
|
if comm_content:
|
||||||
comm_ids = self.es.getDocIds(comm_res)
|
comm_ids = list(comm_content.keys())
|
||||||
if merge_into_first(comm_res, "-Community Report-"):
|
|
||||||
comm_ids = comm_ids[0:1]
|
|
||||||
|
|
||||||
## Text content retrieval
|
## Text content retrieval
|
||||||
bqry = deepcopy(binary_query)
|
condition = self.get_filters(req)
|
||||||
bqry.filter.append(Q("terms", knowledge_graph_kwd=["text"]))
|
condition.update({"knowledge_graph_kwd": ["text"]})
|
||||||
s = Search()
|
txt_res = self.dataStore.search(src, list(), condition, [matchText, matchDense, fusionExpr], OrderByExpr(), 0, 6, idxnm, kb_ids)
|
||||||
s = s.query(bqry)[0: 6]
|
txt_res_fields = self.dataStore.getFields(txt_res, src)
|
||||||
s = s.to_dict()
|
txt_ids = self.dataStore.getChunkIds(txt_res)
|
||||||
txt_res = self.es.search(deepcopy(s), idxnms=idxnm, timeout="600s", src=src)
|
txt_content = merge_into_first(txt_res_fields, "-Original Content-")
|
||||||
txt_ids = self.es.getDocIds(txt_res)
|
if txt_content:
|
||||||
if merge_into_first(txt_res, "-Original Content-"):
|
txt_ids = list(txt_content.keys())
|
||||||
txt_ids = txt_ids[0:1]
|
|
||||||
|
|
||||||
return self.SearchResult(
|
return self.SearchResult(
|
||||||
total=len(ent_ids) + len(comm_ids) + len(txt_ids),
|
total=len(ent_ids) + len(comm_ids) + len(txt_ids),
|
||||||
ids=[*ent_ids, *comm_ids, *txt_ids],
|
ids=[*ent_ids, *comm_ids, *txt_ids],
|
||||||
query_vector=q_vec,
|
query_vector=q_vec,
|
||||||
aggregation=None,
|
|
||||||
highlight=None,
|
highlight=None,
|
||||||
field={**self.getFields(ent_res, src), **self.getFields(comm_res, src), **self.getFields(txt_res, src)},
|
field={**ent_content, **comm_content, **txt_content},
|
||||||
keywords=[]
|
keywords=[]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -31,10 +31,13 @@ if __name__ == "__main__":
|
|||||||
from api.db import LLMType
|
from api.db import LLMType
|
||||||
from api.db.services.llm_service import LLMBundle
|
from api.db.services.llm_service import LLMBundle
|
||||||
from api.settings import retrievaler
|
from api.settings import retrievaler
|
||||||
|
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||||
|
|
||||||
|
kb_ids = KnowledgebaseService.get_kb_ids(args.tenant_id)
|
||||||
|
|
||||||
ex = GraphExtractor(LLMBundle(args.tenant_id, LLMType.CHAT))
|
ex = GraphExtractor(LLMBundle(args.tenant_id, LLMType.CHAT))
|
||||||
docs = [d["content_with_weight"] for d in
|
docs = [d["content_with_weight"] for d in
|
||||||
retrievaler.chunk_list(args.doc_id, args.tenant_id, max_count=6, fields=["content_with_weight"])]
|
retrievaler.chunk_list(args.doc_id, args.tenant_id, kb_ids, max_count=6, fields=["content_with_weight"])]
|
||||||
graph = ex(docs)
|
graph = ex(docs)
|
||||||
|
|
||||||
er = EntityResolution(LLMBundle(args.tenant_id, LLMType.CHAT))
|
er = EntityResolution(LLMBundle(args.tenant_id, LLMType.CHAT))
|
||||||
|
871
poetry.lock
generated
871
poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@ -46,22 +46,23 @@ hanziconv = "0.3.2"
|
|||||||
html-text = "0.6.2"
|
html-text = "0.6.2"
|
||||||
httpx = "0.27.0"
|
httpx = "0.27.0"
|
||||||
huggingface-hub = "^0.25.0"
|
huggingface-hub = "^0.25.0"
|
||||||
infinity-emb = "0.0.51"
|
infinity-sdk = "0.5.0.dev2"
|
||||||
|
infinity-emb = "^0.0.66"
|
||||||
itsdangerous = "2.1.2"
|
itsdangerous = "2.1.2"
|
||||||
markdown = "3.6"
|
markdown = "3.6"
|
||||||
markdown-to-json = "2.1.1"
|
markdown-to-json = "2.1.1"
|
||||||
minio = "7.2.4"
|
minio = "7.2.4"
|
||||||
mistralai = "0.4.2"
|
mistralai = "0.4.2"
|
||||||
nltk = "3.9.1"
|
nltk = "3.9.1"
|
||||||
numpy = "1.26.4"
|
numpy = "^1.26.0"
|
||||||
ollama = "0.2.1"
|
ollama = "0.2.1"
|
||||||
onnxruntime = "1.19.2"
|
onnxruntime = "1.19.2"
|
||||||
openai = "1.45.0"
|
openai = "1.45.0"
|
||||||
opencv-python = "4.10.0.84"
|
opencv-python = "4.10.0.84"
|
||||||
opencv-python-headless = "4.10.0.84"
|
opencv-python-headless = "4.10.0.84"
|
||||||
openpyxl = "3.1.2"
|
openpyxl = "^3.1.0"
|
||||||
ormsgpack = "1.5.0"
|
ormsgpack = "1.5.0"
|
||||||
pandas = "2.2.2"
|
pandas = "^2.2.0"
|
||||||
pdfplumber = "0.10.4"
|
pdfplumber = "0.10.4"
|
||||||
peewee = "3.17.1"
|
peewee = "3.17.1"
|
||||||
pillow = "10.4.0"
|
pillow = "10.4.0"
|
||||||
@ -70,7 +71,7 @@ psycopg2-binary = "2.9.9"
|
|||||||
pyclipper = "1.3.0.post5"
|
pyclipper = "1.3.0.post5"
|
||||||
pycryptodomex = "3.20.0"
|
pycryptodomex = "3.20.0"
|
||||||
pypdf = "^5.0.0"
|
pypdf = "^5.0.0"
|
||||||
pytest = "8.2.2"
|
pytest = "^8.3.0"
|
||||||
python-dotenv = "1.0.1"
|
python-dotenv = "1.0.1"
|
||||||
python-dateutil = "2.8.2"
|
python-dateutil = "2.8.2"
|
||||||
python-pptx = "^1.0.2"
|
python-pptx = "^1.0.2"
|
||||||
@ -86,7 +87,7 @@ ruamel-base = "1.0.0"
|
|||||||
scholarly = "1.7.11"
|
scholarly = "1.7.11"
|
||||||
scikit-learn = "1.5.0"
|
scikit-learn = "1.5.0"
|
||||||
selenium = "4.22.0"
|
selenium = "4.22.0"
|
||||||
setuptools = "70.0.0"
|
setuptools = "^75.2.0"
|
||||||
shapely = "2.0.5"
|
shapely = "2.0.5"
|
||||||
six = "1.16.0"
|
six = "1.16.0"
|
||||||
strenum = "0.4.15"
|
strenum = "0.4.15"
|
||||||
@ -115,6 +116,7 @@ pymysql = "^1.1.1"
|
|||||||
mini-racer = "^0.12.4"
|
mini-racer = "^0.12.4"
|
||||||
pyicu = "^2.13.1"
|
pyicu = "^2.13.1"
|
||||||
flasgger = "^0.9.7.1"
|
flasgger = "^0.9.7.1"
|
||||||
|
polars = "^1.9.0"
|
||||||
|
|
||||||
|
|
||||||
[tool.poetry.group.full]
|
[tool.poetry.group.full]
|
||||||
|
@ -20,6 +20,7 @@ from rag.nlp import tokenize, is_english
|
|||||||
from rag.nlp import rag_tokenizer
|
from rag.nlp import rag_tokenizer
|
||||||
from deepdoc.parser import PdfParser, PptParser, PlainParser
|
from deepdoc.parser import PdfParser, PptParser, PlainParser
|
||||||
from PyPDF2 import PdfReader as pdf2_read
|
from PyPDF2 import PdfReader as pdf2_read
|
||||||
|
import json
|
||||||
|
|
||||||
|
|
||||||
class Ppt(PptParser):
|
class Ppt(PptParser):
|
||||||
@ -107,9 +108,9 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
|
|||||||
d = copy.deepcopy(doc)
|
d = copy.deepcopy(doc)
|
||||||
pn += from_page
|
pn += from_page
|
||||||
d["image"] = img
|
d["image"] = img
|
||||||
d["page_num_int"] = [pn + 1]
|
d["page_num_list"] = json.dumps([pn + 1])
|
||||||
d["top_int"] = [0]
|
d["top_list"] = json.dumps([0])
|
||||||
d["position_int"] = [(pn + 1, 0, img.size[0], 0, img.size[1])]
|
d["position_list"] = json.dumps([(pn + 1, 0, img.size[0], 0, img.size[1])])
|
||||||
tokenize(d, txt, eng)
|
tokenize(d, txt, eng)
|
||||||
res.append(d)
|
res.append(d)
|
||||||
return res
|
return res
|
||||||
@ -123,10 +124,10 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
|
|||||||
pn += from_page
|
pn += from_page
|
||||||
if img:
|
if img:
|
||||||
d["image"] = img
|
d["image"] = img
|
||||||
d["page_num_int"] = [pn + 1]
|
d["page_num_list"] = json.dumps([pn + 1])
|
||||||
d["top_int"] = [0]
|
d["top_list"] = json.dumps([0])
|
||||||
d["position_int"] = [
|
d["position_list"] = json.dumps([
|
||||||
(pn + 1, 0, img.size[0] if img else 0, 0, img.size[1] if img else 0)]
|
(pn + 1, 0, img.size[0] if img else 0, 0, img.size[1] if img else 0)])
|
||||||
tokenize(d, txt, eng)
|
tokenize(d, txt, eng)
|
||||||
res.append(d)
|
res.append(d)
|
||||||
return res
|
return res
|
||||||
|
@ -74,7 +74,7 @@ class Excel(ExcelParser):
|
|||||||
def trans_datatime(s):
|
def trans_datatime(s):
|
||||||
try:
|
try:
|
||||||
return datetime_parse(s.strip()).strftime("%Y-%m-%d %H:%M:%S")
|
return datetime_parse(s.strip()).strftime("%Y-%m-%d %H:%M:%S")
|
||||||
except Exception as e:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@ -112,7 +112,7 @@ def column_data_type(arr):
|
|||||||
continue
|
continue
|
||||||
try:
|
try:
|
||||||
arr[i] = trans[ty](str(arr[i]))
|
arr[i] = trans[ty](str(arr[i]))
|
||||||
except Exception as e:
|
except Exception:
|
||||||
arr[i] = None
|
arr[i] = None
|
||||||
# if ty == "text":
|
# if ty == "text":
|
||||||
# if len(arr) > 128 and uni / len(arr) < 0.1:
|
# if len(arr) > 128 and uni / len(arr) < 0.1:
|
||||||
@ -182,7 +182,7 @@ def chunk(filename, binary=None, from_page=0, to_page=10000000000,
|
|||||||
"datetime": "_dt",
|
"datetime": "_dt",
|
||||||
"bool": "_kwd"}
|
"bool": "_kwd"}
|
||||||
for df in dfs:
|
for df in dfs:
|
||||||
for n in ["id", "_id", "index", "idx"]:
|
for n in ["id", "index", "idx"]:
|
||||||
if n in df.columns:
|
if n in df.columns:
|
||||||
del df[n]
|
del df[n]
|
||||||
clmns = df.columns.values
|
clmns = df.columns.values
|
||||||
|
590
rag/benchmark.py
590
rag/benchmark.py
@ -1,280 +1,310 @@
|
|||||||
#
|
#
|
||||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
# You may obtain a copy of the License at
|
# You may obtain a copy of the License at
|
||||||
#
|
#
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
#
|
#
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
#
|
#
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from collections import defaultdict
|
import sys
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
import time
|
||||||
from copy import deepcopy
|
import argparse
|
||||||
|
from collections import defaultdict
|
||||||
from api.db import LLMType
|
|
||||||
from api.db.services.llm_service import LLMBundle
|
from api.db import LLMType
|
||||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
from api.db.services.llm_service import LLMBundle
|
||||||
from api.settings import retrievaler
|
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||||
from api.utils import get_uuid
|
from api.settings import retrievaler, docStoreConn
|
||||||
from api.utils.file_utils import get_project_base_directory
|
from api.utils import get_uuid
|
||||||
from rag.nlp import tokenize, search
|
from rag.nlp import tokenize, search
|
||||||
from rag.utils.es_conn import ELASTICSEARCH
|
from ranx import evaluate
|
||||||
from ranx import evaluate
|
import pandas as pd
|
||||||
import pandas as pd
|
from tqdm import tqdm
|
||||||
from tqdm import tqdm
|
|
||||||
from ranx import Qrels, Run
|
global max_docs
|
||||||
|
max_docs = sys.maxsize
|
||||||
|
|
||||||
class Benchmark:
|
class Benchmark:
|
||||||
def __init__(self, kb_id):
|
def __init__(self, kb_id):
|
||||||
e, self.kb = KnowledgebaseService.get_by_id(kb_id)
|
self.kb_id = kb_id
|
||||||
self.similarity_threshold = self.kb.similarity_threshold
|
e, self.kb = KnowledgebaseService.get_by_id(kb_id)
|
||||||
self.vector_similarity_weight = self.kb.vector_similarity_weight
|
self.similarity_threshold = self.kb.similarity_threshold
|
||||||
self.embd_mdl = LLMBundle(self.kb.tenant_id, LLMType.EMBEDDING, llm_name=self.kb.embd_id, lang=self.kb.language)
|
self.vector_similarity_weight = self.kb.vector_similarity_weight
|
||||||
|
self.embd_mdl = LLMBundle(self.kb.tenant_id, LLMType.EMBEDDING, llm_name=self.kb.embd_id, lang=self.kb.language)
|
||||||
def _get_benchmarks(self, query, dataset_idxnm, count=16):
|
self.tenant_id = ''
|
||||||
|
self.index_name = ''
|
||||||
req = {"question": query, "size": count, "vector": True, "similarity": self.similarity_threshold}
|
self.initialized_index = False
|
||||||
sres = retrievaler.search(req, search.index_name(dataset_idxnm), self.embd_mdl)
|
|
||||||
return sres
|
def _get_retrieval(self, qrels):
|
||||||
|
# Need to wait for the ES and Infinity index to be ready
|
||||||
def _get_retrieval(self, qrels, dataset_idxnm):
|
time.sleep(20)
|
||||||
run = defaultdict(dict)
|
run = defaultdict(dict)
|
||||||
query_list = list(qrels.keys())
|
query_list = list(qrels.keys())
|
||||||
for query in query_list:
|
for query in query_list:
|
||||||
|
ranks = retrievaler.retrieval(query, self.embd_mdl, self.tenant_id, [self.kb.id], 1, 30,
|
||||||
ranks = retrievaler.retrieval(query, self.embd_mdl,
|
0.0, self.vector_similarity_weight)
|
||||||
dataset_idxnm, [self.kb.id], 1, 30,
|
if len(ranks["chunks"]) == 0:
|
||||||
0.0, self.vector_similarity_weight)
|
print(f"deleted query: {query}")
|
||||||
for c in ranks["chunks"]:
|
del qrels[query]
|
||||||
if "vector" in c:
|
continue
|
||||||
del c["vector"]
|
for c in ranks["chunks"]:
|
||||||
run[query][c["chunk_id"]] = c["similarity"]
|
if "vector" in c:
|
||||||
|
del c["vector"]
|
||||||
return run
|
run[query][c["chunk_id"]] = c["similarity"]
|
||||||
|
return run
|
||||||
def embedding(self, docs, batch_size=16):
|
|
||||||
vects = []
|
def embedding(self, docs, batch_size=16):
|
||||||
cnts = [d["content_with_weight"] for d in docs]
|
vects = []
|
||||||
for i in range(0, len(cnts), batch_size):
|
cnts = [d["content_with_weight"] for d in docs]
|
||||||
vts, c = self.embd_mdl.encode(cnts[i: i + batch_size])
|
for i in range(0, len(cnts), batch_size):
|
||||||
vects.extend(vts.tolist())
|
vts, c = self.embd_mdl.encode(cnts[i: i + batch_size])
|
||||||
assert len(docs) == len(vects)
|
vects.extend(vts.tolist())
|
||||||
for i, d in enumerate(docs):
|
assert len(docs) == len(vects)
|
||||||
v = vects[i]
|
vector_size = 0
|
||||||
d["q_%d_vec" % len(v)] = v
|
for i, d in enumerate(docs):
|
||||||
return docs
|
v = vects[i]
|
||||||
|
vector_size = len(v)
|
||||||
@staticmethod
|
d["q_%d_vec" % len(v)] = v
|
||||||
def init_kb(index_name):
|
return docs, vector_size
|
||||||
idxnm = search.index_name(index_name)
|
|
||||||
if ELASTICSEARCH.indexExist(idxnm):
|
def init_index(self, vector_size: int):
|
||||||
ELASTICSEARCH.deleteIdx(search.index_name(index_name))
|
if self.initialized_index:
|
||||||
|
return
|
||||||
return ELASTICSEARCH.createIdx(idxnm, json.load(
|
if docStoreConn.indexExist(self.index_name, self.kb_id):
|
||||||
open(os.path.join(get_project_base_directory(), "conf", "mapping.json"), "r")))
|
docStoreConn.deleteIdx(self.index_name, self.kb_id)
|
||||||
|
docStoreConn.createIdx(self.index_name, self.kb_id, vector_size)
|
||||||
def ms_marco_index(self, file_path, index_name):
|
self.initialized_index = True
|
||||||
qrels = defaultdict(dict)
|
|
||||||
texts = defaultdict(dict)
|
def ms_marco_index(self, file_path, index_name):
|
||||||
docs = []
|
qrels = defaultdict(dict)
|
||||||
filelist = os.listdir(file_path)
|
texts = defaultdict(dict)
|
||||||
self.init_kb(index_name)
|
docs_count = 0
|
||||||
|
docs = []
|
||||||
max_workers = int(os.environ.get('MAX_WORKERS', 3))
|
filelist = sorted(os.listdir(file_path))
|
||||||
exe = ThreadPoolExecutor(max_workers=max_workers)
|
|
||||||
threads = []
|
for fn in filelist:
|
||||||
|
if docs_count >= max_docs:
|
||||||
def slow_actions(es_docs, idx_nm):
|
break
|
||||||
es_docs = self.embedding(es_docs)
|
if not fn.endswith(".parquet"):
|
||||||
ELASTICSEARCH.bulk(es_docs, idx_nm)
|
continue
|
||||||
return True
|
data = pd.read_parquet(os.path.join(file_path, fn))
|
||||||
|
for i in tqdm(range(len(data)), colour="green", desc="Tokenizing:" + fn):
|
||||||
for dir in filelist:
|
if docs_count >= max_docs:
|
||||||
data = pd.read_parquet(os.path.join(file_path, dir))
|
break
|
||||||
for i in tqdm(range(len(data)), colour="green", desc="Tokenizing:" + dir):
|
query = data.iloc[i]['query']
|
||||||
|
for rel, text in zip(data.iloc[i]['passages']['is_selected'], data.iloc[i]['passages']['passage_text']):
|
||||||
query = data.iloc[i]['query']
|
d = {
|
||||||
for rel, text in zip(data.iloc[i]['passages']['is_selected'], data.iloc[i]['passages']['passage_text']):
|
"id": get_uuid(),
|
||||||
d = {
|
"kb_id": self.kb.id,
|
||||||
"id": get_uuid(),
|
"docnm_kwd": "xxxxx",
|
||||||
"kb_id": self.kb.id,
|
"doc_id": "ksksks"
|
||||||
"docnm_kwd": "xxxxx",
|
}
|
||||||
"doc_id": "ksksks"
|
tokenize(d, text, "english")
|
||||||
}
|
docs.append(d)
|
||||||
tokenize(d, text, "english")
|
texts[d["id"]] = text
|
||||||
docs.append(d)
|
qrels[query][d["id"]] = int(rel)
|
||||||
texts[d["id"]] = text
|
if len(docs) >= 32:
|
||||||
qrels[query][d["id"]] = int(rel)
|
docs_count += len(docs)
|
||||||
if len(docs) >= 32:
|
docs, vector_size = self.embedding(docs)
|
||||||
threads.append(
|
self.init_index(vector_size)
|
||||||
exe.submit(slow_actions, deepcopy(docs), search.index_name(index_name)))
|
docStoreConn.insert(docs, self.index_name, self.kb_id)
|
||||||
docs = []
|
docs = []
|
||||||
|
|
||||||
threads.append(
|
if docs:
|
||||||
exe.submit(slow_actions, deepcopy(docs), search.index_name(index_name)))
|
docs, vector_size = self.embedding(docs)
|
||||||
|
self.init_index(vector_size)
|
||||||
for i in tqdm(range(len(threads)), colour="red", desc="Indexing:" + dir):
|
docStoreConn.insert(docs, self.index_name, self.kb_id)
|
||||||
if not threads[i].result().output:
|
return qrels, texts
|
||||||
print("Indexing error...")
|
|
||||||
|
def trivia_qa_index(self, file_path, index_name):
|
||||||
return qrels, texts
|
qrels = defaultdict(dict)
|
||||||
|
texts = defaultdict(dict)
|
||||||
def trivia_qa_index(self, file_path, index_name):
|
docs_count = 0
|
||||||
qrels = defaultdict(dict)
|
docs = []
|
||||||
texts = defaultdict(dict)
|
filelist = sorted(os.listdir(file_path))
|
||||||
docs = []
|
for fn in filelist:
|
||||||
filelist = os.listdir(file_path)
|
if docs_count >= max_docs:
|
||||||
for dir in filelist:
|
break
|
||||||
data = pd.read_parquet(os.path.join(file_path, dir))
|
if not fn.endswith(".parquet"):
|
||||||
for i in tqdm(range(len(data)), colour="green", desc="Indexing:" + dir):
|
continue
|
||||||
query = data.iloc[i]['question']
|
data = pd.read_parquet(os.path.join(file_path, fn))
|
||||||
for rel, text in zip(data.iloc[i]["search_results"]['rank'],
|
for i in tqdm(range(len(data)), colour="green", desc="Indexing:" + fn):
|
||||||
data.iloc[i]["search_results"]['search_context']):
|
if docs_count >= max_docs:
|
||||||
d = {
|
break
|
||||||
"id": get_uuid(),
|
query = data.iloc[i]['question']
|
||||||
"kb_id": self.kb.id,
|
for rel, text in zip(data.iloc[i]["search_results"]['rank'],
|
||||||
"docnm_kwd": "xxxxx",
|
data.iloc[i]["search_results"]['search_context']):
|
||||||
"doc_id": "ksksks"
|
d = {
|
||||||
}
|
"id": get_uuid(),
|
||||||
tokenize(d, text, "english")
|
"kb_id": self.kb.id,
|
||||||
docs.append(d)
|
"docnm_kwd": "xxxxx",
|
||||||
texts[d["id"]] = text
|
"doc_id": "ksksks"
|
||||||
qrels[query][d["id"]] = int(rel)
|
}
|
||||||
if len(docs) >= 32:
|
tokenize(d, text, "english")
|
||||||
docs = self.embedding(docs)
|
docs.append(d)
|
||||||
ELASTICSEARCH.bulk(docs, search.index_name(index_name))
|
texts[d["id"]] = text
|
||||||
docs = []
|
qrels[query][d["id"]] = int(rel)
|
||||||
|
if len(docs) >= 32:
|
||||||
docs = self.embedding(docs)
|
docs_count += len(docs)
|
||||||
ELASTICSEARCH.bulk(docs, search.index_name(index_name))
|
docs, vector_size = self.embedding(docs)
|
||||||
return qrels, texts
|
self.init_index(vector_size)
|
||||||
|
docStoreConn.insert(docs,self.index_name)
|
||||||
def miracl_index(self, file_path, corpus_path, index_name):
|
docs = []
|
||||||
|
|
||||||
corpus_total = {}
|
docs, vector_size = self.embedding(docs)
|
||||||
for corpus_file in os.listdir(corpus_path):
|
self.init_index(vector_size)
|
||||||
tmp_data = pd.read_json(os.path.join(corpus_path, corpus_file), lines=True)
|
docStoreConn.insert(docs, self.index_name)
|
||||||
for index, i in tmp_data.iterrows():
|
return qrels, texts
|
||||||
corpus_total[i['docid']] = i['text']
|
|
||||||
|
def miracl_index(self, file_path, corpus_path, index_name):
|
||||||
topics_total = {}
|
corpus_total = {}
|
||||||
for topics_file in os.listdir(os.path.join(file_path, 'topics')):
|
for corpus_file in os.listdir(corpus_path):
|
||||||
if 'test' in topics_file:
|
tmp_data = pd.read_json(os.path.join(corpus_path, corpus_file), lines=True)
|
||||||
continue
|
for index, i in tmp_data.iterrows():
|
||||||
tmp_data = pd.read_csv(os.path.join(file_path, 'topics', topics_file), sep='\t', names=['qid', 'query'])
|
corpus_total[i['docid']] = i['text']
|
||||||
for index, i in tmp_data.iterrows():
|
|
||||||
topics_total[i['qid']] = i['query']
|
topics_total = {}
|
||||||
|
for topics_file in os.listdir(os.path.join(file_path, 'topics')):
|
||||||
qrels = defaultdict(dict)
|
if 'test' in topics_file:
|
||||||
texts = defaultdict(dict)
|
continue
|
||||||
docs = []
|
tmp_data = pd.read_csv(os.path.join(file_path, 'topics', topics_file), sep='\t', names=['qid', 'query'])
|
||||||
for qrels_file in os.listdir(os.path.join(file_path, 'qrels')):
|
for index, i in tmp_data.iterrows():
|
||||||
if 'test' in qrels_file:
|
topics_total[i['qid']] = i['query']
|
||||||
continue
|
|
||||||
|
qrels = defaultdict(dict)
|
||||||
tmp_data = pd.read_csv(os.path.join(file_path, 'qrels', qrels_file), sep='\t',
|
texts = defaultdict(dict)
|
||||||
names=['qid', 'Q0', 'docid', 'relevance'])
|
docs_count = 0
|
||||||
for i in tqdm(range(len(tmp_data)), colour="green", desc="Indexing:" + qrels_file):
|
docs = []
|
||||||
query = topics_total[tmp_data.iloc[i]['qid']]
|
for qrels_file in os.listdir(os.path.join(file_path, 'qrels')):
|
||||||
text = corpus_total[tmp_data.iloc[i]['docid']]
|
if 'test' in qrels_file:
|
||||||
rel = tmp_data.iloc[i]['relevance']
|
continue
|
||||||
d = {
|
if docs_count >= max_docs:
|
||||||
"id": get_uuid(),
|
break
|
||||||
"kb_id": self.kb.id,
|
|
||||||
"docnm_kwd": "xxxxx",
|
tmp_data = pd.read_csv(os.path.join(file_path, 'qrels', qrels_file), sep='\t',
|
||||||
"doc_id": "ksksks"
|
names=['qid', 'Q0', 'docid', 'relevance'])
|
||||||
}
|
for i in tqdm(range(len(tmp_data)), colour="green", desc="Indexing:" + qrels_file):
|
||||||
tokenize(d, text, 'english')
|
if docs_count >= max_docs:
|
||||||
docs.append(d)
|
break
|
||||||
texts[d["id"]] = text
|
query = topics_total[tmp_data.iloc[i]['qid']]
|
||||||
qrels[query][d["id"]] = int(rel)
|
text = corpus_total[tmp_data.iloc[i]['docid']]
|
||||||
if len(docs) >= 32:
|
rel = tmp_data.iloc[i]['relevance']
|
||||||
docs = self.embedding(docs)
|
d = {
|
||||||
ELASTICSEARCH.bulk(docs, search.index_name(index_name))
|
"id": get_uuid(),
|
||||||
docs = []
|
"kb_id": self.kb.id,
|
||||||
|
"docnm_kwd": "xxxxx",
|
||||||
docs = self.embedding(docs)
|
"doc_id": "ksksks"
|
||||||
ELASTICSEARCH.bulk(docs, search.index_name(index_name))
|
}
|
||||||
|
tokenize(d, text, 'english')
|
||||||
return qrels, texts
|
docs.append(d)
|
||||||
|
texts[d["id"]] = text
|
||||||
def save_results(self, qrels, run, texts, dataset, file_path):
|
qrels[query][d["id"]] = int(rel)
|
||||||
keep_result = []
|
if len(docs) >= 32:
|
||||||
run_keys = list(run.keys())
|
docs_count += len(docs)
|
||||||
for run_i in tqdm(range(len(run_keys)), desc="Calculating ndcg@10 for single query"):
|
docs, vector_size = self.embedding(docs)
|
||||||
key = run_keys[run_i]
|
self.init_index(vector_size)
|
||||||
keep_result.append({'query': key, 'qrel': qrels[key], 'run': run[key],
|
docStoreConn.insert(docs, self.index_name)
|
||||||
'ndcg@10': evaluate(Qrels({key: qrels[key]}), Run({key: run[key]}), "ndcg@10")})
|
docs = []
|
||||||
keep_result = sorted(keep_result, key=lambda kk: kk['ndcg@10'])
|
|
||||||
with open(os.path.join(file_path, dataset + 'result.md'), 'w', encoding='utf-8') as f:
|
docs, vector_size = self.embedding(docs)
|
||||||
f.write('## Score For Every Query\n')
|
self.init_index(vector_size)
|
||||||
for keep_result_i in keep_result:
|
docStoreConn.insert(docs, self.index_name)
|
||||||
f.write('### query: ' + keep_result_i['query'] + ' ndcg@10:' + str(keep_result_i['ndcg@10']) + '\n')
|
return qrels, texts
|
||||||
scores = [[i[0], i[1]] for i in keep_result_i['run'].items()]
|
|
||||||
scores = sorted(scores, key=lambda kk: kk[1])
|
def save_results(self, qrels, run, texts, dataset, file_path):
|
||||||
for score in scores[:10]:
|
keep_result = []
|
||||||
f.write('- text: ' + str(texts[score[0]]) + '\t qrel: ' + str(score[1]) + '\n')
|
run_keys = list(run.keys())
|
||||||
json.dump(qrels, open(os.path.join(file_path, dataset + '.qrels.json'), "w+"), indent=2)
|
for run_i in tqdm(range(len(run_keys)), desc="Calculating ndcg@10 for single query"):
|
||||||
json.dump(run, open(os.path.join(file_path, dataset + '.run.json'), "w+"), indent=2)
|
key = run_keys[run_i]
|
||||||
print(os.path.join(file_path, dataset + '_result.md'), 'Saved!')
|
keep_result.append({'query': key, 'qrel': qrels[key], 'run': run[key],
|
||||||
|
'ndcg@10': evaluate({key: qrels[key]}, {key: run[key]}, "ndcg@10")})
|
||||||
def __call__(self, dataset, file_path, miracl_corpus=''):
|
keep_result = sorted(keep_result, key=lambda kk: kk['ndcg@10'])
|
||||||
if dataset == "ms_marco_v1.1":
|
with open(os.path.join(file_path, dataset + 'result.md'), 'w', encoding='utf-8') as f:
|
||||||
qrels, texts = self.ms_marco_index(file_path, "benchmark_ms_marco_v1.1")
|
f.write('## Score For Every Query\n')
|
||||||
run = self._get_retrieval(qrels, "benchmark_ms_marco_v1.1")
|
for keep_result_i in keep_result:
|
||||||
print(dataset, evaluate(Qrels(qrels), Run(run), ["ndcg@10", "map@5", "mrr"]))
|
f.write('### query: ' + keep_result_i['query'] + ' ndcg@10:' + str(keep_result_i['ndcg@10']) + '\n')
|
||||||
self.save_results(qrels, run, texts, dataset, file_path)
|
scores = [[i[0], i[1]] for i in keep_result_i['run'].items()]
|
||||||
if dataset == "trivia_qa":
|
scores = sorted(scores, key=lambda kk: kk[1])
|
||||||
qrels, texts = self.trivia_qa_index(file_path, "benchmark_trivia_qa")
|
for score in scores[:10]:
|
||||||
run = self._get_retrieval(qrels, "benchmark_trivia_qa")
|
f.write('- text: ' + str(texts[score[0]]) + '\t qrel: ' + str(score[1]) + '\n')
|
||||||
print(dataset, evaluate((qrels), Run(run), ["ndcg@10", "map@5", "mrr"]))
|
json.dump(qrels, open(os.path.join(file_path, dataset + '.qrels.json'), "w+"), indent=2)
|
||||||
self.save_results(qrels, run, texts, dataset, file_path)
|
json.dump(run, open(os.path.join(file_path, dataset + '.run.json'), "w+"), indent=2)
|
||||||
if dataset == "miracl":
|
print(os.path.join(file_path, dataset + '_result.md'), 'Saved!')
|
||||||
for lang in ['ar', 'bn', 'de', 'en', 'es', 'fa', 'fi', 'fr', 'hi', 'id', 'ja', 'ko', 'ru', 'sw', 'te', 'th',
|
|
||||||
'yo', 'zh']:
|
def __call__(self, dataset, file_path, miracl_corpus=''):
|
||||||
if not os.path.isdir(os.path.join(file_path, 'miracl-v1.0-' + lang)):
|
if dataset == "ms_marco_v1.1":
|
||||||
print('Directory: ' + os.path.join(file_path, 'miracl-v1.0-' + lang) + ' not found!')
|
self.tenant_id = "benchmark_ms_marco_v11"
|
||||||
continue
|
self.index_name = search.index_name(self.tenant_id)
|
||||||
if not os.path.isdir(os.path.join(file_path, 'miracl-v1.0-' + lang, 'qrels')):
|
qrels, texts = self.ms_marco_index(file_path, "benchmark_ms_marco_v1.1")
|
||||||
print('Directory: ' + os.path.join(file_path, 'miracl-v1.0-' + lang, 'qrels') + 'not found!')
|
run = self._get_retrieval(qrels)
|
||||||
continue
|
print(dataset, evaluate(qrels, run, ["ndcg@10", "map@5", "mrr"]))
|
||||||
if not os.path.isdir(os.path.join(file_path, 'miracl-v1.0-' + lang, 'topics')):
|
self.save_results(qrels, run, texts, dataset, file_path)
|
||||||
print('Directory: ' + os.path.join(file_path, 'miracl-v1.0-' + lang, 'topics') + 'not found!')
|
if dataset == "trivia_qa":
|
||||||
continue
|
self.tenant_id = "benchmark_trivia_qa"
|
||||||
if not os.path.isdir(os.path.join(miracl_corpus, 'miracl-corpus-v1.0-' + lang)):
|
self.index_name = search.index_name(self.tenant_id)
|
||||||
print('Directory: ' + os.path.join(miracl_corpus, 'miracl-corpus-v1.0-' + lang) + ' not found!')
|
qrels, texts = self.trivia_qa_index(file_path, "benchmark_trivia_qa")
|
||||||
continue
|
run = self._get_retrieval(qrels)
|
||||||
qrels, texts = self.miracl_index(os.path.join(file_path, 'miracl-v1.0-' + lang),
|
print(dataset, evaluate(qrels, run, ["ndcg@10", "map@5", "mrr"]))
|
||||||
os.path.join(miracl_corpus, 'miracl-corpus-v1.0-' + lang),
|
self.save_results(qrels, run, texts, dataset, file_path)
|
||||||
"benchmark_miracl_" + lang)
|
if dataset == "miracl":
|
||||||
run = self._get_retrieval(qrels, "benchmark_miracl_" + lang)
|
for lang in ['ar', 'bn', 'de', 'en', 'es', 'fa', 'fi', 'fr', 'hi', 'id', 'ja', 'ko', 'ru', 'sw', 'te', 'th',
|
||||||
print(dataset, evaluate(Qrels(qrels), Run(run), ["ndcg@10", "map@5", "mrr"]))
|
'yo', 'zh']:
|
||||||
self.save_results(qrels, run, texts, dataset, file_path)
|
if not os.path.isdir(os.path.join(file_path, 'miracl-v1.0-' + lang)):
|
||||||
|
print('Directory: ' + os.path.join(file_path, 'miracl-v1.0-' + lang) + ' not found!')
|
||||||
|
continue
|
||||||
if __name__ == '__main__':
|
if not os.path.isdir(os.path.join(file_path, 'miracl-v1.0-' + lang, 'qrels')):
|
||||||
print('*****************RAGFlow Benchmark*****************')
|
print('Directory: ' + os.path.join(file_path, 'miracl-v1.0-' + lang, 'qrels') + 'not found!')
|
||||||
kb_id = input('Please input kb_id:\n')
|
continue
|
||||||
ex = Benchmark(kb_id)
|
if not os.path.isdir(os.path.join(file_path, 'miracl-v1.0-' + lang, 'topics')):
|
||||||
dataset = input(
|
print('Directory: ' + os.path.join(file_path, 'miracl-v1.0-' + lang, 'topics') + 'not found!')
|
||||||
'RAGFlow Benchmark Support:\n\tms_marco_v1.1:<https://huggingface.co/datasets/microsoft/ms_marco>\n\ttrivia_qa:<https://huggingface.co/datasets/mandarjoshi/trivia_qa>\n\tmiracl:<https://huggingface.co/datasets/miracl/miracl>\nPlease input dataset choice:\n')
|
continue
|
||||||
if dataset in ['ms_marco_v1.1', 'trivia_qa']:
|
if not os.path.isdir(os.path.join(miracl_corpus, 'miracl-corpus-v1.0-' + lang)):
|
||||||
if dataset == "ms_marco_v1.1":
|
print('Directory: ' + os.path.join(miracl_corpus, 'miracl-corpus-v1.0-' + lang) + ' not found!')
|
||||||
print("Notice: Please provide the ms_marco_v1.1 dataset only. ms_marco_v2.1 is not supported!")
|
continue
|
||||||
dataset_path = input('Please input ' + dataset + ' dataset path:\n')
|
self.tenant_id = "benchmark_miracl_" + lang
|
||||||
ex(dataset, dataset_path)
|
self.index_name = search.index_name(self.tenant_id)
|
||||||
elif dataset == 'miracl':
|
self.initialized_index = False
|
||||||
dataset_path = input('Please input ' + dataset + ' dataset path:\n')
|
qrels, texts = self.miracl_index(os.path.join(file_path, 'miracl-v1.0-' + lang),
|
||||||
corpus_path = input('Please input ' + dataset + '-corpus dataset path:\n')
|
os.path.join(miracl_corpus, 'miracl-corpus-v1.0-' + lang),
|
||||||
ex(dataset, dataset_path, miracl_corpus=corpus_path)
|
"benchmark_miracl_" + lang)
|
||||||
else:
|
run = self._get_retrieval(qrels)
|
||||||
print("Dataset: ", dataset, "not supported!")
|
print(dataset, evaluate(qrels, run, ["ndcg@10", "map@5", "mrr"]))
|
||||||
|
self.save_results(qrels, run, texts, dataset, file_path)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
print('*****************RAGFlow Benchmark*****************')
|
||||||
|
parser = argparse.ArgumentParser(usage="benchmark.py <max_docs> <kb_id> <dataset> <dataset_path> [<miracl_corpus_path>])", description='RAGFlow Benchmark')
|
||||||
|
parser.add_argument('max_docs', metavar='max_docs', type=int, help='max docs to evaluate')
|
||||||
|
parser.add_argument('kb_id', metavar='kb_id', help='knowledgebase id')
|
||||||
|
parser.add_argument('dataset', metavar='dataset', help='dataset name, shall be one of ms_marco_v1.1(https://huggingface.co/datasets/microsoft/ms_marco), trivia_qa(https://huggingface.co/datasets/mandarjoshi/trivia_qa>), miracl(https://huggingface.co/datasets/miracl/miracl')
|
||||||
|
parser.add_argument('dataset_path', metavar='dataset_path', help='dataset path')
|
||||||
|
parser.add_argument('miracl_corpus_path', metavar='miracl_corpus_path', nargs='?', default="", help='miracl corpus path. Only needed when dataset is miracl')
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
max_docs = args.max_docs
|
||||||
|
kb_id = args.kb_id
|
||||||
|
ex = Benchmark(kb_id)
|
||||||
|
|
||||||
|
dataset = args.dataset
|
||||||
|
dataset_path = args.dataset_path
|
||||||
|
|
||||||
|
if dataset == "ms_marco_v1.1" or dataset == "trivia_qa":
|
||||||
|
ex(dataset, dataset_path)
|
||||||
|
elif dataset == "miracl":
|
||||||
|
if len(args) < 5:
|
||||||
|
print('Please input the correct parameters!')
|
||||||
|
exit(1)
|
||||||
|
miracl_corpus_path = args[4]
|
||||||
|
ex(dataset, dataset_path, miracl_corpus=args.miracl_corpus_path)
|
||||||
|
else:
|
||||||
|
print("Dataset: ", dataset, "not supported!")
|
||||||
|
@ -25,6 +25,7 @@ import roman_numbers as r
|
|||||||
from word2number import w2n
|
from word2number import w2n
|
||||||
from cn2an import cn2an
|
from cn2an import cn2an
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
import json
|
||||||
|
|
||||||
all_codecs = [
|
all_codecs = [
|
||||||
'utf-8', 'gb2312', 'gbk', 'utf_16', 'ascii', 'big5', 'big5hkscs',
|
'utf-8', 'gb2312', 'gbk', 'utf_16', 'ascii', 'big5', 'big5hkscs',
|
||||||
@ -51,12 +52,12 @@ def find_codec(blob):
|
|||||||
try:
|
try:
|
||||||
blob[:1024].decode(c)
|
blob[:1024].decode(c)
|
||||||
return c
|
return c
|
||||||
except Exception as e:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
try:
|
try:
|
||||||
blob.decode(c)
|
blob.decode(c)
|
||||||
return c
|
return c
|
||||||
except Exception as e:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
return "utf-8"
|
return "utf-8"
|
||||||
@ -241,7 +242,7 @@ def tokenize_chunks(chunks, doc, eng, pdf_parser=None):
|
|||||||
d["image"], poss = pdf_parser.crop(ck, need_position=True)
|
d["image"], poss = pdf_parser.crop(ck, need_position=True)
|
||||||
add_positions(d, poss)
|
add_positions(d, poss)
|
||||||
ck = pdf_parser.remove_tag(ck)
|
ck = pdf_parser.remove_tag(ck)
|
||||||
except NotImplementedError as e:
|
except NotImplementedError:
|
||||||
pass
|
pass
|
||||||
tokenize(d, ck, eng)
|
tokenize(d, ck, eng)
|
||||||
res.append(d)
|
res.append(d)
|
||||||
@ -289,13 +290,16 @@ def tokenize_table(tbls, doc, eng, batch_size=10):
|
|||||||
def add_positions(d, poss):
|
def add_positions(d, poss):
|
||||||
if not poss:
|
if not poss:
|
||||||
return
|
return
|
||||||
d["page_num_int"] = []
|
page_num_list = []
|
||||||
d["position_int"] = []
|
position_list = []
|
||||||
d["top_int"] = []
|
top_list = []
|
||||||
for pn, left, right, top, bottom in poss:
|
for pn, left, right, top, bottom in poss:
|
||||||
d["page_num_int"].append(int(pn + 1))
|
page_num_list.append(int(pn + 1))
|
||||||
d["top_int"].append(int(top))
|
top_list.append(int(top))
|
||||||
d["position_int"].append((int(pn + 1), int(left), int(right), int(top), int(bottom)))
|
position_list.append((int(pn + 1), int(left), int(right), int(top), int(bottom)))
|
||||||
|
d["page_num_list"] = json.dumps(page_num_list)
|
||||||
|
d["position_list"] = json.dumps(position_list)
|
||||||
|
d["top_list"] = json.dumps(top_list)
|
||||||
|
|
||||||
|
|
||||||
def remove_contents_table(sections, eng=False):
|
def remove_contents_table(sections, eng=False):
|
||||||
|
112
rag/nlp/query.py
112
rag/nlp/query.py
@ -15,20 +15,25 @@
|
|||||||
#
|
#
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import math
|
|
||||||
import re
|
import re
|
||||||
import logging
|
import logging
|
||||||
import copy
|
from rag.utils.doc_store_conn import MatchTextExpr
|
||||||
from elasticsearch_dsl import Q
|
|
||||||
|
|
||||||
from rag.nlp import rag_tokenizer, term_weight, synonym
|
from rag.nlp import rag_tokenizer, term_weight, synonym
|
||||||
|
|
||||||
class EsQueryer:
|
|
||||||
def __init__(self, es):
|
class FulltextQueryer:
|
||||||
|
def __init__(self):
|
||||||
self.tw = term_weight.Dealer()
|
self.tw = term_weight.Dealer()
|
||||||
self.es = es
|
|
||||||
self.syn = synonym.Dealer()
|
self.syn = synonym.Dealer()
|
||||||
self.flds = ["ask_tks^10", "ask_small_tks"]
|
self.query_fields = [
|
||||||
|
"title_tks^10",
|
||||||
|
"title_sm_tks^5",
|
||||||
|
"important_kwd^30",
|
||||||
|
"important_tks^20",
|
||||||
|
"content_ltks^2",
|
||||||
|
"content_sm_ltks",
|
||||||
|
]
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def subSpecialChar(line):
|
def subSpecialChar(line):
|
||||||
@ -43,12 +48,15 @@ class EsQueryer:
|
|||||||
for t in arr:
|
for t in arr:
|
||||||
if not re.match(r"[a-zA-Z]+$", t):
|
if not re.match(r"[a-zA-Z]+$", t):
|
||||||
e += 1
|
e += 1
|
||||||
return e * 1. / len(arr) >= 0.7
|
return e * 1.0 / len(arr) >= 0.7
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def rmWWW(txt):
|
def rmWWW(txt):
|
||||||
patts = [
|
patts = [
|
||||||
(r"是*(什么样的|哪家|一下|那家|请问|啥样|咋样了|什么时候|何时|何地|何人|是否|是不是|多少|哪里|怎么|哪儿|怎么样|如何|哪些|是啥|啥是|啊|吗|呢|吧|咋|什么|有没有|呀)是*", ""),
|
(
|
||||||
|
r"是*(什么样的|哪家|一下|那家|请问|啥样|咋样了|什么时候|何时|何地|何人|是否|是不是|多少|哪里|怎么|哪儿|怎么样|如何|哪些|是啥|啥是|啊|吗|呢|吧|咋|什么|有没有|呀)是*",
|
||||||
|
"",
|
||||||
|
),
|
||||||
(r"(^| )(what|who|how|which|where|why)('re|'s)? ", " "),
|
(r"(^| )(what|who|how|which|where|why)('re|'s)? ", " "),
|
||||||
(r"(^| )('s|'re|is|are|were|was|do|does|did|don't|doesn't|didn't|has|have|be|there|you|me|your|my|mine|just|please|may|i|should|would|wouldn't|will|won't|done|go|for|with|so|the|a|an|by|i'm|it's|he's|she's|they|they're|you're|as|by|on|in|at|up|out|down|of) ", " ")
|
(r"(^| )('s|'re|is|are|were|was|do|does|did|don't|doesn't|didn't|has|have|be|there|you|me|your|my|mine|just|please|may|i|should|would|wouldn't|will|won't|done|go|for|with|so|the|a|an|by|i'm|it's|he's|she's|they|they're|you're|as|by|on|in|at|up|out|down|of) ", " ")
|
||||||
]
|
]
|
||||||
@ -56,16 +64,16 @@ class EsQueryer:
|
|||||||
txt = re.sub(r, p, txt, flags=re.IGNORECASE)
|
txt = re.sub(r, p, txt, flags=re.IGNORECASE)
|
||||||
return txt
|
return txt
|
||||||
|
|
||||||
def question(self, txt, tbl="qa", min_match="60%"):
|
def question(self, txt, tbl="qa", min_match:float=0.6):
|
||||||
txt = re.sub(
|
txt = re.sub(
|
||||||
r"[ :\r\n\t,,。??/`!!&\^%%]+",
|
r"[ :\r\n\t,,。??/`!!&\^%%]+",
|
||||||
" ",
|
" ",
|
||||||
rag_tokenizer.tradi2simp(
|
rag_tokenizer.tradi2simp(rag_tokenizer.strQ2B(txt.lower())),
|
||||||
rag_tokenizer.strQ2B(
|
).strip()
|
||||||
txt.lower()))).strip()
|
txt = FulltextQueryer.rmWWW(txt)
|
||||||
|
|
||||||
if not self.isChinese(txt):
|
if not self.isChinese(txt):
|
||||||
txt = EsQueryer.rmWWW(txt)
|
txt = FulltextQueryer.rmWWW(txt)
|
||||||
tks = rag_tokenizer.tokenize(txt).split(" ")
|
tks = rag_tokenizer.tokenize(txt).split(" ")
|
||||||
tks_w = self.tw.weights(tks)
|
tks_w = self.tw.weights(tks)
|
||||||
tks_w = [(re.sub(r"[ \\\"'^]", "", tk), w) for tk, w in tks_w]
|
tks_w = [(re.sub(r"[ \\\"'^]", "", tk), w) for tk, w in tks_w]
|
||||||
@ -73,14 +81,20 @@ class EsQueryer:
|
|||||||
tks_w = [(re.sub(r"^[\+-]", "", tk), w) for tk, w in tks_w if tk]
|
tks_w = [(re.sub(r"^[\+-]", "", tk), w) for tk, w in tks_w if tk]
|
||||||
q = ["{}^{:.4f}".format(tk, w) for tk, w in tks_w if tk]
|
q = ["{}^{:.4f}".format(tk, w) for tk, w in tks_w if tk]
|
||||||
for i in range(1, len(tks_w)):
|
for i in range(1, len(tks_w)):
|
||||||
q.append("\"%s %s\"^%.4f" % (tks_w[i - 1][0], tks_w[i][0], max(tks_w[i - 1][1], tks_w[i][1])*2))
|
q.append(
|
||||||
|
'"%s %s"^%.4f'
|
||||||
|
% (
|
||||||
|
tks_w[i - 1][0],
|
||||||
|
tks_w[i][0],
|
||||||
|
max(tks_w[i - 1][1], tks_w[i][1]) * 2,
|
||||||
|
)
|
||||||
|
)
|
||||||
if not q:
|
if not q:
|
||||||
q.append(txt)
|
q.append(txt)
|
||||||
return Q("bool",
|
query = " ".join(q)
|
||||||
must=Q("query_string", fields=self.flds,
|
return MatchTextExpr(
|
||||||
type="best_fields", query=" ".join(q),
|
self.query_fields, query, 100
|
||||||
boost=1)#, minimum_should_match=min_match)
|
), tks
|
||||||
), list(set([t for t in txt.split(" ") if t]))
|
|
||||||
|
|
||||||
def need_fine_grained_tokenize(tk):
|
def need_fine_grained_tokenize(tk):
|
||||||
if len(tk) < 3:
|
if len(tk) < 3:
|
||||||
@ -89,7 +103,7 @@ class EsQueryer:
|
|||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
txt = EsQueryer.rmWWW(txt)
|
txt = FulltextQueryer.rmWWW(txt)
|
||||||
qs, keywords = [], []
|
qs, keywords = [], []
|
||||||
for tt in self.tw.split(txt)[:256]: # .split(" "):
|
for tt in self.tw.split(txt)[:256]: # .split(" "):
|
||||||
if not tt:
|
if not tt:
|
||||||
@ -101,65 +115,71 @@ class EsQueryer:
|
|||||||
logging.info(json.dumps(twts, ensure_ascii=False))
|
logging.info(json.dumps(twts, ensure_ascii=False))
|
||||||
tms = []
|
tms = []
|
||||||
for tk, w in sorted(twts, key=lambda x: x[1] * -1):
|
for tk, w in sorted(twts, key=lambda x: x[1] * -1):
|
||||||
sm = rag_tokenizer.fine_grained_tokenize(tk).split(" ") if need_fine_grained_tokenize(tk) else []
|
sm = (
|
||||||
|
rag_tokenizer.fine_grained_tokenize(tk).split(" ")
|
||||||
|
if need_fine_grained_tokenize(tk)
|
||||||
|
else []
|
||||||
|
)
|
||||||
sm = [
|
sm = [
|
||||||
re.sub(
|
re.sub(
|
||||||
r"[ ,\./;'\[\]\\`~!@#$%\^&\*\(\)=\+_<>\?:\"\{\}\|,。;‘’【】、!¥……()——《》?:“”-]+",
|
r"[ ,\./;'\[\]\\`~!@#$%\^&\*\(\)=\+_<>\?:\"\{\}\|,。;‘’【】、!¥……()——《》?:“”-]+",
|
||||||
"",
|
"",
|
||||||
m) for m in sm]
|
m,
|
||||||
sm = [EsQueryer.subSpecialChar(m) for m in sm if len(m) > 1]
|
)
|
||||||
|
for m in sm
|
||||||
|
]
|
||||||
|
sm = [FulltextQueryer.subSpecialChar(m) for m in sm if len(m) > 1]
|
||||||
sm = [m for m in sm if len(m) > 1]
|
sm = [m for m in sm if len(m) > 1]
|
||||||
|
|
||||||
keywords.append(re.sub(r"[ \\\"']+", "", tk))
|
keywords.append(re.sub(r"[ \\\"']+", "", tk))
|
||||||
keywords.extend(sm)
|
keywords.extend(sm)
|
||||||
if len(keywords) >= 12: break
|
if len(keywords) >= 12:
|
||||||
|
break
|
||||||
|
|
||||||
tk_syns = self.syn.lookup(tk)
|
tk_syns = self.syn.lookup(tk)
|
||||||
tk = EsQueryer.subSpecialChar(tk)
|
tk = FulltextQueryer.subSpecialChar(tk)
|
||||||
if tk.find(" ") > 0:
|
if tk.find(" ") > 0:
|
||||||
tk = "\"%s\"" % tk
|
tk = '"%s"' % tk
|
||||||
if tk_syns:
|
if tk_syns:
|
||||||
tk = f"({tk} %s)" % " ".join(tk_syns)
|
tk = f"({tk} %s)" % " ".join(tk_syns)
|
||||||
if sm:
|
if sm:
|
||||||
tk = f"{tk} OR \"%s\" OR (\"%s\"~2)^0.5" % (
|
tk = f'{tk} OR "%s" OR ("%s"~2)^0.5' % (" ".join(sm), " ".join(sm))
|
||||||
" ".join(sm), " ".join(sm))
|
|
||||||
if tk.strip():
|
if tk.strip():
|
||||||
tms.append((tk, w))
|
tms.append((tk, w))
|
||||||
|
|
||||||
tms = " ".join([f"({t})^{w}" for t, w in tms])
|
tms = " ".join([f"({t})^{w}" for t, w in tms])
|
||||||
|
|
||||||
if len(twts) > 1:
|
if len(twts) > 1:
|
||||||
tms += f" (\"%s\"~4)^1.5" % (" ".join([t for t, _ in twts]))
|
tms += ' ("%s"~4)^1.5' % (" ".join([t for t, _ in twts]))
|
||||||
if re.match(r"[0-9a-z ]+$", tt):
|
if re.match(r"[0-9a-z ]+$", tt):
|
||||||
tms = f"(\"{tt}\" OR \"%s\")" % rag_tokenizer.tokenize(tt)
|
tms = f'("{tt}" OR "%s")' % rag_tokenizer.tokenize(tt)
|
||||||
|
|
||||||
syns = " OR ".join(
|
syns = " OR ".join(
|
||||||
["\"%s\"^0.7" % EsQueryer.subSpecialChar(rag_tokenizer.tokenize(s)) for s in syns])
|
[
|
||||||
|
'"%s"^0.7'
|
||||||
|
% FulltextQueryer.subSpecialChar(rag_tokenizer.tokenize(s))
|
||||||
|
for s in syns
|
||||||
|
]
|
||||||
|
)
|
||||||
if syns:
|
if syns:
|
||||||
tms = f"({tms})^5 OR ({syns})^0.7"
|
tms = f"({tms})^5 OR ({syns})^0.7"
|
||||||
|
|
||||||
qs.append(tms)
|
qs.append(tms)
|
||||||
|
|
||||||
flds = copy.deepcopy(self.flds)
|
|
||||||
mst = []
|
|
||||||
if qs:
|
if qs:
|
||||||
mst.append(
|
query = " OR ".join([f"({t})" for t in qs if t])
|
||||||
Q("query_string", fields=flds, type="best_fields",
|
return MatchTextExpr(
|
||||||
query=" OR ".join([f"({t})" for t in qs if t]), boost=1, minimum_should_match=min_match)
|
self.query_fields, query, 100, {"minimum_should_match": min_match}
|
||||||
)
|
), keywords
|
||||||
|
return None, keywords
|
||||||
|
|
||||||
return Q("bool",
|
def hybrid_similarity(self, avec, bvecs, atks, btkss, tkweight=0.3, vtweight=0.7):
|
||||||
must=mst,
|
|
||||||
), list(set(keywords))
|
|
||||||
|
|
||||||
def hybrid_similarity(self, avec, bvecs, atks, btkss, tkweight=0.3,
|
|
||||||
vtweight=0.7):
|
|
||||||
from sklearn.metrics.pairwise import cosine_similarity as CosineSimilarity
|
from sklearn.metrics.pairwise import cosine_similarity as CosineSimilarity
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
sims = CosineSimilarity([avec], bvecs)
|
sims = CosineSimilarity([avec], bvecs)
|
||||||
tksim = self.token_similarity(atks, btkss)
|
tksim = self.token_similarity(atks, btkss)
|
||||||
return np.array(sims[0]) * vtweight + \
|
return np.array(sims[0]) * vtweight + np.array(tksim) * tkweight, tksim, sims[0]
|
||||||
np.array(tksim) * tkweight, tksim, sims[0]
|
|
||||||
|
|
||||||
def token_similarity(self, atks, btkss):
|
def token_similarity(self, atks, btkss):
|
||||||
def toDict(tks):
|
def toDict(tks):
|
||||||
|
@ -14,34 +14,25 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
#
|
#
|
||||||
|
|
||||||
import json
|
|
||||||
import re
|
import re
|
||||||
from copy import deepcopy
|
import json
|
||||||
|
|
||||||
from elasticsearch_dsl import Q, Search
|
|
||||||
from typing import List, Optional, Dict, Union
|
from typing import List, Optional, Dict, Union
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
from rag.settings import es_logger
|
from rag.settings import doc_store_logger
|
||||||
from rag.utils import rmSpace
|
from rag.utils import rmSpace
|
||||||
from rag.nlp import rag_tokenizer, query, is_english
|
from rag.nlp import rag_tokenizer, query
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from rag.utils.doc_store_conn import DocStoreConnection, MatchDenseExpr, FusionExpr, OrderByExpr
|
||||||
|
|
||||||
|
|
||||||
def index_name(uid): return f"ragflow_{uid}"
|
def index_name(uid): return f"ragflow_{uid}"
|
||||||
|
|
||||||
|
|
||||||
class Dealer:
|
class Dealer:
|
||||||
def __init__(self, es):
|
def __init__(self, dataStore: DocStoreConnection):
|
||||||
self.qryr = query.EsQueryer(es)
|
self.qryr = query.FulltextQueryer()
|
||||||
self.qryr.flds = [
|
self.dataStore = dataStore
|
||||||
"title_tks^10",
|
|
||||||
"title_sm_tks^5",
|
|
||||||
"important_kwd^30",
|
|
||||||
"important_tks^20",
|
|
||||||
"content_ltks^2",
|
|
||||||
"content_sm_ltks"]
|
|
||||||
self.es = es
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class SearchResult:
|
class SearchResult:
|
||||||
@ -54,170 +45,99 @@ class Dealer:
|
|||||||
keywords: Optional[List[str]] = None
|
keywords: Optional[List[str]] = None
|
||||||
group_docs: List[List] = None
|
group_docs: List[List] = None
|
||||||
|
|
||||||
def _vector(self, txt, emb_mdl, sim=0.8, topk=10):
|
def get_vector(self, txt, emb_mdl, topk=10, similarity=0.1):
|
||||||
qv, c = emb_mdl.encode_queries(txt)
|
qv, _ = emb_mdl.encode_queries(txt)
|
||||||
return {
|
embedding_data = [float(v) for v in qv]
|
||||||
"field": "q_%d_vec" % len(qv),
|
vector_column_name = f"q_{len(embedding_data)}_vec"
|
||||||
"k": topk,
|
return MatchDenseExpr(vector_column_name, embedding_data, 'float', 'cosine', topk, {"similarity": similarity})
|
||||||
"similarity": sim,
|
|
||||||
"num_candidates": topk * 2,
|
|
||||||
"query_vector": [float(v) for v in qv]
|
|
||||||
}
|
|
||||||
|
|
||||||
def _add_filters(self, bqry, req):
|
def get_filters(self, req):
|
||||||
if req.get("kb_ids"):
|
condition = dict()
|
||||||
bqry.filter.append(Q("terms", kb_id=req["kb_ids"]))
|
for key, field in {"kb_ids": "kb_id", "doc_ids": "doc_id"}.items():
|
||||||
if req.get("doc_ids"):
|
if key in req and req[key] is not None:
|
||||||
bqry.filter.append(Q("terms", doc_id=req["doc_ids"]))
|
condition[field] = req[key]
|
||||||
if req.get("knowledge_graph_kwd"):
|
# TODO(yzc): `available_int` is nullable however infinity doesn't support nullable columns.
|
||||||
bqry.filter.append(Q("terms", knowledge_graph_kwd=req["knowledge_graph_kwd"]))
|
for key in ["knowledge_graph_kwd"]:
|
||||||
if "available_int" in req:
|
if key in req and req[key] is not None:
|
||||||
if req["available_int"] == 0:
|
condition[key] = req[key]
|
||||||
bqry.filter.append(Q("range", available_int={"lt": 1}))
|
return condition
|
||||||
else:
|
|
||||||
bqry.filter.append(
|
|
||||||
Q("bool", must_not=Q("range", available_int={"lt": 1})))
|
|
||||||
return bqry
|
|
||||||
|
|
||||||
def search(self, req, idxnms, emb_mdl=None, highlight=False):
|
def search(self, req, idx_names: list[str], kb_ids: list[str], emb_mdl=None, highlight = False):
|
||||||
qst = req.get("question", "")
|
filters = self.get_filters(req)
|
||||||
bqry, keywords = self.qryr.question(qst, min_match="30%")
|
orderBy = OrderByExpr()
|
||||||
bqry = self._add_filters(bqry, req)
|
|
||||||
bqry.boost = 0.05
|
|
||||||
|
|
||||||
s = Search()
|
|
||||||
pg = int(req.get("page", 1)) - 1
|
pg = int(req.get("page", 1)) - 1
|
||||||
topk = int(req.get("topk", 1024))
|
topk = int(req.get("topk", 1024))
|
||||||
ps = int(req.get("size", topk))
|
ps = int(req.get("size", topk))
|
||||||
|
offset, limit = pg * ps, (pg + 1) * ps
|
||||||
|
|
||||||
src = req.get("fields", ["docnm_kwd", "content_ltks", "kb_id", "img_id", "title_tks", "important_kwd",
|
src = req.get("fields", ["docnm_kwd", "content_ltks", "kb_id", "img_id", "title_tks", "important_kwd",
|
||||||
"image_id", "doc_id", "q_512_vec", "q_768_vec", "position_int", "knowledge_graph_kwd",
|
"doc_id", "position_list", "knowledge_graph_kwd",
|
||||||
"q_1024_vec", "q_1536_vec", "available_int", "content_with_weight"])
|
"available_int", "content_with_weight"])
|
||||||
|
|
||||||
s = s.query(bqry)[pg * ps:(pg + 1) * ps]
|
|
||||||
s = s.highlight("content_ltks")
|
|
||||||
s = s.highlight("title_ltks")
|
|
||||||
if not qst:
|
|
||||||
if not req.get("sort"):
|
|
||||||
s = s.sort(
|
|
||||||
#{"create_time": {"order": "desc", "unmapped_type": "date"}},
|
|
||||||
{"create_timestamp_flt": {
|
|
||||||
"order": "desc", "unmapped_type": "float"}}
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
s = s.sort(
|
|
||||||
{"page_num_int": {"order": "asc", "unmapped_type": "float",
|
|
||||||
"mode": "avg", "numeric_type": "double"}},
|
|
||||||
{"top_int": {"order": "asc", "unmapped_type": "float",
|
|
||||||
"mode": "avg", "numeric_type": "double"}},
|
|
||||||
#{"create_time": {"order": "desc", "unmapped_type": "date"}},
|
|
||||||
{"create_timestamp_flt": {
|
|
||||||
"order": "desc", "unmapped_type": "float"}}
|
|
||||||
)
|
|
||||||
|
|
||||||
if qst:
|
|
||||||
s = s.highlight_options(
|
|
||||||
fragment_size=120,
|
|
||||||
number_of_fragments=5,
|
|
||||||
boundary_scanner_locale="zh-CN",
|
|
||||||
boundary_scanner="SENTENCE",
|
|
||||||
boundary_chars=",./;:\\!(),。?:!……()——、"
|
|
||||||
)
|
|
||||||
s = s.to_dict()
|
|
||||||
q_vec = []
|
|
||||||
if req.get("vector"):
|
|
||||||
assert emb_mdl, "No embedding model selected"
|
|
||||||
s["knn"] = self._vector(
|
|
||||||
qst, emb_mdl, req.get(
|
|
||||||
"similarity", 0.1), topk)
|
|
||||||
s["knn"]["filter"] = bqry.to_dict()
|
|
||||||
if not highlight and "highlight" in s:
|
|
||||||
del s["highlight"]
|
|
||||||
q_vec = s["knn"]["query_vector"]
|
|
||||||
es_logger.info("【Q】: {}".format(json.dumps(s)))
|
|
||||||
res = self.es.search(deepcopy(s), idxnms=idxnms, timeout="600s", src=src)
|
|
||||||
es_logger.info("TOTAL: {}".format(self.es.getTotal(res)))
|
|
||||||
if self.es.getTotal(res) == 0 and "knn" in s:
|
|
||||||
bqry, _ = self.qryr.question(qst, min_match="10%")
|
|
||||||
if req.get("doc_ids"):
|
|
||||||
bqry = Q("bool", must=[])
|
|
||||||
bqry = self._add_filters(bqry, req)
|
|
||||||
s["query"] = bqry.to_dict()
|
|
||||||
s["knn"]["filter"] = bqry.to_dict()
|
|
||||||
s["knn"]["similarity"] = 0.17
|
|
||||||
res = self.es.search(s, idxnms=idxnms, timeout="600s", src=src)
|
|
||||||
es_logger.info("【Q】: {}".format(json.dumps(s)))
|
|
||||||
|
|
||||||
kwds = set([])
|
kwds = set([])
|
||||||
for k in keywords:
|
|
||||||
kwds.add(k)
|
|
||||||
for kk in rag_tokenizer.fine_grained_tokenize(k).split(" "):
|
|
||||||
if len(kk) < 2:
|
|
||||||
continue
|
|
||||||
if kk in kwds:
|
|
||||||
continue
|
|
||||||
kwds.add(kk)
|
|
||||||
|
|
||||||
aggs = self.getAggregation(res, "docnm_kwd")
|
qst = req.get("question", "")
|
||||||
|
q_vec = []
|
||||||
|
if not qst:
|
||||||
|
if req.get("sort"):
|
||||||
|
orderBy.desc("create_timestamp_flt")
|
||||||
|
res = self.dataStore.search(src, [], filters, [], orderBy, offset, limit, idx_names, kb_ids)
|
||||||
|
total=self.dataStore.getTotal(res)
|
||||||
|
doc_store_logger.info("Dealer.search TOTAL: {}".format(total))
|
||||||
|
else:
|
||||||
|
highlightFields = ["content_ltks", "title_tks"] if highlight else []
|
||||||
|
matchText, keywords = self.qryr.question(qst, min_match=0.3)
|
||||||
|
if emb_mdl is None:
|
||||||
|
matchExprs = [matchText]
|
||||||
|
res = self.dataStore.search(src, highlightFields, filters, matchExprs, orderBy, offset, limit, idx_names, kb_ids)
|
||||||
|
total=self.dataStore.getTotal(res)
|
||||||
|
doc_store_logger.info("Dealer.search TOTAL: {}".format(total))
|
||||||
|
else:
|
||||||
|
matchDense = self.get_vector(qst, emb_mdl, topk, req.get("similarity", 0.1))
|
||||||
|
q_vec = matchDense.embedding_data
|
||||||
|
src.append(f"q_{len(q_vec)}_vec")
|
||||||
|
|
||||||
|
fusionExpr = FusionExpr("weighted_sum", topk, {"weights": "0.05, 0.95"})
|
||||||
|
matchExprs = [matchText, matchDense, fusionExpr]
|
||||||
|
|
||||||
|
res = self.dataStore.search(src, highlightFields, filters, matchExprs, orderBy, offset, limit, idx_names, kb_ids)
|
||||||
|
total=self.dataStore.getTotal(res)
|
||||||
|
doc_store_logger.info("Dealer.search TOTAL: {}".format(total))
|
||||||
|
|
||||||
|
# If result is empty, try again with lower min_match
|
||||||
|
if total == 0:
|
||||||
|
matchText, _ = self.qryr.question(qst, min_match=0.1)
|
||||||
|
if "doc_ids" in filters:
|
||||||
|
del filters["doc_ids"]
|
||||||
|
matchDense.extra_options["similarity"] = 0.17
|
||||||
|
res = self.dataStore.search(src, highlightFields, filters, [matchText, matchDense, fusionExpr], orderBy, offset, limit, idx_names, kb_ids)
|
||||||
|
total=self.dataStore.getTotal(res)
|
||||||
|
doc_store_logger.info("Dealer.search 2 TOTAL: {}".format(total))
|
||||||
|
|
||||||
|
for k in keywords:
|
||||||
|
kwds.add(k)
|
||||||
|
for kk in rag_tokenizer.fine_grained_tokenize(k).split(" "):
|
||||||
|
if len(kk) < 2:
|
||||||
|
continue
|
||||||
|
if kk in kwds:
|
||||||
|
continue
|
||||||
|
kwds.add(kk)
|
||||||
|
|
||||||
|
doc_store_logger.info(f"TOTAL: {total}")
|
||||||
|
ids=self.dataStore.getChunkIds(res)
|
||||||
|
keywords=list(kwds)
|
||||||
|
highlight = self.dataStore.getHighlight(res, keywords, "content_with_weight")
|
||||||
|
aggs = self.dataStore.getAggregation(res, "docnm_kwd")
|
||||||
return self.SearchResult(
|
return self.SearchResult(
|
||||||
total=self.es.getTotal(res),
|
total=total,
|
||||||
ids=self.es.getDocIds(res),
|
ids=ids,
|
||||||
query_vector=q_vec,
|
query_vector=q_vec,
|
||||||
aggregation=aggs,
|
aggregation=aggs,
|
||||||
highlight=self.getHighlight(res, keywords, "content_with_weight"),
|
highlight=highlight,
|
||||||
field=self.getFields(res, src),
|
field=self.dataStore.getFields(res, src),
|
||||||
keywords=list(kwds)
|
keywords=keywords
|
||||||
)
|
)
|
||||||
|
|
||||||
def getAggregation(self, res, g):
|
|
||||||
if not "aggregations" in res or "aggs_" + g not in res["aggregations"]:
|
|
||||||
return
|
|
||||||
bkts = res["aggregations"]["aggs_" + g]["buckets"]
|
|
||||||
return [(b["key"], b["doc_count"]) for b in bkts]
|
|
||||||
|
|
||||||
def getHighlight(self, res, keywords, fieldnm):
|
|
||||||
ans = {}
|
|
||||||
for d in res["hits"]["hits"]:
|
|
||||||
hlts = d.get("highlight")
|
|
||||||
if not hlts:
|
|
||||||
continue
|
|
||||||
txt = "...".join([a for a in list(hlts.items())[0][1]])
|
|
||||||
if not is_english(txt.split(" ")):
|
|
||||||
ans[d["_id"]] = txt
|
|
||||||
continue
|
|
||||||
|
|
||||||
txt = d["_source"][fieldnm]
|
|
||||||
txt = re.sub(r"[\r\n]", " ", txt, flags=re.IGNORECASE|re.MULTILINE)
|
|
||||||
txts = []
|
|
||||||
for t in re.split(r"[.?!;\n]", txt):
|
|
||||||
for w in keywords:
|
|
||||||
t = re.sub(r"(^|[ .?/'\"\(\)!,:;-])(%s)([ .?/'\"\(\)!,:;-])"%re.escape(w), r"\1<em>\2</em>\3", t, flags=re.IGNORECASE|re.MULTILINE)
|
|
||||||
if not re.search(r"<em>[^<>]+</em>", t, flags=re.IGNORECASE|re.MULTILINE): continue
|
|
||||||
txts.append(t)
|
|
||||||
ans[d["_id"]] = "...".join(txts) if txts else "...".join([a for a in list(hlts.items())[0][1]])
|
|
||||||
|
|
||||||
return ans
|
|
||||||
|
|
||||||
def getFields(self, sres, flds):
|
|
||||||
res = {}
|
|
||||||
if not flds:
|
|
||||||
return {}
|
|
||||||
for d in self.es.getSource(sres):
|
|
||||||
m = {n: d.get(n) for n in flds if d.get(n) is not None}
|
|
||||||
for n, v in m.items():
|
|
||||||
if isinstance(v, type([])):
|
|
||||||
m[n] = "\t".join([str(vv) if not isinstance(
|
|
||||||
vv, list) else "\t".join([str(vvv) for vvv in vv]) for vv in v])
|
|
||||||
continue
|
|
||||||
if not isinstance(v, type("")):
|
|
||||||
m[n] = str(m[n])
|
|
||||||
#if n.find("tks") > 0:
|
|
||||||
# m[n] = rmSpace(m[n])
|
|
||||||
|
|
||||||
if m:
|
|
||||||
res[d["id"]] = m
|
|
||||||
return res
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def trans2floats(txt):
|
def trans2floats(txt):
|
||||||
return [float(t) for t in txt.split("\t")]
|
return [float(t) for t in txt.split("\t")]
|
||||||
@ -260,7 +180,7 @@ class Dealer:
|
|||||||
continue
|
continue
|
||||||
idx.append(i)
|
idx.append(i)
|
||||||
pieces_.append(t)
|
pieces_.append(t)
|
||||||
es_logger.info("{} => {}".format(answer, pieces_))
|
doc_store_logger.info("{} => {}".format(answer, pieces_))
|
||||||
if not pieces_:
|
if not pieces_:
|
||||||
return answer, set([])
|
return answer, set([])
|
||||||
|
|
||||||
@ -281,7 +201,7 @@ class Dealer:
|
|||||||
chunks_tks,
|
chunks_tks,
|
||||||
tkweight, vtweight)
|
tkweight, vtweight)
|
||||||
mx = np.max(sim) * 0.99
|
mx = np.max(sim) * 0.99
|
||||||
es_logger.info("{} SIM: {}".format(pieces_[i], mx))
|
doc_store_logger.info("{} SIM: {}".format(pieces_[i], mx))
|
||||||
if mx < thr:
|
if mx < thr:
|
||||||
continue
|
continue
|
||||||
cites[idx[i]] = list(
|
cites[idx[i]] = list(
|
||||||
@ -309,9 +229,15 @@ class Dealer:
|
|||||||
def rerank(self, sres, query, tkweight=0.3,
|
def rerank(self, sres, query, tkweight=0.3,
|
||||||
vtweight=0.7, cfield="content_ltks"):
|
vtweight=0.7, cfield="content_ltks"):
|
||||||
_, keywords = self.qryr.question(query)
|
_, keywords = self.qryr.question(query)
|
||||||
ins_embd = [
|
vector_size = len(sres.query_vector)
|
||||||
Dealer.trans2floats(
|
vector_column = f"q_{vector_size}_vec"
|
||||||
sres.field[i].get("q_%d_vec" % len(sres.query_vector), "\t".join(["0"] * len(sres.query_vector)))) for i in sres.ids]
|
zero_vector = [0.0] * vector_size
|
||||||
|
ins_embd = []
|
||||||
|
for chunk_id in sres.ids:
|
||||||
|
vector = sres.field[chunk_id].get(vector_column, zero_vector)
|
||||||
|
if isinstance(vector, str):
|
||||||
|
vector = [float(v) for v in vector.split("\t")]
|
||||||
|
ins_embd.append(vector)
|
||||||
if not ins_embd:
|
if not ins_embd:
|
||||||
return [], [], []
|
return [], [], []
|
||||||
|
|
||||||
@ -377,7 +303,7 @@ class Dealer:
|
|||||||
if isinstance(tenant_ids, str):
|
if isinstance(tenant_ids, str):
|
||||||
tenant_ids = tenant_ids.split(",")
|
tenant_ids = tenant_ids.split(",")
|
||||||
|
|
||||||
sres = self.search(req, [index_name(tid) for tid in tenant_ids], embd_mdl, highlight)
|
sres = self.search(req, [index_name(tid) for tid in tenant_ids], kb_ids, embd_mdl, highlight)
|
||||||
ranks["total"] = sres.total
|
ranks["total"] = sres.total
|
||||||
|
|
||||||
if page <= RERANK_PAGE_LIMIT:
|
if page <= RERANK_PAGE_LIMIT:
|
||||||
@ -393,6 +319,8 @@ class Dealer:
|
|||||||
idx = list(range(len(sres.ids)))
|
idx = list(range(len(sres.ids)))
|
||||||
|
|
||||||
dim = len(sres.query_vector)
|
dim = len(sres.query_vector)
|
||||||
|
vector_column = f"q_{dim}_vec"
|
||||||
|
zero_vector = [0.0] * dim
|
||||||
for i in idx:
|
for i in idx:
|
||||||
if sim[i] < similarity_threshold:
|
if sim[i] < similarity_threshold:
|
||||||
break
|
break
|
||||||
@ -401,34 +329,32 @@ class Dealer:
|
|||||||
continue
|
continue
|
||||||
break
|
break
|
||||||
id = sres.ids[i]
|
id = sres.ids[i]
|
||||||
dnm = sres.field[id]["docnm_kwd"]
|
chunk = sres.field[id]
|
||||||
did = sres.field[id]["doc_id"]
|
dnm = chunk["docnm_kwd"]
|
||||||
|
did = chunk["doc_id"]
|
||||||
|
position_list = chunk.get("position_list", "[]")
|
||||||
|
if not position_list:
|
||||||
|
position_list = "[]"
|
||||||
d = {
|
d = {
|
||||||
"chunk_id": id,
|
"chunk_id": id,
|
||||||
"content_ltks": sres.field[id]["content_ltks"],
|
"content_ltks": chunk["content_ltks"],
|
||||||
"content_with_weight": sres.field[id]["content_with_weight"],
|
"content_with_weight": chunk["content_with_weight"],
|
||||||
"doc_id": sres.field[id]["doc_id"],
|
"doc_id": chunk["doc_id"],
|
||||||
"docnm_kwd": dnm,
|
"docnm_kwd": dnm,
|
||||||
"kb_id": sres.field[id]["kb_id"],
|
"kb_id": chunk["kb_id"],
|
||||||
"important_kwd": sres.field[id].get("important_kwd", []),
|
"important_kwd": chunk.get("important_kwd", []),
|
||||||
"img_id": sres.field[id].get("img_id", ""),
|
"image_id": chunk.get("img_id", ""),
|
||||||
"similarity": sim[i],
|
"similarity": sim[i],
|
||||||
"vector_similarity": vsim[i],
|
"vector_similarity": vsim[i],
|
||||||
"term_similarity": tsim[i],
|
"term_similarity": tsim[i],
|
||||||
"vector": self.trans2floats(sres.field[id].get("q_%d_vec" % dim, "\t".join(["0"] * dim))),
|
"vector": chunk.get(vector_column, zero_vector),
|
||||||
"positions": sres.field[id].get("position_int", "").split("\t")
|
"positions": json.loads(position_list)
|
||||||
}
|
}
|
||||||
if highlight:
|
if highlight:
|
||||||
if id in sres.highlight:
|
if id in sres.highlight:
|
||||||
d["highlight"] = rmSpace(sres.highlight[id])
|
d["highlight"] = rmSpace(sres.highlight[id])
|
||||||
else:
|
else:
|
||||||
d["highlight"] = d["content_with_weight"]
|
d["highlight"] = d["content_with_weight"]
|
||||||
if len(d["positions"]) % 5 == 0:
|
|
||||||
poss = []
|
|
||||||
for i in range(0, len(d["positions"]), 5):
|
|
||||||
poss.append([float(d["positions"][i]), float(d["positions"][i + 1]), float(d["positions"][i + 2]),
|
|
||||||
float(d["positions"][i + 3]), float(d["positions"][i + 4])])
|
|
||||||
d["positions"] = poss
|
|
||||||
ranks["chunks"].append(d)
|
ranks["chunks"].append(d)
|
||||||
if dnm not in ranks["doc_aggs"]:
|
if dnm not in ranks["doc_aggs"]:
|
||||||
ranks["doc_aggs"][dnm] = {"doc_id": did, "count": 0}
|
ranks["doc_aggs"][dnm] = {"doc_id": did, "count": 0}
|
||||||
@ -442,39 +368,11 @@ class Dealer:
|
|||||||
return ranks
|
return ranks
|
||||||
|
|
||||||
def sql_retrieval(self, sql, fetch_size=128, format="json"):
|
def sql_retrieval(self, sql, fetch_size=128, format="json"):
|
||||||
from api.settings import chat_logger
|
tbl = self.dataStore.sql(sql, fetch_size, format)
|
||||||
sql = re.sub(r"[ `]+", " ", sql)
|
return tbl
|
||||||
sql = sql.replace("%", "")
|
|
||||||
es_logger.info(f"Get es sql: {sql}")
|
|
||||||
replaces = []
|
|
||||||
for r in re.finditer(r" ([a-z_]+_l?tks)( like | ?= ?)'([^']+)'", sql):
|
|
||||||
fld, v = r.group(1), r.group(3)
|
|
||||||
match = " MATCH({}, '{}', 'operator=OR;minimum_should_match=30%') ".format(
|
|
||||||
fld, rag_tokenizer.fine_grained_tokenize(rag_tokenizer.tokenize(v)))
|
|
||||||
replaces.append(
|
|
||||||
("{}{}'{}'".format(
|
|
||||||
r.group(1),
|
|
||||||
r.group(2),
|
|
||||||
r.group(3)),
|
|
||||||
match))
|
|
||||||
|
|
||||||
for p, r in replaces:
|
def chunk_list(self, doc_id: str, tenant_id: str, kb_ids: list[str], max_count=1024, fields=["docnm_kwd", "content_with_weight", "img_id"]):
|
||||||
sql = sql.replace(p, r, 1)
|
condition = {"doc_id": doc_id}
|
||||||
chat_logger.info(f"To es: {sql}")
|
res = self.dataStore.search(fields, [], condition, [], OrderByExpr(), 0, max_count, index_name(tenant_id), kb_ids)
|
||||||
|
dict_chunks = self.dataStore.getFields(res, fields)
|
||||||
try:
|
return dict_chunks.values()
|
||||||
tbl = self.es.sql(sql, fetch_size, format)
|
|
||||||
return tbl
|
|
||||||
except Exception as e:
|
|
||||||
chat_logger.error(f"SQL failure: {sql} =>" + str(e))
|
|
||||||
return {"error": str(e)}
|
|
||||||
|
|
||||||
def chunk_list(self, doc_id, tenant_id, max_count=1024, fields=["docnm_kwd", "content_with_weight", "img_id"]):
|
|
||||||
s = Search()
|
|
||||||
s = s.query(Q("match", doc_id=doc_id))[0:max_count]
|
|
||||||
s = s.to_dict()
|
|
||||||
es_res = self.es.search(s, idxnms=index_name(tenant_id), timeout="600s", src=fields)
|
|
||||||
res = []
|
|
||||||
for index, chunk in enumerate(es_res['hits']['hits']):
|
|
||||||
res.append({fld: chunk['_source'].get(fld) for fld in fields})
|
|
||||||
return res
|
|
||||||
|
@ -25,12 +25,13 @@ RAG_CONF_PATH = os.path.join(get_project_base_directory(), "conf")
|
|||||||
SUBPROCESS_STD_LOG_NAME = "std.log"
|
SUBPROCESS_STD_LOG_NAME = "std.log"
|
||||||
|
|
||||||
ES = get_base_config("es", {})
|
ES = get_base_config("es", {})
|
||||||
|
INFINITY = get_base_config("infinity", {"uri": "infinity:23817"})
|
||||||
AZURE = get_base_config("azure", {})
|
AZURE = get_base_config("azure", {})
|
||||||
S3 = get_base_config("s3", {})
|
S3 = get_base_config("s3", {})
|
||||||
MINIO = decrypt_database_config(name="minio")
|
MINIO = decrypt_database_config(name="minio")
|
||||||
try:
|
try:
|
||||||
REDIS = decrypt_database_config(name="redis")
|
REDIS = decrypt_database_config(name="redis")
|
||||||
except Exception as e:
|
except Exception:
|
||||||
REDIS = {}
|
REDIS = {}
|
||||||
pass
|
pass
|
||||||
DOC_MAXIMUM_SIZE = int(os.environ.get("MAX_CONTENT_LENGTH", 128 * 1024 * 1024))
|
DOC_MAXIMUM_SIZE = int(os.environ.get("MAX_CONTENT_LENGTH", 128 * 1024 * 1024))
|
||||||
@ -44,7 +45,7 @@ LoggerFactory.set_directory(
|
|||||||
# {CRITICAL: 50, FATAL:50, ERROR:40, WARNING:30, WARN:30, INFO:20, DEBUG:10, NOTSET:0}
|
# {CRITICAL: 50, FATAL:50, ERROR:40, WARNING:30, WARN:30, INFO:20, DEBUG:10, NOTSET:0}
|
||||||
LoggerFactory.LEVEL = 30
|
LoggerFactory.LEVEL = 30
|
||||||
|
|
||||||
es_logger = getLogger("es")
|
doc_store_logger = getLogger("doc_store")
|
||||||
minio_logger = getLogger("minio")
|
minio_logger = getLogger("minio")
|
||||||
s3_logger = getLogger("s3")
|
s3_logger = getLogger("s3")
|
||||||
azure_logger = getLogger("azure")
|
azure_logger = getLogger("azure")
|
||||||
@ -53,7 +54,7 @@ chunk_logger = getLogger("chunk_logger")
|
|||||||
database_logger = getLogger("database")
|
database_logger = getLogger("database")
|
||||||
|
|
||||||
formatter = logging.Formatter("%(asctime)-15s %(levelname)-8s (%(process)d) %(message)s")
|
formatter = logging.Formatter("%(asctime)-15s %(levelname)-8s (%(process)d) %(message)s")
|
||||||
for logger in [es_logger, minio_logger, s3_logger, azure_logger, cron_logger, chunk_logger, database_logger]:
|
for logger in [doc_store_logger, minio_logger, s3_logger, azure_logger, cron_logger, chunk_logger, database_logger]:
|
||||||
logger.setLevel(logging.INFO)
|
logger.setLevel(logging.INFO)
|
||||||
for handler in logger.handlers:
|
for handler in logger.handlers:
|
||||||
handler.setFormatter(fmt=formatter)
|
handler.setFormatter(fmt=formatter)
|
||||||
|
@ -31,7 +31,6 @@ from timeit import default_timer as timer
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from elasticsearch_dsl import Q
|
|
||||||
|
|
||||||
from api.db import LLMType, ParserType
|
from api.db import LLMType, ParserType
|
||||||
from api.db.services.dialog_service import keyword_extraction, question_proposal
|
from api.db.services.dialog_service import keyword_extraction, question_proposal
|
||||||
@ -39,8 +38,7 @@ from api.db.services.document_service import DocumentService
|
|||||||
from api.db.services.llm_service import LLMBundle
|
from api.db.services.llm_service import LLMBundle
|
||||||
from api.db.services.task_service import TaskService
|
from api.db.services.task_service import TaskService
|
||||||
from api.db.services.file2document_service import File2DocumentService
|
from api.db.services.file2document_service import File2DocumentService
|
||||||
from api.settings import retrievaler
|
from api.settings import retrievaler, docStoreConn
|
||||||
from api.utils.file_utils import get_project_base_directory
|
|
||||||
from api.db.db_models import close_connection
|
from api.db.db_models import close_connection
|
||||||
from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive, one, audio, knowledge_graph, email
|
from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive, one, audio, knowledge_graph, email
|
||||||
from rag.nlp import search, rag_tokenizer
|
from rag.nlp import search, rag_tokenizer
|
||||||
@ -48,7 +46,6 @@ from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as
|
|||||||
from rag.settings import database_logger, SVR_QUEUE_NAME
|
from rag.settings import database_logger, SVR_QUEUE_NAME
|
||||||
from rag.settings import cron_logger, DOC_MAXIMUM_SIZE
|
from rag.settings import cron_logger, DOC_MAXIMUM_SIZE
|
||||||
from rag.utils import rmSpace, num_tokens_from_string
|
from rag.utils import rmSpace, num_tokens_from_string
|
||||||
from rag.utils.es_conn import ELASTICSEARCH
|
|
||||||
from rag.utils.redis_conn import REDIS_CONN, Payload
|
from rag.utils.redis_conn import REDIS_CONN, Payload
|
||||||
from rag.utils.storage_factory import STORAGE_IMPL
|
from rag.utils.storage_factory import STORAGE_IMPL
|
||||||
|
|
||||||
@ -126,7 +123,7 @@ def collect():
|
|||||||
return pd.DataFrame()
|
return pd.DataFrame()
|
||||||
tasks = TaskService.get_tasks(msg["id"])
|
tasks = TaskService.get_tasks(msg["id"])
|
||||||
if not tasks:
|
if not tasks:
|
||||||
cron_logger.warn("{} empty task!".format(msg["id"]))
|
cron_logger.warning("{} empty task!".format(msg["id"]))
|
||||||
return []
|
return []
|
||||||
|
|
||||||
tasks = pd.DataFrame(tasks)
|
tasks = pd.DataFrame(tasks)
|
||||||
@ -187,7 +184,7 @@ def build(row):
|
|||||||
docs = []
|
docs = []
|
||||||
doc = {
|
doc = {
|
||||||
"doc_id": row["doc_id"],
|
"doc_id": row["doc_id"],
|
||||||
"kb_id": [str(row["kb_id"])]
|
"kb_id": str(row["kb_id"])
|
||||||
}
|
}
|
||||||
el = 0
|
el = 0
|
||||||
for ck in cks:
|
for ck in cks:
|
||||||
@ -196,10 +193,14 @@ def build(row):
|
|||||||
md5 = hashlib.md5()
|
md5 = hashlib.md5()
|
||||||
md5.update((ck["content_with_weight"] +
|
md5.update((ck["content_with_weight"] +
|
||||||
str(d["doc_id"])).encode("utf-8"))
|
str(d["doc_id"])).encode("utf-8"))
|
||||||
d["_id"] = md5.hexdigest()
|
d["id"] = md5.hexdigest()
|
||||||
d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19]
|
d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19]
|
||||||
d["create_timestamp_flt"] = datetime.datetime.now().timestamp()
|
d["create_timestamp_flt"] = datetime.datetime.now().timestamp()
|
||||||
if not d.get("image"):
|
if not d.get("image"):
|
||||||
|
d["img_id"] = ""
|
||||||
|
d["page_num_list"] = json.dumps([])
|
||||||
|
d["position_list"] = json.dumps([])
|
||||||
|
d["top_list"] = json.dumps([])
|
||||||
docs.append(d)
|
docs.append(d)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@ -211,13 +212,13 @@ def build(row):
|
|||||||
d["image"].save(output_buffer, format='JPEG')
|
d["image"].save(output_buffer, format='JPEG')
|
||||||
|
|
||||||
st = timer()
|
st = timer()
|
||||||
STORAGE_IMPL.put(row["kb_id"], d["_id"], output_buffer.getvalue())
|
STORAGE_IMPL.put(row["kb_id"], d["id"], output_buffer.getvalue())
|
||||||
el += timer() - st
|
el += timer() - st
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
cron_logger.error(str(e))
|
cron_logger.error(str(e))
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
|
|
||||||
d["img_id"] = "{}-{}".format(row["kb_id"], d["_id"])
|
d["img_id"] = "{}-{}".format(row["kb_id"], d["id"])
|
||||||
del d["image"]
|
del d["image"]
|
||||||
docs.append(d)
|
docs.append(d)
|
||||||
cron_logger.info("MINIO PUT({}):{}".format(row["name"], el))
|
cron_logger.info("MINIO PUT({}):{}".format(row["name"], el))
|
||||||
@ -245,12 +246,9 @@ def build(row):
|
|||||||
return docs
|
return docs
|
||||||
|
|
||||||
|
|
||||||
def init_kb(row):
|
def init_kb(row, vector_size: int):
|
||||||
idxnm = search.index_name(row["tenant_id"])
|
idxnm = search.index_name(row["tenant_id"])
|
||||||
if ELASTICSEARCH.indexExist(idxnm):
|
return docStoreConn.createIdx(idxnm, row["kb_id"], vector_size)
|
||||||
return
|
|
||||||
return ELASTICSEARCH.createIdx(idxnm, json.load(
|
|
||||||
open(os.path.join(get_project_base_directory(), "conf", "mapping.json"), "r")))
|
|
||||||
|
|
||||||
|
|
||||||
def embedding(docs, mdl, parser_config=None, callback=None):
|
def embedding(docs, mdl, parser_config=None, callback=None):
|
||||||
@ -288,17 +286,20 @@ def embedding(docs, mdl, parser_config=None, callback=None):
|
|||||||
cnts) if len(tts) == len(cnts) else cnts
|
cnts) if len(tts) == len(cnts) else cnts
|
||||||
|
|
||||||
assert len(vects) == len(docs)
|
assert len(vects) == len(docs)
|
||||||
|
vector_size = 0
|
||||||
for i, d in enumerate(docs):
|
for i, d in enumerate(docs):
|
||||||
v = vects[i].tolist()
|
v = vects[i].tolist()
|
||||||
|
vector_size = len(v)
|
||||||
d["q_%d_vec" % len(v)] = v
|
d["q_%d_vec" % len(v)] = v
|
||||||
return tk_count
|
return tk_count, vector_size
|
||||||
|
|
||||||
|
|
||||||
def run_raptor(row, chat_mdl, embd_mdl, callback=None):
|
def run_raptor(row, chat_mdl, embd_mdl, callback=None):
|
||||||
vts, _ = embd_mdl.encode(["ok"])
|
vts, _ = embd_mdl.encode(["ok"])
|
||||||
vctr_nm = "q_%d_vec" % len(vts[0])
|
vector_size = len(vts[0])
|
||||||
|
vctr_nm = "q_%d_vec" % vector_size
|
||||||
chunks = []
|
chunks = []
|
||||||
for d in retrievaler.chunk_list(row["doc_id"], row["tenant_id"], fields=["content_with_weight", vctr_nm]):
|
for d in retrievaler.chunk_list(row["doc_id"], row["tenant_id"], [str(row["kb_id"])], fields=["content_with_weight", vctr_nm]):
|
||||||
chunks.append((d["content_with_weight"], np.array(d[vctr_nm])))
|
chunks.append((d["content_with_weight"], np.array(d[vctr_nm])))
|
||||||
|
|
||||||
raptor = Raptor(
|
raptor = Raptor(
|
||||||
@ -323,7 +324,7 @@ def run_raptor(row, chat_mdl, embd_mdl, callback=None):
|
|||||||
d = copy.deepcopy(doc)
|
d = copy.deepcopy(doc)
|
||||||
md5 = hashlib.md5()
|
md5 = hashlib.md5()
|
||||||
md5.update((content + str(d["doc_id"])).encode("utf-8"))
|
md5.update((content + str(d["doc_id"])).encode("utf-8"))
|
||||||
d["_id"] = md5.hexdigest()
|
d["id"] = md5.hexdigest()
|
||||||
d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19]
|
d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19]
|
||||||
d["create_timestamp_flt"] = datetime.datetime.now().timestamp()
|
d["create_timestamp_flt"] = datetime.datetime.now().timestamp()
|
||||||
d[vctr_nm] = vctr.tolist()
|
d[vctr_nm] = vctr.tolist()
|
||||||
@ -332,7 +333,7 @@ def run_raptor(row, chat_mdl, embd_mdl, callback=None):
|
|||||||
d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
|
d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
|
||||||
res.append(d)
|
res.append(d)
|
||||||
tk_count += num_tokens_from_string(content)
|
tk_count += num_tokens_from_string(content)
|
||||||
return res, tk_count
|
return res, tk_count, vector_size
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
@ -352,7 +353,7 @@ def main():
|
|||||||
if r.get("task_type", "") == "raptor":
|
if r.get("task_type", "") == "raptor":
|
||||||
try:
|
try:
|
||||||
chat_mdl = LLMBundle(r["tenant_id"], LLMType.CHAT, llm_name=r["llm_id"], lang=r["language"])
|
chat_mdl = LLMBundle(r["tenant_id"], LLMType.CHAT, llm_name=r["llm_id"], lang=r["language"])
|
||||||
cks, tk_count = run_raptor(r, chat_mdl, embd_mdl, callback)
|
cks, tk_count, vector_size = run_raptor(r, chat_mdl, embd_mdl, callback)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
callback(-1, msg=str(e))
|
callback(-1, msg=str(e))
|
||||||
cron_logger.error(str(e))
|
cron_logger.error(str(e))
|
||||||
@ -373,7 +374,7 @@ def main():
|
|||||||
len(cks))
|
len(cks))
|
||||||
st = timer()
|
st = timer()
|
||||||
try:
|
try:
|
||||||
tk_count = embedding(cks, embd_mdl, r["parser_config"], callback)
|
tk_count, vector_size = embedding(cks, embd_mdl, r["parser_config"], callback)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
callback(-1, "Embedding error:{}".format(str(e)))
|
callback(-1, "Embedding error:{}".format(str(e)))
|
||||||
cron_logger.error(str(e))
|
cron_logger.error(str(e))
|
||||||
@ -381,26 +382,25 @@ def main():
|
|||||||
cron_logger.info("Embedding elapsed({}): {:.2f}".format(r["name"], timer() - st))
|
cron_logger.info("Embedding elapsed({}): {:.2f}".format(r["name"], timer() - st))
|
||||||
callback(msg="Finished embedding({:.2f})! Start to build index!".format(timer() - st))
|
callback(msg="Finished embedding({:.2f})! Start to build index!".format(timer() - st))
|
||||||
|
|
||||||
init_kb(r)
|
# cron_logger.info(f"task_executor init_kb index {search.index_name(r["tenant_id"])} embd_mdl {embd_mdl.llm_name} vector length {vector_size}")
|
||||||
chunk_count = len(set([c["_id"] for c in cks]))
|
init_kb(r, vector_size)
|
||||||
|
chunk_count = len(set([c["id"] for c in cks]))
|
||||||
st = timer()
|
st = timer()
|
||||||
es_r = ""
|
es_r = ""
|
||||||
es_bulk_size = 4
|
es_bulk_size = 4
|
||||||
for b in range(0, len(cks), es_bulk_size):
|
for b in range(0, len(cks), es_bulk_size):
|
||||||
es_r = ELASTICSEARCH.bulk(cks[b:b + es_bulk_size], search.index_name(r["tenant_id"]))
|
es_r = docStoreConn.insert(cks[b:b + es_bulk_size], search.index_name(r["tenant_id"]), r["kb_id"])
|
||||||
if b % 128 == 0:
|
if b % 128 == 0:
|
||||||
callback(prog=0.8 + 0.1 * (b + 1) / len(cks), msg="")
|
callback(prog=0.8 + 0.1 * (b + 1) / len(cks), msg="")
|
||||||
|
|
||||||
cron_logger.info("Indexing elapsed({}): {:.2f}".format(r["name"], timer() - st))
|
cron_logger.info("Indexing elapsed({}): {:.2f}".format(r["name"], timer() - st))
|
||||||
if es_r:
|
if es_r:
|
||||||
callback(-1, "Insert chunk error, detail info please check ragflow-logs/api/cron_logger.log. Please also check ES status!")
|
callback(-1, "Insert chunk error, detail info please check ragflow-logs/api/cron_logger.log. Please also check ES status!")
|
||||||
ELASTICSEARCH.deleteByQuery(
|
docStoreConn.delete({"doc_id": r["doc_id"]}, search.index_name(r["tenant_id"]), r["kb_id"])
|
||||||
Q("match", doc_id=r["doc_id"]), idxnm=search.index_name(r["tenant_id"]))
|
cron_logger.error('Insert chunk error: ' + str(es_r))
|
||||||
cron_logger.error(str(es_r))
|
|
||||||
else:
|
else:
|
||||||
if TaskService.do_cancel(r["id"]):
|
if TaskService.do_cancel(r["id"]):
|
||||||
ELASTICSEARCH.deleteByQuery(
|
docStoreConn.delete({"doc_id": r["doc_id"]}, search.index_name(r["tenant_id"]), r["kb_id"])
|
||||||
Q("match", doc_id=r["doc_id"]), idxnm=search.index_name(r["tenant_id"]))
|
|
||||||
continue
|
continue
|
||||||
callback(1., "Done!")
|
callback(1., "Done!")
|
||||||
DocumentService.increment_chunk_num(
|
DocumentService.increment_chunk_num(
|
||||||
|
251
rag/utils/doc_store_conn.py
Normal file
251
rag/utils/doc_store_conn.py
Normal file
@ -0,0 +1,251 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Optional, Union
|
||||||
|
from dataclasses import dataclass
|
||||||
|
import numpy as np
|
||||||
|
import polars as pl
|
||||||
|
from typing import List, Dict
|
||||||
|
|
||||||
|
DEFAULT_MATCH_VECTOR_TOPN = 10
|
||||||
|
DEFAULT_MATCH_SPARSE_TOPN = 10
|
||||||
|
VEC = Union[list, np.ndarray]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SparseVector:
|
||||||
|
indices: list[int]
|
||||||
|
values: Union[list[float], list[int], None] = None
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
assert (self.values is None) or (len(self.indices) == len(self.values))
|
||||||
|
|
||||||
|
def to_dict_old(self):
|
||||||
|
d = {"indices": self.indices}
|
||||||
|
if self.values is not None:
|
||||||
|
d["values"] = self.values
|
||||||
|
return d
|
||||||
|
|
||||||
|
def to_dict(self):
|
||||||
|
if self.values is None:
|
||||||
|
raise ValueError("SparseVector.values is None")
|
||||||
|
result = {}
|
||||||
|
for i, v in zip(self.indices, self.values):
|
||||||
|
result[str(i)] = v
|
||||||
|
return result
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_dict(d):
|
||||||
|
return SparseVector(d["indices"], d.get("values"))
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return f"SparseVector(indices={self.indices}{'' if self.values is None else f', values={self.values}'})"
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return str(self)
|
||||||
|
|
||||||
|
|
||||||
|
class MatchTextExpr(ABC):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
fields: str,
|
||||||
|
matching_text: str,
|
||||||
|
topn: int,
|
||||||
|
extra_options: dict = dict(),
|
||||||
|
):
|
||||||
|
self.fields = fields
|
||||||
|
self.matching_text = matching_text
|
||||||
|
self.topn = topn
|
||||||
|
self.extra_options = extra_options
|
||||||
|
|
||||||
|
|
||||||
|
class MatchDenseExpr(ABC):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vector_column_name: str,
|
||||||
|
embedding_data: VEC,
|
||||||
|
embedding_data_type: str,
|
||||||
|
distance_type: str,
|
||||||
|
topn: int = DEFAULT_MATCH_VECTOR_TOPN,
|
||||||
|
extra_options: dict = dict(),
|
||||||
|
):
|
||||||
|
self.vector_column_name = vector_column_name
|
||||||
|
self.embedding_data = embedding_data
|
||||||
|
self.embedding_data_type = embedding_data_type
|
||||||
|
self.distance_type = distance_type
|
||||||
|
self.topn = topn
|
||||||
|
self.extra_options = extra_options
|
||||||
|
|
||||||
|
|
||||||
|
class MatchSparseExpr(ABC):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vector_column_name: str,
|
||||||
|
sparse_data: SparseVector | dict,
|
||||||
|
distance_type: str,
|
||||||
|
topn: int,
|
||||||
|
opt_params: Optional[dict] = None,
|
||||||
|
):
|
||||||
|
self.vector_column_name = vector_column_name
|
||||||
|
self.sparse_data = sparse_data
|
||||||
|
self.distance_type = distance_type
|
||||||
|
self.topn = topn
|
||||||
|
self.opt_params = opt_params
|
||||||
|
|
||||||
|
|
||||||
|
class MatchTensorExpr(ABC):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
column_name: str,
|
||||||
|
query_data: VEC,
|
||||||
|
query_data_type: str,
|
||||||
|
topn: int,
|
||||||
|
extra_option: Optional[dict] = None,
|
||||||
|
):
|
||||||
|
self.column_name = column_name
|
||||||
|
self.query_data = query_data
|
||||||
|
self.query_data_type = query_data_type
|
||||||
|
self.topn = topn
|
||||||
|
self.extra_option = extra_option
|
||||||
|
|
||||||
|
|
||||||
|
class FusionExpr(ABC):
|
||||||
|
def __init__(self, method: str, topn: int, fusion_params: Optional[dict] = None):
|
||||||
|
self.method = method
|
||||||
|
self.topn = topn
|
||||||
|
self.fusion_params = fusion_params
|
||||||
|
|
||||||
|
|
||||||
|
MatchExpr = Union[
|
||||||
|
MatchTextExpr, MatchDenseExpr, MatchSparseExpr, MatchTensorExpr, FusionExpr
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class OrderByExpr(ABC):
|
||||||
|
def __init__(self):
|
||||||
|
self.fields = list()
|
||||||
|
def asc(self, field: str):
|
||||||
|
self.fields.append((field, 0))
|
||||||
|
return self
|
||||||
|
def desc(self, field: str):
|
||||||
|
self.fields.append((field, 1))
|
||||||
|
return self
|
||||||
|
def fields(self):
|
||||||
|
return self.fields
|
||||||
|
|
||||||
|
class DocStoreConnection(ABC):
|
||||||
|
"""
|
||||||
|
Database operations
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def dbType(self) -> str:
|
||||||
|
"""
|
||||||
|
Return the type of the database.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError("Not implemented")
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def health(self) -> dict:
|
||||||
|
"""
|
||||||
|
Return the health status of the database.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError("Not implemented")
|
||||||
|
|
||||||
|
"""
|
||||||
|
Table operations
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def createIdx(self, indexName: str, knowledgebaseId: str, vectorSize: int):
|
||||||
|
"""
|
||||||
|
Create an index with given name
|
||||||
|
"""
|
||||||
|
raise NotImplementedError("Not implemented")
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def deleteIdx(self, indexName: str, knowledgebaseId: str):
|
||||||
|
"""
|
||||||
|
Delete an index with given name
|
||||||
|
"""
|
||||||
|
raise NotImplementedError("Not implemented")
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def indexExist(self, indexName: str, knowledgebaseId: str) -> bool:
|
||||||
|
"""
|
||||||
|
Check if an index with given name exists
|
||||||
|
"""
|
||||||
|
raise NotImplementedError("Not implemented")
|
||||||
|
|
||||||
|
"""
|
||||||
|
CRUD operations
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def search(
|
||||||
|
self, selectFields: list[str], highlight: list[str], condition: dict, matchExprs: list[MatchExpr], orderBy: OrderByExpr, offset: int, limit: int, indexNames: str|list[str], knowledgebaseIds: list[str]
|
||||||
|
) -> list[dict] | pl.DataFrame:
|
||||||
|
"""
|
||||||
|
Search with given conjunctive equivalent filtering condition and return all fields of matched documents
|
||||||
|
"""
|
||||||
|
raise NotImplementedError("Not implemented")
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get(self, chunkId: str, indexName: str, knowledgebaseIds: list[str]) -> dict | None:
|
||||||
|
"""
|
||||||
|
Get single chunk with given id
|
||||||
|
"""
|
||||||
|
raise NotImplementedError("Not implemented")
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def insert(self, rows: list[dict], indexName: str, knowledgebaseId: str) -> list[str]:
|
||||||
|
"""
|
||||||
|
Update or insert a bulk of rows
|
||||||
|
"""
|
||||||
|
raise NotImplementedError("Not implemented")
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def update(self, condition: dict, newValue: dict, indexName: str, knowledgebaseId: str) -> bool:
|
||||||
|
"""
|
||||||
|
Update rows with given conjunctive equivalent filtering condition
|
||||||
|
"""
|
||||||
|
raise NotImplementedError("Not implemented")
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def delete(self, condition: dict, indexName: str, knowledgebaseId: str) -> int:
|
||||||
|
"""
|
||||||
|
Delete rows with given conjunctive equivalent filtering condition
|
||||||
|
"""
|
||||||
|
raise NotImplementedError("Not implemented")
|
||||||
|
|
||||||
|
"""
|
||||||
|
Helper functions for search result
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def getTotal(self, res):
|
||||||
|
raise NotImplementedError("Not implemented")
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def getChunkIds(self, res):
|
||||||
|
raise NotImplementedError("Not implemented")
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def getFields(self, res, fields: List[str]) -> Dict[str, dict]:
|
||||||
|
raise NotImplementedError("Not implemented")
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def getHighlight(self, res, keywords: List[str], fieldnm: str):
|
||||||
|
raise NotImplementedError("Not implemented")
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def getAggregation(self, res, fieldnm: str):
|
||||||
|
raise NotImplementedError("Not implemented")
|
||||||
|
|
||||||
|
"""
|
||||||
|
SQL
|
||||||
|
"""
|
||||||
|
@abstractmethod
|
||||||
|
def sql(sql: str, fetch_size: int, format: str):
|
||||||
|
"""
|
||||||
|
Run the sql generated by text-to-sql
|
||||||
|
"""
|
||||||
|
raise NotImplementedError("Not implemented")
|
@ -1,29 +1,29 @@
|
|||||||
import re
|
import re
|
||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
import copy
|
import os
|
||||||
|
from typing import List, Dict
|
||||||
|
|
||||||
import elasticsearch
|
import elasticsearch
|
||||||
from elastic_transport import ConnectionTimeout
|
import copy
|
||||||
from elasticsearch import Elasticsearch
|
from elasticsearch import Elasticsearch
|
||||||
from elasticsearch_dsl import UpdateByQuery, Search, Index
|
from elasticsearch_dsl import UpdateByQuery, Q, Search, Index
|
||||||
from rag.settings import es_logger
|
from elastic_transport import ConnectionTimeout
|
||||||
|
from rag.settings import doc_store_logger
|
||||||
from rag import settings
|
from rag import settings
|
||||||
from rag.utils import singleton
|
from rag.utils import singleton
|
||||||
|
from api.utils.file_utils import get_project_base_directory
|
||||||
|
import polars as pl
|
||||||
|
from rag.utils.doc_store_conn import DocStoreConnection, MatchExpr, OrderByExpr, MatchTextExpr, MatchDenseExpr, FusionExpr
|
||||||
|
from rag.nlp import is_english, rag_tokenizer
|
||||||
|
|
||||||
es_logger.info("Elasticsearch version: "+str(elasticsearch.__version__))
|
doc_store_logger.info("Elasticsearch sdk version: "+str(elasticsearch.__version__))
|
||||||
|
|
||||||
|
|
||||||
@singleton
|
@singleton
|
||||||
class ESConnection:
|
class ESConnection(DocStoreConnection):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.info = {}
|
self.info = {}
|
||||||
self.conn()
|
|
||||||
self.idxnm = settings.ES.get("index_name", "")
|
|
||||||
if not self.es.ping():
|
|
||||||
raise Exception("Can't connect to ES cluster")
|
|
||||||
|
|
||||||
def conn(self):
|
|
||||||
for _ in range(10):
|
for _ in range(10):
|
||||||
try:
|
try:
|
||||||
self.es = Elasticsearch(
|
self.es = Elasticsearch(
|
||||||
@ -34,390 +34,317 @@ class ESConnection:
|
|||||||
)
|
)
|
||||||
if self.es:
|
if self.es:
|
||||||
self.info = self.es.info()
|
self.info = self.es.info()
|
||||||
es_logger.info("Connect to es.")
|
doc_store_logger.info("Connect to es.")
|
||||||
break
|
break
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
es_logger.error("Fail to connect to es: " + str(e))
|
doc_store_logger.error("Fail to connect to es: " + str(e))
|
||||||
time.sleep(1)
|
time.sleep(1)
|
||||||
|
if not self.es.ping():
|
||||||
def version(self):
|
raise Exception("Can't connect to ES cluster")
|
||||||
v = self.info.get("version", {"number": "5.6"})
|
v = self.info.get("version", {"number": "5.6"})
|
||||||
v = v["number"].split(".")[0]
|
v = v["number"].split(".")[0]
|
||||||
return int(v) >= 7
|
if int(v) < 8:
|
||||||
|
raise Exception(f"ES version must be greater than or equal to 8, current version: {v}")
|
||||||
|
fp_mapping = os.path.join(get_project_base_directory(), "conf", "mapping.json")
|
||||||
|
if not os.path.exists(fp_mapping):
|
||||||
|
raise Exception(f"Mapping file not found at {fp_mapping}")
|
||||||
|
self.mapping = json.load(open(fp_mapping, "r"))
|
||||||
|
|
||||||
def health(self):
|
"""
|
||||||
return dict(self.es.cluster.health())
|
Database operations
|
||||||
|
"""
|
||||||
|
def dbType(self) -> str:
|
||||||
|
return "elasticsearch"
|
||||||
|
|
||||||
def upsert(self, df, idxnm=""):
|
def health(self) -> dict:
|
||||||
res = []
|
return dict(self.es.cluster.health()) + {"type": "elasticsearch"}
|
||||||
for d in df:
|
|
||||||
id = d["id"]
|
|
||||||
del d["id"]
|
|
||||||
d = {"doc": d, "doc_as_upsert": "true"}
|
|
||||||
T = False
|
|
||||||
for _ in range(10):
|
|
||||||
try:
|
|
||||||
if not self.version():
|
|
||||||
r = self.es.update(
|
|
||||||
index=(
|
|
||||||
self.idxnm if not idxnm else idxnm),
|
|
||||||
body=d,
|
|
||||||
id=id,
|
|
||||||
doc_type="doc",
|
|
||||||
refresh=True,
|
|
||||||
retry_on_conflict=100)
|
|
||||||
else:
|
|
||||||
r = self.es.update(
|
|
||||||
index=(
|
|
||||||
self.idxnm if not idxnm else idxnm),
|
|
||||||
body=d,
|
|
||||||
id=id,
|
|
||||||
refresh=True,
|
|
||||||
retry_on_conflict=100)
|
|
||||||
es_logger.info("Successfully upsert: %s" % id)
|
|
||||||
T = True
|
|
||||||
break
|
|
||||||
except Exception as e:
|
|
||||||
es_logger.warning("Fail to index: " +
|
|
||||||
json.dumps(d, ensure_ascii=False) + str(e))
|
|
||||||
if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE):
|
|
||||||
time.sleep(3)
|
|
||||||
continue
|
|
||||||
self.conn()
|
|
||||||
T = False
|
|
||||||
|
|
||||||
if not T:
|
"""
|
||||||
res.append(d)
|
Table operations
|
||||||
es_logger.error(
|
"""
|
||||||
"Fail to index: " +
|
def createIdx(self, indexName: str, knowledgebaseId: str, vectorSize: int):
|
||||||
re.sub(
|
if self.indexExist(indexName, knowledgebaseId):
|
||||||
"[\r\n]",
|
|
||||||
"",
|
|
||||||
json.dumps(
|
|
||||||
d,
|
|
||||||
ensure_ascii=False)))
|
|
||||||
d["id"] = id
|
|
||||||
d["_index"] = self.idxnm
|
|
||||||
|
|
||||||
if not res:
|
|
||||||
return True
|
return True
|
||||||
return False
|
try:
|
||||||
|
from elasticsearch.client import IndicesClient
|
||||||
|
return IndicesClient(self.es).create(index=indexName,
|
||||||
|
settings=self.mapping["settings"],
|
||||||
|
mappings=self.mapping["mappings"])
|
||||||
|
except Exception as e:
|
||||||
|
doc_store_logger.error("ES create index error %s ----%s" % (indexName, str(e)))
|
||||||
|
|
||||||
def bulk(self, df, idx_nm=None):
|
def deleteIdx(self, indexName: str, knowledgebaseId: str):
|
||||||
ids, acts = {}, []
|
try:
|
||||||
for d in df:
|
return self.es.indices.delete(indexName, allow_no_indices=True)
|
||||||
id = d["id"] if "id" in d else d["_id"]
|
except Exception as e:
|
||||||
ids[id] = copy.deepcopy(d)
|
doc_store_logger.error("ES delete index error %s ----%s" % (indexName, str(e)))
|
||||||
ids[id]["_index"] = self.idxnm if not idx_nm else idx_nm
|
|
||||||
if "id" in d:
|
|
||||||
del d["id"]
|
|
||||||
if "_id" in d:
|
|
||||||
del d["_id"]
|
|
||||||
acts.append(
|
|
||||||
{"update": {"_id": id, "_index": ids[id]["_index"]}, "retry_on_conflict": 100})
|
|
||||||
acts.append({"doc": d, "doc_as_upsert": "true"})
|
|
||||||
|
|
||||||
res = []
|
def indexExist(self, indexName: str, knowledgebaseId: str) -> bool:
|
||||||
for _ in range(100):
|
s = Index(indexName, self.es)
|
||||||
try:
|
|
||||||
if elasticsearch.__version__[0] < 8:
|
|
||||||
r = self.es.bulk(
|
|
||||||
index=(
|
|
||||||
self.idxnm if not idx_nm else idx_nm),
|
|
||||||
body=acts,
|
|
||||||
refresh=False,
|
|
||||||
timeout="600s")
|
|
||||||
else:
|
|
||||||
r = self.es.bulk(index=(self.idxnm if not idx_nm else
|
|
||||||
idx_nm), operations=acts,
|
|
||||||
refresh=False, timeout="600s")
|
|
||||||
if re.search(r"False", str(r["errors"]), re.IGNORECASE):
|
|
||||||
return res
|
|
||||||
|
|
||||||
for it in r["items"]:
|
|
||||||
if "error" in it["update"]:
|
|
||||||
res.append(str(it["update"]["_id"]) +
|
|
||||||
":" + str(it["update"]["error"]))
|
|
||||||
|
|
||||||
return res
|
|
||||||
except Exception as e:
|
|
||||||
es_logger.warn("Fail to bulk: " + str(e))
|
|
||||||
if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE):
|
|
||||||
time.sleep(3)
|
|
||||||
continue
|
|
||||||
self.conn()
|
|
||||||
|
|
||||||
return res
|
|
||||||
|
|
||||||
def bulk4script(self, df):
|
|
||||||
ids, acts = {}, []
|
|
||||||
for d in df:
|
|
||||||
id = d["id"]
|
|
||||||
ids[id] = copy.deepcopy(d["raw"])
|
|
||||||
acts.append({"update": {"_id": id, "_index": self.idxnm}})
|
|
||||||
acts.append(d["script"])
|
|
||||||
es_logger.info("bulk upsert: %s" % id)
|
|
||||||
|
|
||||||
res = []
|
|
||||||
for _ in range(10):
|
|
||||||
try:
|
|
||||||
if not self.version():
|
|
||||||
r = self.es.bulk(
|
|
||||||
index=self.idxnm,
|
|
||||||
body=acts,
|
|
||||||
refresh=False,
|
|
||||||
timeout="600s",
|
|
||||||
doc_type="doc")
|
|
||||||
else:
|
|
||||||
r = self.es.bulk(
|
|
||||||
index=self.idxnm,
|
|
||||||
body=acts,
|
|
||||||
refresh=False,
|
|
||||||
timeout="600s")
|
|
||||||
if re.search(r"False", str(r["errors"]), re.IGNORECASE):
|
|
||||||
return res
|
|
||||||
|
|
||||||
for it in r["items"]:
|
|
||||||
if "error" in it["update"]:
|
|
||||||
res.append(str(it["update"]["_id"]))
|
|
||||||
|
|
||||||
return res
|
|
||||||
except Exception as e:
|
|
||||||
es_logger.warning("Fail to bulk: " + str(e))
|
|
||||||
if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE):
|
|
||||||
time.sleep(3)
|
|
||||||
continue
|
|
||||||
self.conn()
|
|
||||||
|
|
||||||
return res
|
|
||||||
|
|
||||||
def rm(self, d):
|
|
||||||
for _ in range(10):
|
|
||||||
try:
|
|
||||||
if not self.version():
|
|
||||||
r = self.es.delete(
|
|
||||||
index=self.idxnm,
|
|
||||||
id=d["id"],
|
|
||||||
doc_type="doc",
|
|
||||||
refresh=True)
|
|
||||||
else:
|
|
||||||
r = self.es.delete(
|
|
||||||
index=self.idxnm,
|
|
||||||
id=d["id"],
|
|
||||||
refresh=True,
|
|
||||||
doc_type="_doc")
|
|
||||||
es_logger.info("Remove %s" % d["id"])
|
|
||||||
return True
|
|
||||||
except Exception as e:
|
|
||||||
es_logger.warn("Fail to delete: " + str(d) + str(e))
|
|
||||||
if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE):
|
|
||||||
time.sleep(3)
|
|
||||||
continue
|
|
||||||
if re.search(r"(not_found)", str(e), re.IGNORECASE):
|
|
||||||
return True
|
|
||||||
self.conn()
|
|
||||||
|
|
||||||
es_logger.error("Fail to delete: " + str(d))
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
def search(self, q, idxnms=None, src=False, timeout="2s"):
|
|
||||||
if not isinstance(q, dict):
|
|
||||||
q = Search().query(q).to_dict()
|
|
||||||
if isinstance(idxnms, str):
|
|
||||||
idxnms = idxnms.split(",")
|
|
||||||
for i in range(3):
|
|
||||||
try:
|
|
||||||
res = self.es.search(index=(self.idxnm if not idxnms else idxnms),
|
|
||||||
body=q,
|
|
||||||
timeout=timeout,
|
|
||||||
# search_type="dfs_query_then_fetch",
|
|
||||||
track_total_hits=True,
|
|
||||||
_source=src)
|
|
||||||
if str(res.get("timed_out", "")).lower() == "true":
|
|
||||||
raise Exception("Es Timeout.")
|
|
||||||
return res
|
|
||||||
except Exception as e:
|
|
||||||
es_logger.error(
|
|
||||||
"ES search exception: " +
|
|
||||||
str(e) +
|
|
||||||
"【Q】:" +
|
|
||||||
str(q))
|
|
||||||
if str(e).find("Timeout") > 0:
|
|
||||||
continue
|
|
||||||
raise e
|
|
||||||
es_logger.error("ES search timeout for 3 times!")
|
|
||||||
raise Exception("ES search timeout.")
|
|
||||||
|
|
||||||
def sql(self, sql, fetch_size=128, format="json", timeout="2s"):
|
|
||||||
for i in range(3):
|
|
||||||
try:
|
|
||||||
res = self.es.sql.query(body={"query": sql, "fetch_size": fetch_size}, format=format, request_timeout=timeout)
|
|
||||||
return res
|
|
||||||
except ConnectionTimeout as e:
|
|
||||||
es_logger.error("Timeout【Q】:" + sql)
|
|
||||||
continue
|
|
||||||
except Exception as e:
|
|
||||||
raise e
|
|
||||||
es_logger.error("ES search timeout for 3 times!")
|
|
||||||
raise ConnectionTimeout()
|
|
||||||
|
|
||||||
|
|
||||||
def get(self, doc_id, idxnm=None):
|
|
||||||
for i in range(3):
|
|
||||||
try:
|
|
||||||
res = self.es.get(index=(self.idxnm if not idxnm else idxnm),
|
|
||||||
id=doc_id)
|
|
||||||
if str(res.get("timed_out", "")).lower() == "true":
|
|
||||||
raise Exception("Es Timeout.")
|
|
||||||
return res
|
|
||||||
except Exception as e:
|
|
||||||
es_logger.error(
|
|
||||||
"ES get exception: " +
|
|
||||||
str(e) +
|
|
||||||
"【Q】:" +
|
|
||||||
doc_id)
|
|
||||||
if str(e).find("Timeout") > 0:
|
|
||||||
continue
|
|
||||||
raise e
|
|
||||||
es_logger.error("ES search timeout for 3 times!")
|
|
||||||
raise Exception("ES search timeout.")
|
|
||||||
|
|
||||||
def updateByQuery(self, q, d):
|
|
||||||
ubq = UpdateByQuery(index=self.idxnm).using(self.es).query(q)
|
|
||||||
scripts = ""
|
|
||||||
for k, v in d.items():
|
|
||||||
scripts += "ctx._source.%s = params.%s;" % (str(k), str(k))
|
|
||||||
ubq = ubq.script(source=scripts, params=d)
|
|
||||||
ubq = ubq.params(refresh=False)
|
|
||||||
ubq = ubq.params(slices=5)
|
|
||||||
ubq = ubq.params(conflicts="proceed")
|
|
||||||
for i in range(3):
|
|
||||||
try:
|
|
||||||
r = ubq.execute()
|
|
||||||
return True
|
|
||||||
except Exception as e:
|
|
||||||
es_logger.error("ES updateByQuery exception: " +
|
|
||||||
str(e) + "【Q】:" + str(q.to_dict()))
|
|
||||||
if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
|
|
||||||
continue
|
|
||||||
self.conn()
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
def updateScriptByQuery(self, q, scripts, idxnm=None):
|
|
||||||
ubq = UpdateByQuery(
|
|
||||||
index=self.idxnm if not idxnm else idxnm).using(
|
|
||||||
self.es).query(q)
|
|
||||||
ubq = ubq.script(source=scripts)
|
|
||||||
ubq = ubq.params(refresh=True)
|
|
||||||
ubq = ubq.params(slices=5)
|
|
||||||
ubq = ubq.params(conflicts="proceed")
|
|
||||||
for i in range(3):
|
|
||||||
try:
|
|
||||||
r = ubq.execute()
|
|
||||||
return True
|
|
||||||
except Exception as e:
|
|
||||||
es_logger.error("ES updateByQuery exception: " +
|
|
||||||
str(e) + "【Q】:" + str(q.to_dict()))
|
|
||||||
if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
|
|
||||||
continue
|
|
||||||
self.conn()
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
def deleteByQuery(self, query, idxnm=""):
|
|
||||||
for i in range(3):
|
|
||||||
try:
|
|
||||||
r = self.es.delete_by_query(
|
|
||||||
index=idxnm if idxnm else self.idxnm,
|
|
||||||
refresh = True,
|
|
||||||
body=Search().query(query).to_dict())
|
|
||||||
return True
|
|
||||||
except Exception as e:
|
|
||||||
es_logger.error("ES updateByQuery deleteByQuery: " +
|
|
||||||
str(e) + "【Q】:" + str(query.to_dict()))
|
|
||||||
if str(e).find("NotFoundError") > 0: return True
|
|
||||||
if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
|
|
||||||
continue
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
def update(self, id, script, routing=None):
|
|
||||||
for i in range(3):
|
|
||||||
try:
|
|
||||||
if not self.version():
|
|
||||||
r = self.es.update(
|
|
||||||
index=self.idxnm,
|
|
||||||
id=id,
|
|
||||||
body=json.dumps(
|
|
||||||
script,
|
|
||||||
ensure_ascii=False),
|
|
||||||
doc_type="doc",
|
|
||||||
routing=routing,
|
|
||||||
refresh=False)
|
|
||||||
else:
|
|
||||||
r = self.es.update(index=self.idxnm, id=id, body=json.dumps(script, ensure_ascii=False),
|
|
||||||
routing=routing, refresh=False) # , doc_type="_doc")
|
|
||||||
return True
|
|
||||||
except Exception as e:
|
|
||||||
es_logger.error(
|
|
||||||
"ES update exception: " + str(e) + " id:" + str(id) + ", version:" + str(self.version()) +
|
|
||||||
json.dumps(script, ensure_ascii=False))
|
|
||||||
if str(e).find("Timeout") > 0:
|
|
||||||
continue
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
def indexExist(self, idxnm):
|
|
||||||
s = Index(idxnm if idxnm else self.idxnm, self.es)
|
|
||||||
for i in range(3):
|
for i in range(3):
|
||||||
try:
|
try:
|
||||||
return s.exists()
|
return s.exists()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
es_logger.error("ES updateByQuery indexExist: " + str(e))
|
doc_store_logger.error("ES indexExist: " + str(e))
|
||||||
if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
|
if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def docExist(self, docid, idxnm=None):
|
"""
|
||||||
|
CRUD operations
|
||||||
|
"""
|
||||||
|
def search(self, selectFields: list[str], highlightFields: list[str], condition: dict, matchExprs: list[MatchExpr], orderBy: OrderByExpr, offset: int, limit: int, indexNames: str|list[str], knowledgebaseIds: list[str]) -> list[dict] | pl.DataFrame:
|
||||||
|
"""
|
||||||
|
Refers to https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl.html
|
||||||
|
"""
|
||||||
|
if isinstance(indexNames, str):
|
||||||
|
indexNames = indexNames.split(",")
|
||||||
|
assert isinstance(indexNames, list) and len(indexNames) > 0
|
||||||
|
assert "_id" not in condition
|
||||||
|
s = Search()
|
||||||
|
bqry = None
|
||||||
|
vector_similarity_weight = 0.5
|
||||||
|
for m in matchExprs:
|
||||||
|
if isinstance(m, FusionExpr) and m.method=="weighted_sum" and "weights" in m.fusion_params:
|
||||||
|
assert len(matchExprs)==3 and isinstance(matchExprs[0], MatchTextExpr) and isinstance(matchExprs[1], MatchDenseExpr) and isinstance(matchExprs[2], FusionExpr)
|
||||||
|
weights = m.fusion_params["weights"]
|
||||||
|
vector_similarity_weight = float(weights.split(",")[1])
|
||||||
|
for m in matchExprs:
|
||||||
|
if isinstance(m, MatchTextExpr):
|
||||||
|
minimum_should_match = "0%"
|
||||||
|
if "minimum_should_match" in m.extra_options:
|
||||||
|
minimum_should_match = str(int(m.extra_options["minimum_should_match"] * 100)) + "%"
|
||||||
|
bqry = Q("bool",
|
||||||
|
must=Q("query_string", fields=m.fields,
|
||||||
|
type="best_fields", query=m.matching_text,
|
||||||
|
minimum_should_match = minimum_should_match,
|
||||||
|
boost=1),
|
||||||
|
boost = 1.0 - vector_similarity_weight,
|
||||||
|
)
|
||||||
|
if condition:
|
||||||
|
for k, v in condition.items():
|
||||||
|
if not isinstance(k, str) or not v:
|
||||||
|
continue
|
||||||
|
if isinstance(v, list):
|
||||||
|
bqry.filter.append(Q("terms", **{k: v}))
|
||||||
|
elif isinstance(v, str) or isinstance(v, int):
|
||||||
|
bqry.filter.append(Q("term", **{k: v}))
|
||||||
|
else:
|
||||||
|
raise Exception(f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.")
|
||||||
|
elif isinstance(m, MatchDenseExpr):
|
||||||
|
assert(bqry is not None)
|
||||||
|
similarity = 0.0
|
||||||
|
if "similarity" in m.extra_options:
|
||||||
|
similarity = m.extra_options["similarity"]
|
||||||
|
s = s.knn(m.vector_column_name,
|
||||||
|
m.topn,
|
||||||
|
m.topn * 2,
|
||||||
|
query_vector = list(m.embedding_data),
|
||||||
|
filter = bqry.to_dict(),
|
||||||
|
similarity = similarity,
|
||||||
|
)
|
||||||
|
if matchExprs:
|
||||||
|
s.query = bqry
|
||||||
|
for field in highlightFields:
|
||||||
|
s = s.highlight(field)
|
||||||
|
|
||||||
|
if orderBy:
|
||||||
|
orders = list()
|
||||||
|
for field, order in orderBy.fields:
|
||||||
|
order = "asc" if order == 0 else "desc"
|
||||||
|
orders.append({field: {"order": order, "unmapped_type": "float",
|
||||||
|
"mode": "avg", "numeric_type": "double"}})
|
||||||
|
s = s.sort(*orders)
|
||||||
|
|
||||||
|
if limit > 0:
|
||||||
|
s = s[offset:limit]
|
||||||
|
q = s.to_dict()
|
||||||
|
doc_store_logger.info("ESConnection.search [Q]: " + json.dumps(q))
|
||||||
|
|
||||||
for i in range(3):
|
for i in range(3):
|
||||||
try:
|
try:
|
||||||
return self.es.exists(index=(idxnm if idxnm else self.idxnm),
|
res = self.es.search(index=indexNames,
|
||||||
id=docid)
|
body=q,
|
||||||
|
timeout="600s",
|
||||||
|
# search_type="dfs_query_then_fetch",
|
||||||
|
track_total_hits=True,
|
||||||
|
_source=True)
|
||||||
|
if str(res.get("timed_out", "")).lower() == "true":
|
||||||
|
raise Exception("Es Timeout.")
|
||||||
|
doc_store_logger.info("ESConnection.search res: " + str(res))
|
||||||
|
return res
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
es_logger.error("ES Doc Exist: " + str(e))
|
doc_store_logger.error(
|
||||||
if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
|
"ES search exception: " +
|
||||||
|
str(e) +
|
||||||
|
"\n[Q]: " +
|
||||||
|
str(q))
|
||||||
|
if str(e).find("Timeout") > 0:
|
||||||
continue
|
continue
|
||||||
|
raise e
|
||||||
|
doc_store_logger.error("ES search timeout for 3 times!")
|
||||||
|
raise Exception("ES search timeout.")
|
||||||
|
|
||||||
|
def get(self, chunkId: str, indexName: str, knowledgebaseIds: list[str]) -> dict | None:
|
||||||
|
for i in range(3):
|
||||||
|
try:
|
||||||
|
res = self.es.get(index=(indexName),
|
||||||
|
id=chunkId, source=True,)
|
||||||
|
if str(res.get("timed_out", "")).lower() == "true":
|
||||||
|
raise Exception("Es Timeout.")
|
||||||
|
if not res.get("found"):
|
||||||
|
return None
|
||||||
|
chunk = res["_source"]
|
||||||
|
chunk["id"] = chunkId
|
||||||
|
return chunk
|
||||||
|
except Exception as e:
|
||||||
|
doc_store_logger.error(
|
||||||
|
"ES get exception: " +
|
||||||
|
str(e) +
|
||||||
|
"[Q]: " +
|
||||||
|
chunkId)
|
||||||
|
if str(e).find("Timeout") > 0:
|
||||||
|
continue
|
||||||
|
raise e
|
||||||
|
doc_store_logger.error("ES search timeout for 3 times!")
|
||||||
|
raise Exception("ES search timeout.")
|
||||||
|
|
||||||
|
def insert(self, documents: list[dict], indexName: str, knowledgebaseId: str) -> list[str]:
|
||||||
|
# Refers to https://www.elastic.co/guide/en/elasticsearch/reference/current/docs-bulk.html
|
||||||
|
operations = []
|
||||||
|
for d in documents:
|
||||||
|
assert "_id" not in d
|
||||||
|
assert "id" in d
|
||||||
|
d_copy = copy.deepcopy(d)
|
||||||
|
meta_id = d_copy["id"]
|
||||||
|
del d_copy["id"]
|
||||||
|
operations.append(
|
||||||
|
{"index": {"_index": indexName, "_id": meta_id}})
|
||||||
|
operations.append(d_copy)
|
||||||
|
|
||||||
|
res = []
|
||||||
|
for _ in range(100):
|
||||||
|
try:
|
||||||
|
r = self.es.bulk(index=(indexName), operations=operations,
|
||||||
|
refresh=False, timeout="600s")
|
||||||
|
if re.search(r"False", str(r["errors"]), re.IGNORECASE):
|
||||||
|
return res
|
||||||
|
|
||||||
|
for item in r["items"]:
|
||||||
|
for action in ["create", "delete", "index", "update"]:
|
||||||
|
if action in item and "error" in item[action]:
|
||||||
|
res.append(str(item[action]["_id"]) + ":" + str(item[action]["error"]))
|
||||||
|
return res
|
||||||
|
except Exception as e:
|
||||||
|
doc_store_logger.warning("Fail to bulk: " + str(e))
|
||||||
|
if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE):
|
||||||
|
time.sleep(3)
|
||||||
|
continue
|
||||||
|
return res
|
||||||
|
|
||||||
|
def update(self, condition: dict, newValue: dict, indexName: str, knowledgebaseId: str) -> bool:
|
||||||
|
doc = copy.deepcopy(newValue)
|
||||||
|
del doc['id']
|
||||||
|
if "id" in condition and isinstance(condition["id"], str):
|
||||||
|
# update specific single document
|
||||||
|
chunkId = condition["id"]
|
||||||
|
for i in range(3):
|
||||||
|
try:
|
||||||
|
self.es.update(index=indexName, id=chunkId, doc=doc)
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
doc_store_logger.error(
|
||||||
|
"ES update exception: " + str(e) + " id:" + str(id) +
|
||||||
|
json.dumps(newValue, ensure_ascii=False))
|
||||||
|
if str(e).find("Timeout") > 0:
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
# update unspecific maybe-multiple documents
|
||||||
|
bqry = Q("bool")
|
||||||
|
for k, v in condition.items():
|
||||||
|
if not isinstance(k, str) or not v:
|
||||||
|
continue
|
||||||
|
if isinstance(v, list):
|
||||||
|
bqry.filter.append(Q("terms", **{k: v}))
|
||||||
|
elif isinstance(v, str) or isinstance(v, int):
|
||||||
|
bqry.filter.append(Q("term", **{k: v}))
|
||||||
|
else:
|
||||||
|
raise Exception(f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.")
|
||||||
|
scripts = []
|
||||||
|
for k, v in newValue.items():
|
||||||
|
if not isinstance(k, str) or not v:
|
||||||
|
continue
|
||||||
|
if isinstance(v, str):
|
||||||
|
scripts.append(f"ctx._source.{k} = '{v}'")
|
||||||
|
elif isinstance(v, int):
|
||||||
|
scripts.append(f"ctx._source.{k} = {v}")
|
||||||
|
else:
|
||||||
|
raise Exception(f"newValue `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str.")
|
||||||
|
ubq = UpdateByQuery(
|
||||||
|
index=indexName).using(
|
||||||
|
self.es).query(bqry)
|
||||||
|
ubq = ubq.script(source="; ".join(scripts))
|
||||||
|
ubq = ubq.params(refresh=True)
|
||||||
|
ubq = ubq.params(slices=5)
|
||||||
|
ubq = ubq.params(conflicts="proceed")
|
||||||
|
for i in range(3):
|
||||||
|
try:
|
||||||
|
_ = ubq.execute()
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
doc_store_logger.error("ES update exception: " +
|
||||||
|
str(e) + "[Q]:" + str(bqry.to_dict()))
|
||||||
|
if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
|
||||||
|
continue
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def createIdx(self, idxnm, mapping):
|
def delete(self, condition: dict, indexName: str, knowledgebaseId: str) -> int:
|
||||||
try:
|
qry = None
|
||||||
if elasticsearch.__version__[0] < 8:
|
assert "_id" not in condition
|
||||||
return self.es.indices.create(idxnm, body=mapping)
|
if "id" in condition:
|
||||||
from elasticsearch.client import IndicesClient
|
chunk_ids = condition["id"]
|
||||||
return IndicesClient(self.es).create(index=idxnm,
|
if not isinstance(chunk_ids, list):
|
||||||
settings=mapping["settings"],
|
chunk_ids = [chunk_ids]
|
||||||
mappings=mapping["mappings"])
|
qry = Q("ids", values=chunk_ids)
|
||||||
except Exception as e:
|
else:
|
||||||
es_logger.error("ES create index error %s ----%s" % (idxnm, str(e)))
|
qry = Q("bool")
|
||||||
|
for k, v in condition.items():
|
||||||
|
if isinstance(v, list):
|
||||||
|
qry.must.append(Q("terms", **{k: v}))
|
||||||
|
elif isinstance(v, str) or isinstance(v, int):
|
||||||
|
qry.must.append(Q("term", **{k: v}))
|
||||||
|
else:
|
||||||
|
raise Exception("Condition value must be int, str or list.")
|
||||||
|
doc_store_logger.info("ESConnection.delete [Q]: " + json.dumps(qry.to_dict()))
|
||||||
|
for _ in range(10):
|
||||||
|
try:
|
||||||
|
res = self.es.delete_by_query(
|
||||||
|
index=indexName,
|
||||||
|
body = Search().query(qry).to_dict(),
|
||||||
|
refresh=True)
|
||||||
|
return res["deleted"]
|
||||||
|
except Exception as e:
|
||||||
|
doc_store_logger.warning("Fail to delete: " + str(filter) + str(e))
|
||||||
|
if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE):
|
||||||
|
time.sleep(3)
|
||||||
|
continue
|
||||||
|
if re.search(r"(not_found)", str(e), re.IGNORECASE):
|
||||||
|
return 0
|
||||||
|
return 0
|
||||||
|
|
||||||
def deleteIdx(self, idxnm):
|
|
||||||
try:
|
|
||||||
return self.es.indices.delete(idxnm, allow_no_indices=True)
|
|
||||||
except Exception as e:
|
|
||||||
es_logger.error("ES delete index error %s ----%s" % (idxnm, str(e)))
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
Helper functions for search result
|
||||||
|
"""
|
||||||
def getTotal(self, res):
|
def getTotal(self, res):
|
||||||
if isinstance(res["hits"]["total"], type({})):
|
if isinstance(res["hits"]["total"], type({})):
|
||||||
return res["hits"]["total"]["value"]
|
return res["hits"]["total"]["value"]
|
||||||
return res["hits"]["total"]
|
return res["hits"]["total"]
|
||||||
|
|
||||||
def getDocIds(self, res):
|
def getChunkIds(self, res):
|
||||||
return [d["_id"] for d in res["hits"]["hits"]]
|
return [d["_id"] for d in res["hits"]["hits"]]
|
||||||
|
|
||||||
def getSource(self, res):
|
def __getSource(self, res):
|
||||||
rr = []
|
rr = []
|
||||||
for d in res["hits"]["hits"]:
|
for d in res["hits"]["hits"]:
|
||||||
d["_source"]["id"] = d["_id"]
|
d["_source"]["id"] = d["_id"]
|
||||||
@ -425,40 +352,89 @@ class ESConnection:
|
|||||||
rr.append(d["_source"])
|
rr.append(d["_source"])
|
||||||
return rr
|
return rr
|
||||||
|
|
||||||
def scrollIter(self, pagesize=100, scroll_time='2m', q={
|
def getFields(self, res, fields: List[str]) -> Dict[str, dict]:
|
||||||
"query": {"match_all": {}}, "sort": [{"updated_at": {"order": "desc"}}]}):
|
res_fields = {}
|
||||||
for _ in range(100):
|
if not fields:
|
||||||
|
return {}
|
||||||
|
for d in self.__getSource(res):
|
||||||
|
m = {n: d.get(n) for n in fields if d.get(n) is not None}
|
||||||
|
for n, v in m.items():
|
||||||
|
if isinstance(v, list):
|
||||||
|
m[n] = v
|
||||||
|
continue
|
||||||
|
if not isinstance(v, str):
|
||||||
|
m[n] = str(m[n])
|
||||||
|
# if n.find("tks") > 0:
|
||||||
|
# m[n] = rmSpace(m[n])
|
||||||
|
|
||||||
|
if m:
|
||||||
|
res_fields[d["id"]] = m
|
||||||
|
return res_fields
|
||||||
|
|
||||||
|
def getHighlight(self, res, keywords: List[str], fieldnm: str):
|
||||||
|
ans = {}
|
||||||
|
for d in res["hits"]["hits"]:
|
||||||
|
hlts = d.get("highlight")
|
||||||
|
if not hlts:
|
||||||
|
continue
|
||||||
|
txt = "...".join([a for a in list(hlts.items())[0][1]])
|
||||||
|
if not is_english(txt.split(" ")):
|
||||||
|
ans[d["_id"]] = txt
|
||||||
|
continue
|
||||||
|
|
||||||
|
txt = d["_source"][fieldnm]
|
||||||
|
txt = re.sub(r"[\r\n]", " ", txt, flags=re.IGNORECASE|re.MULTILINE)
|
||||||
|
txts = []
|
||||||
|
for t in re.split(r"[.?!;\n]", txt):
|
||||||
|
for w in keywords:
|
||||||
|
t = re.sub(r"(^|[ .?/'\"\(\)!,:;-])(%s)([ .?/'\"\(\)!,:;-])"%re.escape(w), r"\1<em>\2</em>\3", t, flags=re.IGNORECASE|re.MULTILINE)
|
||||||
|
if not re.search(r"<em>[^<>]+</em>", t, flags=re.IGNORECASE|re.MULTILINE):
|
||||||
|
continue
|
||||||
|
txts.append(t)
|
||||||
|
ans[d["_id"]] = "...".join(txts) if txts else "...".join([a for a in list(hlts.items())[0][1]])
|
||||||
|
|
||||||
|
return ans
|
||||||
|
|
||||||
|
def getAggregation(self, res, fieldnm: str):
|
||||||
|
agg_field = "aggs_" + fieldnm
|
||||||
|
if "aggregations" not in res or agg_field not in res["aggregations"]:
|
||||||
|
return list()
|
||||||
|
bkts = res["aggregations"][agg_field]["buckets"]
|
||||||
|
return [(b["key"], b["doc_count"]) for b in bkts]
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
SQL
|
||||||
|
"""
|
||||||
|
def sql(self, sql: str, fetch_size: int, format: str):
|
||||||
|
doc_store_logger.info(f"ESConnection.sql get sql: {sql}")
|
||||||
|
sql = re.sub(r"[ `]+", " ", sql)
|
||||||
|
sql = sql.replace("%", "")
|
||||||
|
replaces = []
|
||||||
|
for r in re.finditer(r" ([a-z_]+_l?tks)( like | ?= ?)'([^']+)'", sql):
|
||||||
|
fld, v = r.group(1), r.group(3)
|
||||||
|
match = " MATCH({}, '{}', 'operator=OR;minimum_should_match=30%') ".format(
|
||||||
|
fld, rag_tokenizer.fine_grained_tokenize(rag_tokenizer.tokenize(v)))
|
||||||
|
replaces.append(
|
||||||
|
("{}{}'{}'".format(
|
||||||
|
r.group(1),
|
||||||
|
r.group(2),
|
||||||
|
r.group(3)),
|
||||||
|
match))
|
||||||
|
|
||||||
|
for p, r in replaces:
|
||||||
|
sql = sql.replace(p, r, 1)
|
||||||
|
doc_store_logger.info(f"ESConnection.sql to es: {sql}")
|
||||||
|
|
||||||
|
for i in range(3):
|
||||||
try:
|
try:
|
||||||
page = self.es.search(
|
res = self.es.sql.query(body={"query": sql, "fetch_size": fetch_size}, format=format, request_timeout="2s")
|
||||||
index=self.idxnm,
|
return res
|
||||||
scroll=scroll_time,
|
except ConnectionTimeout:
|
||||||
size=pagesize,
|
doc_store_logger.error("ESConnection.sql timeout [Q]: " + sql)
|
||||||
body=q,
|
continue
|
||||||
_source=None
|
|
||||||
)
|
|
||||||
break
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
es_logger.error("ES scrolling fail. " + str(e))
|
doc_store_logger.error(f"ESConnection.sql failure: {sql} => " + str(e))
|
||||||
time.sleep(3)
|
return None
|
||||||
|
doc_store_logger.error("ESConnection.sql timeout for 3 times!")
|
||||||
sid = page['_scroll_id']
|
return None
|
||||||
scroll_size = page['hits']['total']["value"]
|
|
||||||
es_logger.info("[TOTAL]%d" % scroll_size)
|
|
||||||
# Start scrolling
|
|
||||||
while scroll_size > 0:
|
|
||||||
yield page["hits"]["hits"]
|
|
||||||
for _ in range(100):
|
|
||||||
try:
|
|
||||||
page = self.es.scroll(scroll_id=sid, scroll=scroll_time)
|
|
||||||
break
|
|
||||||
except Exception as e:
|
|
||||||
es_logger.error("ES scrolling fail. " + str(e))
|
|
||||||
time.sleep(3)
|
|
||||||
|
|
||||||
# Update the scroll ID
|
|
||||||
sid = page['_scroll_id']
|
|
||||||
# Get the number of results that we returned in the last scroll
|
|
||||||
scroll_size = len(page['hits']['hits'])
|
|
||||||
|
|
||||||
|
|
||||||
ELASTICSEARCH = ESConnection()
|
|
||||||
|
436
rag/utils/infinity_conn.py
Normal file
436
rag/utils/infinity_conn.py
Normal file
@ -0,0 +1,436 @@
|
|||||||
|
import os
|
||||||
|
import re
|
||||||
|
import json
|
||||||
|
from typing import List, Dict
|
||||||
|
import infinity
|
||||||
|
from infinity.common import ConflictType, InfinityException
|
||||||
|
from infinity.index import IndexInfo, IndexType
|
||||||
|
from infinity.connection_pool import ConnectionPool
|
||||||
|
from rag import settings
|
||||||
|
from rag.settings import doc_store_logger
|
||||||
|
from rag.utils import singleton
|
||||||
|
import polars as pl
|
||||||
|
from polars.series.series import Series
|
||||||
|
from api.utils.file_utils import get_project_base_directory
|
||||||
|
|
||||||
|
from rag.utils.doc_store_conn import (
|
||||||
|
DocStoreConnection,
|
||||||
|
MatchExpr,
|
||||||
|
MatchTextExpr,
|
||||||
|
MatchDenseExpr,
|
||||||
|
FusionExpr,
|
||||||
|
OrderByExpr,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def equivalent_condition_to_str(condition: dict) -> str:
|
||||||
|
assert "_id" not in condition
|
||||||
|
cond = list()
|
||||||
|
for k, v in condition.items():
|
||||||
|
if not isinstance(k, str) or not v:
|
||||||
|
continue
|
||||||
|
if isinstance(v, list):
|
||||||
|
inCond = list()
|
||||||
|
for item in v:
|
||||||
|
if isinstance(item, str):
|
||||||
|
inCond.append(f"'{item}'")
|
||||||
|
else:
|
||||||
|
inCond.append(str(item))
|
||||||
|
if inCond:
|
||||||
|
strInCond = ", ".join(inCond)
|
||||||
|
strInCond = f"{k} IN ({strInCond})"
|
||||||
|
cond.append(strInCond)
|
||||||
|
elif isinstance(v, str):
|
||||||
|
cond.append(f"{k}='{v}'")
|
||||||
|
else:
|
||||||
|
cond.append(f"{k}={str(v)}")
|
||||||
|
return " AND ".join(cond)
|
||||||
|
|
||||||
|
|
||||||
|
@singleton
|
||||||
|
class InfinityConnection(DocStoreConnection):
|
||||||
|
def __init__(self):
|
||||||
|
self.dbName = settings.INFINITY.get("db_name", "default_db")
|
||||||
|
infinity_uri = settings.INFINITY["uri"]
|
||||||
|
if ":" in infinity_uri:
|
||||||
|
host, port = infinity_uri.split(":")
|
||||||
|
infinity_uri = infinity.common.NetworkAddress(host, int(port))
|
||||||
|
self.connPool = ConnectionPool(infinity_uri)
|
||||||
|
doc_store_logger.info(f"Connected to infinity {infinity_uri}.")
|
||||||
|
|
||||||
|
"""
|
||||||
|
Database operations
|
||||||
|
"""
|
||||||
|
|
||||||
|
def dbType(self) -> str:
|
||||||
|
return "infinity"
|
||||||
|
|
||||||
|
def health(self) -> dict:
|
||||||
|
"""
|
||||||
|
Return the health status of the database.
|
||||||
|
TODO: Infinity-sdk provides health() to wrap `show global variables` and `show tables`
|
||||||
|
"""
|
||||||
|
inf_conn = self.connPool.get_conn()
|
||||||
|
res = infinity.show_current_node()
|
||||||
|
self.connPool.release_conn(inf_conn)
|
||||||
|
color = "green" if res.error_code == 0 else "red"
|
||||||
|
res2 = {
|
||||||
|
"type": "infinity",
|
||||||
|
"status": f"{res.role} {color}",
|
||||||
|
"error": res.error_msg,
|
||||||
|
}
|
||||||
|
return res2
|
||||||
|
|
||||||
|
"""
|
||||||
|
Table operations
|
||||||
|
"""
|
||||||
|
|
||||||
|
def createIdx(self, indexName: str, knowledgebaseId: str, vectorSize: int):
|
||||||
|
table_name = f"{indexName}_{knowledgebaseId}"
|
||||||
|
inf_conn = self.connPool.get_conn()
|
||||||
|
inf_db = inf_conn.create_database(self.dbName, ConflictType.Ignore)
|
||||||
|
|
||||||
|
fp_mapping = os.path.join(
|
||||||
|
get_project_base_directory(), "conf", "infinity_mapping.json"
|
||||||
|
)
|
||||||
|
if not os.path.exists(fp_mapping):
|
||||||
|
raise Exception(f"Mapping file not found at {fp_mapping}")
|
||||||
|
schema = json.load(open(fp_mapping))
|
||||||
|
vector_name = f"q_{vectorSize}_vec"
|
||||||
|
schema[vector_name] = {"type": f"vector,{vectorSize},float"}
|
||||||
|
inf_table = inf_db.create_table(
|
||||||
|
table_name,
|
||||||
|
schema,
|
||||||
|
ConflictType.Ignore,
|
||||||
|
)
|
||||||
|
inf_table.create_index(
|
||||||
|
"q_vec_idx",
|
||||||
|
IndexInfo(
|
||||||
|
vector_name,
|
||||||
|
IndexType.Hnsw,
|
||||||
|
{
|
||||||
|
"M": "16",
|
||||||
|
"ef_construction": "50",
|
||||||
|
"metric": "cosine",
|
||||||
|
"encode": "lvq",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
ConflictType.Ignore,
|
||||||
|
)
|
||||||
|
text_suffix = ["_tks", "_ltks", "_kwd"]
|
||||||
|
for field_name, field_info in schema.items():
|
||||||
|
if field_info["type"] != "varchar":
|
||||||
|
continue
|
||||||
|
for suffix in text_suffix:
|
||||||
|
if field_name.endswith(suffix):
|
||||||
|
inf_table.create_index(
|
||||||
|
f"text_idx_{field_name}",
|
||||||
|
IndexInfo(
|
||||||
|
field_name, IndexType.FullText, {"ANALYZER": "standard"}
|
||||||
|
),
|
||||||
|
ConflictType.Ignore,
|
||||||
|
)
|
||||||
|
break
|
||||||
|
self.connPool.release_conn(inf_conn)
|
||||||
|
doc_store_logger.info(
|
||||||
|
f"INFINITY created table {table_name}, vector size {vectorSize}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def deleteIdx(self, indexName: str, knowledgebaseId: str):
|
||||||
|
table_name = f"{indexName}_{knowledgebaseId}"
|
||||||
|
inf_conn = self.connPool.get_conn()
|
||||||
|
db_instance = inf_conn.get_database(self.dbName)
|
||||||
|
db_instance.drop_table(table_name, ConflictType.Ignore)
|
||||||
|
self.connPool.release_conn(inf_conn)
|
||||||
|
doc_store_logger.info(f"INFINITY dropped table {table_name}")
|
||||||
|
|
||||||
|
def indexExist(self, indexName: str, knowledgebaseId: str) -> bool:
|
||||||
|
table_name = f"{indexName}_{knowledgebaseId}"
|
||||||
|
try:
|
||||||
|
inf_conn = self.connPool.get_conn()
|
||||||
|
db_instance = inf_conn.get_database(self.dbName)
|
||||||
|
_ = db_instance.get_table(table_name)
|
||||||
|
self.connPool.release_conn(inf_conn)
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
doc_store_logger.error("INFINITY indexExist: " + str(e))
|
||||||
|
return False
|
||||||
|
|
||||||
|
"""
|
||||||
|
CRUD operations
|
||||||
|
"""
|
||||||
|
|
||||||
|
def search(
|
||||||
|
self,
|
||||||
|
selectFields: list[str],
|
||||||
|
highlightFields: list[str],
|
||||||
|
condition: dict,
|
||||||
|
matchExprs: list[MatchExpr],
|
||||||
|
orderBy: OrderByExpr,
|
||||||
|
offset: int,
|
||||||
|
limit: int,
|
||||||
|
indexNames: str|list[str],
|
||||||
|
knowledgebaseIds: list[str],
|
||||||
|
) -> list[dict] | pl.DataFrame:
|
||||||
|
"""
|
||||||
|
TODO: Infinity doesn't provide highlight
|
||||||
|
"""
|
||||||
|
if isinstance(indexNames, str):
|
||||||
|
indexNames = indexNames.split(",")
|
||||||
|
assert isinstance(indexNames, list) and len(indexNames) > 0
|
||||||
|
inf_conn = self.connPool.get_conn()
|
||||||
|
db_instance = inf_conn.get_database(self.dbName)
|
||||||
|
df_list = list()
|
||||||
|
table_list = list()
|
||||||
|
if "id" not in selectFields:
|
||||||
|
selectFields.append("id")
|
||||||
|
|
||||||
|
# Prepare expressions common to all tables
|
||||||
|
filter_cond = ""
|
||||||
|
filter_fulltext = ""
|
||||||
|
if condition:
|
||||||
|
filter_cond = equivalent_condition_to_str(condition)
|
||||||
|
for matchExpr in matchExprs:
|
||||||
|
if isinstance(matchExpr, MatchTextExpr):
|
||||||
|
if len(filter_cond) != 0 and "filter" not in matchExpr.extra_options:
|
||||||
|
matchExpr.extra_options.update({"filter": filter_cond})
|
||||||
|
fields = ",".join(matchExpr.fields)
|
||||||
|
filter_fulltext = (
|
||||||
|
f"filter_fulltext('{fields}', '{matchExpr.matching_text}')"
|
||||||
|
)
|
||||||
|
if len(filter_cond) != 0:
|
||||||
|
filter_fulltext = f"({filter_cond}) AND {filter_fulltext}"
|
||||||
|
# doc_store_logger.info(f"filter_fulltext: {filter_fulltext}")
|
||||||
|
minimum_should_match = "0%"
|
||||||
|
if "minimum_should_match" in matchExpr.extra_options:
|
||||||
|
minimum_should_match = (
|
||||||
|
str(int(matchExpr.extra_options["minimum_should_match"] * 100))
|
||||||
|
+ "%"
|
||||||
|
)
|
||||||
|
matchExpr.extra_options.update(
|
||||||
|
{"minimum_should_match": minimum_should_match}
|
||||||
|
)
|
||||||
|
for k, v in matchExpr.extra_options.items():
|
||||||
|
if not isinstance(v, str):
|
||||||
|
matchExpr.extra_options[k] = str(v)
|
||||||
|
elif isinstance(matchExpr, MatchDenseExpr):
|
||||||
|
if len(filter_cond) != 0 and "filter" not in matchExpr.extra_options:
|
||||||
|
matchExpr.extra_options.update({"filter": filter_fulltext})
|
||||||
|
for k, v in matchExpr.extra_options.items():
|
||||||
|
if not isinstance(v, str):
|
||||||
|
matchExpr.extra_options[k] = str(v)
|
||||||
|
if orderBy.fields:
|
||||||
|
order_by_expr_list = list()
|
||||||
|
for order_field in orderBy.fields:
|
||||||
|
order_by_expr_list.append((order_field[0], order_field[1] == 0))
|
||||||
|
|
||||||
|
# Scatter search tables and gather the results
|
||||||
|
for indexName in indexNames:
|
||||||
|
for knowledgebaseId in knowledgebaseIds:
|
||||||
|
table_name = f"{indexName}_{knowledgebaseId}"
|
||||||
|
try:
|
||||||
|
table_instance = db_instance.get_table(table_name)
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
table_list.append(table_name)
|
||||||
|
builder = table_instance.output(selectFields)
|
||||||
|
for matchExpr in matchExprs:
|
||||||
|
if isinstance(matchExpr, MatchTextExpr):
|
||||||
|
fields = ",".join(matchExpr.fields)
|
||||||
|
builder = builder.match_text(
|
||||||
|
fields,
|
||||||
|
matchExpr.matching_text,
|
||||||
|
matchExpr.topn,
|
||||||
|
matchExpr.extra_options,
|
||||||
|
)
|
||||||
|
elif isinstance(matchExpr, MatchDenseExpr):
|
||||||
|
builder = builder.match_dense(
|
||||||
|
matchExpr.vector_column_name,
|
||||||
|
matchExpr.embedding_data,
|
||||||
|
matchExpr.embedding_data_type,
|
||||||
|
matchExpr.distance_type,
|
||||||
|
matchExpr.topn,
|
||||||
|
matchExpr.extra_options,
|
||||||
|
)
|
||||||
|
elif isinstance(matchExpr, FusionExpr):
|
||||||
|
builder = builder.fusion(
|
||||||
|
matchExpr.method, matchExpr.topn, matchExpr.fusion_params
|
||||||
|
)
|
||||||
|
if orderBy.fields:
|
||||||
|
builder.sort(order_by_expr_list)
|
||||||
|
builder.offset(offset).limit(limit)
|
||||||
|
kb_res = builder.to_pl()
|
||||||
|
df_list.append(kb_res)
|
||||||
|
self.connPool.release_conn(inf_conn)
|
||||||
|
res = pl.concat(df_list)
|
||||||
|
doc_store_logger.info("INFINITY search tables: " + str(table_list))
|
||||||
|
return res
|
||||||
|
|
||||||
|
def get(
|
||||||
|
self, chunkId: str, indexName: str, knowledgebaseIds: list[str]
|
||||||
|
) -> dict | None:
|
||||||
|
inf_conn = self.connPool.get_conn()
|
||||||
|
db_instance = inf_conn.get_database(self.dbName)
|
||||||
|
df_list = list()
|
||||||
|
assert isinstance(knowledgebaseIds, list)
|
||||||
|
for knowledgebaseId in knowledgebaseIds:
|
||||||
|
table_name = f"{indexName}_{knowledgebaseId}"
|
||||||
|
table_instance = db_instance.get_table(table_name)
|
||||||
|
kb_res = table_instance.output(["*"]).filter(f"id = '{chunkId}'").to_pl()
|
||||||
|
df_list.append(kb_res)
|
||||||
|
self.connPool.release_conn(inf_conn)
|
||||||
|
res = pl.concat(df_list)
|
||||||
|
res_fields = self.getFields(res, res.columns)
|
||||||
|
return res_fields.get(chunkId, None)
|
||||||
|
|
||||||
|
def insert(
|
||||||
|
self, documents: list[dict], indexName: str, knowledgebaseId: str
|
||||||
|
) -> list[str]:
|
||||||
|
inf_conn = self.connPool.get_conn()
|
||||||
|
db_instance = inf_conn.get_database(self.dbName)
|
||||||
|
table_name = f"{indexName}_{knowledgebaseId}"
|
||||||
|
try:
|
||||||
|
table_instance = db_instance.get_table(table_name)
|
||||||
|
except InfinityException as e:
|
||||||
|
# src/common/status.cppm, kTableNotExist = 3022
|
||||||
|
if e.error_code != 3022:
|
||||||
|
raise
|
||||||
|
vector_size = 0
|
||||||
|
patt = re.compile(r"q_(?P<vector_size>\d+)_vec")
|
||||||
|
for k in documents[0].keys():
|
||||||
|
m = patt.match(k)
|
||||||
|
if m:
|
||||||
|
vector_size = int(m.group("vector_size"))
|
||||||
|
break
|
||||||
|
if vector_size == 0:
|
||||||
|
raise ValueError("Cannot infer vector size from documents")
|
||||||
|
self.createIdx(indexName, knowledgebaseId, vector_size)
|
||||||
|
table_instance = db_instance.get_table(table_name)
|
||||||
|
|
||||||
|
for d in documents:
|
||||||
|
assert "_id" not in d
|
||||||
|
assert "id" in d
|
||||||
|
for k, v in d.items():
|
||||||
|
if k.endswith("_kwd") and isinstance(v, list):
|
||||||
|
d[k] = " ".join(v)
|
||||||
|
ids = [f"'{d["id"]}'" for d in documents]
|
||||||
|
str_ids = ", ".join(ids)
|
||||||
|
str_filter = f"id IN ({str_ids})"
|
||||||
|
table_instance.delete(str_filter)
|
||||||
|
# for doc in documents:
|
||||||
|
# doc_store_logger.info(f"insert position_list: {doc['position_list']}")
|
||||||
|
# doc_store_logger.info(f"InfinityConnection.insert {json.dumps(documents)}")
|
||||||
|
table_instance.insert(documents)
|
||||||
|
self.connPool.release_conn(inf_conn)
|
||||||
|
doc_store_logger.info(f"inserted into {table_name} {str_ids}.")
|
||||||
|
return []
|
||||||
|
|
||||||
|
def update(
|
||||||
|
self, condition: dict, newValue: dict, indexName: str, knowledgebaseId: str
|
||||||
|
) -> bool:
|
||||||
|
# if 'position_list' in newValue:
|
||||||
|
# doc_store_logger.info(f"update position_list: {newValue['position_list']}")
|
||||||
|
inf_conn = self.connPool.get_conn()
|
||||||
|
db_instance = inf_conn.get_database(self.dbName)
|
||||||
|
table_name = f"{indexName}_{knowledgebaseId}"
|
||||||
|
table_instance = db_instance.get_table(table_name)
|
||||||
|
filter = equivalent_condition_to_str(condition)
|
||||||
|
for k, v in newValue.items():
|
||||||
|
if k.endswith("_kwd") and isinstance(v, list):
|
||||||
|
newValue[k] = " ".join(v)
|
||||||
|
table_instance.update(filter, newValue)
|
||||||
|
self.connPool.release_conn(inf_conn)
|
||||||
|
return True
|
||||||
|
|
||||||
|
def delete(self, condition: dict, indexName: str, knowledgebaseId: str) -> int:
|
||||||
|
inf_conn = self.connPool.get_conn()
|
||||||
|
db_instance = inf_conn.get_database(self.dbName)
|
||||||
|
table_name = f"{indexName}_{knowledgebaseId}"
|
||||||
|
filter = equivalent_condition_to_str(condition)
|
||||||
|
try:
|
||||||
|
table_instance = db_instance.get_table(table_name)
|
||||||
|
except Exception:
|
||||||
|
doc_store_logger.warning(
|
||||||
|
f"Skipped deleting `{filter}` from table {table_name} since the table doesn't exist."
|
||||||
|
)
|
||||||
|
return 0
|
||||||
|
res = table_instance.delete(filter)
|
||||||
|
self.connPool.release_conn(inf_conn)
|
||||||
|
return res.deleted_rows
|
||||||
|
|
||||||
|
"""
|
||||||
|
Helper functions for search result
|
||||||
|
"""
|
||||||
|
|
||||||
|
def getTotal(self, res):
|
||||||
|
return len(res)
|
||||||
|
|
||||||
|
def getChunkIds(self, res):
|
||||||
|
return list(res["id"])
|
||||||
|
|
||||||
|
def getFields(self, res, fields: List[str]) -> Dict[str, dict]:
|
||||||
|
res_fields = {}
|
||||||
|
if not fields:
|
||||||
|
return {}
|
||||||
|
num_rows = len(res)
|
||||||
|
column_id = res["id"]
|
||||||
|
for i in range(num_rows):
|
||||||
|
id = column_id[i]
|
||||||
|
m = {"id": id}
|
||||||
|
for fieldnm in fields:
|
||||||
|
if fieldnm not in res:
|
||||||
|
m[fieldnm] = None
|
||||||
|
continue
|
||||||
|
v = res[fieldnm][i]
|
||||||
|
if isinstance(v, Series):
|
||||||
|
v = list(v)
|
||||||
|
elif fieldnm == "important_kwd":
|
||||||
|
assert isinstance(v, str)
|
||||||
|
v = v.split(" ")
|
||||||
|
else:
|
||||||
|
if not isinstance(v, str):
|
||||||
|
v = str(v)
|
||||||
|
# if fieldnm.endswith("_tks"):
|
||||||
|
# v = rmSpace(v)
|
||||||
|
m[fieldnm] = v
|
||||||
|
res_fields[id] = m
|
||||||
|
return res_fields
|
||||||
|
|
||||||
|
def getHighlight(self, res, keywords: List[str], fieldnm: str):
|
||||||
|
ans = {}
|
||||||
|
num_rows = len(res)
|
||||||
|
column_id = res["id"]
|
||||||
|
for i in range(num_rows):
|
||||||
|
id = column_id[i]
|
||||||
|
txt = res[fieldnm][i]
|
||||||
|
txt = re.sub(r"[\r\n]", " ", txt, flags=re.IGNORECASE | re.MULTILINE)
|
||||||
|
txts = []
|
||||||
|
for t in re.split(r"[.?!;\n]", txt):
|
||||||
|
for w in keywords:
|
||||||
|
t = re.sub(
|
||||||
|
r"(^|[ .?/'\"\(\)!,:;-])(%s)([ .?/'\"\(\)!,:;-])"
|
||||||
|
% re.escape(w),
|
||||||
|
r"\1<em>\2</em>\3",
|
||||||
|
t,
|
||||||
|
flags=re.IGNORECASE | re.MULTILINE,
|
||||||
|
)
|
||||||
|
if not re.search(
|
||||||
|
r"<em>[^<>]+</em>", t, flags=re.IGNORECASE | re.MULTILINE
|
||||||
|
):
|
||||||
|
continue
|
||||||
|
txts.append(t)
|
||||||
|
ans[id] = "...".join(txts)
|
||||||
|
return ans
|
||||||
|
|
||||||
|
def getAggregation(self, res, fieldnm: str):
|
||||||
|
"""
|
||||||
|
TODO: Infinity doesn't provide aggregation
|
||||||
|
"""
|
||||||
|
return list()
|
||||||
|
|
||||||
|
"""
|
||||||
|
SQL
|
||||||
|
"""
|
||||||
|
|
||||||
|
def sql(sql: str, fetch_size: int, format: str):
|
||||||
|
raise NotImplementedError("Not implemented")
|
@ -50,8 +50,8 @@ class Document(Base):
|
|||||||
return res.content
|
return res.content
|
||||||
|
|
||||||
|
|
||||||
def list_chunks(self,page=1, page_size=30, keywords="", id:str=None):
|
def list_chunks(self,page=1, page_size=30, keywords=""):
|
||||||
data={"keywords": keywords,"page":page,"page_size":page_size,"id":id}
|
data={"keywords": keywords,"page":page,"page_size":page_size}
|
||||||
res = self.get(f'/datasets/{self.dataset_id}/documents/{self.id}/chunks', data)
|
res = self.get(f'/datasets/{self.dataset_id}/documents/{self.id}/chunks', data)
|
||||||
res = res.json()
|
res = res.json()
|
||||||
if res.get("code") == 0:
|
if res.get("code") == 0:
|
||||||
|
@ -126,6 +126,7 @@ def test_delete_chunk_with_success(get_api_key_fixture):
|
|||||||
docs = ds.upload_documents(documents)
|
docs = ds.upload_documents(documents)
|
||||||
doc = docs[0]
|
doc = docs[0]
|
||||||
chunk = doc.add_chunk(content="This is a chunk addition test")
|
chunk = doc.add_chunk(content="This is a chunk addition test")
|
||||||
|
sleep(5)
|
||||||
doc.delete_chunks([chunk.id])
|
doc.delete_chunks([chunk.id])
|
||||||
|
|
||||||
|
|
||||||
@ -146,6 +147,8 @@ def test_update_chunk_content(get_api_key_fixture):
|
|||||||
docs = ds.upload_documents(documents)
|
docs = ds.upload_documents(documents)
|
||||||
doc = docs[0]
|
doc = docs[0]
|
||||||
chunk = doc.add_chunk(content="This is a chunk addition test")
|
chunk = doc.add_chunk(content="This is a chunk addition test")
|
||||||
|
# For ElasticSearch, the chunk is not searchable in shot time (~2s).
|
||||||
|
sleep(3)
|
||||||
chunk.update({"content":"This is a updated content"})
|
chunk.update({"content":"This is a updated content"})
|
||||||
|
|
||||||
def test_update_chunk_available(get_api_key_fixture):
|
def test_update_chunk_available(get_api_key_fixture):
|
||||||
@ -165,7 +168,9 @@ def test_update_chunk_available(get_api_key_fixture):
|
|||||||
docs = ds.upload_documents(documents)
|
docs = ds.upload_documents(documents)
|
||||||
doc = docs[0]
|
doc = docs[0]
|
||||||
chunk = doc.add_chunk(content="This is a chunk addition test")
|
chunk = doc.add_chunk(content="This is a chunk addition test")
|
||||||
chunk.update({"available":False})
|
# For ElasticSearch, the chunk is not searchable in shot time (~2s).
|
||||||
|
sleep(3)
|
||||||
|
chunk.update({"available":0})
|
||||||
|
|
||||||
|
|
||||||
def test_retrieve_chunks(get_api_key_fixture):
|
def test_retrieve_chunks(get_api_key_fixture):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user