update document sdk (#2445)

### What problem does this PR solve?

Extends the document HTTP API and the Python SDK together: document info and listing endpoints now return user-facing key names (`chunk_count`, `token_count`, `knowledgebase_id`); the chunk endpoints gain update (`/doc/chunk/set`) and retrieval (`/doc/retrieval_test`) routes with matching SDK methods (`Chunk.save()`, `RAGFlow.retrieval()`); `rm_chunk` and `download_document` now take the authenticated `tenant_id`; the SDK tests are extended to cover the new behavior.
### Type of change

- [x] New Feature (non-breaking change which adds functionality)
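
In SDK terms, the change is meant to be used roughly like this (a sketch only; `API_KEY` and `HOST_ADDRESS` are the placeholders the tests use, and exact behavior is defined by the diffs below):

```python
from ragflow import RAGFlow

rag = RAGFlow(API_KEY, HOST_ADDRESS)

# Documents now expose renamed fields: chunk_count, token_count, knowledgebase_id.
doc = rag.get_document(name="story.txt")

# Chunks can be added, edited and saved back, deleted, or retrieved by query.
chunk = doc.add_chunk(content="ragflow is an open-source RAG engine.")
chunk.content = "updated content"
chunk.save()
chunk.delete()
```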

---------

Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>
JobSmithManipulation 2024-09-18 11:08:19 +08:00 committed by GitHub
parent e7dd487779
commit 62cb5f1bac
5 changed files with 348 additions and 58 deletions

View File

@@ -84,15 +84,28 @@ def upload(dataset_id, tenant_id):
 @token_required
 def docinfos(tenant_id):
     req = request.args
+    if "id" not in req and "name" not in req:
+        return get_data_error_result(
+            retmsg="Id or name should be provided")
+    doc_id = None
     if "id" in req:
         doc_id = req["id"]
-        e, doc = DocumentService.get_by_id(doc_id)
-        return get_json_result(data=doc.to_json())
     if "name" in req:
         doc_name = req["name"]
         doc_id = DocumentService.get_doc_id_by_doc_name(doc_name)
-        e, doc = DocumentService.get_by_id(doc_id)
-        return get_json_result(data=doc.to_json())
+    e, doc = DocumentService.get_by_id(doc_id)
+    # rename key's name
+    key_mapping = {
+        "chunk_num": "chunk_count",
+        "kb_id": "knowledgebase_id",
+        "token_num": "token_count",
+    }
+    renamed_doc = {}
+    for key, value in doc.to_dict().items():
+        new_key = key_mapping.get(key, key)
+        renamed_doc[new_key] = value
+    return get_json_result(data=renamed_doc)


 @manager.route('/save', methods=['POST'])
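
The key-renaming loop introduced here recurs in several endpoints in this file. A small helper like the following (hypothetical, not part of this PR) would express the pattern once:

```python
def rename_keys(record: dict, key_mapping: dict) -> dict:
    """Return a copy of record with its keys renamed per key_mapping."""
    return {key_mapping.get(k, k): v for k, v in record.items()}

# e.g. renamed_doc = rename_keys(doc.to_dict(),
#                                {"chunk_num": "chunk_count",
#                                 "kb_id": "knowledgebase_id",
#                                 "token_num": "token_count"})
```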
@@ -246,7 +259,7 @@ def rename():
             req["doc_id"], {"name": req["name"]}):
         return get_data_error_result(
             retmsg="Database error (Document rename)!")
     informs = File2DocumentService.get_by_document_id(req["doc_id"])
     if informs:
         e, file = FileService.get_by_id(informs[0].file_id)
@@ -259,7 +272,7 @@ def rename():
 @manager.route("/<document_id>", methods=["GET"])
 @token_required
-def download_document(dataset_id, document_id):
+def download_document(dataset_id, document_id, tenant_id):
     try:
         # Check whether there is this document
         exist, document = DocumentService.get_by_id(document_id)
@@ -313,7 +326,21 @@ def list_docs(dataset_id, tenant_id):
     try:
         docs, tol = DocumentService.get_by_kb_id(
             kb_id, page_number, items_per_page, orderby, desc, keywords)
-        return get_json_result(data={"total": tol, "docs": docs})
+        # rename key's name
+        renamed_doc_list = []
+        for doc in docs:
+            key_mapping = {
+                "chunk_num": "chunk_count",
+                "kb_id": "knowledgebase_id",
+                "token_num": "token_count",
+            }
+            renamed_doc = {}
+            for key, value in doc.items():
+                new_key = key_mapping.get(key, key)
+                renamed_doc[new_key] = value
+            renamed_doc_list.append(renamed_doc)
+        return get_json_result(data={"total": tol, "docs": renamed_doc_list})
     except Exception as e:
         return server_error_response(e)
@@ -436,6 +463,8 @@ def list_chunk(tenant_id):
             query["available_int"] = int(req["available_int"])
         sres = retrievaler.search(query, search.index_name(tenant_id), highlight=True)
         res = {"total": sres.total, "chunks": [], "doc": doc.to_dict()}
+        origin_chunks = []
         for id in sres.ids:
             d = {
                 "chunk_id": id,
@@ -455,7 +484,21 @@ def list_chunk(tenant_id):
                 poss.append([float(d["positions"][i]), float(d["positions"][i + 1]), float(d["positions"][i + 2]),
                              float(d["positions"][i + 3]), float(d["positions"][i + 4])])
             d["positions"] = poss
-            res["chunks"].append(d)
+            origin_chunks.append(d)
+        ## rename keys
+        for chunk in origin_chunks:
+            key_mapping = {
+                "chunk_id": "id",
+                "content_with_weight": "content",
+                "doc_id": "document_id",
+                "important_kwd": "important_keywords",
+            }
+            renamed_chunk = {}
+            for key, value in chunk.items():
+                new_key = key_mapping.get(key, key)
+                renamed_chunk[new_key] = value
+            res["chunks"].append(renamed_chunk)
         return get_json_result(data=res)
     except Exception as e:
         if str(e).find("not_found") > 0:
@@ -471,8 +514,9 @@ def create(tenant_id):
     req = request.json
     md5 = hashlib.md5()
     md5.update((req["content_with_weight"] + req["doc_id"]).encode("utf-8"))
-    chunck_id = md5.hexdigest()
-    d = {"id": chunck_id, "content_ltks": rag_tokenizer.tokenize(req["content_with_weight"]),
+    chunk_id = md5.hexdigest()
+    d = {"id": chunk_id, "content_ltks": rag_tokenizer.tokenize(req["content_with_weight"]),
          "content_with_weight": req["content_with_weight"]}
     d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
     d["important_kwd"] = req.get("important_kwd", [])
@@ -503,20 +547,33 @@ def create(tenant_id):
         DocumentService.increment_chunk_num(
             doc.id, doc.kb_id, c, 1, 0)
-        return get_json_result(data={"chunk": d})
-        # return get_json_result(data={"chunk_id": chunck_id})
+        d["chunk_id"] = chunk_id
+        # rename keys
+        key_mapping = {
+            "chunk_id": "id",
+            "content_with_weight": "content",
+            "doc_id": "document_id",
+            "important_kwd": "important_keywords",
+            "kb_id": "knowledge_base_id",
+        }
+        renamed_chunk = {}
+        for key, value in d.items():
+            new_key = key_mapping.get(key, key)
+            renamed_chunk[new_key] = value
+        return get_json_result(data={"chunk": renamed_chunk})
+        # return get_json_result(data={"chunk_id": chunk_id})
     except Exception as e:
         return server_error_response(e)


 @manager.route('/chunk/rm', methods=['POST'])
 @token_required
 @validate_request("chunk_ids", "doc_id")
-def rm_chunk():
+def rm_chunk(tenant_id):
     req = request.json
     try:
         if not ELASTICSEARCH.deleteByQuery(
-                Q("ids", values=req["chunk_ids"]), search.index_name(current_user.id)):
+                Q("ids", values=req["chunk_ids"]), search.index_name(tenant_id)):
             return get_data_error_result(retmsg="Index updating failure")
         e, doc = DocumentService.get_by_id(req["doc_id"])
         if not e:
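
With the rename in place, a successful `/doc/chunk/create` response should look roughly like this (assuming the usual `retcode`/`retmsg` envelope produced by `get_json_result`; values are illustrative, not captured from a live server):

```python
# Illustrative response body for /doc/chunk/create:
response_body = {
    "retcode": 0,
    "retmsg": "success",
    "data": {
        "chunk": {
            "id": "<md5 of content + doc_id>",
            "content": "ragflow is an open-source RAG engine.",
            "document_id": "<document id>",
            "important_keywords": [],
        }
    },
}
```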
@@ -526,4 +583,126 @@ def rm_chunk():
         DocumentService.decrement_chunk_num(doc.id, doc.kb_id, 1, chunk_number, 0)
         return get_json_result(data=True)
     except Exception as e:
+        return server_error_response(e)
+
+
+@manager.route('/chunk/set', methods=['POST'])
+@token_required
+@validate_request("doc_id", "chunk_id", "content_with_weight",
+                  "important_kwd")
+def set(tenant_id):
+    req = request.json
+    d = {
+        "id": req["chunk_id"],
+        "content_with_weight": req["content_with_weight"]}
+    d["content_ltks"] = rag_tokenizer.tokenize(req["content_with_weight"])
+    d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
+    d["important_kwd"] = req["important_kwd"]
+    d["important_tks"] = rag_tokenizer.tokenize(" ".join(req["important_kwd"]))
+    if "available_int" in req:
+        d["available_int"] = req["available_int"]
+
+    try:
+        tenant_id = DocumentService.get_tenant_id(req["doc_id"])
+        if not tenant_id:
+            return get_data_error_result(retmsg="Tenant not found!")
+
+        embd_id = DocumentService.get_embd_id(req["doc_id"])
+        embd_mdl = TenantLLMService.model_instance(
+            tenant_id, LLMType.EMBEDDING.value, embd_id)
+
+        e, doc = DocumentService.get_by_id(req["doc_id"])
+        if not e:
+            return get_data_error_result(retmsg="Document not found!")
+
+        if doc.parser_id == ParserType.QA:
+            arr = [
+                t for t in re.split(
+                    r"[\n\t]",
+                    req["content_with_weight"]) if len(t) > 1]
+            if len(arr) != 2:
+                return get_data_error_result(
+                    retmsg="Q&A must be separated by TAB/ENTER key.")
+            q, a = rmPrefix(arr[0]), rmPrefix(arr[1])
+            d = beAdoc(d, arr[0], arr[1], not any(
+                [rag_tokenizer.is_chinese(t) for t in q + a]))
+
+        v, c = embd_mdl.encode([doc.name, req["content_with_weight"]])
+        v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
+        d["q_%d_vec" % len(v)] = v.tolist()
+        ELASTICSEARCH.upsert([d], search.index_name(tenant_id))
+        return get_json_result(data=True)
+    except Exception as e:
+        return server_error_response(e)
+
+
+@manager.route('/retrieval_test', methods=['POST'])
+@token_required
+@validate_request("kb_id", "question")
+def retrieval_test(tenant_id):
+    req = request.json
+    page = int(req.get("page", 1))
+    size = int(req.get("size", 30))
+    question = req["question"]
+    kb_id = req["kb_id"]
+    if isinstance(kb_id, str): kb_id = [kb_id]
+    doc_ids = req.get("doc_ids", [])
+    similarity_threshold = float(req.get("similarity_threshold", 0.2))
+    vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3))
+    top = int(req.get("top_k", 1024))
+
+    try:
+        tenants = UserTenantService.query(user_id=tenant_id)
+        for kid in kb_id:
+            for tenant in tenants:
+                if KnowledgebaseService.query(
+                        tenant_id=tenant.tenant_id, id=kid):
+                    break
+            else:
+                return get_json_result(
+                    data=False, retmsg=f'Only owner of knowledgebase authorized for this operation.',
+                    retcode=RetCode.OPERATING_ERROR)
+
+        e, kb = KnowledgebaseService.get_by_id(kb_id[0])
+        if not e:
+            return get_data_error_result(retmsg="Knowledgebase not found!")
+
+        embd_mdl = TenantLLMService.model_instance(
+            kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
+
+        rerank_mdl = None
+        if req.get("rerank_id"):
+            rerank_mdl = TenantLLMService.model_instance(
+                kb.tenant_id, LLMType.RERANK.value, llm_name=req["rerank_id"])
+
+        if req.get("keyword", False):
+            chat_mdl = TenantLLMService.model_instance(kb.tenant_id, LLMType.CHAT)
+            question += keyword_extraction(chat_mdl, question)
+
+        retr = retrievaler if kb.parser_id != ParserType.KG else kg_retrievaler
+        ranks = retr.retrieval(question, embd_mdl, kb.tenant_id, kb_id, page, size,
+                               similarity_threshold, vector_similarity_weight, top,
+                               doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"))
+        for c in ranks["chunks"]:
+            if "vector" in c:
+                del c["vector"]
+
+        ## rename keys
+        renamed_chunks = []
+        for chunk in ranks["chunks"]:
+            key_mapping = {
+                "chunk_id": "id",
+                "content_with_weight": "content",
+                "doc_id": "document_id",
+                "important_kwd": "important_keywords",
+            }
+            rename_chunk = {}
+            for key, value in chunk.items():
+                new_key = key_mapping.get(key, key)
+                rename_chunk[new_key] = value
+            renamed_chunks.append(rename_chunk)
+        ranks["chunks"] = renamed_chunks
+        return get_json_result(data=ranks)
+    except Exception as e:
+        if str(e).find("not_found") > 0:
+            return get_json_result(data=False, retmsg=f'No chunk found! Check the chunk status please!',
+                                   retcode=RetCode.DATA_ERROR)
         return server_error_response(e)
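
One detail of the new `/chunk/set` endpoint worth noting: for documents parsed with the QA parser, the content must split into exactly one question and one answer on TAB or newline. A minimal sketch of content that passes the check:

```python
import re

# Content for a QA-parsed document: question and answer separated by TAB.
content_with_weight = "What is RAGFlow?\tAn open-source RAG engine."
arr = [t for t in re.split(r"[\n\t]", content_with_weight) if len(t) > 1]
assert len(arr) == 2, "Q&A must be separated by TAB/ENTER key."
```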

View File

@@ -3,32 +3,48 @@ from .base import Base

 class Chunk(Base):
     def __init__(self, rag, res_dict):
-        # initialize the class attributes
         self.id = ""
-        self.content_with_weight = ""
-        self.content_ltks = []
-        self.content_sm_ltks = []
-        self.important_kwd = []
-        self.important_tks = []
+        self.content = ""
+        self.important_keywords = []
         self.create_time = ""
         self.create_timestamp_flt = 0.0
-        self.kb_id = None
-        self.docnm_kwd = ""
-        self.doc_id = ""
-        self.q_vec = []
+        self.knowledgebase_id = None
+        self.document_name = ""
+        self.document_id = ""
         self.status = "1"
-        for k, v in res_dict.items():
-            if hasattr(self, k):
-                setattr(self, k, v)
+        for k in list(res_dict.keys()):
+            if k not in self.__dict__:
+                res_dict.pop(k)
         super().__init__(rag, res_dict)

     def delete(self) -> bool:
         """
         Delete the chunk in the document.
         """
-        res = self.rm('/doc/chunk/rm',
-                      {"doc_id": [self.id], ""})
+        res = self.post('/doc/chunk/rm',
+                        {"doc_id": self.document_id, 'chunk_ids': [self.id]})
         res = res.json()
         if res.get("retmsg") == "success":
             return True
         raise Exception(res["retmsg"])
+
+    def save(self) -> bool:
+        """
+        Save the chunk details to the server.
+        """
+        res = self.post('/doc/chunk/set',
+                        {"chunk_id": self.id,
+                         "kb_id": self.knowledgebase_id,
+                         "name": self.document_name,
+                         "content_with_weight": self.content,
+                         "important_kwd": self.important_keywords,
+                         "create_time": self.create_time,
+                         "create_timestamp_flt": self.create_timestamp_flt,
+                         "doc_id": self.document_id,
+                         "status": self.status,
+                         })
+        res = res.json()
+        if res.get("retmsg") == "success":
+            return True
+        raise Exception(res["retmsg"])
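
The constructor change above inverts the old copy-if-known logic: unknown keys are now popped from `res_dict` before it reaches `Base`, so server-side extras (vector fields, for example) never become attributes. An isolated illustration of that filtering (a standalone sketch, not the class itself):

```python
# Standalone sketch of the filtering now done in Chunk.__init__:
res_dict = {"id": "c1", "content": "text", "q_768_vec": [0.1] * 768}
known = {"id": "", "content": "", "important_keywords": []}  # stands in for self.__dict__
for k in list(res_dict.keys()):
    if k not in known:
        res_dict.pop(k)
assert "q_768_vec" not in res_dict
```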

View File

@@ -9,15 +9,15 @@ class Document(Base):
         self.id = ""
         self.name = ""
         self.thumbnail = None
-        self.kb_id = None
+        self.knowledgebase_id = None
         self.parser_method = ""
         self.parser_config = {"pages": [[1, 1000000]]}
         self.source_type = "local"
         self.type = ""
         self.created_by = ""
         self.size = 0
-        self.token_num = 0
-        self.chunk_num = 0
+        self.token_count = 0
+        self.chunk_count = 0
         self.progress = 0.0
         self.progress_msg = ""
         self.process_begin_at = None
@@ -34,10 +34,10 @@ class Document(Base):
         Save the document details to the server.
         """
         res = self.post('/doc/save',
-                        {"id": self.id, "name": self.name, "thumbnail": self.thumbnail, "kb_id": self.kb_id,
+                        {"id": self.id, "name": self.name, "thumbnail": self.thumbnail, "kb_id": self.knowledgebase_id,
                          "parser_id": self.parser_method, "parser_config": self.parser_config.to_json(),
                          "source_type": self.source_type, "type": self.type, "created_by": self.created_by,
-                         "size": self.size, "token_num": self.token_num, "chunk_num": self.chunk_num,
+                         "size": self.size, "token_num": self.token_count, "chunk_num": self.chunk_count,
                          "progress": self.progress, "progress_msg": self.progress_msg,
                          "process_begin_at": self.process_begin_at, "process_duation": self.process_duration
                          })
@@ -177,8 +177,10 @@ class Document(Base):
         if res.status_code == 200:
             res_data = res.json()
             if res_data.get("retmsg") == "success":
-                chunks = res_data["data"]["chunks"]
-                self.chunks = chunks  # Store the chunks in the document instance
+                chunks = []
+                for chunk_data in res_data["data"].get("chunks", []):
+                    chunk = Chunk(self.rag, chunk_data)
+                    chunks.append(chunk)
                 return chunks
             else:
                 raise Exception(f"Error fetching chunks: {res_data.get('retmsg')}")
@@ -187,10 +189,9 @@ class Document(Base):

     def add_chunk(self, content: str):
         res = self.post('/doc/chunk/create', {"doc_id": self.id, "content_with_weight": content})
-        # assume the returned response contains the chunk info
         if res.status_code == 200:
-            chunk_data = res.json()
-            return Chunk(self.rag, chunk_data)  # assume a Chunk class handles the chunk object
+            res_data = res.json().get("data")
+            chunk_data = res_data.get("chunk")
+            return Chunk(self.rag, chunk_data)
         else:
             raise Exception(f"Failed to add chunk: {res.status_code} {res.text}")

View File

@@ -20,6 +20,8 @@ import requests
 from .modules.assistant import Assistant
 from .modules.dataset import DataSet
 from .modules.document import Document
+from .modules.chunk import Chunk
+

 class RAGFlow:
     def __init__(self, user_key, base_url, version='v1'):
@@ -143,7 +145,7 @@ class RAGFlow:
             return result_list
         raise Exception(res["retmsg"])

-    def create_document(self, ds:DataSet, name: str, blob: bytes) -> bool:
+    def create_document(self, ds: DataSet, name: str, blob: bytes) -> bool:
         url = f"/doc/dataset/{ds.id}/documents/upload"
         files = {
             'file': (name, blob)
@@ -164,6 +166,7 @@ class RAGFlow:
             raise Exception(f"Upload failed: {response.json().get('retmsg')}")
         return False

+
     def get_document(self, id: str = None, name: str = None) -> Document:
         res = self.get("/doc/infos", {"id": id, "name": name})
         res = res.json()
@@ -204,8 +207,6 @@ class RAGFlow:
         if not doc_ids or not isinstance(doc_ids, list):
             raise ValueError("doc_ids must be a non-empty list of document IDs")
         data = {"doc_ids": doc_ids, "run": 2}
-
-
         res = self.post(f'/doc/run', data)
         if res.status_code != 200:
@@ -217,4 +218,61 @@ class RAGFlow:
             print(f"Error occurred during canceling parsing for documents: {str(e)}")
             raise

+    def retrieval(self,
+                  question,
+                  datasets=None,
+                  documents=None,
+                  offset=0,
+                  limit=6,
+                  similarity_threshold=0.1,
+                  vector_similarity_weight=0.3,
+                  top_k=1024):
+        """
+        Perform document retrieval based on the given parameters.
+
+        :param question: The query question.
+        :param datasets: A list of datasets (optional, as documents may be provided directly).
+        :param documents: A list of documents (if specific documents are provided).
+        :param offset: Offset for the retrieval results.
+        :param limit: Maximum number of retrieval results.
+        :param similarity_threshold: Similarity threshold.
+        :param vector_similarity_weight: Weight of vector similarity.
+        :param top_k: Number of top most similar documents to consider (for pre-filtering or ranking).
+
+        Note: This is a hypothetical implementation and may need adjustments based on the actual backend service API.
+        """
+        try:
+            data = {
+                "question": question,
+                "datasets": datasets if datasets is not None else [],
+                "documents": [doc.id if hasattr(doc, 'id') else doc for doc in
+                              documents] if documents is not None else [],
+                "offset": offset,
+                "limit": limit,
+                "similarity_threshold": similarity_threshold,
+                "vector_similarity_weight": vector_similarity_weight,
+                "top_k": top_k,
+                "kb_id": datasets,
+            }
+            # Send a POST request to the backend service (using requests library as an example, actual implementation may vary)
+            res = self.post(f'/doc/retrieval_test', data)
+
+            # Check the response status code
+            if res.status_code == 200:
+                res_data = res.json()
+                if res_data.get("retmsg") == "success":
+                    chunks = []
+                    for chunk_data in res_data["data"].get("chunks", []):
+                        chunk = Chunk(self, chunk_data)
+                        chunks.append(chunk)
+                    return chunks
+                else:
+                    raise Exception(f"Error fetching chunks: {res_data.get('retmsg')}")
+            else:
+                raise Exception(f"API request failed with status code {res.status_code}")
+        except Exception as e:
+            print(f"An error occurred during retrieval: {e}")
+            raise
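
Client-side, the new retrieval call then mirrors the test added at the bottom of this PR (dataset and document handles illustrative):

```python
from ragflow import RAGFlow

rag = RAGFlow(API_KEY, HOST_ADDRESS)
ds = rag.create_dataset(name="God8")           # as in the new test below
doc = rag.get_document(name="ragflow_test.txt")

for c in rag.retrieval(question="What's ragflow?",
                       datasets=[ds.id],       # dataset ids
                       documents=[doc],        # Document objects or ids
                       offset=0, limit=6,
                       similarity_threshold=0.1,
                       vector_similarity_weight=0.3,
                       top_k=1024):
    print(c.content)                           # each result is a Chunk
```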

View File

@@ -41,6 +41,7 @@ class TestDocument(TestSdk):
     def test_update_document_with_success(self):
         """
         Test updating a document with success.
+        Updating name or parser_method is supported.
         """
         rag = RAGFlow(API_KEY, HOST_ADDRESS)
         doc = rag.get_document(name="TestDocument.txt")
@@ -60,7 +61,7 @@ class TestDocument(TestSdk):
         rag = RAGFlow(API_KEY, HOST_ADDRESS)

         # Retrieve a document
-        doc = rag.get_document(name="TestDocument.txt")
+        doc = rag.get_document(name="manual.txt")

         # Check if the retrieved document is of type Document
         if isinstance(doc, Document):
@@ -147,14 +148,16 @@ class TestDocument(TestSdk):
         ds = rag.create_dataset(name="God4")

         # Define the document name and path
-        name3 = 'ai.pdf'
-        path = 'test_data/ai.pdf'
+        name3 = 'westworld.pdf'
+        path = 'test_data/westworld.pdf'

         # Create a document in the dataset using the file path
         rag.create_document(ds, name=name3, blob=open(path, "rb").read())

         # Retrieve the document by name
-        doc = rag.get_document(name="ai.pdf")
+        doc = rag.get_document(name="westworld.pdf")

         # Initiate asynchronous parsing
         doc.async_parse()
@@ -185,9 +188,9 @@ class TestDocument(TestSdk):
         # Prepare a list of file names and paths
         documents = [
-            {'name': 'ai1.pdf', 'path': 'test_data/ai1.pdf'},
-            {'name': 'ai2.pdf', 'path': 'test_data/ai2.pdf'},
-            {'name': 'ai3.pdf', 'path': 'test_data/ai3.pdf'}
+            {'name': 'test1.txt', 'path': 'test_data/test1.txt'},
+            {'name': 'test2.txt', 'path': 'test_data/test2.txt'},
+            {'name': 'test3.txt', 'path': 'test_data/test3.txt'}
         ]

         # Create documents in bulk
@@ -248,6 +251,7 @@ class TestDocument(TestSdk):
             print(c)
             assert c is not None, "Chunk is None"
             assert "rag" in c['content_with_weight'].lower(), f"Keyword 'rag' not found in chunk content: {c.content}"
+
     def test_add_chunk_to_chunk_list(self):
         rag = RAGFlow(API_KEY, HOST_ADDRESS)
         doc = rag.get_document(name='story.txt')
@@ -258,12 +262,44 @@ class TestDocument(TestSdk):
     def test_delete_chunk_of_chunk_list(self):
         rag = RAGFlow(API_KEY, HOST_ADDRESS)
         doc = rag.get_document(name='story.txt')
         chunk = doc.add_chunk(content="assss")
         assert chunk is not None, "Chunk is None"
         assert isinstance(chunk, Chunk), "Chunk was not added to chunk list"
-        chunk_num_before = doc.chunk_num
+        doc = rag.get_document(name='story.txt')
+        chunk_count_before = doc.chunk_count
         chunk.delete()
-        assert doc.chunk_num == chunk_num_before - 1, "Chunk was not deleted"
+        doc = rag.get_document(name='story.txt')
+        assert doc.chunk_count == chunk_count_before - 1, "Chunk was not deleted"
+
+    def test_update_chunk_content(self):
+        rag = RAGFlow(API_KEY, HOST_ADDRESS)
+        doc = rag.get_document(name='story.txt')
+        chunk = doc.add_chunk(content="assssd")
+        assert chunk is not None, "Chunk is None"
+        assert isinstance(chunk, Chunk), "Chunk was not added to chunk list"
+        chunk.content = "ragflow123"
+        res = chunk.save()
+        assert res is True, f"Failed to update chunk, error: {res}"
+
+    def test_retrieval_chunks(self):
+        rag = RAGFlow(API_KEY, HOST_ADDRESS)
+        ds = rag.create_dataset(name="God8")
+        name = 'ragflow_test.txt'
+        path = 'test_data/ragflow_test.txt'
+        rag.create_document(ds, name=name, blob=open(path, "rb").read())
+        doc = rag.get_document(name=name)
+        doc.async_parse()
+
+        # Wait for parsing to complete and get progress updates using join
+        for progress, msg in doc.join(interval=5, timeout=30):
+            print(progress, msg)
+            assert 0 <= progress <= 100, f"Invalid progress: {progress}"
+            assert msg, "Message should not be empty"
+
+        for c in rag.retrieval(question="What's ragflow?",
+                               datasets=[ds.id], documents=[doc],
+                               offset=0, limit=6, similarity_threshold=0.1,
+                               vector_similarity_weight=0.3,
+                               top_k=1024
+                               ):
+            print(c)
+            assert c is not None, "Chunk is None"
+            assert "ragflow" in c.content.lower(), f"Keyword 'ragflow' not found in chunk content: {c.content}"