Validate returned chunk at list_chunks and add_chunk (#4153)

### What problem does this PR solve?

Adds a pydantic `Chunk` model and validates the chunk payloads returned by `list_chunks` and `add_chunk` against it, so schema drift fails loudly instead of leaking malformed responses to API callers. Also makes `InfinityConnection.insert` operate on a deep copy of its input documents, since the insert path encodes list-valued fields in place and previously mutated the caller's data.

### Type of change

- [x] Refactoring
Zhichang Yu, 2024-12-20 22:55:45 +08:00 (committed by GitHub)
parent 35580af875 · commit 85083ad400
2 changed files with 28 additions and 17 deletions


@@ -42,9 +42,30 @@ from rag.nlp import search
 from rag.utils import rmSpace
 from rag.utils.storage_factory import STORAGE_IMPL
+from pydantic import BaseModel, Field, validator
 
 MAXIMUM_OF_UPLOADING_FILES = 256
 
+class Chunk(BaseModel):
+    id: str = ""
+    content: str = ""
+    document_id: str = ""
+    docnm_kwd: str = ""
+    important_keywords: list = Field(default_factory=list)
+    questions: list = Field(default_factory=list)
+    question_tks: str = ""
+    image_id: str = ""
+    available: bool = True
+    positions: list[list[int]] = Field(default_factory=list)
+
+    @validator('positions')
+    def validate_positions(cls, value):
+        for sublist in value:
+            if len(sublist) != 5:
+                raise ValueError("Each sublist in positions must have a length of 5")
+        return value
+
+
 @manager.route("/datasets/<dataset_id>/documents", methods=["POST"])  # noqa: F821
 @token_required
 def upload(dataset_id, tenant_id):
@@ -848,20 +869,6 @@ def list_chunks(tenant_id, dataset_id, document_id):
                 "available_int": sres.field[id].get("available_int", 1),
                 "positions": sres.field[id].get("position_int", []),
             }
-            if len(d["positions"]) % 5 == 0:
-                poss = []
-                for i in range(0, len(d["positions"]), 5):
-                    poss.append(
-                        [
-                            float(d["positions"][i]),
-                            float(d["positions"][i + 1]),
-                            float(d["positions"][i + 2]),
-                            float(d["positions"][i + 3]),
-                            float(d["positions"][i + 4]),
-                        ]
-                    )
-                d["positions"] = poss
             origin_chunks.append(d)
             if req.get("id"):
                 if req.get("id") == id:
@@ -892,6 +899,7 @@ def list_chunks(tenant_id, dataset_id, document_id):
         if renamed_chunk["available"] == 1:
             renamed_chunk["available"] = True
         res["chunks"].append(renamed_chunk)
+        _ = Chunk(**renamed_chunk)  # validate the chunk
     return get_result(data=res)
@@ -1031,6 +1039,7 @@ def add_chunk(tenant_id, dataset_id, document_id):
             if key in key_mapping:
                 new_key = key_mapping.get(key, key)
                 renamed_chunk[new_key] = value
+    _ = Chunk(**renamed_chunk)  # validate the chunk
     return get_result(data={"chunk": renamed_chunk})
     # return get_result(data={"chunk_id": chunk_id})
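
For reference, a minimal sketch of what `_ = Chunk(**renamed_chunk)` buys at the two call sites: a payload that drifts from the schema raises `pydantic.ValidationError` instead of being returned to the API caller. The trimmed model below is illustrative and assumes pydantic v1, matching the legacy `validator` decorator used in the diff (pydantic v2 renames it `field_validator`).

```python
# Trimmed-down sketch of the Chunk model above; assumes pydantic v1
# (legacy `validator` decorator) and Python 3.9+ for list[list[int]].
from pydantic import BaseModel, Field, ValidationError, validator

class Chunk(BaseModel):
    id: str = ""
    content: str = ""
    positions: list[list[int]] = Field(default_factory=list)

    @validator("positions")
    def validate_positions(cls, value):
        for sublist in value:
            if len(sublist) != 5:
                raise ValueError("Each sublist in positions must have a length of 5")
        return value

# A well-formed payload passes silently, mirroring `_ = Chunk(**renamed_chunk)`:
_ = Chunk(id="c1", content="hello", positions=[[1, 0, 0, 10, 10]])

# A malformed one fails loudly instead of leaking to the API caller:
try:
    _ = Chunk(id="c2", positions=[[1, 0, 0]])  # 3 ints, not 5
except ValidationError as exc:
    print(exc)  # pinpoints `positions` and the length-5 rule
```

Note that the validation is assertion-style: the constructed `Chunk` is discarded, so an unhandled `ValidationError` surfaces as a server error rather than silently coercing the payload. The second file touched is the Infinity doc-store connector, where the same chunk payloads are written.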


@@ -3,6 +3,7 @@ import os
 import re
 import json
 import time
+import copy
 import infinity
 from infinity.common import ConflictType, InfinityException, SortType
 from infinity.index import IndexInfo, IndexType
@@ -390,7 +391,8 @@ class InfinityConnection(DocStoreConnection):
             self.createIdx(indexName, knowledgebaseId, vector_size)
             table_instance = db_instance.get_table(table_name)
 
-        for d in documents:
+        docs = copy.deepcopy(documents)
+        for d in docs:
             assert "_id" not in d
             assert "id" in d
             for k, v in d.items():
@@ -407,14 +409,14 @@ class InfinityConnection(DocStoreConnection):
                 elif k in ["page_num_int", "top_int"]:
                     assert isinstance(v, list)
                     d[k] = "_".join(f"{num:08x}" for num in v)
-        ids = ["'{}'".format(d["id"]) for d in documents]
+        ids = ["'{}'".format(d["id"]) for d in docs]
         str_ids = ", ".join(ids)
         str_filter = f"id IN ({str_ids})"
         table_instance.delete(str_filter)
         # for doc in documents:
         #     logger.info(f"insert position_int: {doc['position_int']}")
         # logger.info(f"InfinityConnection.insert {json.dumps(documents)}")
-        table_instance.insert(documents)
+        table_instance.insert(docs)
         self.connPool.release_conn(inf_conn)
         logger.debug(f"INFINITY inserted into {table_name} {str_ids}.")
         return []
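
The deep copy is the behavioural fix here: `insert` rewrites list-valued fields such as `page_num_int` and `top_int` in place before handing them to Infinity, so looping over `documents` directly handed those mutations back to the caller. A self-contained sketch of the aliasing issue, under stated assumptions (the helper name `encode_for_insert` is hypothetical, distilled from the diff; it is not part of the codebase):

```python
# Hypothetical distillation of the insert path's field encoding.
import copy

def encode_for_insert(documents: list[dict]) -> list[dict]:
    docs = copy.deepcopy(documents)  # work on a private copy
    for d in docs:
        for k, v in d.items():
            if k in ("page_num_int", "top_int"):
                d[k] = "_".join(f"{num:08x}" for num in v)  # int list -> hex string
    return docs

original = [{"id": "c1", "page_num_int": [1, 2]}]
encoded = encode_for_insert(original)
print(encoded[0]["page_num_int"])   # 00000001_00000002
print(original[0]["page_num_int"])  # [1, 2] -- caller's list is untouched
```

With the pre-fix `for d in documents:` loop, the final `print` would show the hex-encoded string instead of the original list, which is exactly the silent mutation this commit removes.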