Integration with Infinity (#2894)

### What problem does this PR solve?

Integration with Infinity

- Replaced ELASTICSEARCH with docStoreConn, a storage-agnostic connection (see the interface sketch below)
- Renamed deleteByQuery to delete
- Renamed bulk to upsertBulk
- Added getHighlight and getAggregation
- Fixed KGSearch.search
- Moved Dealer.sql_retrieval to es_conn.py
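
Below is a minimal sketch of the connection interface these renames imply. The method names come from the call sites in this diff (`docStoreConn.delete`, `insert`, `update`, `get`, `indexExist`, `createIdx`, `deleteIdx`, `health`); the parameter names, type hints, and docstrings are inferred from those call sites, not copied from the project.

```python
from abc import ABC, abstractmethod


class DocStoreConnection(ABC):
    """Sketch of the storage-agnostic interface implied by this diff.
    Signatures are inferred from call sites and are not authoritative."""

    @abstractmethod
    def health(self) -> dict:
        """Backend health info, e.g. {"type": "elasticsearch", "status": "green"}."""

    @abstractmethod
    def indexExist(self, index_name: str, kb_id: str) -> bool: ...

    @abstractmethod
    def createIdx(self, index_name: str, kb_id: str, vector_size: int) -> None:
        """Create the per-tenant index/table, sized to the embedding dimension."""

    @abstractmethod
    def deleteIdx(self, index_name: str, kb_id: str) -> None: ...

    @abstractmethod
    def insert(self, documents: list[dict], index_name: str, kb_id: str):
        """Insert/upsert whole chunks (replaces ELASTICSEARCH.upsert / bulk)."""

    @abstractmethod
    def update(self, condition: dict, new_values: dict,
               index_name: str, kb_id: str) -> bool:
        """Set new_values on every chunk matching condition, e.g.
        update({"doc_id": doc_id}, {"available_int": 1}, idx, kb_id)."""

    @abstractmethod
    def delete(self, condition: dict, index_name: str, kb_id: str) -> int:
        """Delete chunks matching condition; returns the number deleted
        (replaces ELASTICSEARCH.deleteByQuery)."""

    @abstractmethod
    def get(self, chunk_id: str, index_name: str, kb_ids: list[str]) -> dict | None:
        """Fetch a single chunk as a plain dict, or None if missing."""
```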


### Type of change

- [x] Refactoring
Zhichang Yu 2024-11-12 14:59:41 +08:00 committed by GitHub
parent 00b6000b76
commit f4c52371ab
42 changed files with 2647 additions and 1878 deletions

View File

@ -78,7 +78,7 @@ jobs:
echo "Waiting for service to be available..." echo "Waiting for service to be available..."
sleep 5 sleep 5
done done
cd sdk/python && poetry install && source .venv/bin/activate && cd test && pytest t_dataset.py t_chat.py t_session.py t_document.py t_chunk.py cd sdk/python && poetry install && source .venv/bin/activate && cd test && pytest --tb=short t_dataset.py t_chat.py t_session.py t_document.py t_chunk.py
- name: Stop ragflow:dev - name: Stop ragflow:dev
if: always() # always run this step even if previous steps failed if: always() # always run this step even if previous steps failed

View File

@ -285,7 +285,7 @@ docker build -f Dockerfile -t infiniflow/ragflow:dev .
git clone https://github.com/infiniflow/ragflow.git git clone https://github.com/infiniflow/ragflow.git
cd ragflow/ cd ragflow/
export POETRY_VIRTUALENVS_CREATE=true POETRY_VIRTUALENVS_IN_PROJECT=true export POETRY_VIRTUALENVS_CREATE=true POETRY_VIRTUALENVS_IN_PROJECT=true
~/.local/bin/poetry install --sync --no-root # install RAGFlow dependent python modules ~/.local/bin/poetry install --sync --no-root --with=full # install RAGFlow dependent python modules
``` ```
3. Launch the dependent services (MinIO, Elasticsearch, Redis, and MySQL) using Docker Compose: 3. Launch the dependent services (MinIO, Elasticsearch, Redis, and MySQL) using Docker Compose:
@ -295,7 +295,7 @@ docker build -f Dockerfile -t infiniflow/ragflow:dev .
Add the following line to `/etc/hosts` to resolve all hosts specified in **docker/service_conf.yaml** to `127.0.0.1`: Add the following line to `/etc/hosts` to resolve all hosts specified in **docker/service_conf.yaml** to `127.0.0.1`:
``` ```
127.0.0.1 es01 mysql minio redis 127.0.0.1 es01 infinity mysql minio redis
``` ```
In **docker/service_conf.yaml**, update mysql port to `5455` and es port to `1200`, as specified in **docker/.env**. In **docker/service_conf.yaml**, update mysql port to `5455` and es port to `1200`, as specified in **docker/.env**.

View File

@ -250,7 +250,7 @@ docker build -f Dockerfile -t infiniflow/ragflow:dev .
Add the following line to `/etc/hosts` to resolve all hosts specified in **docker/service_conf.yaml** to `127.0.0.1`:
``` ```
127.0.0.1 es01 mysql minio redis 127.0.0.1 es01 infinity mysql minio redis
``` ```
In **docker/service_conf.yaml**, update the mysql port to `5455` and the es port to `1200`, as specified in **docker/.env**.

View File

@ -254,7 +254,7 @@ docker build -f Dockerfile -t infiniflow/ragflow:dev .
Add the following line to `/etc/hosts` to resolve all hosts specified in **docker/service_conf.yaml** to `127.0.0.1`:
``` ```
127.0.0.1 es01 mysql minio redis 127.0.0.1 es01 infinity mysql minio redis
``` ```
In **docker/service_conf.yaml**, update the mysql port to `5455` and the es port to `1200`, as specified in **docker/.env**.

View File

@ -252,7 +252,7 @@ docker build -f Dockerfile -t infiniflow/ragflow:dev .
Add the following line to `/etc/hosts` to resolve all host addresses in **docker/service_conf.yaml** to `127.0.0.1`:
``` ```
127.0.0.1 es01 mysql minio redis 127.0.0.1 es01 infinity mysql minio redis
``` ```
In **docker/service_conf.yaml**, update the mysql port to `5455` and the es port to `1200`, in line with the settings in **docker/.env**.

View File

@ -529,13 +529,14 @@ def list_chunks():
return get_json_result( return get_json_result(
data=False, message="Can't find doc_name or doc_id" data=False, message="Can't find doc_name or doc_id"
) )
kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
res = retrievaler.chunk_list(doc_id=doc_id, tenant_id=tenant_id) res = retrievaler.chunk_list(doc_id, tenant_id, kb_ids)
res = [ res = [
{ {
"content": res_item["content_with_weight"], "content": res_item["content_with_weight"],
"doc_name": res_item["docnm_kwd"], "doc_name": res_item["docnm_kwd"],
"img_id": res_item["img_id"] "image_id": res_item["img_id"]
} for res_item in res } for res_item in res
] ]

View File

@ -18,12 +18,10 @@ import json
from flask import request from flask import request
from flask_login import login_required, current_user from flask_login import login_required, current_user
from elasticsearch_dsl import Q
from api.db.services.dialog_service import keyword_extraction from api.db.services.dialog_service import keyword_extraction
from rag.app.qa import rmPrefix, beAdoc from rag.app.qa import rmPrefix, beAdoc
from rag.nlp import search, rag_tokenizer from rag.nlp import search, rag_tokenizer
from rag.utils.es_conn import ELASTICSEARCH
from rag.utils import rmSpace from rag.utils import rmSpace
from api.db import LLMType, ParserType from api.db import LLMType, ParserType
from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.knowledgebase_service import KnowledgebaseService
@ -31,12 +29,11 @@ from api.db.services.llm_service import LLMBundle
from api.db.services.user_service import UserTenantService from api.db.services.user_service import UserTenantService
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
from api.db.services.document_service import DocumentService from api.db.services.document_service import DocumentService
from api.settings import RetCode, retrievaler, kg_retrievaler from api.settings import RetCode, retrievaler, kg_retrievaler, docStoreConn
from api.utils.api_utils import get_json_result from api.utils.api_utils import get_json_result
import hashlib import hashlib
import re import re
@manager.route('/list', methods=['POST']) @manager.route('/list', methods=['POST'])
@login_required @login_required
@validate_request("doc_id") @validate_request("doc_id")
@ -53,12 +50,13 @@ def list_chunk():
e, doc = DocumentService.get_by_id(doc_id) e, doc = DocumentService.get_by_id(doc_id)
if not e: if not e:
return get_data_error_result(message="Document not found!") return get_data_error_result(message="Document not found!")
kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
query = { query = {
"doc_ids": [doc_id], "page": page, "size": size, "question": question, "sort": True "doc_ids": [doc_id], "page": page, "size": size, "question": question, "sort": True
} }
if "available_int" in req: if "available_int" in req:
query["available_int"] = int(req["available_int"]) query["available_int"] = int(req["available_int"])
sres = retrievaler.search(query, search.index_name(tenant_id), highlight=True) sres = retrievaler.search(query, search.index_name(tenant_id), kb_ids, highlight=True)
res = {"total": sres.total, "chunks": [], "doc": doc.to_dict()} res = {"total": sres.total, "chunks": [], "doc": doc.to_dict()}
for id in sres.ids: for id in sres.ids:
d = { d = {
@ -69,16 +67,12 @@ def list_chunk():
"doc_id": sres.field[id]["doc_id"], "doc_id": sres.field[id]["doc_id"],
"docnm_kwd": sres.field[id]["docnm_kwd"], "docnm_kwd": sres.field[id]["docnm_kwd"],
"important_kwd": sres.field[id].get("important_kwd", []), "important_kwd": sres.field[id].get("important_kwd", []),
"img_id": sres.field[id].get("img_id", ""), "image_id": sres.field[id].get("img_id", ""),
"available_int": sres.field[id].get("available_int", 1), "available_int": sres.field[id].get("available_int", 1),
"positions": sres.field[id].get("position_int", "").split("\t") "positions": json.loads(sres.field[id].get("position_list", "[]")),
} }
if len(d["positions"]) % 5 == 0: assert isinstance(d["positions"], list)
poss = [] assert len(d["positions"])==0 or (isinstance(d["positions"][0], list) and len(d["positions"][0]) == 5)
for i in range(0, len(d["positions"]), 5):
poss.append([float(d["positions"][i]), float(d["positions"][i + 1]), float(d["positions"][i + 2]),
float(d["positions"][i + 3]), float(d["positions"][i + 4])])
d["positions"] = poss
res["chunks"].append(d) res["chunks"].append(d)
return get_json_result(data=res) return get_json_result(data=res)
except Exception as e: except Exception as e:
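
The chunk position payload changes in this hunk: the old code split a tab-separated `position_int` string into groups of five floats, while the new code expects `position_list` to already hold a JSON-encoded list of five-element lists. A standalone illustration of both decodings (field names from the diff, sample values invented):

```python
import json


def positions_from_position_int(raw: str) -> list[list[float]]:
    # Old format: flat, tab-separated numbers, grouped in fives.
    parts = raw.split("\t") if raw else []
    return [[float(x) for x in parts[i:i + 5]] for i in range(0, len(parts), 5)]


def positions_from_position_list(raw: str) -> list[list[float]]:
    # New format: the five-element groups are already JSON-encoded.
    return json.loads(raw or "[]")


old = "1\t10.0\t20.0\t30.0\t40.0"
new = "[[1, 10.0, 20.0, 30.0, 40.0]]"
assert positions_from_position_int(old) == positions_from_position_list(new)
```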
@ -96,22 +90,20 @@ def get():
tenants = UserTenantService.query(user_id=current_user.id) tenants = UserTenantService.query(user_id=current_user.id)
if not tenants: if not tenants:
return get_data_error_result(message="Tenant not found!") return get_data_error_result(message="Tenant not found!")
res = ELASTICSEARCH.get( tenant_id = tenants[0].tenant_id
chunk_id, search.index_name(
tenants[0].tenant_id)) kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
if not res.get("found"): chunk = docStoreConn.get(chunk_id, search.index_name(tenant_id), kb_ids)
if chunk is None:
return server_error_response("Chunk not found") return server_error_response("Chunk not found")
id = res["_id"]
res = res["_source"]
res["chunk_id"] = id
k = [] k = []
for n in res.keys(): for n in chunk.keys():
if re.search(r"(_vec$|_sm_|_tks|_ltks)", n): if re.search(r"(_vec$|_sm_|_tks|_ltks)", n):
k.append(n) k.append(n)
for n in k: for n in k:
del res[n] del chunk[n]
return get_json_result(data=res) return get_json_result(data=chunk)
except Exception as e: except Exception as e:
if str(e).find("NotFoundError") >= 0: if str(e).find("NotFoundError") >= 0:
return get_json_result(data=False, message='Chunk not found!', return get_json_result(data=False, message='Chunk not found!',
@ -162,7 +154,7 @@ def set():
v, c = embd_mdl.encode([doc.name, req["content_with_weight"]]) v, c = embd_mdl.encode([doc.name, req["content_with_weight"]])
v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1] v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
d["q_%d_vec" % len(v)] = v.tolist() d["q_%d_vec" % len(v)] = v.tolist()
ELASTICSEARCH.upsert([d], search.index_name(tenant_id)) docStoreConn.insert([d], search.index_name(tenant_id), doc.kb_id)
return get_json_result(data=True) return get_json_result(data=True)
except Exception as e: except Exception as e:
return server_error_response(e) return server_error_response(e)
@ -174,11 +166,11 @@ def set():
def switch(): def switch():
req = request.json req = request.json
try: try:
tenant_id = DocumentService.get_tenant_id(req["doc_id"]) e, doc = DocumentService.get_by_id(req["doc_id"])
if not tenant_id: if not e:
return get_data_error_result(message="Tenant not found!") return get_data_error_result(message="Document not found!")
if not ELASTICSEARCH.upsert([{"id": i, "available_int": int(req["available_int"])} for i in req["chunk_ids"]], if not docStoreConn.update({"id": req["chunk_ids"]}, {"available_int": int(req["available_int"])},
search.index_name(tenant_id)): search.index_name(doc.tenant_id), doc.kb_id):
return get_data_error_result(message="Index updating failure") return get_data_error_result(message="Index updating failure")
return get_json_result(data=True) return get_json_result(data=True)
except Exception as e: except Exception as e:
@ -191,12 +183,11 @@ def switch():
def rm(): def rm():
req = request.json req = request.json
try: try:
if not ELASTICSEARCH.deleteByQuery(
Q("ids", values=req["chunk_ids"]), search.index_name(current_user.id)):
return get_data_error_result(message="Index updating failure")
e, doc = DocumentService.get_by_id(req["doc_id"]) e, doc = DocumentService.get_by_id(req["doc_id"])
if not e: if not e:
return get_data_error_result(message="Document not found!") return get_data_error_result(message="Document not found!")
if not docStoreConn.delete({"id": req["chunk_ids"]}, search.index_name(current_user.id), doc.kb_id):
return get_data_error_result(message="Index updating failure")
deleted_chunk_ids = req["chunk_ids"] deleted_chunk_ids = req["chunk_ids"]
chunk_number = len(deleted_chunk_ids) chunk_number = len(deleted_chunk_ids)
DocumentService.decrement_chunk_num(doc.id, doc.kb_id, 1, chunk_number, 0) DocumentService.decrement_chunk_num(doc.id, doc.kb_id, 1, chunk_number, 0)
@ -239,7 +230,7 @@ def create():
v, c = embd_mdl.encode([doc.name, req["content_with_weight"]]) v, c = embd_mdl.encode([doc.name, req["content_with_weight"]])
v = 0.1 * v[0] + 0.9 * v[1] v = 0.1 * v[0] + 0.9 * v[1]
d["q_%d_vec" % len(v)] = v.tolist() d["q_%d_vec" % len(v)] = v.tolist()
ELASTICSEARCH.upsert([d], search.index_name(tenant_id)) docStoreConn.insert([d], search.index_name(tenant_id), doc.kb_id)
DocumentService.increment_chunk_num( DocumentService.increment_chunk_num(
doc.id, doc.kb_id, c, 1, 0) doc.id, doc.kb_id, c, 1, 0)
@ -256,8 +247,9 @@ def retrieval_test():
page = int(req.get("page", 1)) page = int(req.get("page", 1))
size = int(req.get("size", 30)) size = int(req.get("size", 30))
question = req["question"] question = req["question"]
kb_id = req["kb_id"] kb_ids = req["kb_id"]
if isinstance(kb_id, str): kb_id = [kb_id] if isinstance(kb_ids, str):
kb_ids = [kb_ids]
doc_ids = req.get("doc_ids", []) doc_ids = req.get("doc_ids", [])
similarity_threshold = float(req.get("similarity_threshold", 0.0)) similarity_threshold = float(req.get("similarity_threshold", 0.0))
vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3)) vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3))
@ -265,17 +257,17 @@ def retrieval_test():
try: try:
tenants = UserTenantService.query(user_id=current_user.id) tenants = UserTenantService.query(user_id=current_user.id)
for kid in kb_id: for kb_id in kb_ids:
for tenant in tenants: for tenant in tenants:
if KnowledgebaseService.query( if KnowledgebaseService.query(
tenant_id=tenant.tenant_id, id=kid): tenant_id=tenant.tenant_id, id=kb_id):
break break
else: else:
return get_json_result( return get_json_result(
data=False, message='Only owner of knowledgebase authorized for this operation.', data=False, message='Only owner of knowledgebase authorized for this operation.',
code=RetCode.OPERATING_ERROR) code=RetCode.OPERATING_ERROR)
e, kb = KnowledgebaseService.get_by_id(kb_id[0]) e, kb = KnowledgebaseService.get_by_id(kb_ids[0])
if not e: if not e:
return get_data_error_result(message="Knowledgebase not found!") return get_data_error_result(message="Knowledgebase not found!")
@ -290,7 +282,7 @@ def retrieval_test():
question += keyword_extraction(chat_mdl, question) question += keyword_extraction(chat_mdl, question)
retr = retrievaler if kb.parser_id != ParserType.KG else kg_retrievaler retr = retrievaler if kb.parser_id != ParserType.KG else kg_retrievaler
ranks = retr.retrieval(question, embd_mdl, kb.tenant_id, kb_id, page, size, ranks = retr.retrieval(question, embd_mdl, kb.tenant_id, kb_ids, page, size,
similarity_threshold, vector_similarity_weight, top, similarity_threshold, vector_similarity_weight, top,
doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight")) doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"))
for c in ranks["chunks"]: for c in ranks["chunks"]:
@ -309,12 +301,16 @@ def retrieval_test():
@login_required @login_required
def knowledge_graph(): def knowledge_graph():
doc_id = request.args["doc_id"] doc_id = request.args["doc_id"]
e, doc = DocumentService.get_by_id(doc_id)
if not e:
return get_data_error_result(message="Document not found!")
tenant_id = DocumentService.get_tenant_id(doc_id)
kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
req = { req = {
"doc_ids":[doc_id], "doc_ids":[doc_id],
"knowledge_graph_kwd": ["graph", "mind_map"] "knowledge_graph_kwd": ["graph", "mind_map"]
} }
tenant_id = DocumentService.get_tenant_id(doc_id) sres = retrievaler.search(req, search.index_name(tenant_id), kb_ids, doc.kb_id)
sres = retrievaler.search(req, search.index_name(tenant_id))
obj = {"graph": {}, "mind_map": {}} obj = {"graph": {}, "mind_map": {}}
for id in sres.ids[:2]: for id in sres.ids[:2]:
ty = sres.field[id]["knowledge_graph_kwd"] ty = sres.field[id]["knowledge_graph_kwd"]
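
Retrieval calls in this file now carry the knowledge-base ids alongside the per-tenant index name. A compact sketch of how those arguments are assembled; the body of `index_name` is an assumption about `rag.nlp.search`, shown only to make the scoping concrete:

```python
def index_name(tenant_id: str) -> str:
    # Assumed helper shape: one search index per tenant.
    return f"ragflow_{tenant_id}"


def knowledge_graph_query(tenant_id: str, kb_ids: list[str], doc_id: str):
    req = {"doc_ids": [doc_id], "knowledge_graph_kwd": ["graph", "mind_map"]}
    # New call shape in this diff: retrievaler.search(req, idxnm, kb_ids, ...)
    return req, index_name(tenant_id), kb_ids


print(knowledge_graph_query("tenant_a", ["kb1"], "doc42"))
```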

View File

@ -17,7 +17,6 @@ import pathlib
import re import re
import flask import flask
from elasticsearch_dsl import Q
from flask import request from flask import request
from flask_login import login_required, current_user from flask_login import login_required, current_user
@ -27,14 +26,13 @@ from api.db.services.file_service import FileService
from api.db.services.task_service import TaskService, queue_tasks from api.db.services.task_service import TaskService, queue_tasks
from api.db.services.user_service import UserTenantService from api.db.services.user_service import UserTenantService
from rag.nlp import search from rag.nlp import search
from rag.utils.es_conn import ELASTICSEARCH
from api.db.services import duplicate_name from api.db.services import duplicate_name
from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.knowledgebase_service import KnowledgebaseService
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
from api.utils import get_uuid from api.utils import get_uuid
from api.db import FileType, TaskStatus, ParserType, FileSource from api.db import FileType, TaskStatus, ParserType, FileSource
from api.db.services.document_service import DocumentService, doc_upload_and_parse from api.db.services.document_service import DocumentService, doc_upload_and_parse
from api.settings import RetCode from api.settings import RetCode, docStoreConn
from api.utils.api_utils import get_json_result from api.utils.api_utils import get_json_result
from rag.utils.storage_factory import STORAGE_IMPL from rag.utils.storage_factory import STORAGE_IMPL
from api.utils.file_utils import filename_type, thumbnail from api.utils.file_utils import filename_type, thumbnail
@ -275,18 +273,8 @@ def change_status():
return get_data_error_result( return get_data_error_result(
message="Database error (Document update)!") message="Database error (Document update)!")
if str(req["status"]) == "0": status = int(req["status"])
ELASTICSEARCH.updateScriptByQuery(Q("term", doc_id=req["doc_id"]), docStoreConn.update({"doc_id": req["doc_id"]}, {"available_int": status}, search.index_name(kb.tenant_id), doc.kb_id)
scripts="ctx._source.available_int=0;",
idxnm=search.index_name(
kb.tenant_id)
)
else:
ELASTICSEARCH.updateScriptByQuery(Q("term", doc_id=req["doc_id"]),
scripts="ctx._source.available_int=1;",
idxnm=search.index_name(
kb.tenant_id)
)
return get_json_result(data=True) return get_json_result(data=True)
except Exception as e: except Exception as e:
return server_error_response(e) return server_error_response(e)
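
The two `updateScriptByQuery` branches above collapse into a single condition-based update. A toy in-memory model of that contract, purely illustrative rather than the real backend:

```python
def update(store: list[dict], condition: dict, new_values: dict) -> int:
    """Toy model of docStoreConn.update: set new_values on every record whose
    fields equal the condition; return how many records were touched."""
    hits = 0
    for rec in store:
        if all(rec.get(k) == v for k, v in condition.items()):
            rec.update(new_values)
            hits += 1
    return hits


chunks = [{"doc_id": "d1", "available_int": 0}, {"doc_id": "d2", "available_int": 0}]
update(chunks, {"doc_id": "d1"}, {"available_int": 1})
assert chunks[0]["available_int"] == 1 and chunks[1]["available_int"] == 0
```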
@ -365,8 +353,11 @@ def run():
tenant_id = DocumentService.get_tenant_id(id) tenant_id = DocumentService.get_tenant_id(id)
if not tenant_id: if not tenant_id:
return get_data_error_result(message="Tenant not found!") return get_data_error_result(message="Tenant not found!")
ELASTICSEARCH.deleteByQuery( e, doc = DocumentService.get_by_id(id)
Q("match", doc_id=id), idxnm=search.index_name(tenant_id)) if not e:
return get_data_error_result(message="Document not found!")
if docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), doc.kb_id)
if str(req["run"]) == TaskStatus.RUNNING.value: if str(req["run"]) == TaskStatus.RUNNING.value:
TaskService.filter_delete([Task.doc_id == id]) TaskService.filter_delete([Task.doc_id == id])
@ -490,8 +481,8 @@ def change_parser():
tenant_id = DocumentService.get_tenant_id(req["doc_id"]) tenant_id = DocumentService.get_tenant_id(req["doc_id"])
if not tenant_id: if not tenant_id:
return get_data_error_result(message="Tenant not found!") return get_data_error_result(message="Tenant not found!")
ELASTICSEARCH.deleteByQuery( if docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
Q("match", doc_id=doc.id), idxnm=search.index_name(tenant_id)) docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
return get_json_result(data=True) return get_json_result(data=True)
except Exception as e: except Exception as e:

View File

@ -28,6 +28,8 @@ from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.db_models import File from api.db.db_models import File
from api.settings import RetCode from api.settings import RetCode
from api.utils.api_utils import get_json_result from api.utils.api_utils import get_json_result
from api.settings import docStoreConn
from rag.nlp import search
@manager.route('/create', methods=['post']) @manager.route('/create', methods=['post'])
@ -166,6 +168,9 @@ def rm():
if not KnowledgebaseService.delete_by_id(req["kb_id"]): if not KnowledgebaseService.delete_by_id(req["kb_id"]):
return get_data_error_result( return get_data_error_result(
message="Database error (Knowledgebase removal)!") message="Database error (Knowledgebase removal)!")
tenants = UserTenantService.query(user_id=current_user.id)
for tenant in tenants:
docStoreConn.deleteIdx(search.index_name(tenant.tenant_id), req["kb_id"])
return get_json_result(data=True) return get_json_result(data=True)
except Exception as e: except Exception as e:
return server_error_response(e) return server_error_response(e)

View File

@ -30,7 +30,6 @@ from api.db.services.task_service import TaskService, queue_tasks
from api.utils.api_utils import server_error_response from api.utils.api_utils import server_error_response
from api.utils.api_utils import get_result, get_error_data_result from api.utils.api_utils import get_result, get_error_data_result
from io import BytesIO from io import BytesIO
from elasticsearch_dsl import Q
from flask import request, send_file from flask import request, send_file
from api.db import FileSource, TaskStatus, FileType from api.db import FileSource, TaskStatus, FileType
from api.db.db_models import File from api.db.db_models import File
@ -42,7 +41,7 @@ from api.settings import RetCode, retrievaler
from api.utils.api_utils import construct_json_result, get_parser_config from api.utils.api_utils import construct_json_result, get_parser_config
from rag.nlp import search from rag.nlp import search
from rag.utils import rmSpace from rag.utils import rmSpace
from rag.utils.es_conn import ELASTICSEARCH from api.settings import docStoreConn
from rag.utils.storage_factory import STORAGE_IMPL from rag.utils.storage_factory import STORAGE_IMPL
import os import os
@ -293,9 +292,7 @@ def update_doc(tenant_id, dataset_id, document_id):
) )
if not e: if not e:
return get_error_data_result(message="Document not found!") return get_error_data_result(message="Document not found!")
ELASTICSEARCH.deleteByQuery( docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), dataset_id)
Q("match", doc_id=doc.id), idxnm=search.index_name(tenant_id)
)
return get_result() return get_result()
@ -647,9 +644,7 @@ def parse(tenant_id, dataset_id):
info["chunk_num"] = 0 info["chunk_num"] = 0
info["token_num"] = 0 info["token_num"] = 0
DocumentService.update_by_id(id, info) DocumentService.update_by_id(id, info)
ELASTICSEARCH.deleteByQuery( docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), dataset_id)
Q("match", doc_id=id), idxnm=search.index_name(tenant_id)
)
TaskService.filter_delete([Task.doc_id == id]) TaskService.filter_delete([Task.doc_id == id])
e, doc = DocumentService.get_by_id(id) e, doc = DocumentService.get_by_id(id)
doc = doc.to_dict() doc = doc.to_dict()
@ -713,9 +708,7 @@ def stop_parsing(tenant_id, dataset_id):
) )
info = {"run": "2", "progress": 0, "chunk_num": 0} info = {"run": "2", "progress": 0, "chunk_num": 0}
DocumentService.update_by_id(id, info) DocumentService.update_by_id(id, info)
ELASTICSEARCH.deleteByQuery( docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), dataset_id)
Q("match", doc_id=id), idxnm=search.index_name(tenant_id)
)
return get_result() return get_result()
@ -812,7 +805,6 @@ def list_chunks(tenant_id, dataset_id, document_id):
"question": question, "question": question,
"sort": True, "sort": True,
} }
sres = retrievaler.search(query, search.index_name(tenant_id), highlight=True)
key_mapping = { key_mapping = {
"chunk_num": "chunk_count", "chunk_num": "chunk_count",
"kb_id": "dataset_id", "kb_id": "dataset_id",
@ -833,51 +825,56 @@ def list_chunks(tenant_id, dataset_id, document_id):
renamed_doc[new_key] = value renamed_doc[new_key] = value
if key == "run": if key == "run":
renamed_doc["run"] = run_mapping.get(str(value)) renamed_doc["run"] = run_mapping.get(str(value))
res = {"total": sres.total, "chunks": [], "doc": renamed_doc}
origin_chunks = []
sign = 0
for id in sres.ids:
d = {
"chunk_id": id,
"content_with_weight": (
rmSpace(sres.highlight[id])
if question and id in sres.highlight
else sres.field[id].get("content_with_weight", "")
),
"doc_id": sres.field[id]["doc_id"],
"docnm_kwd": sres.field[id]["docnm_kwd"],
"important_kwd": sres.field[id].get("important_kwd", []),
"img_id": sres.field[id].get("img_id", ""),
"available_int": sres.field[id].get("available_int", 1),
"positions": sres.field[id].get("position_int", "").split("\t"),
}
if len(d["positions"]) % 5 == 0:
poss = []
for i in range(0, len(d["positions"]), 5):
poss.append(
[
float(d["positions"][i]),
float(d["positions"][i + 1]),
float(d["positions"][i + 2]),
float(d["positions"][i + 3]),
float(d["positions"][i + 4]),
]
)
d["positions"] = poss
origin_chunks.append(d) res = {"total": 0, "chunks": [], "doc": renamed_doc}
origin_chunks = []
if docStoreConn.indexExist(search.index_name(tenant_id), dataset_id):
sres = retrievaler.search(query, search.index_name(tenant_id), [dataset_id], emb_mdl=None, highlight=True)
res["total"] = sres.total
sign = 0
for id in sres.ids:
d = {
"id": id,
"content_with_weight": (
rmSpace(sres.highlight[id])
if question and id in sres.highlight
else sres.field[id].get("content_with_weight", "")
),
"doc_id": sres.field[id]["doc_id"],
"docnm_kwd": sres.field[id]["docnm_kwd"],
"important_kwd": sres.field[id].get("important_kwd", []),
"img_id": sres.field[id].get("img_id", ""),
"available_int": sres.field[id].get("available_int", 1),
"positions": sres.field[id].get("position_int", "").split("\t"),
}
if len(d["positions"]) % 5 == 0:
poss = []
for i in range(0, len(d["positions"]), 5):
poss.append(
[
float(d["positions"][i]),
float(d["positions"][i + 1]),
float(d["positions"][i + 2]),
float(d["positions"][i + 3]),
float(d["positions"][i + 4]),
]
)
d["positions"] = poss
origin_chunks.append(d)
if req.get("id"):
if req.get("id") == id:
origin_chunks.clear()
origin_chunks.append(d)
sign = 1
break
if req.get("id"): if req.get("id"):
if req.get("id") == id: if sign == 0:
origin_chunks.clear() return get_error_data_result(f"Can't find this chunk {req.get('id')}")
origin_chunks.append(d)
sign = 1
break
if req.get("id"):
if sign == 0:
return get_error_data_result(f"Can't find this chunk {req.get('id')}")
for chunk in origin_chunks: for chunk in origin_chunks:
key_mapping = { key_mapping = {
"chunk_id": "id", "id": "id",
"content_with_weight": "content", "content_with_weight": "content",
"doc_id": "document_id", "doc_id": "document_id",
"important_kwd": "important_keywords", "important_kwd": "important_keywords",
@ -996,9 +993,9 @@ def add_chunk(tenant_id, dataset_id, document_id):
) )
d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19] d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19]
d["create_timestamp_flt"] = datetime.datetime.now().timestamp() d["create_timestamp_flt"] = datetime.datetime.now().timestamp()
d["kb_id"] = [doc.kb_id] d["kb_id"] = dataset_id
d["docnm_kwd"] = doc.name d["docnm_kwd"] = doc.name
d["doc_id"] = doc.id d["doc_id"] = document_id
embd_id = DocumentService.get_embd_id(document_id) embd_id = DocumentService.get_embd_id(document_id)
embd_mdl = TenantLLMService.model_instance( embd_mdl = TenantLLMService.model_instance(
tenant_id, LLMType.EMBEDDING.value, embd_id tenant_id, LLMType.EMBEDDING.value, embd_id
@ -1006,14 +1003,12 @@ def add_chunk(tenant_id, dataset_id, document_id):
v, c = embd_mdl.encode([doc.name, req["content"]]) v, c = embd_mdl.encode([doc.name, req["content"]])
v = 0.1 * v[0] + 0.9 * v[1] v = 0.1 * v[0] + 0.9 * v[1]
d["q_%d_vec" % len(v)] = v.tolist() d["q_%d_vec" % len(v)] = v.tolist()
ELASTICSEARCH.upsert([d], search.index_name(tenant_id)) docStoreConn.insert([d], search.index_name(tenant_id), dataset_id)
DocumentService.increment_chunk_num(doc.id, doc.kb_id, c, 1, 0) DocumentService.increment_chunk_num(doc.id, doc.kb_id, c, 1, 0)
d["chunk_id"] = chunk_id
d["kb_id"] = doc.kb_id
# rename keys # rename keys
key_mapping = { key_mapping = {
"chunk_id": "id", "id": "id",
"content_with_weight": "content", "content_with_weight": "content",
"doc_id": "document_id", "doc_id": "document_id",
"important_kwd": "important_keywords", "important_kwd": "important_keywords",
@ -1079,36 +1074,16 @@ def rm_chunk(tenant_id, dataset_id, document_id):
""" """
if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id): if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
return get_error_data_result(message=f"You don't own the dataset {dataset_id}.") return get_error_data_result(message=f"You don't own the dataset {dataset_id}.")
doc = DocumentService.query(id=document_id, kb_id=dataset_id)
if not doc:
return get_error_data_result(
message=f"You don't own the document {document_id}."
)
doc = doc[0]
req = request.json req = request.json
if not req.get("chunk_ids"): condition = {"doc_id": document_id}
return get_error_data_result("`chunk_ids` is required") if "chunk_ids" in req:
query = {"doc_ids": [doc.id], "page": 1, "size": 1024, "question": "", "sort": True} condition["id"] = req["chunk_ids"]
sres = retrievaler.search(query, search.index_name(tenant_id), highlight=True) chunk_number = docStoreConn.delete(condition, search.index_name(tenant_id), dataset_id)
if not req: if chunk_number != 0:
chunk_ids = None DocumentService.decrement_chunk_num(document_id, dataset_id, 1, chunk_number, 0)
else: if "chunk_ids" in req and chunk_number != len(req["chunk_ids"]):
chunk_ids = req.get("chunk_ids") return get_error_data_result(message=f"rm_chunk deleted chunks {chunk_number}, expect {len(req["chunk_ids"])}")
if not chunk_ids: return get_result(message=f"deleted {chunk_number} chunks")
chunk_list = sres.ids
else:
chunk_list = chunk_ids
for chunk_id in chunk_list:
if chunk_id not in sres.ids:
return get_error_data_result(f"Chunk {chunk_id} not found")
if not ELASTICSEARCH.deleteByQuery(
Q("ids", values=chunk_list), search.index_name(tenant_id)
):
return get_error_data_result(message="Index updating failure")
deleted_chunk_ids = chunk_list
chunk_number = len(deleted_chunk_ids)
DocumentService.decrement_chunk_num(doc.id, doc.kb_id, 1, chunk_number, 0)
return get_result()
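
The rewritten rm_chunk builds a condition of `doc_id` plus optional chunk ids and trusts the count returned by `delete`. A toy version of the matching rule where a list-valued condition means "any of these", mirroring `{"id": req["chunk_ids"]}`; this is an assumption about the semantics, not the backend implementation:

```python
def matches(record: dict, condition: dict) -> bool:
    # Toy rule: a list value means "any of these", a scalar means equality.
    for field, expected in condition.items():
        value = record.get(field)
        if isinstance(expected, list):
            if value not in expected:
                return False
        elif value != expected:
            return False
    return True


chunk = {"id": "c1", "doc_id": "d1"}
assert matches(chunk, {"doc_id": "d1", "id": ["c1", "c2"]})
assert not matches(chunk, {"doc_id": "d2"})
```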
@manager.route( @manager.route(
@ -1168,9 +1143,8 @@ def update_chunk(tenant_id, dataset_id, document_id, chunk_id):
schema: schema:
type: object type: object
""" """
try: chunk = docStoreConn.get(chunk_id, search.index_name(tenant_id), [dataset_id])
res = ELASTICSEARCH.get(chunk_id, search.index_name(tenant_id)) if chunk is None:
except Exception:
return get_error_data_result(f"Can't find this chunk {chunk_id}") return get_error_data_result(f"Can't find this chunk {chunk_id}")
if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id): if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
return get_error_data_result(message=f"You don't own the dataset {dataset_id}.") return get_error_data_result(message=f"You don't own the dataset {dataset_id}.")
@ -1180,19 +1154,12 @@ def update_chunk(tenant_id, dataset_id, document_id, chunk_id):
message=f"You don't own the document {document_id}." message=f"You don't own the document {document_id}."
) )
doc = doc[0] doc = doc[0]
query = {
"doc_ids": [document_id],
"page": 1,
"size": 1024,
"question": "",
"sort": True,
}
sres = retrievaler.search(query, search.index_name(tenant_id), highlight=True)
if chunk_id not in sres.ids:
return get_error_data_result(f"You don't own the chunk {chunk_id}")
req = request.json req = request.json
content = res["_source"].get("content_with_weight") if "content" in req:
d = {"id": chunk_id, "content_with_weight": req.get("content", content)} content = req["content"]
else:
content = chunk.get("content_with_weight", "")
d = {"id": chunk_id, "content_with_weight": content}
d["content_ltks"] = rag_tokenizer.tokenize(d["content_with_weight"]) d["content_ltks"] = rag_tokenizer.tokenize(d["content_with_weight"])
d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"]) d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
if "important_keywords" in req: if "important_keywords" in req:
@ -1220,7 +1187,7 @@ def update_chunk(tenant_id, dataset_id, document_id, chunk_id):
v, c = embd_mdl.encode([doc.name, d["content_with_weight"]]) v, c = embd_mdl.encode([doc.name, d["content_with_weight"]])
v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1] v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
d["q_%d_vec" % len(v)] = v.tolist() d["q_%d_vec" % len(v)] = v.tolist()
ELASTICSEARCH.upsert([d], search.index_name(tenant_id)) docStoreConn.update({"id": chunk_id}, d, search.index_name(tenant_id), dataset_id)
return get_result() return get_result()

View File

@ -31,7 +31,7 @@ from api.utils.api_utils import (
generate_confirmation_token, generate_confirmation_token,
) )
from api.versions import get_rag_version from api.versions import get_rag_version
from rag.utils.es_conn import ELASTICSEARCH from api.settings import docStoreConn
from rag.utils.storage_factory import STORAGE_IMPL, STORAGE_IMPL_TYPE from rag.utils.storage_factory import STORAGE_IMPL, STORAGE_IMPL_TYPE
from timeit import default_timer as timer from timeit import default_timer as timer
@ -98,10 +98,11 @@ def status():
res = {} res = {}
st = timer() st = timer()
try: try:
res["es"] = ELASTICSEARCH.health() res["doc_store"] = docStoreConn.health()
res["es"]["elapsed"] = "{:.1f}".format((timer() - st) * 1000.0) res["doc_store"]["elapsed"] = "{:.1f}".format((timer() - st) * 1000.0)
except Exception as e: except Exception as e:
res["es"] = { res["doc_store"] = {
"type": "unknown",
"status": "red", "status": "red",
"elapsed": "{:.1f}".format((timer() - st) * 1000.0), "elapsed": "{:.1f}".format((timer() - st) * 1000.0),
"error": str(e), "error": str(e),

View File

@ -470,7 +470,7 @@ class User(DataBaseModel, UserMixin):
status = CharField( status = CharField(
max_length=1, max_length=1,
null=True, null=True,
help_text="is it validate(0: wasted1: validate)", help_text="is it validate(0: wasted, 1: validate)",
default="1", default="1",
index=True) index=True)
is_superuser = BooleanField(null=True, help_text="is root", default=False, index=True) is_superuser = BooleanField(null=True, help_text="is root", default=False, index=True)
@ -525,7 +525,7 @@ class Tenant(DataBaseModel):
status = CharField( status = CharField(
max_length=1, max_length=1,
null=True, null=True,
help_text="is it validate(0: wasted1: validate)", help_text="is it validate(0: wasted, 1: validate)",
default="1", default="1",
index=True) index=True)
@ -542,7 +542,7 @@ class UserTenant(DataBaseModel):
status = CharField( status = CharField(
max_length=1, max_length=1,
null=True, null=True,
help_text="is it validate(0: wasted1: validate)", help_text="is it validate(0: wasted, 1: validate)",
default="1", default="1",
index=True) index=True)
@ -559,7 +559,7 @@ class InvitationCode(DataBaseModel):
status = CharField( status = CharField(
max_length=1, max_length=1,
null=True, null=True,
help_text="is it validate(0: wasted1: validate)", help_text="is it validate(0: wasted, 1: validate)",
default="1", default="1",
index=True) index=True)
@ -582,7 +582,7 @@ class LLMFactories(DataBaseModel):
status = CharField( status = CharField(
max_length=1, max_length=1,
null=True, null=True,
help_text="is it validate(0: wasted1: validate)", help_text="is it validate(0: wasted, 1: validate)",
default="1", default="1",
index=True) index=True)
@ -616,7 +616,7 @@ class LLM(DataBaseModel):
status = CharField( status = CharField(
max_length=1, max_length=1,
null=True, null=True,
help_text="is it validate(0: wasted1: validate)", help_text="is it validate(0: wasted, 1: validate)",
default="1", default="1",
index=True) index=True)
@ -703,7 +703,7 @@ class Knowledgebase(DataBaseModel):
status = CharField( status = CharField(
max_length=1, max_length=1,
null=True, null=True,
help_text="is it validate(0: wasted1: validate)", help_text="is it validate(0: wasted, 1: validate)",
default="1", default="1",
index=True) index=True)
@ -767,7 +767,7 @@ class Document(DataBaseModel):
status = CharField( status = CharField(
max_length=1, max_length=1,
null=True, null=True,
help_text="is it validate(0: wasted1: validate)", help_text="is it validate(0: wasted, 1: validate)",
default="1", default="1",
index=True) index=True)
@ -904,7 +904,7 @@ class Dialog(DataBaseModel):
status = CharField( status = CharField(
max_length=1, max_length=1,
null=True, null=True,
help_text="is it validate(0: wasted1: validate)", help_text="is it validate(0: wasted, 1: validate)",
default="1", default="1",
index=True) index=True)
@ -987,7 +987,7 @@ def migrate_db():
help_text="where dose this document come from", help_text="where dose this document come from",
index=True)) index=True))
) )
except Exception as e: except Exception:
pass pass
try: try:
migrate( migrate(
@ -996,7 +996,7 @@ def migrate_db():
help_text="default rerank model ID")) help_text="default rerank model ID"))
) )
except Exception as e: except Exception:
pass pass
try: try:
migrate( migrate(
@ -1004,59 +1004,59 @@ def migrate_db():
help_text="default rerank model ID")) help_text="default rerank model ID"))
) )
except Exception as e: except Exception:
pass pass
try: try:
migrate( migrate(
migrator.add_column('dialog', 'top_k', IntegerField(default=1024)) migrator.add_column('dialog', 'top_k', IntegerField(default=1024))
) )
except Exception as e: except Exception:
pass pass
try: try:
migrate( migrate(
migrator.alter_column_type('tenant_llm', 'api_key', migrator.alter_column_type('tenant_llm', 'api_key',
CharField(max_length=1024, null=True, help_text="API KEY", index=True)) CharField(max_length=1024, null=True, help_text="API KEY", index=True))
) )
except Exception as e: except Exception:
pass pass
try: try:
migrate( migrate(
migrator.add_column('api_token', 'source', migrator.add_column('api_token', 'source',
CharField(max_length=16, null=True, help_text="none|agent|dialog", index=True)) CharField(max_length=16, null=True, help_text="none|agent|dialog", index=True))
) )
except Exception as e: except Exception:
pass pass
try: try:
migrate( migrate(
migrator.add_column("tenant","tts_id", migrator.add_column("tenant","tts_id",
CharField(max_length=256,null=True,help_text="default tts model ID",index=True)) CharField(max_length=256,null=True,help_text="default tts model ID",index=True))
) )
except Exception as e: except Exception:
pass pass
try: try:
migrate( migrate(
migrator.add_column('api_4_conversation', 'source', migrator.add_column('api_4_conversation', 'source',
CharField(max_length=16, null=True, help_text="none|agent|dialog", index=True)) CharField(max_length=16, null=True, help_text="none|agent|dialog", index=True))
) )
except Exception as e: except Exception:
pass pass
try: try:
DB.execute_sql('ALTER TABLE llm DROP PRIMARY KEY;') DB.execute_sql('ALTER TABLE llm DROP PRIMARY KEY;')
DB.execute_sql('ALTER TABLE llm ADD PRIMARY KEY (llm_name,fid);') DB.execute_sql('ALTER TABLE llm ADD PRIMARY KEY (llm_name,fid);')
except Exception as e: except Exception:
pass pass
try: try:
migrate( migrate(
migrator.add_column('task', 'retry_count', IntegerField(default=0)) migrator.add_column('task', 'retry_count', IntegerField(default=0))
) )
except Exception as e: except Exception:
pass pass
try: try:
migrate( migrate(
migrator.alter_column_type('api_token', 'dialog_id', migrator.alter_column_type('api_token', 'dialog_id',
CharField(max_length=32, null=True, index=True)) CharField(max_length=32, null=True, index=True))
) )
except Exception as e: except Exception:
pass pass

View File

@ -15,7 +15,6 @@
# #
import hashlib import hashlib
import json import json
import os
import random import random
import re import re
import traceback import traceback
@ -24,16 +23,13 @@ from copy import deepcopy
from datetime import datetime from datetime import datetime
from io import BytesIO from io import BytesIO
from elasticsearch_dsl import Q
from peewee import fn from peewee import fn
from api.db.db_utils import bulk_insert_into_db from api.db.db_utils import bulk_insert_into_db
from api.settings import stat_logger from api.settings import stat_logger, docStoreConn
from api.utils import current_timestamp, get_format_time, get_uuid from api.utils import current_timestamp, get_format_time, get_uuid
from api.utils.file_utils import get_project_base_directory
from graphrag.mind_map_extractor import MindMapExtractor from graphrag.mind_map_extractor import MindMapExtractor
from rag.settings import SVR_QUEUE_NAME from rag.settings import SVR_QUEUE_NAME
from rag.utils.es_conn import ELASTICSEARCH
from rag.utils.storage_factory import STORAGE_IMPL from rag.utils.storage_factory import STORAGE_IMPL
from rag.nlp import search, rag_tokenizer from rag.nlp import search, rag_tokenizer
@ -112,8 +108,7 @@ class DocumentService(CommonService):
@classmethod @classmethod
@DB.connection_context() @DB.connection_context()
def remove_document(cls, doc, tenant_id): def remove_document(cls, doc, tenant_id):
ELASTICSEARCH.deleteByQuery( docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
Q("match", doc_id=doc.id), idxnm=search.index_name(tenant_id))
cls.clear_chunk_num(doc.id) cls.clear_chunk_num(doc.id)
return cls.delete_by_id(doc.id) return cls.delete_by_id(doc.id)
@ -225,6 +220,15 @@ class DocumentService(CommonService):
return return
return docs[0]["tenant_id"] return docs[0]["tenant_id"]
@classmethod
@DB.connection_context()
def get_knowledgebase_id(cls, doc_id):
docs = cls.model.select(cls.model.kb_id).where(cls.model.id == doc_id)
docs = docs.dicts()
if not docs:
return
return docs[0]["kb_id"]
@classmethod @classmethod
@DB.connection_context() @DB.connection_context()
def get_tenant_id_by_name(cls, name): def get_tenant_id_by_name(cls, name):
@ -438,11 +442,6 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
if not e: if not e:
raise LookupError("Can't find this knowledgebase!") raise LookupError("Can't find this knowledgebase!")
idxnm = search.index_name(kb.tenant_id)
if not ELASTICSEARCH.indexExist(idxnm):
ELASTICSEARCH.createIdx(idxnm, json.load(
open(os.path.join(get_project_base_directory(), "conf", "mapping.json"), "r")))
embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING, llm_name=kb.embd_id, lang=kb.language) embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING, llm_name=kb.embd_id, lang=kb.language)
err, files = FileService.upload_document(kb, file_objs, user_id) err, files = FileService.upload_document(kb, file_objs, user_id)
@ -486,7 +485,7 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
md5 = hashlib.md5() md5 = hashlib.md5()
md5.update((ck["content_with_weight"] + md5.update((ck["content_with_weight"] +
str(d["doc_id"])).encode("utf-8")) str(d["doc_id"])).encode("utf-8"))
d["_id"] = md5.hexdigest() d["id"] = md5.hexdigest()
d["create_time"] = str(datetime.now()).replace("T", " ")[:19] d["create_time"] = str(datetime.now()).replace("T", " ")[:19]
d["create_timestamp_flt"] = datetime.now().timestamp() d["create_timestamp_flt"] = datetime.now().timestamp()
if not d.get("image"): if not d.get("image"):
@ -499,8 +498,8 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
else: else:
d["image"].save(output_buffer, format='JPEG') d["image"].save(output_buffer, format='JPEG')
STORAGE_IMPL.put(kb.id, d["_id"], output_buffer.getvalue()) STORAGE_IMPL.put(kb.id, d["id"], output_buffer.getvalue())
d["img_id"] = "{}-{}".format(kb.id, d["_id"]) d["img_id"] = "{}-{}".format(kb.id, d["id"])
del d["image"] del d["image"]
docs.append(d) docs.append(d)
@ -520,6 +519,9 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
token_counts[doc_id] += c token_counts[doc_id] += c
return vects return vects
idxnm = search.index_name(kb.tenant_id)
try_create_idx = True
_, tenant = TenantService.get_by_id(kb.tenant_id) _, tenant = TenantService.get_by_id(kb.tenant_id)
llm_bdl = LLMBundle(kb.tenant_id, LLMType.CHAT, tenant.llm_id) llm_bdl = LLMBundle(kb.tenant_id, LLMType.CHAT, tenant.llm_id)
for doc_id in docids: for doc_id in docids:
@ -550,7 +552,11 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
v = vects[i] v = vects[i]
d["q_%d_vec" % len(v)] = v d["q_%d_vec" % len(v)] = v
for b in range(0, len(cks), es_bulk_size): for b in range(0, len(cks), es_bulk_size):
ELASTICSEARCH.bulk(cks[b:b + es_bulk_size], idxnm) if try_create_idx:
if not docStoreConn.indexExist(idxnm, kb_id):
docStoreConn.createIdx(idxnm, kb_id, len(vects[0]))
try_create_idx = False
docStoreConn.insert(cks[b:b + es_bulk_size], idxnm, kb_id)
DocumentService.increment_chunk_num( DocumentService.increment_chunk_num(
doc_id, kb.id, token_counts[doc_id], chunk_counts[doc_id], 0) doc_id, kb.id, token_counts[doc_id], chunk_counts[doc_id], 0)
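
Where the old path created the Elasticsearch index up front from conf/mapping.json, the new path creates it lazily, once, sized to the embedding dimension of the first batch. A runnable toy version of that guard; the fake connection stands in for docStoreConn:

```python
class _FakeConn:
    """Minimal stand-in so the snippet runs; the real object is docStoreConn."""
    def __init__(self):
        self.created = set()

    def indexExist(self, idxnm, kb_id):
        return (idxnm, kb_id) in self.created

    def createIdx(self, idxnm, kb_id, vector_size):
        self.created.add((idxnm, kb_id))

    def insert(self, docs, idxnm, kb_id):
        pass


conn, idxnm, kb_id = _FakeConn(), "ragflow_tenant", "kb1"
vects = [[0.1] * 768]                      # embedding dim discovered at runtime
if not conn.indexExist(idxnm, kb_id):      # same guard as in doc_upload_and_parse
    conn.createIdx(idxnm, kb_id, len(vects[0]))
conn.insert([{"id": "c1"}], idxnm, kb_id)
```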

View File

@ -66,6 +66,16 @@ class KnowledgebaseService(CommonService):
return list(kbs.dicts()) return list(kbs.dicts())
@classmethod
@DB.connection_context()
def get_kb_ids(cls, tenant_id):
fields = [
cls.model.id,
]
kbs = cls.model.select(*fields).where(cls.model.tenant_id == tenant_id)
kb_ids = [kb["id"] for kb in kbs]
return kb_ids
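
Several endpoints above now resolve every knowledge base the tenant owns before touching the store. A toy illustration of that lookup-then-scope pattern with in-memory data; the real method runs a peewee query:

```python
# In-memory stand-in for the Knowledgebase table: kb id -> owning tenant.
KBS = {"kb1": "tenant_a", "kb2": "tenant_a", "kb3": "tenant_b"}


def get_kb_ids(tenant_id: str) -> list[str]:
    # Mirrors KnowledgebaseService.get_kb_ids: all kb ids owned by the tenant.
    return [kb_id for kb_id, owner in KBS.items() if owner == tenant_id]


kb_ids = get_kb_ids("tenant_a")
assert kb_ids == ["kb1", "kb2"]
# These ids are then passed alongside index_name(tenant_id) into
# retrievaler.chunk_list / retrievaler.search / docStoreConn.get.
```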
@classmethod @classmethod
@DB.connection_context() @DB.connection_context()
def get_detail(cls, kb_id): def get_detail(cls, kb_id):

View File

@ -18,6 +18,8 @@ from datetime import date
from enum import IntEnum, Enum from enum import IntEnum, Enum
from api.utils.file_utils import get_project_base_directory from api.utils.file_utils import get_project_base_directory
from api.utils.log_utils import LoggerFactory, getLogger from api.utils.log_utils import LoggerFactory, getLogger
import rag.utils.es_conn
import rag.utils.infinity_conn
# Logger # Logger
LoggerFactory.set_directory( LoggerFactory.set_directory(
@ -33,7 +35,7 @@ access_logger = getLogger("access")
database_logger = getLogger("database") database_logger = getLogger("database")
chat_logger = getLogger("chat") chat_logger = getLogger("chat")
from rag.utils.es_conn import ELASTICSEARCH import rag.utils
from rag.nlp import search from rag.nlp import search
from graphrag import search as kg_search from graphrag import search as kg_search
from api.utils import get_base_config, decrypt_database_config from api.utils import get_base_config, decrypt_database_config
@ -206,8 +208,12 @@ AUTHENTICATION_DEFAULT_TIMEOUT = 7 * 24 * 60 * 60 # s
PRIVILEGE_COMMAND_WHITELIST = [] PRIVILEGE_COMMAND_WHITELIST = []
CHECK_NODES_IDENTITY = False CHECK_NODES_IDENTITY = False
retrievaler = search.Dealer(ELASTICSEARCH) if 'username' in get_base_config("es", {}):
kg_retrievaler = kg_search.KGSearch(ELASTICSEARCH) docStoreConn = rag.utils.es_conn.ESConnection()
else:
docStoreConn = rag.utils.infinity_conn.InfinityConnection()
retrievaler = search.Dealer(docStoreConn)
kg_retrievaler = kg_search.KGSearch(docStoreConn)
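
The storage backend is now picked once at import time: Elasticsearch when the `es` section of the service config carries a `username`, Infinity otherwise. A condensed sketch of that decision, with illustrative (invented) config values:

```python
def choose_doc_store(base_config: dict) -> str:
    # Mirrors the selection in api/settings.py: the presence of es.username
    # is what routes the deployment to Elasticsearch.
    es_conf = base_config.get("es", {})
    return "elasticsearch" if "username" in es_conf else "infinity"


assert choose_doc_store({"es": {"hosts": "http://es01:9200", "username": "elastic"}}) == "elasticsearch"
assert choose_doc_store({"infinity": {"uri": "infinity:23817"}}) == "infinity"
```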
class CustomEnum(Enum): class CustomEnum(Enum):

View File

@ -126,10 +126,6 @@ def server_error_response(e):
if len(e.args) > 1: if len(e.args) > 1:
return get_json_result( return get_json_result(
code=RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=e.args[1]) code=RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=e.args[1])
if repr(e).find("index_not_found_exception") >= 0:
return get_json_result(code=RetCode.EXCEPTION_ERROR,
message="No chunk found, please upload file and parse it.")
return get_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e)) return get_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e))
@ -270,10 +266,6 @@ def construct_error_response(e):
pass pass
if len(e.args) > 1: if len(e.args) > 1:
return construct_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=e.args[1]) return construct_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=e.args[1])
if repr(e).find("index_not_found_exception") >= 0:
return construct_json_result(code=RetCode.EXCEPTION_ERROR,
message="No chunk found, please upload file and parse it.")
return construct_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e)) return construct_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e))
@ -295,7 +287,7 @@ def token_required(func):
return decorated_function return decorated_function
def get_result(code=RetCode.SUCCESS, message='error', data=None): def get_result(code=RetCode.SUCCESS, message="", data=None):
if code == 0: if code == 0:
if data is not None: if data is not None:
response = {"code": code, "data": data} response = {"code": code, "data": data}

View File

@ -0,0 +1,26 @@
{
"id": {"type": "varchar", "default": ""},
"doc_id": {"type": "varchar", "default": ""},
"kb_id": {"type": "varchar", "default": ""},
"create_time": {"type": "varchar", "default": ""},
"create_timestamp_flt": {"type": "float", "default": 0.0},
"img_id": {"type": "varchar", "default": ""},
"docnm_kwd": {"type": "varchar", "default": ""},
"title_tks": {"type": "varchar", "default": ""},
"title_sm_tks": {"type": "varchar", "default": ""},
"name_kwd": {"type": "varchar", "default": ""},
"important_kwd": {"type": "varchar", "default": ""},
"important_tks": {"type": "varchar", "default": ""},
"content_with_weight": {"type": "varchar", "default": ""},
"content_ltks": {"type": "varchar", "default": ""},
"content_sm_ltks": {"type": "varchar", "default": ""},
"page_num_list": {"type": "varchar", "default": ""},
"top_list": {"type": "varchar", "default": ""},
"position_list": {"type": "varchar", "default": ""},
"weight_int": {"type": "integer", "default": 0},
"weight_flt": {"type": "float", "default": 0.0},
"rank_int": {"type": "integer", "default": 0},
"available_int": {"type": "integer", "default": 1},
"knowledge_graph_kwd": {"type": "varchar", "default": ""},
"entities_kwd": {"type": "varchar", "default": ""}
}
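
The new conf/infinity_mapping.json describes each chunk field as a typed column with a default. A sketch of how such a mapping could be turned into column definitions at index-creation time; the inlined mapping is a subset, and the `vector,<dim>,float` type string is an assumption about Infinity's column spec rather than something shown in this diff:

```python
import json

# A few entries from conf/infinity_mapping.json, inlined for the example.
MAPPING = json.loads("""{
  "doc_id":        {"type": "varchar", "default": ""},
  "kb_id":         {"type": "varchar", "default": ""},
  "available_int": {"type": "integer", "default": 1}
}""")


def columns_for_index(vector_size: int) -> dict[str, dict]:
    # Static chunk fields come from the mapping; the embedding column is added
    # per index because its width is only known at createIdx time.
    cols = dict(MAPPING)
    cols[f"q_{vector_size}_vec"] = {"type": f"vector,{vector_size},float"}  # assumed Infinity type string
    return cols


print(columns_for_index(1024))
```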

View File

@ -1,200 +1,203 @@
{ {
"settings": { "settings": {
"index": { "index": {
"number_of_shards": 2, "number_of_shards": 2,
"number_of_replicas": 0, "number_of_replicas": 0,
"refresh_interval" : "1000ms" "refresh_interval": "1000ms"
}, },
"similarity": { "similarity": {
"scripted_sim": { "scripted_sim": {
"type": "scripted", "type": "scripted",
"script": { "script": {
"source": "double idf = Math.log(1+(field.docCount-term.docFreq+0.5)/(term.docFreq + 0.5))/Math.log(1+((field.docCount-0.5)/1.5)); return query.boost * idf * Math.min(doc.freq, 1);" "source": "double idf = Math.log(1+(field.docCount-term.docFreq+0.5)/(term.docFreq + 0.5))/Math.log(1+((field.docCount-0.5)/1.5)); return query.boost * idf * Math.min(doc.freq, 1);"
}
} }
}
} }
}, },
"mappings": { "mappings": {
"properties": { "properties": {
"lat_lon": {"type": "geo_point", "store":"true"} "lat_lon": {
}, "type": "geo_point",
"date_detection": "true", "store": "true"
"dynamic_templates": [ }
{ },
"int": { "date_detection": "true",
"match": "*_int", "dynamic_templates": [
"mapping": { {
"type": "integer", "int": {
"store": "true" "match": "*_int",
} "mapping": {
"type": "integer",
"store": "true"
} }
},
{
"ulong": {
"match": "*_ulong",
"mapping": {
"type": "unsigned_long",
"store": "true"
}
}
},
{
"long": {
"match": "*_long",
"mapping": {
"type": "long",
"store": "true"
}
}
},
{
"short": {
"match": "*_short",
"mapping": {
"type": "short",
"store": "true"
}
}
},
{
"numeric": {
"match": "*_flt",
"mapping": {
"type": "float",
"store": true
}
}
},
{
"tks": {
"match": "*_tks",
"mapping": {
"type": "text",
"similarity": "scripted_sim",
"analyzer": "whitespace",
"store": true
}
}
},
{
"ltks":{
"match": "*_ltks",
"mapping": {
"type": "text",
"analyzer": "whitespace",
"store": true
}
}
},
{
"kwd": {
"match_pattern": "regex",
"match": "^(.*_(kwd|id|ids|uid|uids)|uid)$",
"mapping": {
"type": "keyword",
"similarity": "boolean",
"store": true
}
}
},
{
"dt": {
"match_pattern": "regex",
"match": "^.*(_dt|_time|_at)$",
"mapping": {
"type": "date",
"format": "yyyy-MM-dd HH:mm:ss||yyyy-MM-dd||yyyy-MM-dd_HH:mm:ss",
"store": true
}
}
},
{
"nested": {
"match": "*_nst",
"mapping": {
"type": "nested"
}
}
},
{
"object": {
"match": "*_obj",
"mapping": {
"type": "object",
"dynamic": "true"
}
}
},
{
"string": {
"match": "*_with_weight",
"mapping": {
"type": "text",
"index": "false",
"store": true
}
}
},
{
"string": {
"match": "*_fea",
"mapping": {
"type": "rank_feature"
}
}
},
{
"dense_vector": {
"match": "*_512_vec",
"mapping": {
"type": "dense_vector",
"index": true,
"similarity": "cosine",
"dims": 512
}
}
},
{
"dense_vector": {
"match": "*_768_vec",
"mapping": {
"type": "dense_vector",
"index": true,
"similarity": "cosine",
"dims": 768
}
}
},
{
"dense_vector": {
"match": "*_1024_vec",
"mapping": {
"type": "dense_vector",
"index": true,
"similarity": "cosine",
"dims": 1024
}
}
},
{
"dense_vector": {
"match": "*_1536_vec",
"mapping": {
"type": "dense_vector",
"index": true,
"similarity": "cosine",
"dims": 1536
}
}
},
{
"binary": {
"match": "*_bin",
"mapping": {
"type": "binary"
}
}
} }
] },
} {
} "ulong": {
"match": "*_ulong",
"mapping": {
"type": "unsigned_long",
"store": "true"
}
}
},
{
"long": {
"match": "*_long",
"mapping": {
"type": "long",
"store": "true"
}
}
},
{
"short": {
"match": "*_short",
"mapping": {
"type": "short",
"store": "true"
}
}
},
{
"numeric": {
"match": "*_flt",
"mapping": {
"type": "float",
"store": true
}
}
},
{
"tks": {
"match": "*_tks",
"mapping": {
"type": "text",
"similarity": "scripted_sim",
"analyzer": "whitespace",
"store": true
}
}
},
{
"ltks": {
"match": "*_ltks",
"mapping": {
"type": "text",
"analyzer": "whitespace",
"store": true
}
}
},
{
"kwd": {
"match_pattern": "regex",
"match": "^(.*_(kwd|id|ids|uid|uids)|uid)$",
"mapping": {
"type": "keyword",
"similarity": "boolean",
"store": true
}
}
},
{
"dt": {
"match_pattern": "regex",
"match": "^.*(_dt|_time|_at)$",
"mapping": {
"type": "date",
"format": "yyyy-MM-dd HH:mm:ss||yyyy-MM-dd||yyyy-MM-dd_HH:mm:ss",
"store": true
}
}
},
{
"nested": {
"match": "*_nst",
"mapping": {
"type": "nested"
}
}
},
{
"object": {
"match": "*_obj",
"mapping": {
"type": "object",
"dynamic": "true"
}
}
},
{
"string": {
"match": "*_(with_weight|list)$",
"mapping": {
"type": "text",
"index": "false",
"store": true
}
}
},
{
"string": {
"match": "*_fea",
"mapping": {
"type": "rank_feature"
}
}
},
{
"dense_vector": {
"match": "*_512_vec",
"mapping": {
"type": "dense_vector",
"index": true,
"similarity": "cosine",
"dims": 512
}
}
},
{
"dense_vector": {
"match": "*_768_vec",
"mapping": {
"type": "dense_vector",
"index": true,
"similarity": "cosine",
"dims": 768
}
}
},
{
"dense_vector": {
"match": "*_1024_vec",
"mapping": {
"type": "dense_vector",
"index": true,
"similarity": "cosine",
"dims": 1024
}
}
},
{
"dense_vector": {
"match": "*_1536_vec",
"mapping": {
"type": "dense_vector",
"index": true,
"similarity": "cosine",
"dims": 1536
}
}
},
{
"binary": {
"match": "*_bin",
"mapping": {
"type": "binary"
}
}
}
]
}
}

View File

@ -19,6 +19,11 @@ KIBANA_PASSWORD=infini_rag_flow
# Update it according to the available memory in the host machine. # Update it according to the available memory in the host machine.
MEM_LIMIT=8073741824 MEM_LIMIT=8073741824
# Port to expose Infinity API to the host
INFINITY_THRIFT_PORT=23817
INFINITY_HTTP_PORT=23820
INFINITY_PSQL_PORT=5432
# The password for MySQL. # The password for MySQL.
# When updated, you must revise the `mysql.password` entry in service_conf.yaml. # When updated, you must revise the `mysql.password` entry in service_conf.yaml.
MYSQL_PASSWORD=infini_rag_flow MYSQL_PASSWORD=infini_rag_flow

View File

@ -6,6 +6,7 @@ services:
- esdata01:/usr/share/elasticsearch/data - esdata01:/usr/share/elasticsearch/data
ports: ports:
- ${ES_PORT}:9200 - ${ES_PORT}:9200
env_file: .env
environment: environment:
- node.name=es01 - node.name=es01
- ELASTIC_PASSWORD=${ELASTIC_PASSWORD} - ELASTIC_PASSWORD=${ELASTIC_PASSWORD}
@ -27,12 +28,40 @@ services:
retries: 120 retries: 120
networks: networks:
- ragflow - ragflow
restart: always restart: on-failure
# infinity:
# container_name: ragflow-infinity
# image: infiniflow/infinity:v0.5.0-dev2
# volumes:
# - infinity_data:/var/infinity
# ports:
# - ${INFINITY_THRIFT_PORT}:23817
# - ${INFINITY_HTTP_PORT}:23820
# - ${INFINITY_PSQL_PORT}:5432
# env_file: .env
# environment:
# - TZ=${TIMEZONE}
# mem_limit: ${MEM_LIMIT}
# ulimits:
# nofile:
# soft: 500000
# hard: 500000
# networks:
# - ragflow
# healthcheck:
# test: ["CMD", "curl", "http://localhost:23820/admin/node/current"]
# interval: 10s
# timeout: 10s
# retries: 120
# restart: on-failure
mysql: mysql:
# mysql:5.7 linux/arm64 image is unavailable. # mysql:5.7 linux/arm64 image is unavailable.
image: mysql:8.0.39 image: mysql:8.0.39
container_name: ragflow-mysql container_name: ragflow-mysql
env_file: .env
environment: environment:
- MYSQL_ROOT_PASSWORD=${MYSQL_PASSWORD} - MYSQL_ROOT_PASSWORD=${MYSQL_PASSWORD}
- TZ=${TIMEZONE} - TZ=${TIMEZONE}
@ -55,7 +84,7 @@ services:
interval: 10s interval: 10s
timeout: 10s timeout: 10s
retries: 3 retries: 3
restart: always restart: on-failure
minio: minio:
image: quay.io/minio/minio:RELEASE.2023-12-20T01-00-02Z image: quay.io/minio/minio:RELEASE.2023-12-20T01-00-02Z
@ -64,6 +93,7 @@ services:
ports: ports:
- ${MINIO_PORT}:9000 - ${MINIO_PORT}:9000
- ${MINIO_CONSOLE_PORT}:9001 - ${MINIO_CONSOLE_PORT}:9001
env_file: .env
environment: environment:
- MINIO_ROOT_USER=${MINIO_USER} - MINIO_ROOT_USER=${MINIO_USER}
- MINIO_ROOT_PASSWORD=${MINIO_PASSWORD} - MINIO_ROOT_PASSWORD=${MINIO_PASSWORD}
@ -72,25 +102,28 @@ services:
- minio_data:/data - minio_data:/data
networks: networks:
- ragflow - ragflow
restart: always restart: on-failure
redis: redis:
image: valkey/valkey:8 image: valkey/valkey:8
container_name: ragflow-redis container_name: ragflow-redis
command: redis-server --requirepass ${REDIS_PASSWORD} --maxmemory 128mb --maxmemory-policy allkeys-lru command: redis-server --requirepass ${REDIS_PASSWORD} --maxmemory 128mb --maxmemory-policy allkeys-lru
env_file: .env
ports: ports:
- ${REDIS_PORT}:6379 - ${REDIS_PORT}:6379
volumes: volumes:
- redis_data:/data - redis_data:/data
networks: networks:
- ragflow - ragflow
restart: always restart: on-failure
volumes: volumes:
esdata01: esdata01:
driver: local driver: local
infinity_data:
driver: local
mysql_data: mysql_data:
driver: local driver: local
minio_data: minio_data:

View File

@ -1,6 +1,5 @@
include: include:
- path: ./docker-compose-base.yml - ./docker-compose-base.yml
env_file: ./.env
services: services:
ragflow: ragflow:
@ -15,19 +14,21 @@ services:
- ${SVR_HTTP_PORT}:9380 - ${SVR_HTTP_PORT}:9380
- 80:80 - 80:80
- 443:443 - 443:443
- 5678:5678
volumes: volumes:
- ./service_conf.yaml:/ragflow/conf/service_conf.yaml - ./service_conf.yaml:/ragflow/conf/service_conf.yaml
- ./ragflow-logs:/ragflow/logs - ./ragflow-logs:/ragflow/logs
- ./nginx/ragflow.conf:/etc/nginx/conf.d/ragflow.conf - ./nginx/ragflow.conf:/etc/nginx/conf.d/ragflow.conf
- ./nginx/proxy.conf:/etc/nginx/proxy.conf - ./nginx/proxy.conf:/etc/nginx/proxy.conf
- ./nginx/nginx.conf:/etc/nginx/nginx.conf - ./nginx/nginx.conf:/etc/nginx/nginx.conf
env_file: .env
environment: environment:
- TZ=${TIMEZONE} - TZ=${TIMEZONE}
- HF_ENDPOINT=${HF_ENDPOINT} - HF_ENDPOINT=${HF_ENDPOINT}
- MACOS=${MACOS} - MACOS=${MACOS}
networks: networks:
- ragflow - ragflow
restart: always restart: on-failure
# https://docs.docker.com/engine/daemon/prometheus/#create-a-prometheus-configuration
# If you're using Docker Desktop, the --add-host flag is optional. This flag makes sure that the host's internal IP gets exposed to the Prometheus container.
extra_hosts: extra_hosts:
- "host.docker.internal:host-gateway" - "host.docker.internal:host-gateway"

View File

@ -67,7 +67,7 @@ docker compose -f docker/docker-compose-base.yml up -d
1. Add the following line to `/etc/hosts` to resolve all hosts specified in **docker/service_conf.yaml** to `127.0.0.1`: 1. Add the following line to `/etc/hosts` to resolve all hosts specified in **docker/service_conf.yaml** to `127.0.0.1`:
``` ```
127.0.0.1 es01 mysql minio redis 127.0.0.1 es01 infinity mysql minio redis
``` ```
2. In **docker/service_conf.yaml**, update mysql port to `5455` and es port to `1200`, as specified in **docker/.env**. 2. In **docker/service_conf.yaml**, update mysql port to `5455` and es port to `1200`, as specified in **docker/.env**.

View File

@ -1280,7 +1280,7 @@ Success:
"document_keyword": "1.txt", "document_keyword": "1.txt",
"highlight": "<em>ragflow</em> content", "highlight": "<em>ragflow</em> content",
"id": "d78435d142bd5cf6704da62c778795c5", "id": "d78435d142bd5cf6704da62c778795c5",
"img_id": "", "image_id": "",
"important_keywords": [ "important_keywords": [
"" ""
], ],

View File

@ -1351,7 +1351,7 @@ A list of `Chunk` objects representing references to the message, each containin
The chunk ID. The chunk ID.
- `content` `str` - `content` `str`
The content of the chunk. The content of the chunk.
- `image_id` `str` - `img_id` `str`
The ID of the snapshot of the chunk. Applicable only when the source of the chunk is an image, PPT, PPTX, or PDF file. The ID of the snapshot of the chunk. Applicable only when the source of the chunk is an image, PPT, PPTX, or PDF file.
- `document_id` `str` - `document_id` `str`
The ID of the referenced document. The ID of the referenced document.
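
As a small, hedged illustration (the helper below is hypothetical, not part of the SDK), this is how a caller might read the renamed field from a list of such `Chunk` reference objects:

```python
def collect_snapshots(references):
    """Hypothetical helper: given Chunk references with the fields documented
    above (id, content, img_id, document_id), return the snapshot IDs, which
    are only set when the source is an image, PPT, PPTX, or PDF file."""
    return [chunk.img_id for chunk in references if getattr(chunk, "img_id", "")]
```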

View File

@ -254,9 +254,12 @@ if __name__ == "__main__":
from api.db import LLMType from api.db import LLMType
from api.db.services.llm_service import LLMBundle from api.db.services.llm_service import LLMBundle
from api.settings import retrievaler from api.settings import retrievaler
from api.db.services.knowledgebase_service import KnowledgebaseService
kb_ids = KnowledgebaseService.get_kb_ids(args.tenant_id)
ex = ClaimExtractor(LLMBundle(args.tenant_id, LLMType.CHAT)) ex = ClaimExtractor(LLMBundle(args.tenant_id, LLMType.CHAT))
docs = [d["content_with_weight"] for d in retrievaler.chunk_list(args.doc_id, args.tenant_id, max_count=12, fields=["content_with_weight"])] docs = [d["content_with_weight"] for d in retrievaler.chunk_list(args.doc_id, args.tenant_id, kb_ids, max_count=12, fields=["content_with_weight"])]
info = { info = {
"input_text": docs, "input_text": docs,
"entity_specs": "organization, person", "entity_specs": "organization, person",

View File

@ -15,95 +15,90 @@
# #
import json import json
from copy import deepcopy from copy import deepcopy
from typing import Dict
import pandas as pd import pandas as pd
from elasticsearch_dsl import Q, Search from rag.utils.doc_store_conn import OrderByExpr, FusionExpr
from rag.nlp.search import Dealer from rag.nlp.search import Dealer
class KGSearch(Dealer): class KGSearch(Dealer):
def search(self, req, idxnm, emb_mdl=None, highlight=False): def search(self, req, idxnm, kb_ids, emb_mdl, highlight=False):
def merge_into_first(sres, title=""): def merge_into_first(sres, title="") -> Dict[str, str]:
df,texts = [],[] if not sres:
for d in sres["hits"]["hits"]: return {}
content_with_weight = ""
df, texts = [],[]
for d in sres.values():
try: try:
df.append(json.loads(d["_source"]["content_with_weight"])) df.append(json.loads(d["content_with_weight"]))
except Exception as e: except Exception:
texts.append(d["_source"]["content_with_weight"]) texts.append(d["content_with_weight"])
pass
if not df and not texts: return False
if df: if df:
try: content_with_weight = title + "\n" + pd.DataFrame(df).to_csv()
sres["hits"]["hits"][0]["_source"]["content_with_weight"] = title + "\n" + pd.DataFrame(df).to_csv()
except Exception as e:
pass
else: else:
sres["hits"]["hits"][0]["_source"]["content_with_weight"] = title + "\n" + "\n".join(texts) content_with_weight = title + "\n" + "\n".join(texts)
return True first_id = ""
first_source = {}
for k, v in sres.items():
first_id = k
first_source = deepcopy(v)
break
first_source["content_with_weight"] = content_with_weight
first_id = next(iter(sres))
return {first_id: first_source}
qst = req.get("question", "")
matchText, keywords = self.qryr.question(qst, min_match=0.05)
condition = self.get_filters(req)
## Entity retrieval
condition.update({"knowledge_graph_kwd": ["entity"]})
assert emb_mdl, "No embedding model selected"
matchDense = self.get_vector(qst, emb_mdl, 1024, req.get("similarity", 0.1))
q_vec = matchDense.embedding_data
src = req.get("fields", ["docnm_kwd", "content_ltks", "kb_id", "img_id", "title_tks", "important_kwd", src = req.get("fields", ["docnm_kwd", "content_ltks", "kb_id", "img_id", "title_tks", "important_kwd",
"image_id", "doc_id", "q_512_vec", "q_768_vec", "position_int", "name_kwd", "doc_id", f"q_{len(q_vec)}_vec", "position_list", "name_kwd",
"q_1024_vec", "q_1536_vec", "available_int", "content_with_weight", "q_1024_vec", "q_1536_vec", "available_int", "content_with_weight",
"weight_int", "weight_flt", "rank_int" "weight_int", "weight_flt", "rank_int"
]) ])
qst = req.get("question", "") fusionExpr = FusionExpr("weighted_sum", 32, {"weights": "0.5, 0.5"})
binary_query, keywords = self.qryr.question(qst, min_match="5%")
binary_query = self._add_filters(binary_query, req)
## Entity retrieval ent_res = self.dataStore.search(src, list(), condition, [matchText, matchDense, fusionExpr], OrderByExpr(), 0, 32, idxnm, kb_ids)
bqry = deepcopy(binary_query) ent_res_fields = self.dataStore.getFields(ent_res, src)
bqry.filter.append(Q("terms", knowledge_graph_kwd=["entity"])) entities = [d["name_kwd"] for d in ent_res_fields.values()]
s = Search() ent_ids = self.dataStore.getChunkIds(ent_res)
s = s.query(bqry)[0: 32] ent_content = merge_into_first(ent_res_fields, "-Entities-")
if ent_content:
s = s.to_dict() ent_ids = list(ent_content.keys())
q_vec = []
if req.get("vector"):
assert emb_mdl, "No embedding model selected"
s["knn"] = self._vector(
qst, emb_mdl, req.get(
"similarity", 0.1), 1024)
s["knn"]["filter"] = bqry.to_dict()
q_vec = s["knn"]["query_vector"]
ent_res = self.es.search(deepcopy(s), idxnms=idxnm, timeout="600s", src=src)
entities = [d["name_kwd"] for d in self.es.getSource(ent_res)]
ent_ids = self.es.getDocIds(ent_res)
if merge_into_first(ent_res, "-Entities-"):
ent_ids = ent_ids[0:1]
## Community retrieval ## Community retrieval
bqry = deepcopy(binary_query) condition = self.get_filters(req)
bqry.filter.append(Q("terms", entities_kwd=entities)) condition.update({"entities_kwd": entities, "knowledge_graph_kwd": ["community_report"]})
bqry.filter.append(Q("terms", knowledge_graph_kwd=["community_report"])) comm_res = self.dataStore.search(src, list(), condition, [matchText, matchDense, fusionExpr], OrderByExpr(), 0, 32, idxnm, kb_ids)
s = Search() comm_res_fields = self.dataStore.getFields(comm_res, src)
s = s.query(bqry)[0: 32] comm_ids = self.dataStore.getChunkIds(comm_res)
s = s.to_dict() comm_content = merge_into_first(comm_res_fields, "-Community Report-")
comm_res = self.es.search(deepcopy(s), idxnms=idxnm, timeout="600s", src=src) if comm_content:
comm_ids = self.es.getDocIds(comm_res) comm_ids = list(comm_content.keys())
if merge_into_first(comm_res, "-Community Report-"):
comm_ids = comm_ids[0:1]
## Text content retrieval ## Text content retrieval
bqry = deepcopy(binary_query) condition = self.get_filters(req)
bqry.filter.append(Q("terms", knowledge_graph_kwd=["text"])) condition.update({"knowledge_graph_kwd": ["text"]})
s = Search() txt_res = self.dataStore.search(src, list(), condition, [matchText, matchDense, fusionExpr], OrderByExpr(), 0, 6, idxnm, kb_ids)
s = s.query(bqry)[0: 6] txt_res_fields = self.dataStore.getFields(txt_res, src)
s = s.to_dict() txt_ids = self.dataStore.getChunkIds(txt_res)
txt_res = self.es.search(deepcopy(s), idxnms=idxnm, timeout="600s", src=src) txt_content = merge_into_first(txt_res_fields, "-Original Content-")
txt_ids = self.es.getDocIds(txt_res) if txt_content:
if merge_into_first(txt_res, "-Original Content-"): txt_ids = list(txt_content.keys())
txt_ids = txt_ids[0:1]
return self.SearchResult( return self.SearchResult(
total=len(ent_ids) + len(comm_ids) + len(txt_ids), total=len(ent_ids) + len(comm_ids) + len(txt_ids),
ids=[*ent_ids, *comm_ids, *txt_ids], ids=[*ent_ids, *comm_ids, *txt_ids],
query_vector=q_vec, query_vector=q_vec,
aggregation=None,
highlight=None, highlight=None,
field={**self.getFields(ent_res, src), **self.getFields(comm_res, src), **self.getFields(txt_res, src)}, field={**ent_content, **comm_content, **txt_content},
keywords=[] keywords=[]
) )
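
For readability, a minimal sketch of the hybrid-retrieval call pattern that the rewritten `KGSearch.search` issues three times (entities, community reports, original text); it only restates the calls shown in the diff and assumes a `dataStore` implementing `DocStoreConnection`:

```python
from rag.utils.doc_store_conn import OrderByExpr, FusionExpr

def _hybrid_search(self, question, emb_mdl, condition, src, idxnm, kb_ids, limit=32):
    # Sketch only: `self` stands for the KGSearch/Dealer instance used above.
    matchText, _keywords = self.qryr.question(question, min_match=0.05)
    matchDense = self.get_vector(question, emb_mdl, 1024, 0.1)
    fusionExpr = FusionExpr("weighted_sum", 32, {"weights": "0.5, 0.5"})
    res = self.dataStore.search(src, list(), condition,
                                [matchText, matchDense, fusionExpr],
                                OrderByExpr(), 0, limit, idxnm, kb_ids)
    # Field dicts keyed by chunk id, as consumed by merge_into_first above.
    return self.dataStore.getFields(res, src), self.dataStore.getChunkIds(res)
```

Each of the three retrieval stages then differs only in the `condition` filter and the result limit.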

View File

@ -31,10 +31,13 @@ if __name__ == "__main__":
from api.db import LLMType from api.db import LLMType
from api.db.services.llm_service import LLMBundle from api.db.services.llm_service import LLMBundle
from api.settings import retrievaler from api.settings import retrievaler
from api.db.services.knowledgebase_service import KnowledgebaseService
kb_ids = KnowledgebaseService.get_kb_ids(args.tenant_id)
ex = GraphExtractor(LLMBundle(args.tenant_id, LLMType.CHAT)) ex = GraphExtractor(LLMBundle(args.tenant_id, LLMType.CHAT))
docs = [d["content_with_weight"] for d in docs = [d["content_with_weight"] for d in
retrievaler.chunk_list(args.doc_id, args.tenant_id, max_count=6, fields=["content_with_weight"])] retrievaler.chunk_list(args.doc_id, args.tenant_id, kb_ids, max_count=6, fields=["content_with_weight"])]
graph = ex(docs) graph = ex(docs)
er = EntityResolution(LLMBundle(args.tenant_id, LLMType.CHAT)) er = EntityResolution(LLMBundle(args.tenant_id, LLMType.CHAT))

871
poetry.lock generated

File diff suppressed because it is too large

View File

@ -46,22 +46,23 @@ hanziconv = "0.3.2"
html-text = "0.6.2" html-text = "0.6.2"
httpx = "0.27.0" httpx = "0.27.0"
huggingface-hub = "^0.25.0" huggingface-hub = "^0.25.0"
infinity-emb = "0.0.51" infinity-sdk = "0.5.0.dev2"
infinity-emb = "^0.0.66"
itsdangerous = "2.1.2" itsdangerous = "2.1.2"
markdown = "3.6" markdown = "3.6"
markdown-to-json = "2.1.1" markdown-to-json = "2.1.1"
minio = "7.2.4" minio = "7.2.4"
mistralai = "0.4.2" mistralai = "0.4.2"
nltk = "3.9.1" nltk = "3.9.1"
numpy = "1.26.4" numpy = "^1.26.0"
ollama = "0.2.1" ollama = "0.2.1"
onnxruntime = "1.19.2" onnxruntime = "1.19.2"
openai = "1.45.0" openai = "1.45.0"
opencv-python = "4.10.0.84" opencv-python = "4.10.0.84"
opencv-python-headless = "4.10.0.84" opencv-python-headless = "4.10.0.84"
openpyxl = "3.1.2" openpyxl = "^3.1.0"
ormsgpack = "1.5.0" ormsgpack = "1.5.0"
pandas = "2.2.2" pandas = "^2.2.0"
pdfplumber = "0.10.4" pdfplumber = "0.10.4"
peewee = "3.17.1" peewee = "3.17.1"
pillow = "10.4.0" pillow = "10.4.0"
@ -70,7 +71,7 @@ psycopg2-binary = "2.9.9"
pyclipper = "1.3.0.post5" pyclipper = "1.3.0.post5"
pycryptodomex = "3.20.0" pycryptodomex = "3.20.0"
pypdf = "^5.0.0" pypdf = "^5.0.0"
pytest = "8.2.2" pytest = "^8.3.0"
python-dotenv = "1.0.1" python-dotenv = "1.0.1"
python-dateutil = "2.8.2" python-dateutil = "2.8.2"
python-pptx = "^1.0.2" python-pptx = "^1.0.2"
@ -86,7 +87,7 @@ ruamel-base = "1.0.0"
scholarly = "1.7.11" scholarly = "1.7.11"
scikit-learn = "1.5.0" scikit-learn = "1.5.0"
selenium = "4.22.0" selenium = "4.22.0"
setuptools = "70.0.0" setuptools = "^75.2.0"
shapely = "2.0.5" shapely = "2.0.5"
six = "1.16.0" six = "1.16.0"
strenum = "0.4.15" strenum = "0.4.15"
@ -115,6 +116,7 @@ pymysql = "^1.1.1"
mini-racer = "^0.12.4" mini-racer = "^0.12.4"
pyicu = "^2.13.1" pyicu = "^2.13.1"
flasgger = "^0.9.7.1" flasgger = "^0.9.7.1"
polars = "^1.9.0"
[tool.poetry.group.full] [tool.poetry.group.full]

View File

@ -20,6 +20,7 @@ from rag.nlp import tokenize, is_english
from rag.nlp import rag_tokenizer from rag.nlp import rag_tokenizer
from deepdoc.parser import PdfParser, PptParser, PlainParser from deepdoc.parser import PdfParser, PptParser, PlainParser
from PyPDF2 import PdfReader as pdf2_read from PyPDF2 import PdfReader as pdf2_read
import json
class Ppt(PptParser): class Ppt(PptParser):
@ -107,9 +108,9 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
d = copy.deepcopy(doc) d = copy.deepcopy(doc)
pn += from_page pn += from_page
d["image"] = img d["image"] = img
d["page_num_int"] = [pn + 1] d["page_num_list"] = json.dumps([pn + 1])
d["top_int"] = [0] d["top_list"] = json.dumps([0])
d["position_int"] = [(pn + 1, 0, img.size[0], 0, img.size[1])] d["position_list"] = json.dumps([(pn + 1, 0, img.size[0], 0, img.size[1])])
tokenize(d, txt, eng) tokenize(d, txt, eng)
res.append(d) res.append(d)
return res return res
@ -123,10 +124,10 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
pn += from_page pn += from_page
if img: if img:
d["image"] = img d["image"] = img
d["page_num_int"] = [pn + 1] d["page_num_list"] = json.dumps([pn + 1])
d["top_int"] = [0] d["top_list"] = json.dumps([0])
d["position_int"] = [ d["position_list"] = json.dumps([
(pn + 1, 0, img.size[0] if img else 0, 0, img.size[1] if img else 0)] (pn + 1, 0, img.size[0] if img else 0, 0, img.size[1] if img else 0)])
tokenize(d, txt, eng) tokenize(d, txt, eng)
res.append(d) res.append(d)
return res return res

View File

@ -74,7 +74,7 @@ class Excel(ExcelParser):
def trans_datatime(s): def trans_datatime(s):
try: try:
return datetime_parse(s.strip()).strftime("%Y-%m-%d %H:%M:%S") return datetime_parse(s.strip()).strftime("%Y-%m-%d %H:%M:%S")
except Exception as e: except Exception:
pass pass
@ -112,7 +112,7 @@ def column_data_type(arr):
continue continue
try: try:
arr[i] = trans[ty](str(arr[i])) arr[i] = trans[ty](str(arr[i]))
except Exception as e: except Exception:
arr[i] = None arr[i] = None
# if ty == "text": # if ty == "text":
# if len(arr) > 128 and uni / len(arr) < 0.1: # if len(arr) > 128 and uni / len(arr) < 0.1:
@ -182,7 +182,7 @@ def chunk(filename, binary=None, from_page=0, to_page=10000000000,
"datetime": "_dt", "datetime": "_dt",
"bool": "_kwd"} "bool": "_kwd"}
for df in dfs: for df in dfs:
for n in ["id", "_id", "index", "idx"]: for n in ["id", "index", "idx"]:
if n in df.columns: if n in df.columns:
del df[n] del df[n]
clmns = df.columns.values clmns = df.columns.values

View File

@ -1,280 +1,310 @@
# #
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# #
import json import json
import os import os
from collections import defaultdict import sys
from concurrent.futures import ThreadPoolExecutor import time
from copy import deepcopy import argparse
from collections import defaultdict
from api.db import LLMType
from api.db.services.llm_service import LLMBundle from api.db import LLMType
from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.llm_service import LLMBundle
from api.settings import retrievaler from api.db.services.knowledgebase_service import KnowledgebaseService
from api.utils import get_uuid from api.settings import retrievaler, docStoreConn
from api.utils.file_utils import get_project_base_directory from api.utils import get_uuid
from rag.nlp import tokenize, search from rag.nlp import tokenize, search
from rag.utils.es_conn import ELASTICSEARCH from ranx import evaluate
from ranx import evaluate import pandas as pd
import pandas as pd from tqdm import tqdm
from tqdm import tqdm
from ranx import Qrels, Run global max_docs
max_docs = sys.maxsize
class Benchmark: class Benchmark:
def __init__(self, kb_id): def __init__(self, kb_id):
e, self.kb = KnowledgebaseService.get_by_id(kb_id) self.kb_id = kb_id
self.similarity_threshold = self.kb.similarity_threshold e, self.kb = KnowledgebaseService.get_by_id(kb_id)
self.vector_similarity_weight = self.kb.vector_similarity_weight self.similarity_threshold = self.kb.similarity_threshold
self.embd_mdl = LLMBundle(self.kb.tenant_id, LLMType.EMBEDDING, llm_name=self.kb.embd_id, lang=self.kb.language) self.vector_similarity_weight = self.kb.vector_similarity_weight
self.embd_mdl = LLMBundle(self.kb.tenant_id, LLMType.EMBEDDING, llm_name=self.kb.embd_id, lang=self.kb.language)
def _get_benchmarks(self, query, dataset_idxnm, count=16): self.tenant_id = ''
self.index_name = ''
req = {"question": query, "size": count, "vector": True, "similarity": self.similarity_threshold} self.initialized_index = False
sres = retrievaler.search(req, search.index_name(dataset_idxnm), self.embd_mdl)
return sres def _get_retrieval(self, qrels):
# Need to wait for the ES and Infinity index to be ready
def _get_retrieval(self, qrels, dataset_idxnm): time.sleep(20)
run = defaultdict(dict) run = defaultdict(dict)
query_list = list(qrels.keys()) query_list = list(qrels.keys())
for query in query_list: for query in query_list:
ranks = retrievaler.retrieval(query, self.embd_mdl, self.tenant_id, [self.kb.id], 1, 30,
ranks = retrievaler.retrieval(query, self.embd_mdl, 0.0, self.vector_similarity_weight)
dataset_idxnm, [self.kb.id], 1, 30, if len(ranks["chunks"]) == 0:
0.0, self.vector_similarity_weight) print(f"deleted query: {query}")
for c in ranks["chunks"]: del qrels[query]
if "vector" in c: continue
del c["vector"] for c in ranks["chunks"]:
run[query][c["chunk_id"]] = c["similarity"] if "vector" in c:
del c["vector"]
return run run[query][c["chunk_id"]] = c["similarity"]
return run
def embedding(self, docs, batch_size=16):
vects = [] def embedding(self, docs, batch_size=16):
cnts = [d["content_with_weight"] for d in docs] vects = []
for i in range(0, len(cnts), batch_size): cnts = [d["content_with_weight"] for d in docs]
vts, c = self.embd_mdl.encode(cnts[i: i + batch_size]) for i in range(0, len(cnts), batch_size):
vects.extend(vts.tolist()) vts, c = self.embd_mdl.encode(cnts[i: i + batch_size])
assert len(docs) == len(vects) vects.extend(vts.tolist())
for i, d in enumerate(docs): assert len(docs) == len(vects)
v = vects[i] vector_size = 0
d["q_%d_vec" % len(v)] = v for i, d in enumerate(docs):
return docs v = vects[i]
vector_size = len(v)
@staticmethod d["q_%d_vec" % len(v)] = v
def init_kb(index_name): return docs, vector_size
idxnm = search.index_name(index_name)
if ELASTICSEARCH.indexExist(idxnm): def init_index(self, vector_size: int):
ELASTICSEARCH.deleteIdx(search.index_name(index_name)) if self.initialized_index:
return
return ELASTICSEARCH.createIdx(idxnm, json.load( if docStoreConn.indexExist(self.index_name, self.kb_id):
open(os.path.join(get_project_base_directory(), "conf", "mapping.json"), "r"))) docStoreConn.deleteIdx(self.index_name, self.kb_id)
docStoreConn.createIdx(self.index_name, self.kb_id, vector_size)
def ms_marco_index(self, file_path, index_name): self.initialized_index = True
qrels = defaultdict(dict)
texts = defaultdict(dict) def ms_marco_index(self, file_path, index_name):
docs = [] qrels = defaultdict(dict)
filelist = os.listdir(file_path) texts = defaultdict(dict)
self.init_kb(index_name) docs_count = 0
docs = []
max_workers = int(os.environ.get('MAX_WORKERS', 3)) filelist = sorted(os.listdir(file_path))
exe = ThreadPoolExecutor(max_workers=max_workers)
threads = [] for fn in filelist:
if docs_count >= max_docs:
def slow_actions(es_docs, idx_nm): break
es_docs = self.embedding(es_docs) if not fn.endswith(".parquet"):
ELASTICSEARCH.bulk(es_docs, idx_nm) continue
return True data = pd.read_parquet(os.path.join(file_path, fn))
for i in tqdm(range(len(data)), colour="green", desc="Tokenizing:" + fn):
for dir in filelist: if docs_count >= max_docs:
data = pd.read_parquet(os.path.join(file_path, dir)) break
for i in tqdm(range(len(data)), colour="green", desc="Tokenizing:" + dir): query = data.iloc[i]['query']
for rel, text in zip(data.iloc[i]['passages']['is_selected'], data.iloc[i]['passages']['passage_text']):
query = data.iloc[i]['query'] d = {
for rel, text in zip(data.iloc[i]['passages']['is_selected'], data.iloc[i]['passages']['passage_text']): "id": get_uuid(),
d = { "kb_id": self.kb.id,
"id": get_uuid(), "docnm_kwd": "xxxxx",
"kb_id": self.kb.id, "doc_id": "ksksks"
"docnm_kwd": "xxxxx", }
"doc_id": "ksksks" tokenize(d, text, "english")
} docs.append(d)
tokenize(d, text, "english") texts[d["id"]] = text
docs.append(d) qrels[query][d["id"]] = int(rel)
texts[d["id"]] = text if len(docs) >= 32:
qrels[query][d["id"]] = int(rel) docs_count += len(docs)
if len(docs) >= 32: docs, vector_size = self.embedding(docs)
threads.append( self.init_index(vector_size)
exe.submit(slow_actions, deepcopy(docs), search.index_name(index_name))) docStoreConn.insert(docs, self.index_name, self.kb_id)
docs = [] docs = []
threads.append( if docs:
exe.submit(slow_actions, deepcopy(docs), search.index_name(index_name))) docs, vector_size = self.embedding(docs)
self.init_index(vector_size)
for i in tqdm(range(len(threads)), colour="red", desc="Indexing:" + dir): docStoreConn.insert(docs, self.index_name, self.kb_id)
if not threads[i].result().output: return qrels, texts
print("Indexing error...")
def trivia_qa_index(self, file_path, index_name):
return qrels, texts qrels = defaultdict(dict)
texts = defaultdict(dict)
def trivia_qa_index(self, file_path, index_name): docs_count = 0
qrels = defaultdict(dict) docs = []
texts = defaultdict(dict) filelist = sorted(os.listdir(file_path))
docs = [] for fn in filelist:
filelist = os.listdir(file_path) if docs_count >= max_docs:
for dir in filelist: break
data = pd.read_parquet(os.path.join(file_path, dir)) if not fn.endswith(".parquet"):
for i in tqdm(range(len(data)), colour="green", desc="Indexing:" + dir): continue
query = data.iloc[i]['question'] data = pd.read_parquet(os.path.join(file_path, fn))
for rel, text in zip(data.iloc[i]["search_results"]['rank'], for i in tqdm(range(len(data)), colour="green", desc="Indexing:" + fn):
data.iloc[i]["search_results"]['search_context']): if docs_count >= max_docs:
d = { break
"id": get_uuid(), query = data.iloc[i]['question']
"kb_id": self.kb.id, for rel, text in zip(data.iloc[i]["search_results"]['rank'],
"docnm_kwd": "xxxxx", data.iloc[i]["search_results"]['search_context']):
"doc_id": "ksksks" d = {
} "id": get_uuid(),
tokenize(d, text, "english") "kb_id": self.kb.id,
docs.append(d) "docnm_kwd": "xxxxx",
texts[d["id"]] = text "doc_id": "ksksks"
qrels[query][d["id"]] = int(rel) }
if len(docs) >= 32: tokenize(d, text, "english")
docs = self.embedding(docs) docs.append(d)
ELASTICSEARCH.bulk(docs, search.index_name(index_name)) texts[d["id"]] = text
docs = [] qrels[query][d["id"]] = int(rel)
if len(docs) >= 32:
docs = self.embedding(docs) docs_count += len(docs)
ELASTICSEARCH.bulk(docs, search.index_name(index_name)) docs, vector_size = self.embedding(docs)
return qrels, texts self.init_index(vector_size)
docStoreConn.insert(docs,self.index_name)
def miracl_index(self, file_path, corpus_path, index_name): docs = []
corpus_total = {} docs, vector_size = self.embedding(docs)
for corpus_file in os.listdir(corpus_path): self.init_index(vector_size)
tmp_data = pd.read_json(os.path.join(corpus_path, corpus_file), lines=True) docStoreConn.insert(docs, self.index_name)
for index, i in tmp_data.iterrows(): return qrels, texts
corpus_total[i['docid']] = i['text']
def miracl_index(self, file_path, corpus_path, index_name):
topics_total = {} corpus_total = {}
for topics_file in os.listdir(os.path.join(file_path, 'topics')): for corpus_file in os.listdir(corpus_path):
if 'test' in topics_file: tmp_data = pd.read_json(os.path.join(corpus_path, corpus_file), lines=True)
continue for index, i in tmp_data.iterrows():
tmp_data = pd.read_csv(os.path.join(file_path, 'topics', topics_file), sep='\t', names=['qid', 'query']) corpus_total[i['docid']] = i['text']
for index, i in tmp_data.iterrows():
topics_total[i['qid']] = i['query'] topics_total = {}
for topics_file in os.listdir(os.path.join(file_path, 'topics')):
qrels = defaultdict(dict) if 'test' in topics_file:
texts = defaultdict(dict) continue
docs = [] tmp_data = pd.read_csv(os.path.join(file_path, 'topics', topics_file), sep='\t', names=['qid', 'query'])
for qrels_file in os.listdir(os.path.join(file_path, 'qrels')): for index, i in tmp_data.iterrows():
if 'test' in qrels_file: topics_total[i['qid']] = i['query']
continue
qrels = defaultdict(dict)
tmp_data = pd.read_csv(os.path.join(file_path, 'qrels', qrels_file), sep='\t', texts = defaultdict(dict)
names=['qid', 'Q0', 'docid', 'relevance']) docs_count = 0
for i in tqdm(range(len(tmp_data)), colour="green", desc="Indexing:" + qrels_file): docs = []
query = topics_total[tmp_data.iloc[i]['qid']] for qrels_file in os.listdir(os.path.join(file_path, 'qrels')):
text = corpus_total[tmp_data.iloc[i]['docid']] if 'test' in qrels_file:
rel = tmp_data.iloc[i]['relevance'] continue
d = { if docs_count >= max_docs:
"id": get_uuid(), break
"kb_id": self.kb.id,
"docnm_kwd": "xxxxx", tmp_data = pd.read_csv(os.path.join(file_path, 'qrels', qrels_file), sep='\t',
"doc_id": "ksksks" names=['qid', 'Q0', 'docid', 'relevance'])
} for i in tqdm(range(len(tmp_data)), colour="green", desc="Indexing:" + qrels_file):
tokenize(d, text, 'english') if docs_count >= max_docs:
docs.append(d) break
texts[d["id"]] = text query = topics_total[tmp_data.iloc[i]['qid']]
qrels[query][d["id"]] = int(rel) text = corpus_total[tmp_data.iloc[i]['docid']]
if len(docs) >= 32: rel = tmp_data.iloc[i]['relevance']
docs = self.embedding(docs) d = {
ELASTICSEARCH.bulk(docs, search.index_name(index_name)) "id": get_uuid(),
docs = [] "kb_id": self.kb.id,
"docnm_kwd": "xxxxx",
docs = self.embedding(docs) "doc_id": "ksksks"
ELASTICSEARCH.bulk(docs, search.index_name(index_name)) }
tokenize(d, text, 'english')
return qrels, texts docs.append(d)
texts[d["id"]] = text
def save_results(self, qrels, run, texts, dataset, file_path): qrels[query][d["id"]] = int(rel)
keep_result = [] if len(docs) >= 32:
run_keys = list(run.keys()) docs_count += len(docs)
for run_i in tqdm(range(len(run_keys)), desc="Calculating ndcg@10 for single query"): docs, vector_size = self.embedding(docs)
key = run_keys[run_i] self.init_index(vector_size)
keep_result.append({'query': key, 'qrel': qrels[key], 'run': run[key], docStoreConn.insert(docs, self.index_name)
'ndcg@10': evaluate(Qrels({key: qrels[key]}), Run({key: run[key]}), "ndcg@10")}) docs = []
keep_result = sorted(keep_result, key=lambda kk: kk['ndcg@10'])
with open(os.path.join(file_path, dataset + 'result.md'), 'w', encoding='utf-8') as f: docs, vector_size = self.embedding(docs)
f.write('## Score For Every Query\n') self.init_index(vector_size)
for keep_result_i in keep_result: docStoreConn.insert(docs, self.index_name)
f.write('### query: ' + keep_result_i['query'] + ' ndcg@10:' + str(keep_result_i['ndcg@10']) + '\n') return qrels, texts
scores = [[i[0], i[1]] for i in keep_result_i['run'].items()]
scores = sorted(scores, key=lambda kk: kk[1]) def save_results(self, qrels, run, texts, dataset, file_path):
for score in scores[:10]: keep_result = []
f.write('- text: ' + str(texts[score[0]]) + '\t qrel: ' + str(score[1]) + '\n') run_keys = list(run.keys())
json.dump(qrels, open(os.path.join(file_path, dataset + '.qrels.json'), "w+"), indent=2) for run_i in tqdm(range(len(run_keys)), desc="Calculating ndcg@10 for single query"):
json.dump(run, open(os.path.join(file_path, dataset + '.run.json'), "w+"), indent=2) key = run_keys[run_i]
print(os.path.join(file_path, dataset + '_result.md'), 'Saved!') keep_result.append({'query': key, 'qrel': qrels[key], 'run': run[key],
'ndcg@10': evaluate({key: qrels[key]}, {key: run[key]}, "ndcg@10")})
def __call__(self, dataset, file_path, miracl_corpus=''): keep_result = sorted(keep_result, key=lambda kk: kk['ndcg@10'])
if dataset == "ms_marco_v1.1": with open(os.path.join(file_path, dataset + 'result.md'), 'w', encoding='utf-8') as f:
qrels, texts = self.ms_marco_index(file_path, "benchmark_ms_marco_v1.1") f.write('## Score For Every Query\n')
run = self._get_retrieval(qrels, "benchmark_ms_marco_v1.1") for keep_result_i in keep_result:
print(dataset, evaluate(Qrels(qrels), Run(run), ["ndcg@10", "map@5", "mrr"])) f.write('### query: ' + keep_result_i['query'] + ' ndcg@10:' + str(keep_result_i['ndcg@10']) + '\n')
self.save_results(qrels, run, texts, dataset, file_path) scores = [[i[0], i[1]] for i in keep_result_i['run'].items()]
if dataset == "trivia_qa": scores = sorted(scores, key=lambda kk: kk[1])
qrels, texts = self.trivia_qa_index(file_path, "benchmark_trivia_qa") for score in scores[:10]:
run = self._get_retrieval(qrels, "benchmark_trivia_qa") f.write('- text: ' + str(texts[score[0]]) + '\t qrel: ' + str(score[1]) + '\n')
print(dataset, evaluate((qrels), Run(run), ["ndcg@10", "map@5", "mrr"])) json.dump(qrels, open(os.path.join(file_path, dataset + '.qrels.json'), "w+"), indent=2)
self.save_results(qrels, run, texts, dataset, file_path) json.dump(run, open(os.path.join(file_path, dataset + '.run.json'), "w+"), indent=2)
if dataset == "miracl": print(os.path.join(file_path, dataset + '_result.md'), 'Saved!')
for lang in ['ar', 'bn', 'de', 'en', 'es', 'fa', 'fi', 'fr', 'hi', 'id', 'ja', 'ko', 'ru', 'sw', 'te', 'th',
'yo', 'zh']: def __call__(self, dataset, file_path, miracl_corpus=''):
if not os.path.isdir(os.path.join(file_path, 'miracl-v1.0-' + lang)): if dataset == "ms_marco_v1.1":
print('Directory: ' + os.path.join(file_path, 'miracl-v1.0-' + lang) + ' not found!') self.tenant_id = "benchmark_ms_marco_v11"
continue self.index_name = search.index_name(self.tenant_id)
if not os.path.isdir(os.path.join(file_path, 'miracl-v1.0-' + lang, 'qrels')): qrels, texts = self.ms_marco_index(file_path, "benchmark_ms_marco_v1.1")
print('Directory: ' + os.path.join(file_path, 'miracl-v1.0-' + lang, 'qrels') + 'not found!') run = self._get_retrieval(qrels)
continue print(dataset, evaluate(qrels, run, ["ndcg@10", "map@5", "mrr"]))
if not os.path.isdir(os.path.join(file_path, 'miracl-v1.0-' + lang, 'topics')): self.save_results(qrels, run, texts, dataset, file_path)
print('Directory: ' + os.path.join(file_path, 'miracl-v1.0-' + lang, 'topics') + 'not found!') if dataset == "trivia_qa":
continue self.tenant_id = "benchmark_trivia_qa"
if not os.path.isdir(os.path.join(miracl_corpus, 'miracl-corpus-v1.0-' + lang)): self.index_name = search.index_name(self.tenant_id)
print('Directory: ' + os.path.join(miracl_corpus, 'miracl-corpus-v1.0-' + lang) + ' not found!') qrels, texts = self.trivia_qa_index(file_path, "benchmark_trivia_qa")
continue run = self._get_retrieval(qrels)
qrels, texts = self.miracl_index(os.path.join(file_path, 'miracl-v1.0-' + lang), print(dataset, evaluate(qrels, run, ["ndcg@10", "map@5", "mrr"]))
os.path.join(miracl_corpus, 'miracl-corpus-v1.0-' + lang), self.save_results(qrels, run, texts, dataset, file_path)
"benchmark_miracl_" + lang) if dataset == "miracl":
run = self._get_retrieval(qrels, "benchmark_miracl_" + lang) for lang in ['ar', 'bn', 'de', 'en', 'es', 'fa', 'fi', 'fr', 'hi', 'id', 'ja', 'ko', 'ru', 'sw', 'te', 'th',
print(dataset, evaluate(Qrels(qrels), Run(run), ["ndcg@10", "map@5", "mrr"])) 'yo', 'zh']:
self.save_results(qrels, run, texts, dataset, file_path) if not os.path.isdir(os.path.join(file_path, 'miracl-v1.0-' + lang)):
print('Directory: ' + os.path.join(file_path, 'miracl-v1.0-' + lang) + ' not found!')
continue
if __name__ == '__main__': if not os.path.isdir(os.path.join(file_path, 'miracl-v1.0-' + lang, 'qrels')):
print('*****************RAGFlow Benchmark*****************') print('Directory: ' + os.path.join(file_path, 'miracl-v1.0-' + lang, 'qrels') + 'not found!')
kb_id = input('Please input kb_id:\n') continue
ex = Benchmark(kb_id) if not os.path.isdir(os.path.join(file_path, 'miracl-v1.0-' + lang, 'topics')):
dataset = input( print('Directory: ' + os.path.join(file_path, 'miracl-v1.0-' + lang, 'topics') + 'not found!')
'RAGFlow Benchmark Support:\n\tms_marco_v1.1:<https://huggingface.co/datasets/microsoft/ms_marco>\n\ttrivia_qa:<https://huggingface.co/datasets/mandarjoshi/trivia_qa>\n\tmiracl:<https://huggingface.co/datasets/miracl/miracl>\nPlease input dataset choice:\n') continue
if dataset in ['ms_marco_v1.1', 'trivia_qa']: if not os.path.isdir(os.path.join(miracl_corpus, 'miracl-corpus-v1.0-' + lang)):
if dataset == "ms_marco_v1.1": print('Directory: ' + os.path.join(miracl_corpus, 'miracl-corpus-v1.0-' + lang) + ' not found!')
print("Notice: Please provide the ms_marco_v1.1 dataset only. ms_marco_v2.1 is not supported!") continue
dataset_path = input('Please input ' + dataset + ' dataset path:\n') self.tenant_id = "benchmark_miracl_" + lang
ex(dataset, dataset_path) self.index_name = search.index_name(self.tenant_id)
elif dataset == 'miracl': self.initialized_index = False
dataset_path = input('Please input ' + dataset + ' dataset path:\n') qrels, texts = self.miracl_index(os.path.join(file_path, 'miracl-v1.0-' + lang),
corpus_path = input('Please input ' + dataset + '-corpus dataset path:\n') os.path.join(miracl_corpus, 'miracl-corpus-v1.0-' + lang),
ex(dataset, dataset_path, miracl_corpus=corpus_path) "benchmark_miracl_" + lang)
else: run = self._get_retrieval(qrels)
print("Dataset: ", dataset, "not supported!") print(dataset, evaluate(qrels, run, ["ndcg@10", "map@5", "mrr"]))
self.save_results(qrels, run, texts, dataset, file_path)
if __name__ == '__main__':
print('*****************RAGFlow Benchmark*****************')
parser = argparse.ArgumentParser(usage="benchmark.py <max_docs> <kb_id> <dataset> <dataset_path> [<miracl_corpus_path>]", description='RAGFlow Benchmark')
parser.add_argument('max_docs', metavar='max_docs', type=int, help='max docs to evaluate')
parser.add_argument('kb_id', metavar='kb_id', help='knowledgebase id')
parser.add_argument('dataset', metavar='dataset', help='dataset name, shall be one of ms_marco_v1.1(https://huggingface.co/datasets/microsoft/ms_marco), trivia_qa(https://huggingface.co/datasets/mandarjoshi/trivia_qa), miracl(https://huggingface.co/datasets/miracl/miracl)')
parser.add_argument('dataset_path', metavar='dataset_path', help='dataset path')
parser.add_argument('miracl_corpus_path', metavar='miracl_corpus_path', nargs='?', default="", help='miracl corpus path. Only needed when dataset is miracl')
args = parser.parse_args()
max_docs = args.max_docs
kb_id = args.kb_id
ex = Benchmark(kb_id)
dataset = args.dataset
dataset_path = args.dataset_path
if dataset == "ms_marco_v1.1" or dataset == "trivia_qa":
ex(dataset, dataset_path)
elif dataset == "miracl":
if not args.miracl_corpus_path:
    print('Please input the correct parameters!')
    exit(1)
ex(dataset, dataset_path, miracl_corpus=args.miracl_corpus_path)
else:
print("Dataset: ", dataset, "not supported!")

View File

@ -25,6 +25,7 @@ import roman_numbers as r
from word2number import w2n from word2number import w2n
from cn2an import cn2an from cn2an import cn2an
from PIL import Image from PIL import Image
import json
all_codecs = [ all_codecs = [
'utf-8', 'gb2312', 'gbk', 'utf_16', 'ascii', 'big5', 'big5hkscs', 'utf-8', 'gb2312', 'gbk', 'utf_16', 'ascii', 'big5', 'big5hkscs',
@ -51,12 +52,12 @@ def find_codec(blob):
try: try:
blob[:1024].decode(c) blob[:1024].decode(c)
return c return c
except Exception as e: except Exception:
pass pass
try: try:
blob.decode(c) blob.decode(c)
return c return c
except Exception as e: except Exception:
pass pass
return "utf-8" return "utf-8"
@ -241,7 +242,7 @@ def tokenize_chunks(chunks, doc, eng, pdf_parser=None):
d["image"], poss = pdf_parser.crop(ck, need_position=True) d["image"], poss = pdf_parser.crop(ck, need_position=True)
add_positions(d, poss) add_positions(d, poss)
ck = pdf_parser.remove_tag(ck) ck = pdf_parser.remove_tag(ck)
except NotImplementedError as e: except NotImplementedError:
pass pass
tokenize(d, ck, eng) tokenize(d, ck, eng)
res.append(d) res.append(d)
@ -289,13 +290,16 @@ def tokenize_table(tbls, doc, eng, batch_size=10):
def add_positions(d, poss): def add_positions(d, poss):
if not poss: if not poss:
return return
d["page_num_int"] = [] page_num_list = []
d["position_int"] = [] position_list = []
d["top_int"] = [] top_list = []
for pn, left, right, top, bottom in poss: for pn, left, right, top, bottom in poss:
d["page_num_int"].append(int(pn + 1)) page_num_list.append(int(pn + 1))
d["top_int"].append(int(top)) top_list.append(int(top))
d["position_int"].append((int(pn + 1), int(left), int(right), int(top), int(bottom))) position_list.append((int(pn + 1), int(left), int(right), int(top), int(bottom)))
d["page_num_list"] = json.dumps(page_num_list)
d["position_list"] = json.dumps(position_list)
d["top_list"] = json.dumps(top_list)
def remove_contents_table(sections, eng=False): def remove_contents_table(sections, eng=False):
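
To make the new serialization concrete, a short sketch (coordinates are invented) of what `add_positions` now writes, with the former `*_int` fields replaced by JSON-encoded `*_list` strings:

```python
import json

d = {}
poss = [(0, 10, 200, 15, 40)]  # (page, left, right, top, bottom), zero-based page index
# Equivalent of add_positions(d, poss) with the values above:
d["page_num_list"] = json.dumps([1])                      # pages become 1-based
d["top_list"] = json.dumps([15])
d["position_list"] = json.dumps([(1, 10, 200, 15, 40)])
print(d["position_list"])  # -> [[1, 10, 200, 15, 40]]
```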

View File

@ -15,20 +15,25 @@
# #
import json import json
import math
import re import re
import logging import logging
import copy from rag.utils.doc_store_conn import MatchTextExpr
from elasticsearch_dsl import Q
from rag.nlp import rag_tokenizer, term_weight, synonym from rag.nlp import rag_tokenizer, term_weight, synonym
class EsQueryer:
def __init__(self, es): class FulltextQueryer:
def __init__(self):
self.tw = term_weight.Dealer() self.tw = term_weight.Dealer()
self.es = es
self.syn = synonym.Dealer() self.syn = synonym.Dealer()
self.flds = ["ask_tks^10", "ask_small_tks"] self.query_fields = [
"title_tks^10",
"title_sm_tks^5",
"important_kwd^30",
"important_tks^20",
"content_ltks^2",
"content_sm_ltks",
]
@staticmethod @staticmethod
def subSpecialChar(line): def subSpecialChar(line):
@ -43,12 +48,15 @@ class EsQueryer:
for t in arr: for t in arr:
if not re.match(r"[a-zA-Z]+$", t): if not re.match(r"[a-zA-Z]+$", t):
e += 1 e += 1
return e * 1. / len(arr) >= 0.7 return e * 1.0 / len(arr) >= 0.7
@staticmethod @staticmethod
def rmWWW(txt): def rmWWW(txt):
patts = [ patts = [
(r"是*(什么样的|哪家|一下|那家|请问|啥样|咋样了|什么时候|何时|何地|何人|是否|是不是|多少|哪里|怎么|哪儿|怎么样|如何|哪些|是啥|啥是|啊|吗|呢|吧|咋|什么|有没有|呀)是*", ""), (
r"是*(什么样的|哪家|一下|那家|请问|啥样|咋样了|什么时候|何时|何地|何人|是否|是不是|多少|哪里|怎么|哪儿|怎么样|如何|哪些|是啥|啥是|啊|吗|呢|吧|咋|什么|有没有|呀)是*",
"",
),
(r"(^| )(what|who|how|which|where|why)('re|'s)? ", " "), (r"(^| )(what|who|how|which|where|why)('re|'s)? ", " "),
(r"(^| )('s|'re|is|are|were|was|do|does|did|don't|doesn't|didn't|has|have|be|there|you|me|your|my|mine|just|please|may|i|should|would|wouldn't|will|won't|done|go|for|with|so|the|a|an|by|i'm|it's|he's|she's|they|they're|you're|as|by|on|in|at|up|out|down|of) ", " ") (r"(^| )('s|'re|is|are|were|was|do|does|did|don't|doesn't|didn't|has|have|be|there|you|me|your|my|mine|just|please|may|i|should|would|wouldn't|will|won't|done|go|for|with|so|the|a|an|by|i'm|it's|he's|she's|they|they're|you're|as|by|on|in|at|up|out|down|of) ", " ")
] ]
@ -56,16 +64,16 @@ class EsQueryer:
txt = re.sub(r, p, txt, flags=re.IGNORECASE) txt = re.sub(r, p, txt, flags=re.IGNORECASE)
return txt return txt
def question(self, txt, tbl="qa", min_match="60%"): def question(self, txt, tbl="qa", min_match:float=0.6):
txt = re.sub( txt = re.sub(
r"[ :\r\n\t,,。??/`!&\^%%]+", r"[ :\r\n\t,,。??/`!&\^%%]+",
" ", " ",
rag_tokenizer.tradi2simp( rag_tokenizer.tradi2simp(rag_tokenizer.strQ2B(txt.lower())),
rag_tokenizer.strQ2B( ).strip()
txt.lower()))).strip() txt = FulltextQueryer.rmWWW(txt)
if not self.isChinese(txt): if not self.isChinese(txt):
txt = EsQueryer.rmWWW(txt) txt = FulltextQueryer.rmWWW(txt)
tks = rag_tokenizer.tokenize(txt).split(" ") tks = rag_tokenizer.tokenize(txt).split(" ")
tks_w = self.tw.weights(tks) tks_w = self.tw.weights(tks)
tks_w = [(re.sub(r"[ \\\"'^]", "", tk), w) for tk, w in tks_w] tks_w = [(re.sub(r"[ \\\"'^]", "", tk), w) for tk, w in tks_w]
@ -73,14 +81,20 @@ class EsQueryer:
tks_w = [(re.sub(r"^[\+-]", "", tk), w) for tk, w in tks_w if tk] tks_w = [(re.sub(r"^[\+-]", "", tk), w) for tk, w in tks_w if tk]
q = ["{}^{:.4f}".format(tk, w) for tk, w in tks_w if tk] q = ["{}^{:.4f}".format(tk, w) for tk, w in tks_w if tk]
for i in range(1, len(tks_w)): for i in range(1, len(tks_w)):
q.append("\"%s %s\"^%.4f" % (tks_w[i - 1][0], tks_w[i][0], max(tks_w[i - 1][1], tks_w[i][1])*2)) q.append(
'"%s %s"^%.4f'
% (
tks_w[i - 1][0],
tks_w[i][0],
max(tks_w[i - 1][1], tks_w[i][1]) * 2,
)
)
if not q: if not q:
q.append(txt) q.append(txt)
return Q("bool", query = " ".join(q)
must=Q("query_string", fields=self.flds, return MatchTextExpr(
type="best_fields", query=" ".join(q), self.query_fields, query, 100
boost=1)#, minimum_should_match=min_match) ), tks
), list(set([t for t in txt.split(" ") if t]))
def need_fine_grained_tokenize(tk): def need_fine_grained_tokenize(tk):
if len(tk) < 3: if len(tk) < 3:
@ -89,7 +103,7 @@ class EsQueryer:
return False return False
return True return True
txt = EsQueryer.rmWWW(txt) txt = FulltextQueryer.rmWWW(txt)
qs, keywords = [], [] qs, keywords = [], []
for tt in self.tw.split(txt)[:256]: # .split(" "): for tt in self.tw.split(txt)[:256]: # .split(" "):
if not tt: if not tt:
@ -101,65 +115,71 @@ class EsQueryer:
logging.info(json.dumps(twts, ensure_ascii=False)) logging.info(json.dumps(twts, ensure_ascii=False))
tms = [] tms = []
for tk, w in sorted(twts, key=lambda x: x[1] * -1): for tk, w in sorted(twts, key=lambda x: x[1] * -1):
sm = rag_tokenizer.fine_grained_tokenize(tk).split(" ") if need_fine_grained_tokenize(tk) else [] sm = (
rag_tokenizer.fine_grained_tokenize(tk).split(" ")
if need_fine_grained_tokenize(tk)
else []
)
sm = [ sm = [
re.sub( re.sub(
r"[ ,\./;'\[\]\\`~!@#$%\^&\*\(\)=\+_<>\?:\"\{\}\|,。;‘’【】、!¥……()——《》?:“”-]+", r"[ ,\./;'\[\]\\`~!@#$%\^&\*\(\)=\+_<>\?:\"\{\}\|,。;‘’【】、!¥……()——《》?:“”-]+",
"", "",
m) for m in sm] m,
sm = [EsQueryer.subSpecialChar(m) for m in sm if len(m) > 1] )
for m in sm
]
sm = [FulltextQueryer.subSpecialChar(m) for m in sm if len(m) > 1]
sm = [m for m in sm if len(m) > 1] sm = [m for m in sm if len(m) > 1]
keywords.append(re.sub(r"[ \\\"']+", "", tk)) keywords.append(re.sub(r"[ \\\"']+", "", tk))
keywords.extend(sm) keywords.extend(sm)
if len(keywords) >= 12: break if len(keywords) >= 12:
break
tk_syns = self.syn.lookup(tk) tk_syns = self.syn.lookup(tk)
tk = EsQueryer.subSpecialChar(tk) tk = FulltextQueryer.subSpecialChar(tk)
if tk.find(" ") > 0: if tk.find(" ") > 0:
tk = "\"%s\"" % tk tk = '"%s"' % tk
if tk_syns: if tk_syns:
tk = f"({tk} %s)" % " ".join(tk_syns) tk = f"({tk} %s)" % " ".join(tk_syns)
if sm: if sm:
tk = f"{tk} OR \"%s\" OR (\"%s\"~2)^0.5" % ( tk = f'{tk} OR "%s" OR ("%s"~2)^0.5' % (" ".join(sm), " ".join(sm))
" ".join(sm), " ".join(sm))
if tk.strip(): if tk.strip():
tms.append((tk, w)) tms.append((tk, w))
tms = " ".join([f"({t})^{w}" for t, w in tms]) tms = " ".join([f"({t})^{w}" for t, w in tms])
if len(twts) > 1: if len(twts) > 1:
tms += f" (\"%s\"~4)^1.5" % (" ".join([t for t, _ in twts])) tms += ' ("%s"~4)^1.5' % (" ".join([t for t, _ in twts]))
if re.match(r"[0-9a-z ]+$", tt): if re.match(r"[0-9a-z ]+$", tt):
tms = f"(\"{tt}\" OR \"%s\")" % rag_tokenizer.tokenize(tt) tms = f'("{tt}" OR "%s")' % rag_tokenizer.tokenize(tt)
syns = " OR ".join( syns = " OR ".join(
["\"%s\"^0.7" % EsQueryer.subSpecialChar(rag_tokenizer.tokenize(s)) for s in syns]) [
'"%s"^0.7'
% FulltextQueryer.subSpecialChar(rag_tokenizer.tokenize(s))
for s in syns
]
)
if syns: if syns:
tms = f"({tms})^5 OR ({syns})^0.7" tms = f"({tms})^5 OR ({syns})^0.7"
qs.append(tms) qs.append(tms)
flds = copy.deepcopy(self.flds)
mst = []
if qs: if qs:
mst.append( query = " OR ".join([f"({t})" for t in qs if t])
Q("query_string", fields=flds, type="best_fields", return MatchTextExpr(
query=" OR ".join([f"({t})" for t in qs if t]), boost=1, minimum_should_match=min_match) self.query_fields, query, 100, {"minimum_should_match": min_match}
) ), keywords
return None, keywords
return Q("bool", def hybrid_similarity(self, avec, bvecs, atks, btkss, tkweight=0.3, vtweight=0.7):
must=mst,
), list(set(keywords))
def hybrid_similarity(self, avec, bvecs, atks, btkss, tkweight=0.3,
vtweight=0.7):
from sklearn.metrics.pairwise import cosine_similarity as CosineSimilarity from sklearn.metrics.pairwise import cosine_similarity as CosineSimilarity
import numpy as np import numpy as np
sims = CosineSimilarity([avec], bvecs) sims = CosineSimilarity([avec], bvecs)
tksim = self.token_similarity(atks, btkss) tksim = self.token_similarity(atks, btkss)
return np.array(sims[0]) * vtweight + \ return np.array(sims[0]) * vtweight + np.array(tksim) * tkweight, tksim, sims[0]
np.array(tksim) * tkweight, tksim, sims[0]
def token_similarity(self, atks, btkss): def token_similarity(self, atks, btkss):
def toDict(tks): def toDict(tks):

View File

@ -14,34 +14,25 @@
# limitations under the License. # limitations under the License.
# #
import json
import re import re
from copy import deepcopy import json
from elasticsearch_dsl import Q, Search
from typing import List, Optional, Dict, Union from typing import List, Optional, Dict, Union
from dataclasses import dataclass from dataclasses import dataclass
from rag.settings import es_logger from rag.settings import doc_store_logger
from rag.utils import rmSpace from rag.utils import rmSpace
from rag.nlp import rag_tokenizer, query, is_english from rag.nlp import rag_tokenizer, query
import numpy as np import numpy as np
from rag.utils.doc_store_conn import DocStoreConnection, MatchDenseExpr, FusionExpr, OrderByExpr
def index_name(uid): return f"ragflow_{uid}" def index_name(uid): return f"ragflow_{uid}"
class Dealer: class Dealer:
def __init__(self, es): def __init__(self, dataStore: DocStoreConnection):
self.qryr = query.EsQueryer(es) self.qryr = query.FulltextQueryer()
self.qryr.flds = [ self.dataStore = dataStore
"title_tks^10",
"title_sm_tks^5",
"important_kwd^30",
"important_tks^20",
"content_ltks^2",
"content_sm_ltks"]
self.es = es
@dataclass @dataclass
class SearchResult: class SearchResult:
@ -54,170 +45,99 @@ class Dealer:
keywords: Optional[List[str]] = None keywords: Optional[List[str]] = None
group_docs: List[List] = None group_docs: List[List] = None
def _vector(self, txt, emb_mdl, sim=0.8, topk=10): def get_vector(self, txt, emb_mdl, topk=10, similarity=0.1):
qv, c = emb_mdl.encode_queries(txt) qv, _ = emb_mdl.encode_queries(txt)
return { embedding_data = [float(v) for v in qv]
"field": "q_%d_vec" % len(qv), vector_column_name = f"q_{len(embedding_data)}_vec"
"k": topk, return MatchDenseExpr(vector_column_name, embedding_data, 'float', 'cosine', topk, {"similarity": similarity})
"similarity": sim,
"num_candidates": topk * 2,
"query_vector": [float(v) for v in qv]
}
def _add_filters(self, bqry, req): def get_filters(self, req):
if req.get("kb_ids"): condition = dict()
bqry.filter.append(Q("terms", kb_id=req["kb_ids"])) for key, field in {"kb_ids": "kb_id", "doc_ids": "doc_id"}.items():
if req.get("doc_ids"): if key in req and req[key] is not None:
bqry.filter.append(Q("terms", doc_id=req["doc_ids"])) condition[field] = req[key]
if req.get("knowledge_graph_kwd"): # TODO(yzc): `available_int` is nullable however infinity doesn't support nullable columns.
bqry.filter.append(Q("terms", knowledge_graph_kwd=req["knowledge_graph_kwd"])) for key in ["knowledge_graph_kwd"]:
if "available_int" in req: if key in req and req[key] is not None:
if req["available_int"] == 0: condition[key] = req[key]
bqry.filter.append(Q("range", available_int={"lt": 1})) return condition
else:
bqry.filter.append(
Q("bool", must_not=Q("range", available_int={"lt": 1})))
return bqry
def search(self, req, idxnms, emb_mdl=None, highlight=False): def search(self, req, idx_names: list[str], kb_ids: list[str], emb_mdl=None, highlight = False):
qst = req.get("question", "") filters = self.get_filters(req)
bqry, keywords = self.qryr.question(qst, min_match="30%") orderBy = OrderByExpr()
bqry = self._add_filters(bqry, req)
bqry.boost = 0.05
s = Search()
pg = int(req.get("page", 1)) - 1 pg = int(req.get("page", 1)) - 1
topk = int(req.get("topk", 1024)) topk = int(req.get("topk", 1024))
ps = int(req.get("size", topk)) ps = int(req.get("size", topk))
offset, limit = pg * ps, (pg + 1) * ps
src = req.get("fields", ["docnm_kwd", "content_ltks", "kb_id", "img_id", "title_tks", "important_kwd", src = req.get("fields", ["docnm_kwd", "content_ltks", "kb_id", "img_id", "title_tks", "important_kwd",
"image_id", "doc_id", "q_512_vec", "q_768_vec", "position_int", "knowledge_graph_kwd", "doc_id", "position_list", "knowledge_graph_kwd",
"q_1024_vec", "q_1536_vec", "available_int", "content_with_weight"]) "available_int", "content_with_weight"])
s = s.query(bqry)[pg * ps:(pg + 1) * ps]
s = s.highlight("content_ltks")
s = s.highlight("title_ltks")
if not qst:
if not req.get("sort"):
s = s.sort(
#{"create_time": {"order": "desc", "unmapped_type": "date"}},
{"create_timestamp_flt": {
"order": "desc", "unmapped_type": "float"}}
)
else:
s = s.sort(
{"page_num_int": {"order": "asc", "unmapped_type": "float",
"mode": "avg", "numeric_type": "double"}},
{"top_int": {"order": "asc", "unmapped_type": "float",
"mode": "avg", "numeric_type": "double"}},
#{"create_time": {"order": "desc", "unmapped_type": "date"}},
{"create_timestamp_flt": {
"order": "desc", "unmapped_type": "float"}}
)
if qst:
s = s.highlight_options(
fragment_size=120,
number_of_fragments=5,
boundary_scanner_locale="zh-CN",
boundary_scanner="SENTENCE",
boundary_chars=",./;:\\!(),。?:!……()——、"
)
s = s.to_dict()
q_vec = []
if req.get("vector"):
assert emb_mdl, "No embedding model selected"
s["knn"] = self._vector(
qst, emb_mdl, req.get(
"similarity", 0.1), topk)
s["knn"]["filter"] = bqry.to_dict()
if not highlight and "highlight" in s:
del s["highlight"]
q_vec = s["knn"]["query_vector"]
es_logger.info("【Q】: {}".format(json.dumps(s)))
res = self.es.search(deepcopy(s), idxnms=idxnms, timeout="600s", src=src)
es_logger.info("TOTAL: {}".format(self.es.getTotal(res)))
if self.es.getTotal(res) == 0 and "knn" in s:
bqry, _ = self.qryr.question(qst, min_match="10%")
if req.get("doc_ids"):
bqry = Q("bool", must=[])
bqry = self._add_filters(bqry, req)
s["query"] = bqry.to_dict()
s["knn"]["filter"] = bqry.to_dict()
s["knn"]["similarity"] = 0.17
res = self.es.search(s, idxnms=idxnms, timeout="600s", src=src)
es_logger.info("【Q】: {}".format(json.dumps(s)))
kwds = set([]) kwds = set([])
for k in keywords:
kwds.add(k)
for kk in rag_tokenizer.fine_grained_tokenize(k).split(" "):
if len(kk) < 2:
continue
if kk in kwds:
continue
kwds.add(kk)
aggs = self.getAggregation(res, "docnm_kwd") qst = req.get("question", "")
q_vec = []
if not qst:
if req.get("sort"):
orderBy.desc("create_timestamp_flt")
res = self.dataStore.search(src, [], filters, [], orderBy, offset, limit, idx_names, kb_ids)
total=self.dataStore.getTotal(res)
doc_store_logger.info("Dealer.search TOTAL: {}".format(total))
else:
highlightFields = ["content_ltks", "title_tks"] if highlight else []
matchText, keywords = self.qryr.question(qst, min_match=0.3)
if emb_mdl is None:
matchExprs = [matchText]
res = self.dataStore.search(src, highlightFields, filters, matchExprs, orderBy, offset, limit, idx_names, kb_ids)
total=self.dataStore.getTotal(res)
doc_store_logger.info("Dealer.search TOTAL: {}".format(total))
else:
matchDense = self.get_vector(qst, emb_mdl, topk, req.get("similarity", 0.1))
q_vec = matchDense.embedding_data
src.append(f"q_{len(q_vec)}_vec")
fusionExpr = FusionExpr("weighted_sum", topk, {"weights": "0.05, 0.95"})
matchExprs = [matchText, matchDense, fusionExpr]
res = self.dataStore.search(src, highlightFields, filters, matchExprs, orderBy, offset, limit, idx_names, kb_ids)
total=self.dataStore.getTotal(res)
doc_store_logger.info("Dealer.search TOTAL: {}".format(total))
# If result is empty, try again with lower min_match
if total == 0:
matchText, _ = self.qryr.question(qst, min_match=0.1)
if "doc_ids" in filters:
del filters["doc_ids"]
matchDense.extra_options["similarity"] = 0.17
res = self.dataStore.search(src, highlightFields, filters, [matchText, matchDense, fusionExpr], orderBy, offset, limit, idx_names, kb_ids)
total=self.dataStore.getTotal(res)
doc_store_logger.info("Dealer.search 2 TOTAL: {}".format(total))
for k in keywords:
kwds.add(k)
for kk in rag_tokenizer.fine_grained_tokenize(k).split(" "):
if len(kk) < 2:
continue
if kk in kwds:
continue
kwds.add(kk)
doc_store_logger.info(f"TOTAL: {total}")
ids=self.dataStore.getChunkIds(res)
keywords=list(kwds)
highlight = self.dataStore.getHighlight(res, keywords, "content_with_weight")
aggs = self.dataStore.getAggregation(res, "docnm_kwd")
return self.SearchResult( return self.SearchResult(
total=self.es.getTotal(res), total=total,
ids=self.es.getDocIds(res), ids=ids,
query_vector=q_vec, query_vector=q_vec,
aggregation=aggs, aggregation=aggs,
highlight=self.getHighlight(res, keywords, "content_with_weight"), highlight=highlight,
field=self.getFields(res, src), field=self.dataStore.getFields(res, src),
keywords=list(kwds) keywords=keywords
) )
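
The refactored `Dealer.search` reads everything through `self.dataStore` and now needs the knowledge-base ids alongside the tenant index names. A rough usage sketch; `dealer`, `idx_names`, `kb_ids` and `emb_mdl` are placeholders, and the argument order follows the `retrieval` call further down:

```
# Hedged sketch: assumes a Dealer wired to ESConnection or InfinityConnection
# via api.settings; ids and the embedding model object are placeholders.
sres = dealer.search({"question": "what is ragflow", "similarity": 0.1},
                     idx_names, kb_ids, emb_mdl)
print(sres.total)                  # from dataStore.getTotal(res)
print(sres.ids[:3])                # chunk ids from dataStore.getChunkIds(res)
chunk = sres.field[sres.ids[0]]    # per-chunk field dict from dataStore.getFields(res, src)
print(chunk.get("docnm_kwd"), sres.aggregation[:3], sres.keywords[:5])
```
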
def getAggregation(self, res, g):
if not "aggregations" in res or "aggs_" + g not in res["aggregations"]:
return
bkts = res["aggregations"]["aggs_" + g]["buckets"]
return [(b["key"], b["doc_count"]) for b in bkts]
def getHighlight(self, res, keywords, fieldnm):
ans = {}
for d in res["hits"]["hits"]:
hlts = d.get("highlight")
if not hlts:
continue
txt = "...".join([a for a in list(hlts.items())[0][1]])
if not is_english(txt.split(" ")):
ans[d["_id"]] = txt
continue
txt = d["_source"][fieldnm]
txt = re.sub(r"[\r\n]", " ", txt, flags=re.IGNORECASE|re.MULTILINE)
txts = []
for t in re.split(r"[.?!;\n]", txt):
for w in keywords:
t = re.sub(r"(^|[ .?/'\"\(\)!,:;-])(%s)([ .?/'\"\(\)!,:;-])"%re.escape(w), r"\1<em>\2</em>\3", t, flags=re.IGNORECASE|re.MULTILINE)
if not re.search(r"<em>[^<>]+</em>", t, flags=re.IGNORECASE|re.MULTILINE): continue
txts.append(t)
ans[d["_id"]] = "...".join(txts) if txts else "...".join([a for a in list(hlts.items())[0][1]])
return ans
def getFields(self, sres, flds):
res = {}
if not flds:
return {}
for d in self.es.getSource(sres):
m = {n: d.get(n) for n in flds if d.get(n) is not None}
for n, v in m.items():
if isinstance(v, type([])):
m[n] = "\t".join([str(vv) if not isinstance(
vv, list) else "\t".join([str(vvv) for vvv in vv]) for vv in v])
continue
if not isinstance(v, type("")):
m[n] = str(m[n])
#if n.find("tks") > 0:
# m[n] = rmSpace(m[n])
if m:
res[d["id"]] = m
return res
@staticmethod @staticmethod
def trans2floats(txt): def trans2floats(txt):
return [float(t) for t in txt.split("\t")] return [float(t) for t in txt.split("\t")]
@ -260,7 +180,7 @@ class Dealer:
continue continue
idx.append(i) idx.append(i)
pieces_.append(t) pieces_.append(t)
es_logger.info("{} => {}".format(answer, pieces_)) doc_store_logger.info("{} => {}".format(answer, pieces_))
if not pieces_: if not pieces_:
return answer, set([]) return answer, set([])
@ -281,7 +201,7 @@ class Dealer:
chunks_tks, chunks_tks,
tkweight, vtweight) tkweight, vtweight)
mx = np.max(sim) * 0.99 mx = np.max(sim) * 0.99
es_logger.info("{} SIM: {}".format(pieces_[i], mx)) doc_store_logger.info("{} SIM: {}".format(pieces_[i], mx))
if mx < thr: if mx < thr:
continue continue
cites[idx[i]] = list( cites[idx[i]] = list(
@ -309,9 +229,15 @@ class Dealer:
def rerank(self, sres, query, tkweight=0.3, def rerank(self, sres, query, tkweight=0.3,
vtweight=0.7, cfield="content_ltks"): vtweight=0.7, cfield="content_ltks"):
_, keywords = self.qryr.question(query) _, keywords = self.qryr.question(query)
ins_embd = [ vector_size = len(sres.query_vector)
Dealer.trans2floats( vector_column = f"q_{vector_size}_vec"
sres.field[i].get("q_%d_vec" % len(sres.query_vector), "\t".join(["0"] * len(sres.query_vector)))) for i in sres.ids] zero_vector = [0.0] * vector_size
ins_embd = []
for chunk_id in sres.ids:
vector = sres.field[chunk_id].get(vector_column, zero_vector)
if isinstance(vector, str):
vector = [float(v) for v in vector.split("\t")]
ins_embd.append(vector)
if not ins_embd: if not ins_embd:
return [], [], [] return [], [], []
@ -377,7 +303,7 @@ class Dealer:
if isinstance(tenant_ids, str): if isinstance(tenant_ids, str):
tenant_ids = tenant_ids.split(",") tenant_ids = tenant_ids.split(",")
sres = self.search(req, [index_name(tid) for tid in tenant_ids], embd_mdl, highlight) sres = self.search(req, [index_name(tid) for tid in tenant_ids], kb_ids, embd_mdl, highlight)
ranks["total"] = sres.total ranks["total"] = sres.total
if page <= RERANK_PAGE_LIMIT: if page <= RERANK_PAGE_LIMIT:
@ -393,6 +319,8 @@ class Dealer:
idx = list(range(len(sres.ids))) idx = list(range(len(sres.ids)))
dim = len(sres.query_vector) dim = len(sres.query_vector)
vector_column = f"q_{dim}_vec"
zero_vector = [0.0] * dim
for i in idx: for i in idx:
if sim[i] < similarity_threshold: if sim[i] < similarity_threshold:
break break
@ -401,34 +329,32 @@ class Dealer:
continue continue
break break
id = sres.ids[i] id = sres.ids[i]
dnm = sres.field[id]["docnm_kwd"] chunk = sres.field[id]
did = sres.field[id]["doc_id"] dnm = chunk["docnm_kwd"]
did = chunk["doc_id"]
position_list = chunk.get("position_list", "[]")
if not position_list:
position_list = "[]"
d = { d = {
"chunk_id": id, "chunk_id": id,
"content_ltks": sres.field[id]["content_ltks"], "content_ltks": chunk["content_ltks"],
"content_with_weight": sres.field[id]["content_with_weight"], "content_with_weight": chunk["content_with_weight"],
"doc_id": sres.field[id]["doc_id"], "doc_id": chunk["doc_id"],
"docnm_kwd": dnm, "docnm_kwd": dnm,
"kb_id": sres.field[id]["kb_id"], "kb_id": chunk["kb_id"],
"important_kwd": sres.field[id].get("important_kwd", []), "important_kwd": chunk.get("important_kwd", []),
"img_id": sres.field[id].get("img_id", ""), "image_id": chunk.get("img_id", ""),
"similarity": sim[i], "similarity": sim[i],
"vector_similarity": vsim[i], "vector_similarity": vsim[i],
"term_similarity": tsim[i], "term_similarity": tsim[i],
"vector": self.trans2floats(sres.field[id].get("q_%d_vec" % dim, "\t".join(["0"] * dim))), "vector": chunk.get(vector_column, zero_vector),
"positions": sres.field[id].get("position_int", "").split("\t") "positions": json.loads(position_list)
} }
if highlight: if highlight:
if id in sres.highlight: if id in sres.highlight:
d["highlight"] = rmSpace(sres.highlight[id]) d["highlight"] = rmSpace(sres.highlight[id])
else: else:
d["highlight"] = d["content_with_weight"] d["highlight"] = d["content_with_weight"]
if len(d["positions"]) % 5 == 0:
poss = []
for i in range(0, len(d["positions"]), 5):
poss.append([float(d["positions"][i]), float(d["positions"][i + 1]), float(d["positions"][i + 2]),
float(d["positions"][i + 3]), float(d["positions"][i + 4])])
d["positions"] = poss
ranks["chunks"].append(d) ranks["chunks"].append(d)
if dnm not in ranks["doc_aggs"]: if dnm not in ranks["doc_aggs"]:
ranks["doc_aggs"][dnm] = {"doc_id": did, "count": 0} ranks["doc_aggs"][dnm] = {"doc_id": did, "count": 0}
@ -442,39 +368,11 @@ class Dealer:
return ranks return ranks
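
The `positions` handling above changes encoding: the old tab-separated `position_int` string, regrouped into five-number groups at query time, gives way to a JSON-encoded `position_list` that retrieval simply loads. A small self-contained sketch of both encodings, with made-up coordinates:

```
import json

# Five numbers per position, as in the old regrouping code; values are made up.
positions = [[1, 72.0, 540.0, 96.0, 128.0], [2, 72.0, 540.0, 64.0, 90.0]]

old_chunk = {"position_int": "\t".join(str(x) for row in positions for x in row)}
new_chunk = {"position_list": json.dumps(positions)}             # what build() now stores

flat = old_chunk["position_int"].split("\t")
old_decoded = [[float(x) for x in flat[i:i + 5]] for i in range(0, len(flat), 5)]
new_decoded = json.loads(new_chunk.get("position_list", "[]"))   # what retrieval now reads
assert old_decoded == new_decoded == [[float(x) for x in row] for row in positions]
```
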
def sql_retrieval(self, sql, fetch_size=128, format="json"): def sql_retrieval(self, sql, fetch_size=128, format="json"):
from api.settings import chat_logger tbl = self.dataStore.sql(sql, fetch_size, format)
sql = re.sub(r"[ `]+", " ", sql) return tbl
sql = sql.replace("%", "")
es_logger.info(f"Get es sql: {sql}")
replaces = []
for r in re.finditer(r" ([a-z_]+_l?tks)( like | ?= ?)'([^']+)'", sql):
fld, v = r.group(1), r.group(3)
match = " MATCH({}, '{}', 'operator=OR;minimum_should_match=30%') ".format(
fld, rag_tokenizer.fine_grained_tokenize(rag_tokenizer.tokenize(v)))
replaces.append(
("{}{}'{}'".format(
r.group(1),
r.group(2),
r.group(3)),
match))
for p, r in replaces: def chunk_list(self, doc_id: str, tenant_id: str, kb_ids: list[str], max_count=1024, fields=["docnm_kwd", "content_with_weight", "img_id"]):
sql = sql.replace(p, r, 1) condition = {"doc_id": doc_id}
chat_logger.info(f"To es: {sql}") res = self.dataStore.search(fields, [], condition, [], OrderByExpr(), 0, max_count, index_name(tenant_id), kb_ids)
dict_chunks = self.dataStore.getFields(res, fields)
try: return dict_chunks.values()
tbl = self.es.sql(sql, fetch_size, format)
return tbl
except Exception as e:
chat_logger.error(f"SQL failure: {sql} =>" + str(e))
return {"error": str(e)}
def chunk_list(self, doc_id, tenant_id, max_count=1024, fields=["docnm_kwd", "content_with_weight", "img_id"]):
s = Search()
s = s.query(Q("match", doc_id=doc_id))[0:max_count]
s = s.to_dict()
es_res = self.es.search(s, idxnms=index_name(tenant_id), timeout="600s", src=fields)
res = []
for index, chunk in enumerate(es_res['hits']['hits']):
res.append({fld: chunk['_source'].get(fld) for fld in fields})
return res
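
`chunk_list` keeps its role but gains a `kb_ids` argument, since the Infinity backend keeps one table per (index, knowledge base) pair. A hedged call sketch matching the updated `run_raptor` call site; the ids are placeholders:

```
from api.settings import retrievaler

chunks = retrievaler.chunk_list(
    doc_id="d0b7e52a", tenant_id="t1", kb_ids=["kb1"],
    fields=["docnm_kwd", "content_with_weight", "img_id"])
for chunk in chunks:
    print(chunk["docnm_kwd"], len(chunk["content_with_weight"]))
```
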


@ -25,12 +25,13 @@ RAG_CONF_PATH = os.path.join(get_project_base_directory(), "conf")
SUBPROCESS_STD_LOG_NAME = "std.log" SUBPROCESS_STD_LOG_NAME = "std.log"
ES = get_base_config("es", {}) ES = get_base_config("es", {})
INFINITY = get_base_config("infinity", {"uri": "infinity:23817"})
AZURE = get_base_config("azure", {}) AZURE = get_base_config("azure", {})
S3 = get_base_config("s3", {}) S3 = get_base_config("s3", {})
MINIO = decrypt_database_config(name="minio") MINIO = decrypt_database_config(name="minio")
try: try:
REDIS = decrypt_database_config(name="redis") REDIS = decrypt_database_config(name="redis")
except Exception as e: except Exception:
REDIS = {} REDIS = {}
pass pass
DOC_MAXIMUM_SIZE = int(os.environ.get("MAX_CONTENT_LENGTH", 128 * 1024 * 1024)) DOC_MAXIMUM_SIZE = int(os.environ.get("MAX_CONTENT_LENGTH", 128 * 1024 * 1024))
@ -44,7 +45,7 @@ LoggerFactory.set_directory(
# {CRITICAL: 50, FATAL:50, ERROR:40, WARNING:30, WARN:30, INFO:20, DEBUG:10, NOTSET:0} # {CRITICAL: 50, FATAL:50, ERROR:40, WARNING:30, WARN:30, INFO:20, DEBUG:10, NOTSET:0}
LoggerFactory.LEVEL = 30 LoggerFactory.LEVEL = 30
es_logger = getLogger("es") doc_store_logger = getLogger("doc_store")
minio_logger = getLogger("minio") minio_logger = getLogger("minio")
s3_logger = getLogger("s3") s3_logger = getLogger("s3")
azure_logger = getLogger("azure") azure_logger = getLogger("azure")
@ -53,7 +54,7 @@ chunk_logger = getLogger("chunk_logger")
database_logger = getLogger("database") database_logger = getLogger("database")
formatter = logging.Formatter("%(asctime)-15s %(levelname)-8s (%(process)d) %(message)s") formatter = logging.Formatter("%(asctime)-15s %(levelname)-8s (%(process)d) %(message)s")
for logger in [es_logger, minio_logger, s3_logger, azure_logger, cron_logger, chunk_logger, database_logger]: for logger in [doc_store_logger, minio_logger, s3_logger, azure_logger, cron_logger, chunk_logger, database_logger]:
logger.setLevel(logging.INFO) logger.setLevel(logging.INFO)
for handler in logger.handlers: for handler in logger.handlers:
handler.setFormatter(fmt=formatter) handler.setFormatter(fmt=formatter)


@ -31,7 +31,6 @@ from timeit import default_timer as timer
import numpy as np import numpy as np
import pandas as pd import pandas as pd
from elasticsearch_dsl import Q
from api.db import LLMType, ParserType from api.db import LLMType, ParserType
from api.db.services.dialog_service import keyword_extraction, question_proposal from api.db.services.dialog_service import keyword_extraction, question_proposal
@ -39,8 +38,7 @@ from api.db.services.document_service import DocumentService
from api.db.services.llm_service import LLMBundle from api.db.services.llm_service import LLMBundle
from api.db.services.task_service import TaskService from api.db.services.task_service import TaskService
from api.db.services.file2document_service import File2DocumentService from api.db.services.file2document_service import File2DocumentService
from api.settings import retrievaler from api.settings import retrievaler, docStoreConn
from api.utils.file_utils import get_project_base_directory
from api.db.db_models import close_connection from api.db.db_models import close_connection
from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive, one, audio, knowledge_graph, email from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive, one, audio, knowledge_graph, email
from rag.nlp import search, rag_tokenizer from rag.nlp import search, rag_tokenizer
@ -48,7 +46,6 @@ from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as
from rag.settings import database_logger, SVR_QUEUE_NAME from rag.settings import database_logger, SVR_QUEUE_NAME
from rag.settings import cron_logger, DOC_MAXIMUM_SIZE from rag.settings import cron_logger, DOC_MAXIMUM_SIZE
from rag.utils import rmSpace, num_tokens_from_string from rag.utils import rmSpace, num_tokens_from_string
from rag.utils.es_conn import ELASTICSEARCH
from rag.utils.redis_conn import REDIS_CONN, Payload from rag.utils.redis_conn import REDIS_CONN, Payload
from rag.utils.storage_factory import STORAGE_IMPL from rag.utils.storage_factory import STORAGE_IMPL
@ -126,7 +123,7 @@ def collect():
return pd.DataFrame() return pd.DataFrame()
tasks = TaskService.get_tasks(msg["id"]) tasks = TaskService.get_tasks(msg["id"])
if not tasks: if not tasks:
cron_logger.warn("{} empty task!".format(msg["id"])) cron_logger.warning("{} empty task!".format(msg["id"]))
return [] return []
tasks = pd.DataFrame(tasks) tasks = pd.DataFrame(tasks)
@ -187,7 +184,7 @@ def build(row):
docs = [] docs = []
doc = { doc = {
"doc_id": row["doc_id"], "doc_id": row["doc_id"],
"kb_id": [str(row["kb_id"])] "kb_id": str(row["kb_id"])
} }
el = 0 el = 0
for ck in cks: for ck in cks:
@ -196,10 +193,14 @@ def build(row):
md5 = hashlib.md5() md5 = hashlib.md5()
md5.update((ck["content_with_weight"] + md5.update((ck["content_with_weight"] +
str(d["doc_id"])).encode("utf-8")) str(d["doc_id"])).encode("utf-8"))
d["_id"] = md5.hexdigest() d["id"] = md5.hexdigest()
d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19] d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19]
d["create_timestamp_flt"] = datetime.datetime.now().timestamp() d["create_timestamp_flt"] = datetime.datetime.now().timestamp()
if not d.get("image"): if not d.get("image"):
d["img_id"] = ""
d["page_num_list"] = json.dumps([])
d["position_list"] = json.dumps([])
d["top_list"] = json.dumps([])
docs.append(d) docs.append(d)
continue continue
@ -211,13 +212,13 @@ def build(row):
d["image"].save(output_buffer, format='JPEG') d["image"].save(output_buffer, format='JPEG')
st = timer() st = timer()
STORAGE_IMPL.put(row["kb_id"], d["_id"], output_buffer.getvalue()) STORAGE_IMPL.put(row["kb_id"], d["id"], output_buffer.getvalue())
el += timer() - st el += timer() - st
except Exception as e: except Exception as e:
cron_logger.error(str(e)) cron_logger.error(str(e))
traceback.print_exc() traceback.print_exc()
d["img_id"] = "{}-{}".format(row["kb_id"], d["_id"]) d["img_id"] = "{}-{}".format(row["kb_id"], d["id"])
del d["image"] del d["image"]
docs.append(d) docs.append(d)
cron_logger.info("MINIO PUT({}):{}".format(row["name"], el)) cron_logger.info("MINIO PUT({}):{}".format(row["name"], el))
@ -245,12 +246,9 @@ def build(row):
return docs return docs
def init_kb(row): def init_kb(row, vector_size: int):
idxnm = search.index_name(row["tenant_id"]) idxnm = search.index_name(row["tenant_id"])
if ELASTICSEARCH.indexExist(idxnm): return docStoreConn.createIdx(idxnm, row["kb_id"], vector_size)
return
return ELASTICSEARCH.createIdx(idxnm, json.load(
open(os.path.join(get_project_base_directory(), "conf", "mapping.json"), "r")))
def embedding(docs, mdl, parser_config=None, callback=None): def embedding(docs, mdl, parser_config=None, callback=None):
@ -288,17 +286,20 @@ def embedding(docs, mdl, parser_config=None, callback=None):
cnts) if len(tts) == len(cnts) else cnts cnts) if len(tts) == len(cnts) else cnts
assert len(vects) == len(docs) assert len(vects) == len(docs)
vector_size = 0
for i, d in enumerate(docs): for i, d in enumerate(docs):
v = vects[i].tolist() v = vects[i].tolist()
vector_size = len(v)
d["q_%d_vec" % len(v)] = v d["q_%d_vec" % len(v)] = v
return tk_count return tk_count, vector_size
def run_raptor(row, chat_mdl, embd_mdl, callback=None): def run_raptor(row, chat_mdl, embd_mdl, callback=None):
vts, _ = embd_mdl.encode(["ok"]) vts, _ = embd_mdl.encode(["ok"])
vctr_nm = "q_%d_vec" % len(vts[0]) vector_size = len(vts[0])
vctr_nm = "q_%d_vec" % vector_size
chunks = [] chunks = []
for d in retrievaler.chunk_list(row["doc_id"], row["tenant_id"], fields=["content_with_weight", vctr_nm]): for d in retrievaler.chunk_list(row["doc_id"], row["tenant_id"], [str(row["kb_id"])], fields=["content_with_weight", vctr_nm]):
chunks.append((d["content_with_weight"], np.array(d[vctr_nm]))) chunks.append((d["content_with_weight"], np.array(d[vctr_nm])))
raptor = Raptor( raptor = Raptor(
@ -323,7 +324,7 @@ def run_raptor(row, chat_mdl, embd_mdl, callback=None):
d = copy.deepcopy(doc) d = copy.deepcopy(doc)
md5 = hashlib.md5() md5 = hashlib.md5()
md5.update((content + str(d["doc_id"])).encode("utf-8")) md5.update((content + str(d["doc_id"])).encode("utf-8"))
d["_id"] = md5.hexdigest() d["id"] = md5.hexdigest()
d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19] d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19]
d["create_timestamp_flt"] = datetime.datetime.now().timestamp() d["create_timestamp_flt"] = datetime.datetime.now().timestamp()
d[vctr_nm] = vctr.tolist() d[vctr_nm] = vctr.tolist()
@ -332,7 +333,7 @@ def run_raptor(row, chat_mdl, embd_mdl, callback=None):
d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"]) d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
res.append(d) res.append(d)
tk_count += num_tokens_from_string(content) tk_count += num_tokens_from_string(content)
return res, tk_count return res, tk_count, vector_size
def main(): def main():
@ -352,7 +353,7 @@ def main():
if r.get("task_type", "") == "raptor": if r.get("task_type", "") == "raptor":
try: try:
chat_mdl = LLMBundle(r["tenant_id"], LLMType.CHAT, llm_name=r["llm_id"], lang=r["language"]) chat_mdl = LLMBundle(r["tenant_id"], LLMType.CHAT, llm_name=r["llm_id"], lang=r["language"])
cks, tk_count = run_raptor(r, chat_mdl, embd_mdl, callback) cks, tk_count, vector_size = run_raptor(r, chat_mdl, embd_mdl, callback)
except Exception as e: except Exception as e:
callback(-1, msg=str(e)) callback(-1, msg=str(e))
cron_logger.error(str(e)) cron_logger.error(str(e))
@ -373,7 +374,7 @@ def main():
len(cks)) len(cks))
st = timer() st = timer()
try: try:
tk_count = embedding(cks, embd_mdl, r["parser_config"], callback) tk_count, vector_size = embedding(cks, embd_mdl, r["parser_config"], callback)
except Exception as e: except Exception as e:
callback(-1, "Embedding error:{}".format(str(e))) callback(-1, "Embedding error:{}".format(str(e)))
cron_logger.error(str(e)) cron_logger.error(str(e))
@ -381,26 +382,25 @@ def main():
cron_logger.info("Embedding elapsed({}): {:.2f}".format(r["name"], timer() - st)) cron_logger.info("Embedding elapsed({}): {:.2f}".format(r["name"], timer() - st))
callback(msg="Finished embedding({:.2f})! Start to build index!".format(timer() - st)) callback(msg="Finished embedding({:.2f})! Start to build index!".format(timer() - st))
init_kb(r) # cron_logger.info(f"task_executor init_kb index {search.index_name(r["tenant_id"])} embd_mdl {embd_mdl.llm_name} vector length {vector_size}")
chunk_count = len(set([c["_id"] for c in cks])) init_kb(r, vector_size)
chunk_count = len(set([c["id"] for c in cks]))
st = timer() st = timer()
es_r = "" es_r = ""
es_bulk_size = 4 es_bulk_size = 4
for b in range(0, len(cks), es_bulk_size): for b in range(0, len(cks), es_bulk_size):
es_r = ELASTICSEARCH.bulk(cks[b:b + es_bulk_size], search.index_name(r["tenant_id"])) es_r = docStoreConn.insert(cks[b:b + es_bulk_size], search.index_name(r["tenant_id"]), r["kb_id"])
if b % 128 == 0: if b % 128 == 0:
callback(prog=0.8 + 0.1 * (b + 1) / len(cks), msg="") callback(prog=0.8 + 0.1 * (b + 1) / len(cks), msg="")
cron_logger.info("Indexing elapsed({}): {:.2f}".format(r["name"], timer() - st)) cron_logger.info("Indexing elapsed({}): {:.2f}".format(r["name"], timer() - st))
if es_r: if es_r:
callback(-1, "Insert chunk error, detail info please check ragflow-logs/api/cron_logger.log. Please also check ES status!") callback(-1, "Insert chunk error, detail info please check ragflow-logs/api/cron_logger.log. Please also check ES status!")
ELASTICSEARCH.deleteByQuery( docStoreConn.delete({"doc_id": r["doc_id"]}, search.index_name(r["tenant_id"]), r["kb_id"])
Q("match", doc_id=r["doc_id"]), idxnm=search.index_name(r["tenant_id"])) cron_logger.error('Insert chunk error: ' + str(es_r))
cron_logger.error(str(es_r))
else: else:
if TaskService.do_cancel(r["id"]): if TaskService.do_cancel(r["id"]):
ELASTICSEARCH.deleteByQuery( docStoreConn.delete({"doc_id": r["doc_id"]}, search.index_name(r["tenant_id"]), r["kb_id"])
Q("match", doc_id=r["doc_id"]), idxnm=search.index_name(r["tenant_id"]))
continue continue
callback(1., "Done!") callback(1., "Done!")
DocumentService.increment_chunk_num( DocumentService.increment_chunk_num(
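
A condensed sketch of the write path above after the refactor: the vector size reported by `embedding()` drives index creation, and all writes go through the abstract `docStoreConn` rather than `ELASTICSEARCH`. Error handling is reduced to an exception here; the real loop reports through `callback` as shown above.

```
from api.settings import docStoreConn
from rag.nlp import search

def index_chunks(r, cks, vector_size):
    idx_nm = search.index_name(r["tenant_id"])
    docStoreConn.createIdx(idx_nm, r["kb_id"], vector_size)    # what init_kb() now does
    for b in range(0, len(cks), 4):                            # same bulk size as the loop above
        errors = docStoreConn.insert(cks[b:b + 4], idx_nm, r["kb_id"])
        if errors:                                             # roll the document back on failure
            docStoreConn.delete({"doc_id": r["doc_id"]}, idx_nm, r["kb_id"])
            raise RuntimeError(f"insert failed: {errors}")
```
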

rag/utils/doc_store_conn.py (new file, 251 lines)

@ -0,0 +1,251 @@
from abc import ABC, abstractmethod
from typing import Optional, Union
from dataclasses import dataclass
import numpy as np
import polars as pl
from typing import List, Dict
DEFAULT_MATCH_VECTOR_TOPN = 10
DEFAULT_MATCH_SPARSE_TOPN = 10
VEC = Union[list, np.ndarray]
@dataclass
class SparseVector:
indices: list[int]
values: Union[list[float], list[int], None] = None
def __post_init__(self):
assert (self.values is None) or (len(self.indices) == len(self.values))
def to_dict_old(self):
d = {"indices": self.indices}
if self.values is not None:
d["values"] = self.values
return d
def to_dict(self):
if self.values is None:
raise ValueError("SparseVector.values is None")
result = {}
for i, v in zip(self.indices, self.values):
result[str(i)] = v
return result
@staticmethod
def from_dict(d):
return SparseVector(d["indices"], d.get("values"))
def __str__(self):
return f"SparseVector(indices={self.indices}{'' if self.values is None else f', values={self.values}'})"
def __repr__(self):
return str(self)
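
A quick illustration of the `SparseVector` helper with made-up indices and weights; `to_dict` produces a plain `{index: weight}` mapping:

```
from rag.utils.doc_store_conn import SparseVector

sv = SparseVector(indices=[3, 17, 42], values=[0.5, 0.25, 1.0])
assert sv.to_dict() == {"3": 0.5, "17": 0.25, "42": 1.0}
print(sv)   # SparseVector(indices=[3, 17, 42], values=[0.5, 0.25, 1.0])
```
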
class MatchTextExpr(ABC):
def __init__(
self,
fields: str,
matching_text: str,
topn: int,
extra_options: dict = dict(),
):
self.fields = fields
self.matching_text = matching_text
self.topn = topn
self.extra_options = extra_options
class MatchDenseExpr(ABC):
def __init__(
self,
vector_column_name: str,
embedding_data: VEC,
embedding_data_type: str,
distance_type: str,
topn: int = DEFAULT_MATCH_VECTOR_TOPN,
extra_options: dict = dict(),
):
self.vector_column_name = vector_column_name
self.embedding_data = embedding_data
self.embedding_data_type = embedding_data_type
self.distance_type = distance_type
self.topn = topn
self.extra_options = extra_options
class MatchSparseExpr(ABC):
def __init__(
self,
vector_column_name: str,
sparse_data: SparseVector | dict,
distance_type: str,
topn: int,
opt_params: Optional[dict] = None,
):
self.vector_column_name = vector_column_name
self.sparse_data = sparse_data
self.distance_type = distance_type
self.topn = topn
self.opt_params = opt_params
class MatchTensorExpr(ABC):
def __init__(
self,
column_name: str,
query_data: VEC,
query_data_type: str,
topn: int,
extra_option: Optional[dict] = None,
):
self.column_name = column_name
self.query_data = query_data
self.query_data_type = query_data_type
self.topn = topn
self.extra_option = extra_option
class FusionExpr(ABC):
def __init__(self, method: str, topn: int, fusion_params: Optional[dict] = None):
self.method = method
self.topn = topn
self.fusion_params = fusion_params
MatchExpr = Union[
MatchTextExpr, MatchDenseExpr, MatchSparseExpr, MatchTensorExpr, FusionExpr
]
class OrderByExpr(ABC):
def __init__(self):
self.fields = list()
def asc(self, field: str):
self.fields.append((field, 0))
return self
def desc(self, field: str):
self.fields.append((field, 1))
return self
def fields(self):
return self.fields
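
How these expression classes compose for hybrid retrieval, mirroring the `Dealer.search` path above (full text plus dense vector, fused by `weighted_sum` with the 0.05/0.95 split). The field names, the 1024-dim vector, the `cosine` distance and the option values are illustrative; the real ones come from the query builder and `get_vector`, which are not shown here:

```
from rag.utils.doc_store_conn import (
    MatchTextExpr, MatchDenseExpr, FusionExpr, OrderByExpr)

topk = 1024
match_text = MatchTextExpr(["content_ltks", "title_tks"], "what is ragflow",
                           topk, {"minimum_should_match": 0.3})
match_dense = MatchDenseExpr("q_1024_vec", [0.1] * 1024, "float", "cosine",
                             topk, {"similarity": 0.1})
fusion = FusionExpr("weighted_sum", topk, {"weights": "0.05, 0.95"})
match_exprs = [match_text, match_dense, fusion]

order_by = OrderByExpr().asc("page_num_int").asc("top_int").desc("create_timestamp_flt")
```
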
class DocStoreConnection(ABC):
"""
Database operations
"""
@abstractmethod
def dbType(self) -> str:
"""
Return the type of the database.
"""
raise NotImplementedError("Not implemented")
@abstractmethod
def health(self) -> dict:
"""
Return the health status of the database.
"""
raise NotImplementedError("Not implemented")
"""
Table operations
"""
@abstractmethod
def createIdx(self, indexName: str, knowledgebaseId: str, vectorSize: int):
"""
Create an index with given name
"""
raise NotImplementedError("Not implemented")
@abstractmethod
def deleteIdx(self, indexName: str, knowledgebaseId: str):
"""
Delete an index with given name
"""
raise NotImplementedError("Not implemented")
@abstractmethod
def indexExist(self, indexName: str, knowledgebaseId: str) -> bool:
"""
Check if an index with given name exists
"""
raise NotImplementedError("Not implemented")
"""
CRUD operations
"""
@abstractmethod
def search(
self, selectFields: list[str], highlight: list[str], condition: dict, matchExprs: list[MatchExpr], orderBy: OrderByExpr, offset: int, limit: int, indexNames: str|list[str], knowledgebaseIds: list[str]
) -> list[dict] | pl.DataFrame:
"""
Search with given conjunctive equivalent filtering condition and return all fields of matched documents
"""
raise NotImplementedError("Not implemented")
@abstractmethod
def get(self, chunkId: str, indexName: str, knowledgebaseIds: list[str]) -> dict | None:
"""
Get single chunk with given id
"""
raise NotImplementedError("Not implemented")
@abstractmethod
def insert(self, rows: list[dict], indexName: str, knowledgebaseId: str) -> list[str]:
"""
Update or insert a bulk of rows
"""
raise NotImplementedError("Not implemented")
@abstractmethod
def update(self, condition: dict, newValue: dict, indexName: str, knowledgebaseId: str) -> bool:
"""
Update rows with given conjunctive equivalent filtering condition
"""
raise NotImplementedError("Not implemented")
@abstractmethod
def delete(self, condition: dict, indexName: str, knowledgebaseId: str) -> int:
"""
Delete rows with given conjunctive equivalent filtering condition
"""
raise NotImplementedError("Not implemented")
"""
Helper functions for search result
"""
@abstractmethod
def getTotal(self, res):
raise NotImplementedError("Not implemented")
@abstractmethod
def getChunkIds(self, res):
raise NotImplementedError("Not implemented")
@abstractmethod
def getFields(self, res, fields: List[str]) -> Dict[str, dict]:
raise NotImplementedError("Not implemented")
@abstractmethod
def getHighlight(self, res, keywords: List[str], fieldnm: str):
raise NotImplementedError("Not implemented")
@abstractmethod
def getAggregation(self, res, fieldnm: str):
raise NotImplementedError("Not implemented")
"""
SQL
"""
@abstractmethod
def sql(self, sql: str, fetch_size: int, format: str):
"""
Run the sql generated by text-to-sql
"""
raise NotImplementedError("Not implemented")
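
A read-path sketch against this interface, mirroring the no-question branch of `Dealer.search`: a filter-only listing followed by the result helpers. `docStoreConn` is whichever backend `api.settings` wires up; the tenant, document and knowledge-base ids are placeholders.

```
from api.settings import docStoreConn
from rag.nlp import search
from rag.utils.doc_store_conn import OrderByExpr

fields = ["docnm_kwd", "content_with_weight"]
order_by = OrderByExpr().desc("create_timestamp_flt")
res = docStoreConn.search(fields, [], {"doc_id": "d0b7e52a"}, [],
                          order_by, 0, 128, search.index_name("t1"), ["kb1"])
total = docStoreConn.getTotal(res)
ids = docStoreConn.getChunkIds(res)
by_id = docStoreConn.getFields(res, fields)    # {chunk_id: {field: value}}
print(total, ids[:3], next(iter(by_id.values()), None))
```
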


@ -1,29 +1,29 @@
import re import re
import json import json
import time import time
import copy import os
from typing import List, Dict
import elasticsearch import elasticsearch
from elastic_transport import ConnectionTimeout import copy
from elasticsearch import Elasticsearch from elasticsearch import Elasticsearch
from elasticsearch_dsl import UpdateByQuery, Search, Index from elasticsearch_dsl import UpdateByQuery, Q, Search, Index
from rag.settings import es_logger from elastic_transport import ConnectionTimeout
from rag.settings import doc_store_logger
from rag import settings from rag import settings
from rag.utils import singleton from rag.utils import singleton
from api.utils.file_utils import get_project_base_directory
import polars as pl
from rag.utils.doc_store_conn import DocStoreConnection, MatchExpr, OrderByExpr, MatchTextExpr, MatchDenseExpr, FusionExpr
from rag.nlp import is_english, rag_tokenizer
es_logger.info("Elasticsearch version: "+str(elasticsearch.__version__)) doc_store_logger.info("Elasticsearch sdk version: "+str(elasticsearch.__version__))
@singleton @singleton
class ESConnection: class ESConnection(DocStoreConnection):
def __init__(self): def __init__(self):
self.info = {} self.info = {}
self.conn()
self.idxnm = settings.ES.get("index_name", "")
if not self.es.ping():
raise Exception("Can't connect to ES cluster")
def conn(self):
for _ in range(10): for _ in range(10):
try: try:
self.es = Elasticsearch( self.es = Elasticsearch(
@ -34,390 +34,317 @@ class ESConnection:
) )
if self.es: if self.es:
self.info = self.es.info() self.info = self.es.info()
es_logger.info("Connect to es.") doc_store_logger.info("Connect to es.")
break break
except Exception as e: except Exception as e:
es_logger.error("Fail to connect to es: " + str(e)) doc_store_logger.error("Fail to connect to es: " + str(e))
time.sleep(1) time.sleep(1)
if not self.es.ping():
def version(self): raise Exception("Can't connect to ES cluster")
v = self.info.get("version", {"number": "5.6"}) v = self.info.get("version", {"number": "5.6"})
v = v["number"].split(".")[0] v = v["number"].split(".")[0]
return int(v) >= 7 if int(v) < 8:
raise Exception(f"ES version must be greater than or equal to 8, current version: {v}")
fp_mapping = os.path.join(get_project_base_directory(), "conf", "mapping.json")
if not os.path.exists(fp_mapping):
raise Exception(f"Mapping file not found at {fp_mapping}")
self.mapping = json.load(open(fp_mapping, "r"))
def health(self): """
return dict(self.es.cluster.health()) Database operations
"""
def dbType(self) -> str:
return "elasticsearch"
def upsert(self, df, idxnm=""): def health(self) -> dict:
res = [] health_dict = dict(self.es.cluster.health(), type="elasticsearch")
for d in df:
id = d["id"]
del d["id"]
d = {"doc": d, "doc_as_upsert": "true"}
T = False
for _ in range(10):
try:
if not self.version():
r = self.es.update(
index=(
self.idxnm if not idxnm else idxnm),
body=d,
id=id,
doc_type="doc",
refresh=True,
retry_on_conflict=100)
else:
r = self.es.update(
index=(
self.idxnm if not idxnm else idxnm),
body=d,
id=id,
refresh=True,
retry_on_conflict=100)
es_logger.info("Successfully upsert: %s" % id)
T = True
break
except Exception as e:
es_logger.warning("Fail to index: " +
json.dumps(d, ensure_ascii=False) + str(e))
if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE):
time.sleep(3)
continue
self.conn()
T = False
if not T: """
res.append(d) Table operations
es_logger.error( """
"Fail to index: " + def createIdx(self, indexName: str, knowledgebaseId: str, vectorSize: int):
re.sub( if self.indexExist(indexName, knowledgebaseId):
"[\r\n]",
"",
json.dumps(
d,
ensure_ascii=False)))
d["id"] = id
d["_index"] = self.idxnm
if not res:
return True return True
return False try:
from elasticsearch.client import IndicesClient
return IndicesClient(self.es).create(index=indexName,
settings=self.mapping["settings"],
mappings=self.mapping["mappings"])
except Exception as e:
doc_store_logger.error("ES create index error %s ----%s" % (indexName, str(e)))
def bulk(self, df, idx_nm=None): def deleteIdx(self, indexName: str, knowledgebaseId: str):
ids, acts = {}, [] try:
for d in df: return self.es.indices.delete(indexName, allow_no_indices=True)
id = d["id"] if "id" in d else d["_id"] except Exception as e:
ids[id] = copy.deepcopy(d) doc_store_logger.error("ES delete index error %s ----%s" % (indexName, str(e)))
ids[id]["_index"] = self.idxnm if not idx_nm else idx_nm
if "id" in d:
del d["id"]
if "_id" in d:
del d["_id"]
acts.append(
{"update": {"_id": id, "_index": ids[id]["_index"]}, "retry_on_conflict": 100})
acts.append({"doc": d, "doc_as_upsert": "true"})
res = [] def indexExist(self, indexName: str, knowledgebaseId: str) -> bool:
for _ in range(100): s = Index(indexName, self.es)
try:
if elasticsearch.__version__[0] < 8:
r = self.es.bulk(
index=(
self.idxnm if not idx_nm else idx_nm),
body=acts,
refresh=False,
timeout="600s")
else:
r = self.es.bulk(index=(self.idxnm if not idx_nm else
idx_nm), operations=acts,
refresh=False, timeout="600s")
if re.search(r"False", str(r["errors"]), re.IGNORECASE):
return res
for it in r["items"]:
if "error" in it["update"]:
res.append(str(it["update"]["_id"]) +
":" + str(it["update"]["error"]))
return res
except Exception as e:
es_logger.warn("Fail to bulk: " + str(e))
if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE):
time.sleep(3)
continue
self.conn()
return res
def bulk4script(self, df):
ids, acts = {}, []
for d in df:
id = d["id"]
ids[id] = copy.deepcopy(d["raw"])
acts.append({"update": {"_id": id, "_index": self.idxnm}})
acts.append(d["script"])
es_logger.info("bulk upsert: %s" % id)
res = []
for _ in range(10):
try:
if not self.version():
r = self.es.bulk(
index=self.idxnm,
body=acts,
refresh=False,
timeout="600s",
doc_type="doc")
else:
r = self.es.bulk(
index=self.idxnm,
body=acts,
refresh=False,
timeout="600s")
if re.search(r"False", str(r["errors"]), re.IGNORECASE):
return res
for it in r["items"]:
if "error" in it["update"]:
res.append(str(it["update"]["_id"]))
return res
except Exception as e:
es_logger.warning("Fail to bulk: " + str(e))
if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE):
time.sleep(3)
continue
self.conn()
return res
def rm(self, d):
for _ in range(10):
try:
if not self.version():
r = self.es.delete(
index=self.idxnm,
id=d["id"],
doc_type="doc",
refresh=True)
else:
r = self.es.delete(
index=self.idxnm,
id=d["id"],
refresh=True,
doc_type="_doc")
es_logger.info("Remove %s" % d["id"])
return True
except Exception as e:
es_logger.warn("Fail to delete: " + str(d) + str(e))
if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE):
time.sleep(3)
continue
if re.search(r"(not_found)", str(e), re.IGNORECASE):
return True
self.conn()
es_logger.error("Fail to delete: " + str(d))
return False
def search(self, q, idxnms=None, src=False, timeout="2s"):
if not isinstance(q, dict):
q = Search().query(q).to_dict()
if isinstance(idxnms, str):
idxnms = idxnms.split(",")
for i in range(3):
try:
res = self.es.search(index=(self.idxnm if not idxnms else idxnms),
body=q,
timeout=timeout,
# search_type="dfs_query_then_fetch",
track_total_hits=True,
_source=src)
if str(res.get("timed_out", "")).lower() == "true":
raise Exception("Es Timeout.")
return res
except Exception as e:
es_logger.error(
"ES search exception: " +
str(e) +
"【Q】" +
str(q))
if str(e).find("Timeout") > 0:
continue
raise e
es_logger.error("ES search timeout for 3 times!")
raise Exception("ES search timeout.")
def sql(self, sql, fetch_size=128, format="json", timeout="2s"):
for i in range(3):
try:
res = self.es.sql.query(body={"query": sql, "fetch_size": fetch_size}, format=format, request_timeout=timeout)
return res
except ConnectionTimeout as e:
es_logger.error("Timeout【Q】" + sql)
continue
except Exception as e:
raise e
es_logger.error("ES search timeout for 3 times!")
raise ConnectionTimeout()
def get(self, doc_id, idxnm=None):
for i in range(3):
try:
res = self.es.get(index=(self.idxnm if not idxnm else idxnm),
id=doc_id)
if str(res.get("timed_out", "")).lower() == "true":
raise Exception("Es Timeout.")
return res
except Exception as e:
es_logger.error(
"ES get exception: " +
str(e) +
"【Q】" +
doc_id)
if str(e).find("Timeout") > 0:
continue
raise e
es_logger.error("ES search timeout for 3 times!")
raise Exception("ES search timeout.")
def updateByQuery(self, q, d):
ubq = UpdateByQuery(index=self.idxnm).using(self.es).query(q)
scripts = ""
for k, v in d.items():
scripts += "ctx._source.%s = params.%s;" % (str(k), str(k))
ubq = ubq.script(source=scripts, params=d)
ubq = ubq.params(refresh=False)
ubq = ubq.params(slices=5)
ubq = ubq.params(conflicts="proceed")
for i in range(3):
try:
r = ubq.execute()
return True
except Exception as e:
es_logger.error("ES updateByQuery exception: " +
str(e) + "【Q】" + str(q.to_dict()))
if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
continue
self.conn()
return False
def updateScriptByQuery(self, q, scripts, idxnm=None):
ubq = UpdateByQuery(
index=self.idxnm if not idxnm else idxnm).using(
self.es).query(q)
ubq = ubq.script(source=scripts)
ubq = ubq.params(refresh=True)
ubq = ubq.params(slices=5)
ubq = ubq.params(conflicts="proceed")
for i in range(3):
try:
r = ubq.execute()
return True
except Exception as e:
es_logger.error("ES updateByQuery exception: " +
str(e) + "【Q】" + str(q.to_dict()))
if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
continue
self.conn()
return False
def deleteByQuery(self, query, idxnm=""):
for i in range(3):
try:
r = self.es.delete_by_query(
index=idxnm if idxnm else self.idxnm,
refresh = True,
body=Search().query(query).to_dict())
return True
except Exception as e:
es_logger.error("ES updateByQuery deleteByQuery: " +
str(e) + "【Q】" + str(query.to_dict()))
if str(e).find("NotFoundError") > 0: return True
if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
continue
return False
def update(self, id, script, routing=None):
for i in range(3):
try:
if not self.version():
r = self.es.update(
index=self.idxnm,
id=id,
body=json.dumps(
script,
ensure_ascii=False),
doc_type="doc",
routing=routing,
refresh=False)
else:
r = self.es.update(index=self.idxnm, id=id, body=json.dumps(script, ensure_ascii=False),
routing=routing, refresh=False) # , doc_type="_doc")
return True
except Exception as e:
es_logger.error(
"ES update exception: " + str(e) + " id" + str(id) + ", version:" + str(self.version()) +
json.dumps(script, ensure_ascii=False))
if str(e).find("Timeout") > 0:
continue
return False
def indexExist(self, idxnm):
s = Index(idxnm if idxnm else self.idxnm, self.es)
for i in range(3): for i in range(3):
try: try:
return s.exists() return s.exists()
except Exception as e: except Exception as e:
es_logger.error("ES updateByQuery indexExist: " + str(e)) doc_store_logger.error("ES indexExist: " + str(e))
if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0: if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
continue continue
return False return False
def docExist(self, docid, idxnm=None): """
CRUD operations
"""
def search(self, selectFields: list[str], highlightFields: list[str], condition: dict, matchExprs: list[MatchExpr], orderBy: OrderByExpr, offset: int, limit: int, indexNames: str|list[str], knowledgebaseIds: list[str]) -> list[dict] | pl.DataFrame:
"""
Refers to https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl.html
"""
if isinstance(indexNames, str):
indexNames = indexNames.split(",")
assert isinstance(indexNames, list) and len(indexNames) > 0
assert "_id" not in condition
s = Search()
bqry = None
vector_similarity_weight = 0.5
for m in matchExprs:
if isinstance(m, FusionExpr) and m.method=="weighted_sum" and "weights" in m.fusion_params:
assert len(matchExprs)==3 and isinstance(matchExprs[0], MatchTextExpr) and isinstance(matchExprs[1], MatchDenseExpr) and isinstance(matchExprs[2], FusionExpr)
weights = m.fusion_params["weights"]
vector_similarity_weight = float(weights.split(",")[1])
for m in matchExprs:
if isinstance(m, MatchTextExpr):
minimum_should_match = "0%"
if "minimum_should_match" in m.extra_options:
minimum_should_match = str(int(m.extra_options["minimum_should_match"] * 100)) + "%"
bqry = Q("bool",
must=Q("query_string", fields=m.fields,
type="best_fields", query=m.matching_text,
minimum_should_match = minimum_should_match,
boost=1),
boost = 1.0 - vector_similarity_weight,
)
if condition:
for k, v in condition.items():
if not isinstance(k, str) or not v:
continue
if isinstance(v, list):
bqry.filter.append(Q("terms", **{k: v}))
elif isinstance(v, str) or isinstance(v, int):
bqry.filter.append(Q("term", **{k: v}))
else:
raise Exception(f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.")
elif isinstance(m, MatchDenseExpr):
assert(bqry is not None)
similarity = 0.0
if "similarity" in m.extra_options:
similarity = m.extra_options["similarity"]
s = s.knn(m.vector_column_name,
m.topn,
m.topn * 2,
query_vector = list(m.embedding_data),
filter = bqry.to_dict(),
similarity = similarity,
)
if matchExprs:
s.query = bqry
for field in highlightFields:
s = s.highlight(field)
if orderBy:
orders = list()
for field, order in orderBy.fields:
order = "asc" if order == 0 else "desc"
orders.append({field: {"order": order, "unmapped_type": "float",
"mode": "avg", "numeric_type": "double"}})
s = s.sort(*orders)
if limit > 0:
s = s[offset:limit]
q = s.to_dict()
doc_store_logger.info("ESConnection.search [Q]: " + json.dumps(q))
for i in range(3): for i in range(3):
try: try:
return self.es.exists(index=(idxnm if idxnm else self.idxnm), res = self.es.search(index=indexNames,
id=docid) body=q,
timeout="600s",
# search_type="dfs_query_then_fetch",
track_total_hits=True,
_source=True)
if str(res.get("timed_out", "")).lower() == "true":
raise Exception("Es Timeout.")
doc_store_logger.info("ESConnection.search res: " + str(res))
return res
except Exception as e: except Exception as e:
es_logger.error("ES Doc Exist: " + str(e)) doc_store_logger.error(
if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0: "ES search exception: " +
str(e) +
"\n[Q]: " +
str(q))
if str(e).find("Timeout") > 0:
continue continue
raise e
doc_store_logger.error("ES search timeout for 3 times!")
raise Exception("ES search timeout.")
def get(self, chunkId: str, indexName: str, knowledgebaseIds: list[str]) -> dict | None:
for i in range(3):
try:
res = self.es.get(index=(indexName),
id=chunkId, source=True,)
if str(res.get("timed_out", "")).lower() == "true":
raise Exception("Es Timeout.")
if not res.get("found"):
return None
chunk = res["_source"]
chunk["id"] = chunkId
return chunk
except Exception as e:
doc_store_logger.error(
"ES get exception: " +
str(e) +
"[Q]: " +
chunkId)
if str(e).find("Timeout") > 0:
continue
raise e
doc_store_logger.error("ES search timeout for 3 times!")
raise Exception("ES search timeout.")
def insert(self, documents: list[dict], indexName: str, knowledgebaseId: str) -> list[str]:
# Refers to https://www.elastic.co/guide/en/elasticsearch/reference/current/docs-bulk.html
operations = []
for d in documents:
assert "_id" not in d
assert "id" in d
d_copy = copy.deepcopy(d)
meta_id = d_copy["id"]
del d_copy["id"]
operations.append(
{"index": {"_index": indexName, "_id": meta_id}})
operations.append(d_copy)
res = []
for _ in range(100):
try:
r = self.es.bulk(index=(indexName), operations=operations,
refresh=False, timeout="600s")
if re.search(r"False", str(r["errors"]), re.IGNORECASE):
return res
for item in r["items"]:
for action in ["create", "delete", "index", "update"]:
if action in item and "error" in item[action]:
res.append(str(item[action]["_id"]) + ":" + str(item[action]["error"]))
return res
except Exception as e:
doc_store_logger.warning("Fail to bulk: " + str(e))
if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE):
time.sleep(3)
continue
return res
def update(self, condition: dict, newValue: dict, indexName: str, knowledgebaseId: str) -> bool:
doc = copy.deepcopy(newValue)
del doc['id']
if "id" in condition and isinstance(condition["id"], str):
# update specific single document
chunkId = condition["id"]
for i in range(3):
try:
self.es.update(index=indexName, id=chunkId, doc=doc)
return True
except Exception as e:
doc_store_logger.error(
"ES update exception: " + str(e) + " id:" + str(id) +
json.dumps(newValue, ensure_ascii=False))
if str(e).find("Timeout") > 0:
continue
else:
# update unspecific maybe-multiple documents
bqry = Q("bool")
for k, v in condition.items():
if not isinstance(k, str) or not v:
continue
if isinstance(v, list):
bqry.filter.append(Q("terms", **{k: v}))
elif isinstance(v, str) or isinstance(v, int):
bqry.filter.append(Q("term", **{k: v}))
else:
raise Exception(f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.")
scripts = []
for k, v in newValue.items():
if not isinstance(k, str) or not v:
continue
if isinstance(v, str):
scripts.append(f"ctx._source.{k} = '{v}'")
elif isinstance(v, int):
scripts.append(f"ctx._source.{k} = {v}")
else:
raise Exception(f"newValue `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str.")
ubq = UpdateByQuery(
index=indexName).using(
self.es).query(bqry)
ubq = ubq.script(source="; ".join(scripts))
ubq = ubq.params(refresh=True)
ubq = ubq.params(slices=5)
ubq = ubq.params(conflicts="proceed")
for i in range(3):
try:
_ = ubq.execute()
return True
except Exception as e:
doc_store_logger.error("ES update exception: " +
str(e) + "[Q]:" + str(bqry.to_dict()))
if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
continue
return False return False
def createIdx(self, idxnm, mapping): def delete(self, condition: dict, indexName: str, knowledgebaseId: str) -> int:
try: qry = None
if elasticsearch.__version__[0] < 8: assert "_id" not in condition
return self.es.indices.create(idxnm, body=mapping) if "id" in condition:
from elasticsearch.client import IndicesClient chunk_ids = condition["id"]
return IndicesClient(self.es).create(index=idxnm, if not isinstance(chunk_ids, list):
settings=mapping["settings"], chunk_ids = [chunk_ids]
mappings=mapping["mappings"]) qry = Q("ids", values=chunk_ids)
except Exception as e: else:
es_logger.error("ES create index error %s ----%s" % (idxnm, str(e))) qry = Q("bool")
for k, v in condition.items():
if isinstance(v, list):
qry.must.append(Q("terms", **{k: v}))
elif isinstance(v, str) or isinstance(v, int):
qry.must.append(Q("term", **{k: v}))
else:
raise Exception("Condition value must be int, str or list.")
doc_store_logger.info("ESConnection.delete [Q]: " + json.dumps(qry.to_dict()))
for _ in range(10):
try:
res = self.es.delete_by_query(
index=indexName,
body = Search().query(qry).to_dict(),
refresh=True)
return res["deleted"]
except Exception as e:
doc_store_logger.warning("Fail to delete: " + str(condition) + str(e))
if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE):
time.sleep(3)
continue
if re.search(r"(not_found)", str(e), re.IGNORECASE):
return 0
return 0
def deleteIdx(self, idxnm):
try:
return self.es.indices.delete(idxnm, allow_no_indices=True)
except Exception as e:
es_logger.error("ES delete index error %s ----%s" % (idxnm, str(e)))
"""
Helper functions for search result
"""
def getTotal(self, res): def getTotal(self, res):
if isinstance(res["hits"]["total"], type({})): if isinstance(res["hits"]["total"], type({})):
return res["hits"]["total"]["value"] return res["hits"]["total"]["value"]
return res["hits"]["total"] return res["hits"]["total"]
def getDocIds(self, res): def getChunkIds(self, res):
return [d["_id"] for d in res["hits"]["hits"]] return [d["_id"] for d in res["hits"]["hits"]]
def getSource(self, res): def __getSource(self, res):
rr = [] rr = []
for d in res["hits"]["hits"]: for d in res["hits"]["hits"]:
d["_source"]["id"] = d["_id"] d["_source"]["id"] = d["_id"]
@ -425,40 +352,89 @@ class ESConnection:
rr.append(d["_source"]) rr.append(d["_source"])
return rr return rr
def scrollIter(self, pagesize=100, scroll_time='2m', q={ def getFields(self, res, fields: List[str]) -> Dict[str, dict]:
"query": {"match_all": {}}, "sort": [{"updated_at": {"order": "desc"}}]}): res_fields = {}
for _ in range(100): if not fields:
return {}
for d in self.__getSource(res):
m = {n: d.get(n) for n in fields if d.get(n) is not None}
for n, v in m.items():
if isinstance(v, list):
m[n] = v
continue
if not isinstance(v, str):
m[n] = str(m[n])
# if n.find("tks") > 0:
# m[n] = rmSpace(m[n])
if m:
res_fields[d["id"]] = m
return res_fields
def getHighlight(self, res, keywords: List[str], fieldnm: str):
ans = {}
for d in res["hits"]["hits"]:
hlts = d.get("highlight")
if not hlts:
continue
txt = "...".join([a for a in list(hlts.items())[0][1]])
if not is_english(txt.split(" ")):
ans[d["_id"]] = txt
continue
txt = d["_source"][fieldnm]
txt = re.sub(r"[\r\n]", " ", txt, flags=re.IGNORECASE|re.MULTILINE)
txts = []
for t in re.split(r"[.?!;\n]", txt):
for w in keywords:
t = re.sub(r"(^|[ .?/'\"\(\)!,:;-])(%s)([ .?/'\"\(\)!,:;-])"%re.escape(w), r"\1<em>\2</em>\3", t, flags=re.IGNORECASE|re.MULTILINE)
if not re.search(r"<em>[^<>]+</em>", t, flags=re.IGNORECASE|re.MULTILINE):
continue
txts.append(t)
ans[d["_id"]] = "...".join(txts) if txts else "...".join([a for a in list(hlts.items())[0][1]])
return ans
def getAggregation(self, res, fieldnm: str):
agg_field = "aggs_" + fieldnm
if "aggregations" not in res or agg_field not in res["aggregations"]:
return list()
bkts = res["aggregations"][agg_field]["buckets"]
return [(b["key"], b["doc_count"]) for b in bkts]
"""
SQL
"""
def sql(self, sql: str, fetch_size: int, format: str):
doc_store_logger.info(f"ESConnection.sql get sql: {sql}")
sql = re.sub(r"[ `]+", " ", sql)
sql = sql.replace("%", "")
replaces = []
for r in re.finditer(r" ([a-z_]+_l?tks)( like | ?= ?)'([^']+)'", sql):
fld, v = r.group(1), r.group(3)
match = " MATCH({}, '{}', 'operator=OR;minimum_should_match=30%') ".format(
fld, rag_tokenizer.fine_grained_tokenize(rag_tokenizer.tokenize(v)))
replaces.append(
("{}{}'{}'".format(
r.group(1),
r.group(2),
r.group(3)),
match))
for p, r in replaces:
sql = sql.replace(p, r, 1)
doc_store_logger.info(f"ESConnection.sql to es: {sql}")
for i in range(3):
try: try:
page = self.es.search( res = self.es.sql.query(body={"query": sql, "fetch_size": fetch_size}, format=format, request_timeout="2s")
index=self.idxnm, return res
scroll=scroll_time, except ConnectionTimeout:
size=pagesize, doc_store_logger.error("ESConnection.sql timeout [Q]: " + sql)
body=q, continue
_source=None
)
break
except Exception as e: except Exception as e:
es_logger.error("ES scrolling fail. " + str(e)) doc_store_logger.error(f"ESConnection.sql failure: {sql} => " + str(e))
time.sleep(3) return None
doc_store_logger.error("ESConnection.sql timeout for 3 times!")
sid = page['_scroll_id'] return None
scroll_size = page['hits']['total']["value"]
es_logger.info("[TOTAL]%d" % scroll_size)
# Start scrolling
while scroll_size > 0:
yield page["hits"]["hits"]
for _ in range(100):
try:
page = self.es.scroll(scroll_id=sid, scroll=scroll_time)
break
except Exception as e:
es_logger.error("ES scrolling fail. " + str(e))
time.sleep(3)
# Update the scroll ID
sid = page['_scroll_id']
# Get the number of results that we returned in the last scroll
scroll_size = len(page['hits']['hits'])
ELASTICSEARCH = ESConnection()

rag/utils/infinity_conn.py (new file, 436 lines)

@ -0,0 +1,436 @@
import os
import re
import json
from typing import List, Dict
import infinity
from infinity.common import ConflictType, InfinityException
from infinity.index import IndexInfo, IndexType
from infinity.connection_pool import ConnectionPool
from rag import settings
from rag.settings import doc_store_logger
from rag.utils import singleton
import polars as pl
from polars.series.series import Series
from api.utils.file_utils import get_project_base_directory
from rag.utils.doc_store_conn import (
DocStoreConnection,
MatchExpr,
MatchTextExpr,
MatchDenseExpr,
FusionExpr,
OrderByExpr,
)
def equivalent_condition_to_str(condition: dict) -> str:
assert "_id" not in condition
cond = list()
for k, v in condition.items():
if not isinstance(k, str) or not v:
continue
if isinstance(v, list):
inCond = list()
for item in v:
if isinstance(item, str):
inCond.append(f"'{item}'")
else:
inCond.append(str(item))
if inCond:
strInCond = ", ".join(inCond)
strInCond = f"{k} IN ({strInCond})"
cond.append(strInCond)
elif isinstance(v, str):
cond.append(f"{k}='{v}'")
else:
cond.append(f"{k}={str(v)}")
return " AND ".join(cond)
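
A quick check of how `equivalent_condition_to_str` maps the conjunctive condition dicts used elsewhere onto an Infinity filter expression; the ids are made up:

```
from rag.utils.infinity_conn import equivalent_condition_to_str

cond = {"doc_id": "d0b7e52a", "kb_id": ["kb1", "kb2"], "available_int": 1}
print(equivalent_condition_to_str(cond))
# doc_id='d0b7e52a' AND kb_id IN ('kb1', 'kb2') AND available_int=1
```
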
@singleton
class InfinityConnection(DocStoreConnection):
def __init__(self):
self.dbName = settings.INFINITY.get("db_name", "default_db")
infinity_uri = settings.INFINITY["uri"]
if ":" in infinity_uri:
host, port = infinity_uri.split(":")
infinity_uri = infinity.common.NetworkAddress(host, int(port))
self.connPool = ConnectionPool(infinity_uri)
doc_store_logger.info(f"Connected to infinity {infinity_uri}.")
"""
Database operations
"""
def dbType(self) -> str:
return "infinity"
def health(self) -> dict:
"""
Return the health status of the database.
TODO: Infinity-sdk provides health() to wrap `show global variables` and `show tables`
"""
inf_conn = self.connPool.get_conn()
res = infinity.show_current_node()
self.connPool.release_conn(inf_conn)
color = "green" if res.error_code == 0 else "red"
res2 = {
"type": "infinity",
"status": f"{res.role} {color}",
"error": res.error_msg,
}
return res2
"""
Table operations
"""
def createIdx(self, indexName: str, knowledgebaseId: str, vectorSize: int):
table_name = f"{indexName}_{knowledgebaseId}"
inf_conn = self.connPool.get_conn()
inf_db = inf_conn.create_database(self.dbName, ConflictType.Ignore)
fp_mapping = os.path.join(
get_project_base_directory(), "conf", "infinity_mapping.json"
)
if not os.path.exists(fp_mapping):
raise Exception(f"Mapping file not found at {fp_mapping}")
schema = json.load(open(fp_mapping))
vector_name = f"q_{vectorSize}_vec"
schema[vector_name] = {"type": f"vector,{vectorSize},float"}
inf_table = inf_db.create_table(
table_name,
schema,
ConflictType.Ignore,
)
inf_table.create_index(
"q_vec_idx",
IndexInfo(
vector_name,
IndexType.Hnsw,
{
"M": "16",
"ef_construction": "50",
"metric": "cosine",
"encode": "lvq",
},
),
ConflictType.Ignore,
)
text_suffix = ["_tks", "_ltks", "_kwd"]
for field_name, field_info in schema.items():
if field_info["type"] != "varchar":
continue
for suffix in text_suffix:
if field_name.endswith(suffix):
inf_table.create_index(
f"text_idx_{field_name}",
IndexInfo(
field_name, IndexType.FullText, {"ANALYZER": "standard"}
),
ConflictType.Ignore,
)
break
self.connPool.release_conn(inf_conn)
doc_store_logger.info(
f"INFINITY created table {table_name}, vector size {vectorSize}"
)
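
`createIdx` encodes two naming conventions the rest of the pipeline depends on: one table per `{indexName}_{knowledgebaseId}` pair, and the embedding column named `q_<dim>_vec`. The sketch below shows the column-name contract that `embedding()` in the task executor and `insert()` below both rely on; the size inference mirrors the regex used in `insert`:

```
import re

def vector_column_name(dim: int) -> str:
    return f"q_{dim}_vec"               # e.g. q_1024_vec, created with an HNSW cosine index

def infer_vector_size(doc: dict) -> int:
    for key in doc:
        m = re.match(r"q_(\d+)_vec", key)
        if m:
            return int(m.group(1))
    raise ValueError("Cannot infer vector size from document")

doc = {"id": "abc", "content_ltks": "hello world",
       vector_column_name(4): [0.1, 0.2, 0.3, 0.4]}
assert infer_vector_size(doc) == 4
```
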
def deleteIdx(self, indexName: str, knowledgebaseId: str):
table_name = f"{indexName}_{knowledgebaseId}"
inf_conn = self.connPool.get_conn()
db_instance = inf_conn.get_database(self.dbName)
db_instance.drop_table(table_name, ConflictType.Ignore)
self.connPool.release_conn(inf_conn)
doc_store_logger.info(f"INFINITY dropped table {table_name}")
def indexExist(self, indexName: str, knowledgebaseId: str) -> bool:
table_name = f"{indexName}_{knowledgebaseId}"
try:
inf_conn = self.connPool.get_conn()
db_instance = inf_conn.get_database(self.dbName)
_ = db_instance.get_table(table_name)
self.connPool.release_conn(inf_conn)
return True
except Exception as e:
doc_store_logger.error("INFINITY indexExist: " + str(e))
return False
"""
CRUD operations
"""
def search(
self,
selectFields: list[str],
highlightFields: list[str],
condition: dict,
matchExprs: list[MatchExpr],
orderBy: OrderByExpr,
offset: int,
limit: int,
indexNames: str|list[str],
knowledgebaseIds: list[str],
) -> list[dict] | pl.DataFrame:
"""
TODO: Infinity doesn't provide highlight
"""
if isinstance(indexNames, str):
indexNames = indexNames.split(",")
assert isinstance(indexNames, list) and len(indexNames) > 0
inf_conn = self.connPool.get_conn()
db_instance = inf_conn.get_database(self.dbName)
df_list = list()
table_list = list()
if "id" not in selectFields:
selectFields.append("id")
# Prepare expressions common to all tables
filter_cond = ""
filter_fulltext = ""
if condition:
filter_cond = equivalent_condition_to_str(condition)
for matchExpr in matchExprs:
if isinstance(matchExpr, MatchTextExpr):
if len(filter_cond) != 0 and "filter" not in matchExpr.extra_options:
matchExpr.extra_options.update({"filter": filter_cond})
fields = ",".join(matchExpr.fields)
filter_fulltext = (
f"filter_fulltext('{fields}', '{matchExpr.matching_text}')"
)
if len(filter_cond) != 0:
filter_fulltext = f"({filter_cond}) AND {filter_fulltext}"
# doc_store_logger.info(f"filter_fulltext: {filter_fulltext}")
minimum_should_match = "0%"
if "minimum_should_match" in matchExpr.extra_options:
minimum_should_match = (
str(int(matchExpr.extra_options["minimum_should_match"] * 100))
+ "%"
)
matchExpr.extra_options.update(
{"minimum_should_match": minimum_should_match}
)
for k, v in matchExpr.extra_options.items():
if not isinstance(v, str):
matchExpr.extra_options[k] = str(v)
elif isinstance(matchExpr, MatchDenseExpr):
if len(filter_cond) != 0 and "filter" not in matchExpr.extra_options:
matchExpr.extra_options.update({"filter": filter_fulltext})
for k, v in matchExpr.extra_options.items():
if not isinstance(v, str):
matchExpr.extra_options[k] = str(v)
if orderBy.fields:
order_by_expr_list = list()
for order_field in orderBy.fields:
order_by_expr_list.append((order_field[0], order_field[1] == 0))
# Scatter search tables and gather the results
for indexName in indexNames:
for knowledgebaseId in knowledgebaseIds:
table_name = f"{indexName}_{knowledgebaseId}"
try:
table_instance = db_instance.get_table(table_name)
except Exception:
continue
table_list.append(table_name)
builder = table_instance.output(selectFields)
for matchExpr in matchExprs:
if isinstance(matchExpr, MatchTextExpr):
fields = ",".join(matchExpr.fields)
builder = builder.match_text(
fields,
matchExpr.matching_text,
matchExpr.topn,
matchExpr.extra_options,
)
elif isinstance(matchExpr, MatchDenseExpr):
builder = builder.match_dense(
matchExpr.vector_column_name,
matchExpr.embedding_data,
matchExpr.embedding_data_type,
matchExpr.distance_type,
matchExpr.topn,
matchExpr.extra_options,
)
elif isinstance(matchExpr, FusionExpr):
builder = builder.fusion(
matchExpr.method, matchExpr.topn, matchExpr.fusion_params
)
if orderBy.fields:
builder.sort(order_by_expr_list)
builder.offset(offset).limit(limit)
kb_res = builder.to_pl()
df_list.append(kb_res)
self.connPool.release_conn(inf_conn)
res = pl.concat(df_list)
doc_store_logger.info("INFINITY search tables: " + str(table_list))
return res
def get(
self, chunkId: str, indexName: str, knowledgebaseIds: list[str]
) -> dict | None:
inf_conn = self.connPool.get_conn()
db_instance = inf_conn.get_database(self.dbName)
df_list = list()
assert isinstance(knowledgebaseIds, list)
for knowledgebaseId in knowledgebaseIds:
table_name = f"{indexName}_{knowledgebaseId}"
table_instance = db_instance.get_table(table_name)
kb_res = table_instance.output(["*"]).filter(f"id = '{chunkId}'").to_pl()
df_list.append(kb_res)
self.connPool.release_conn(inf_conn)
res = pl.concat(df_list)
res_fields = self.getFields(res, res.columns)
return res_fields.get(chunkId, None)
def insert(
self, documents: list[dict], indexName: str, knowledgebaseId: str
) -> list[str]:
inf_conn = self.connPool.get_conn()
db_instance = inf_conn.get_database(self.dbName)
table_name = f"{indexName}_{knowledgebaseId}"
try:
table_instance = db_instance.get_table(table_name)
except InfinityException as e:
# src/common/status.cppm, kTableNotExist = 3022
if e.error_code != 3022:
raise
vector_size = 0
patt = re.compile(r"q_(?P<vector_size>\d+)_vec")
for k in documents[0].keys():
m = patt.match(k)
if m:
vector_size = int(m.group("vector_size"))
break
if vector_size == 0:
raise ValueError("Cannot infer vector size from documents")
self.createIdx(indexName, knowledgebaseId, vector_size)
table_instance = db_instance.get_table(table_name)
for d in documents:
assert "_id" not in d
assert "id" in d
for k, v in d.items():
if k.endswith("_kwd") and isinstance(v, list):
d[k] = " ".join(v)
ids = [f"'{d["id"]}'" for d in documents]
str_ids = ", ".join(ids)
str_filter = f"id IN ({str_ids})"
table_instance.delete(str_filter)
# for doc in documents:
# doc_store_logger.info(f"insert position_list: {doc['position_list']}")
# doc_store_logger.info(f"InfinityConnection.insert {json.dumps(documents)}")
table_instance.insert(documents)
self.connPool.release_conn(inf_conn)
doc_store_logger.info(f"inserted into {table_name} {str_ids}.")
return []
def update(
self, condition: dict, newValue: dict, indexName: str, knowledgebaseId: str
) -> bool:
# if 'position_list' in newValue:
# doc_store_logger.info(f"update position_list: {newValue['position_list']}")
inf_conn = self.connPool.get_conn()
db_instance = inf_conn.get_database(self.dbName)
table_name = f"{indexName}_{knowledgebaseId}"
table_instance = db_instance.get_table(table_name)
filter = equivalent_condition_to_str(condition)
for k, v in newValue.items():
if k.endswith("_kwd") and isinstance(v, list):
newValue[k] = " ".join(v)
table_instance.update(filter, newValue)
self.connPool.release_conn(inf_conn)
return True
def delete(self, condition: dict, indexName: str, knowledgebaseId: str) -> int:
inf_conn = self.connPool.get_conn()
db_instance = inf_conn.get_database(self.dbName)
table_name = f"{indexName}_{knowledgebaseId}"
filter = equivalent_condition_to_str(condition)
try:
table_instance = db_instance.get_table(table_name)
except Exception:
doc_store_logger.warning(
f"Skipped deleting `{filter}` from table {table_name} since the table doesn't exist."
)
return 0
res = table_instance.delete(filter)
self.connPool.release_conn(inf_conn)
return res.deleted_rows
"""
Helper functions for search result
"""
def getTotal(self, res):
return len(res)
def getChunkIds(self, res):
return list(res["id"])
def getFields(self, res, fields: List[str]) -> Dict[str, dict]:
res_fields = {}
if not fields:
return {}
num_rows = len(res)
column_id = res["id"]
for i in range(num_rows):
id = column_id[i]
m = {"id": id}
for fieldnm in fields:
if fieldnm not in res:
m[fieldnm] = None
continue
v = res[fieldnm][i]
if isinstance(v, Series):
v = list(v)
elif fieldnm == "important_kwd":
assert isinstance(v, str)
v = v.split(" ")
else:
if not isinstance(v, str):
v = str(v)
# if fieldnm.endswith("_tks"):
# v = rmSpace(v)
m[fieldnm] = v
res_fields[id] = m
return res_fields
def getHighlight(self, res, keywords: List[str], fieldnm: str):
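# Infinity has no server-side highlighting (see the TODO in search()), so keyword
# hits are wrapped in <em> tags here, sentence by sentence, on the client side.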
ans = {}
num_rows = len(res)
column_id = res["id"]
for i in range(num_rows):
id = column_id[i]
txt = res[fieldnm][i]
txt = re.sub(r"[\r\n]", " ", txt, flags=re.IGNORECASE | re.MULTILINE)
txts = []
for t in re.split(r"[.?!;\n]", txt):
for w in keywords:
t = re.sub(
r"(^|[ .?/'\"\(\)!,:;-])(%s)([ .?/'\"\(\)!,:;-])"
% re.escape(w),
r"\1<em>\2</em>\3",
t,
flags=re.IGNORECASE | re.MULTILINE,
)
if not re.search(
r"<em>[^<>]+</em>", t, flags=re.IGNORECASE | re.MULTILINE
):
continue
txts.append(t)
ans[id] = "...".join(txts)
return ans
def getAggregation(self, res, fieldnm: str):
"""
TODO: Infinity doesn't provide aggregation
"""
return list()
"""
SQL
"""
def sql(self, sql: str, fetch_size: int, format: str):
raise NotImplementedError("Not implemented")
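For orientation, the snippet below shows roughly how a caller drives the CRUD interface above. It is a minimal sketch, not code from this PR: the import path, the `MatchTextExpr`/`OrderByExpr` constructor arguments, and the field names are illustrative assumptions.

```python
# Minimal usage sketch (assumed: import path, constructor signatures, field names).
from rag.utils.doc_store_conn import MatchTextExpr, OrderByExpr  # assumed helper module

store = InfinityConnection()  # assumed to build its own connection pool from config

# Full-text search over two knowledge bases of one tenant index.
match = MatchTextExpr(
    ["title_tks", "content_ltks"],        # assumed text columns
    "retrieval augmented generation",     # query text
    10,                                   # topn
    {"minimum_should_match": 0.3},        # fraction; search() converts it to "30%"
)
rows = store.search(
    selectFields=["id", "content_with_weight"],
    highlightFields=[],
    condition={"available_int": 1},       # converted to a filter string by equivalent_condition_to_str()
    matchExprs=[match],
    orderBy=OrderByExpr(),                # empty ordering; search() skips the sort in that case
    offset=0,
    limit=10,
    indexNames="ragflow_tenant_1",
    knowledgebaseIds=["kb_1", "kb_2"],
)
print(store.getChunkIds(rows))                         # ids of the matching chunks
print(store.getFields(rows, ["content_with_weight"]))  # id -> selected fields
```

In RAGFlow itself these calls are reached through the `dataStoreConn` indirection introduced by this PR rather than by constructing `InfinityConnection` directly.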

@@ -50,8 +50,8 @@ class Document(Base):
return res.content

- def list_chunks(self,page=1, page_size=30, keywords="", id:str=None):
- data={"keywords": keywords,"page":page,"page_size":page_size,"id":id}
+ def list_chunks(self,page=1, page_size=30, keywords=""):
+ data={"keywords": keywords,"page":page,"page_size":page_size}
res = self.get(f'/datasets/{self.dataset_id}/documents/{self.id}/chunks', data)
res = res.json()
if res.get("code") == 0:

@@ -126,6 +126,7 @@ def test_delete_chunk_with_success(get_api_key_fixture):
docs = ds.upload_documents(documents)
doc = docs[0]
chunk = doc.add_chunk(content="This is a chunk addition test")
+ sleep(5)
doc.delete_chunks([chunk.id])

@@ -146,6 +147,8 @@ def test_update_chunk_content(get_api_key_fixture):
docs = ds.upload_documents(documents)
doc = docs[0]
chunk = doc.add_chunk(content="This is a chunk addition test")
+ # For ElasticSearch, the chunk is not searchable in a short time (~2s).
+ sleep(3)
chunk.update({"content":"This is a updated content"})

def test_update_chunk_available(get_api_key_fixture):

@@ -165,7 +168,9 @@ def test_update_chunk_available(get_api_key_fixture):
docs = ds.upload_documents(documents)
doc = docs[0]
chunk = doc.add_chunk(content="This is a chunk addition test")
- chunk.update({"available":False})
+ # For ElasticSearch, the chunk is not searchable in a short time (~2s).
+ sleep(3)
+ chunk.update({"available":0})

def test_retrieve_chunks(get_api_key_fixture):