### What problem does this PR solve?

#4367

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
Commit c5da3cdd97 (parent f892d7d426), authored by Kevin Hu on 2025-01-09 17:07:21 +08:00 and committed via GitHub.
30 changed files with 736 additions and 202 deletions
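For reviewers: the change threads a new `rank_feature` argument through every retrieval call, computed by the new `label_question()` helper. Below is a minimal sketch of that call shape, assuming the surrounding variables (`question`, `embd_mdl`, `tenant_ids`, `kb_ids`, `kbs`, `rerank_mdl`) are already set up as in the endpoints further down; the wrapper function itself is illustrative and not part of the commit.

```python
# Illustrative only: mirrors the call pattern added across the retrieval endpoints in this commit.
from api import settings
from api.db.services.dialog_service import label_question


def retrieve_with_tags(question, embd_mdl, tenant_ids, kb_ids, kbs, rerank_mdl=None):
    # label_question() returns a {tag: relevance} dict (or None) derived from the "tag"
    # knowledge bases listed in each KB's parser_config["tag_kb_ids"]; it is forwarded to
    # the doc store as a rank feature alongside pagerank.
    return settings.retrievaler.retrieval(
        question, embd_mdl, tenant_ids, kb_ids, 1, 12,
        0.1, 0.3, aggs=False, rerank_mdl=rerank_mdl,
        rank_feature=label_question(question, kbs)
    )
```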

View File

@@ -19,6 +19,7 @@ from abc import ABC
 import pandas as pd
 from api.db import LLMType
+from api.db.services.dialog_service import label_question
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.services.llm_service import LLMBundle
 from api import settings
@@ -70,7 +71,8 @@ class Retrieval(ComponentBase, ABC):
         kbinfos = settings.retrievaler.retrieval(query, embd_mdl, kbs[0].tenant_id, self._param.kb_ids,
                                                  1, self._param.top_n,
                                                  self._param.similarity_threshold, 1 - self._param.keywords_similarity_weight,
-                                                 aggs=False, rerank_mdl=rerank_mdl)
+                                                 aggs=False, rerank_mdl=rerank_mdl,
+                                                 rank_feature=label_question(query, kbs))
         if not kbinfos["chunks"]:
             df = Retrieval.be_output("")

View File

@@ -25,7 +25,7 @@ from api.db import FileType, LLMType, ParserType, FileSource
 from api.db.db_models import APIToken, Task, File
 from api.db.services import duplicate_name
 from api.db.services.api_service import APITokenService, API4ConversationService
-from api.db.services.dialog_service import DialogService, chat, keyword_extraction
+from api.db.services.dialog_service import DialogService, chat, keyword_extraction, label_question
 from api.db.services.document_service import DocumentService, doc_upload_and_parse
 from api.db.services.file2document_service import File2DocumentService
 from api.db.services.file_service import FileService
@@ -840,7 +840,8 @@ def retrieval():
             question += keyword_extraction(chat_mdl, question)
         ranks = settings.retrievaler.retrieval(question, embd_mdl, kbs[0].tenant_id, kb_ids, page, size,
                                                similarity_threshold, vector_similarity_weight, top,
-                                               doc_ids, rerank_mdl=rerank_mdl)
+                                               doc_ids, rerank_mdl=rerank_mdl,
+                                               rank_feature=label_question(question, kbs))
         for c in ranks["chunks"]:
             c.pop("vector", None)
         return get_json_result(data=ranks)

View File

@@ -19,9 +19,10 @@ import json
 from flask import request
 from flask_login import login_required, current_user
-from api.db.services.dialog_service import keyword_extraction
+from api.db.services.dialog_service import keyword_extraction, label_question
 from rag.app.qa import rmPrefix, beAdoc
 from rag.nlp import search, rag_tokenizer
+from rag.settings import PAGERANK_FLD
 from rag.utils import rmSpace
 from api.db import LLMType, ParserType
 from api.db.services.knowledgebase_service import KnowledgebaseService
@@ -124,10 +125,14 @@ def set():
          "content_with_weight": req["content_with_weight"]}
     d["content_ltks"] = rag_tokenizer.tokenize(req["content_with_weight"])
     d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
-    d["important_kwd"] = req["important_kwd"]
-    d["important_tks"] = rag_tokenizer.tokenize(" ".join(req["important_kwd"]))
-    d["question_kwd"] = req["question_kwd"]
-    d["question_tks"] = rag_tokenizer.tokenize("\n".join(req["question_kwd"]))
+    if req.get("important_kwd"):
+        d["important_kwd"] = req["important_kwd"]
+        d["important_tks"] = rag_tokenizer.tokenize(" ".join(req["important_kwd"]))
+    if req.get("question_kwd"):
+        d["question_kwd"] = req["question_kwd"]
+        d["question_tks"] = rag_tokenizer.tokenize("\n".join(req["question_kwd"]))
+    if req.get("tag_kwd"):
+        d["tag_kwd"] = req["tag_kwd"]
     if "available_int" in req:
         d["available_int"] = req["available_int"]
@@ -220,7 +225,7 @@ def create():
         e, doc = DocumentService.get_by_id(req["doc_id"])
         if not e:
             return get_data_error_result(message="Document not found!")
-        d["kb_id"] = doc.kb_id
+        d["kb_id"] = [doc.kb_id]
         d["docnm_kwd"] = doc.name
         d["title_tks"] = rag_tokenizer.tokenize(doc.name)
         d["doc_id"] = doc.id
@@ -233,7 +238,7 @@
         if not e:
             return get_data_error_result(message="Knowledgebase not found!")
         if kb.pagerank:
-            d["pagerank_fea"] = kb.pagerank
+            d[PAGERANK_FLD] = kb.pagerank
         embd_id = DocumentService.get_embd_id(req["doc_id"])
         embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING.value, embd_id)
@@ -294,12 +299,16 @@ def retrieval_test():
             chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
             question += keyword_extraction(chat_mdl, question)
+        labels = label_question(question, [kb])
         retr = settings.retrievaler if kb.parser_id != ParserType.KG else settings.kg_retrievaler
         ranks = retr.retrieval(question, embd_mdl, tenant_ids, kb_ids, page, size,
                                similarity_threshold, vector_similarity_weight, top,
-                               doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"))
+                               doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"),
+                               rank_feature=labels
+                               )
         for c in ranks["chunks"]:
             c.pop("vector", None)
+        ranks["labels"] = labels
         return get_json_result(data=ranks)
     except Exception as e:

View File

@@ -25,7 +25,7 @@ from flask import request, Response
 from flask_login import login_required, current_user
 from api.db import LLMType
-from api.db.services.dialog_service import DialogService, chat, ask
+from api.db.services.dialog_service import DialogService, chat, ask, label_question
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.services.llm_service import LLMBundle, TenantService, TenantLLMService
 from api import settings
@@ -379,8 +379,11 @@ def mindmap():
     embd_mdl = TenantLLMService.model_instance(
         kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
     chat_mdl = LLMBundle(current_user.id, LLMType.CHAT)
-    ranks = settings.retrievaler.retrieval(req["question"], embd_mdl, kb.tenant_id, kb_ids, 1, 12,
-                                           0.3, 0.3, aggs=False)
+    question = req["question"]
+    ranks = settings.retrievaler.retrieval(question, embd_mdl, kb.tenant_id, kb_ids, 1, 12,
+                                           0.3, 0.3, aggs=False,
+                                           rank_feature=label_question(question, [kb])
+                                           )
     mindmap = MindMapExtractor(chat_mdl)
     mind_map = mindmap([c["content_with_weight"] for c in ranks["chunks"]]).output
     if "error" in mind_map:

View File

@@ -30,6 +30,7 @@ from api.utils.api_utils import get_json_result
 from api import settings
 from rag.nlp import search
 from api.constants import DATASET_NAME_LIMIT
+from rag.settings import PAGERANK_FLD


 @manager.route('/create', methods=['post'])  # noqa: F821
@@ -104,11 +105,11 @@ def update():
         if kb.pagerank != req.get("pagerank", 0):
             if req.get("pagerank", 0) > 0:
-                settings.docStoreConn.update({"kb_id": kb.id}, {"pagerank_fea": req["pagerank"]},
+                settings.docStoreConn.update({"kb_id": kb.id}, {PAGERANK_FLD: req["pagerank"]},
                                              search.index_name(kb.tenant_id), kb.id)
             else:
-                # Elasticsearch requires pagerank_fea be non-zero!
-                settings.docStoreConn.update({"exist": "pagerank_fea"}, {"remove": "pagerank_fea"},
+                # Elasticsearch requires PAGERANK_FLD be non-zero!
+                settings.docStoreConn.update({"exist": PAGERANK_FLD}, {"remove": PAGERANK_FLD},
                                              search.index_name(kb.tenant_id), kb.id)

         e, kb = KnowledgebaseService.get_by_id(kb.id)
@@ -150,12 +151,14 @@ def list_kbs():
     keywords = request.args.get("keywords", "")
     page_number = int(request.args.get("page", 1))
     items_per_page = int(request.args.get("page_size", 150))
+    parser_id = request.args.get("parser_id")
     orderby = request.args.get("orderby", "create_time")
     desc = request.args.get("desc", True)
     try:
         tenants = TenantService.get_joined_tenants_by_user_id(current_user.id)
         kbs, total = KnowledgebaseService.get_by_tenant_ids(
-            [m["tenant_id"] for m in tenants], current_user.id, page_number, items_per_page, orderby, desc, keywords)
+            [m["tenant_id"] for m in tenants], current_user.id, page_number,
+            items_per_page, orderby, desc, keywords, parser_id)
         return get_json_result(data={"kbs": kbs, "total": total})
     except Exception as e:
         return server_error_response(e)
@@ -199,3 +202,72 @@ def rm():
         return get_json_result(data=True)
     except Exception as e:
         return server_error_response(e)
+
+
+@manager.route('/<kb_id>/tags', methods=['GET'])  # noqa: F821
+@login_required
+def list_tags(kb_id):
+    if not KnowledgebaseService.accessible(kb_id, current_user.id):
+        return get_json_result(
+            data=False,
+            message='No authorization.',
+            code=settings.RetCode.AUTHENTICATION_ERROR
+        )
+
+    tags = settings.retrievaler.all_tags(current_user.id, [kb_id])
+    return get_json_result(data=tags)
+
+
+@manager.route('/tags', methods=['GET'])  # noqa: F821
+@login_required
+def list_tags_from_kbs():
+    kb_ids = request.args.get("kb_ids", "").split(",")
+    for kb_id in kb_ids:
+        if not KnowledgebaseService.accessible(kb_id, current_user.id):
+            return get_json_result(
+                data=False,
+                message='No authorization.',
+                code=settings.RetCode.AUTHENTICATION_ERROR
+            )
+
+    tags = settings.retrievaler.all_tags(current_user.id, kb_ids)
+    return get_json_result(data=tags)
+
+
+@manager.route('/<kb_id>/rm_tags', methods=['POST'])  # noqa: F821
+@login_required
+def rm_tags(kb_id):
+    req = request.json
+    if not KnowledgebaseService.accessible(kb_id, current_user.id):
+        return get_json_result(
+            data=False,
+            message='No authorization.',
+            code=settings.RetCode.AUTHENTICATION_ERROR
+        )
+    e, kb = KnowledgebaseService.get_by_id(kb_id)
+
+    for t in req["tags"]:
+        settings.docStoreConn.update({"tag_kwd": t, "kb_id": [kb_id]},
+                                     {"remove": {"tag_kwd": t}},
+                                     search.index_name(kb.tenant_id),
+                                     kb_id)
+    return get_json_result(data=True)
+
+
+@manager.route('/<kb_id>/rename_tag', methods=['POST'])  # noqa: F821
+@login_required
+def rename_tags(kb_id):
+    req = request.json
+    if not KnowledgebaseService.accessible(kb_id, current_user.id):
+        return get_json_result(
+            data=False,
+            message='No authorization.',
+            code=settings.RetCode.AUTHENTICATION_ERROR
+        )
+    e, kb = KnowledgebaseService.get_by_id(kb_id)
+
+    settings.docStoreConn.update({"tag_kwd": req["from_tag"], "kb_id": [kb_id]},
+                                 {"remove": {"tag_kwd": req["from_tag"].strip()}, "add": {"tag_kwd": req["to_tag"]}},
+                                 search.index_name(kb.tenant_id),
+                                 kb_id)
+    return get_json_result(data=True)
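A hedged usage sketch for the four tag-management endpoints added above. The route paths come from the decorators; the base URL, mount prefix, authentication cookie, and `kb_id` values are assumptions for illustration only.

```python
# Hypothetical client calls against the new kb tag endpoints (Python `requests`).
import requests

BASE = "http://localhost:9380/v1/kb"   # assumed host and mount prefix
session = requests.Session()           # assumed to already carry a valid login cookie
kb_id = "0f1d2e3c4b5a"                 # hypothetical knowledge base id

# List tags aggregated from one KB, then from several KBs at once.
print(session.get(f"{BASE}/{kb_id}/tags").json())
print(session.get(f"{BASE}/tags", params={"kb_ids": f"{kb_id},another_kb_id"}).json())

# Remove a batch of tags from every chunk in the KB.
session.post(f"{BASE}/{kb_id}/rm_tags", json={"tags": ["obsolete", "draft"]})

# Rename one tag across the KB.
session.post(f"{BASE}/{kb_id}/rename_tag", json={"from_tag": "ml", "to_tag": "machine-learning"})
```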

View File

@@ -73,7 +73,8 @@ def create(tenant_id):
            chunk_method:
              type: string
              enum: ["naive", "manual", "qa", "table", "paper", "book", "laws",
-                    "presentation", "picture", "one", "knowledge_graph", "email"]
+                    "presentation", "picture", "one", "knowledge_graph", "email", "tag"
+                    ]
              description: Chunking method.
            parser_config:
              type: object
@@ -108,6 +109,7 @@ def create(tenant_id):
        "one",
        "knowledge_graph",
        "email",
+        "tag"
    ]
    check_validation = valid(
        permission,
@@ -302,7 +304,8 @@ def update(tenant_id, dataset_id):
            chunk_method:
              type: string
              enum: ["naive", "manual", "qa", "table", "paper", "book", "laws",
-                    "presentation", "picture", "one", "knowledge_graph", "email"]
+                    "presentation", "picture", "one", "knowledge_graph", "email", "tag"
+                    ]
              description: Updated chunking method.
            parser_config:
              type: object
@@ -339,6 +342,7 @@ def update(tenant_id, dataset_id):
        "one",
        "knowledge_graph",
        "email",
+        "tag"
    ]
    check_validation = valid(
        permission,

View File

@@ -16,6 +16,7 @@
 from flask import request, jsonify
 from api.db import LLMType, ParserType
+from api.db.services.dialog_service import label_question
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.services.llm_service import LLMBundle
 from api import settings
@@ -54,7 +55,8 @@ def retrieval(tenant_id):
         page_size=top,
         similarity_threshold=similarity_threshold,
         vector_similarity_weight=0.3,
-        top=top
+        top=top,
+        rank_feature=label_question(question, [kb])
     )
     records = []
     for c in ranks["chunks"]:

View File

@@ -16,7 +16,7 @@
 import pathlib
 import datetime
-from api.db.services.dialog_service import keyword_extraction
+from api.db.services.dialog_service import keyword_extraction, label_question
 from rag.app.qa import rmPrefix, beAdoc
 from rag.nlp import rag_tokenizer
 from api.db import LLMType, ParserType
@@ -276,6 +276,7 @@ def update_doc(tenant_id, dataset_id, document_id):
        "one",
        "knowledge_graph",
        "email",
+        "tag"
    }
    if req.get("chunk_method") not in valid_chunk_method:
        return get_error_data_result(
@@ -1355,6 +1356,7 @@ def retrieval_test(tenant_id):
            doc_ids,
            rerank_mdl=rerank_mdl,
            highlight=highlight,
+            rank_feature=label_question(question, kbs)
        )
        for c in ranks["chunks"]:
            c.pop("vector", None)

View File

@@ -89,6 +89,7 @@ class ParserType(StrEnum):
     AUDIO = "audio"
     EMAIL = "email"
     KG = "knowledge_graph"
+    TAG = "tag"


 class FileSource(StrEnum):

View File

@@ -133,7 +133,7 @@ def init_llm_factory():
     TenantLLMService.filter_update([TenantLLMService.model.llm_factory == "QAnything"], {"llm_factory": "Youdao"})
     TenantLLMService.filter_update([TenantLLMService.model.llm_factory == "cohere"], {"llm_factory": "Cohere"})
     TenantService.filter_update([1 == 1], {
-        "parser_ids": "naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,knowledge_graph:Knowledge Graph,email:Email"})
+        "parser_ids": "naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,knowledge_graph:Knowledge Graph,email:Email,tag:Tag"})
     ## insert openai two embedding models to the current openai user.
     # print("Start to insert 2 OpenAI embedding models...")
     tenant_ids = set([row["tenant_id"] for row in TenantLLMService.get_openai_models()])
@@ -153,14 +153,7 @@ def init_llm_factory():
                 break
     for kb_id in KnowledgebaseService.get_all_ids():
         KnowledgebaseService.update_by_id(kb_id, {"doc_num": DocumentService.get_kb_doc_count(kb_id)})
-    """
-    drop table llm;
-    drop table llm_factories;
-    update tenant set parser_ids='naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,knowledge_graph:Knowledge Graph';
-    alter table knowledgebase modify avatar longtext;
-    alter table user modify avatar longtext;
-    alter table dialog modify icon longtext;
-    """


 def add_graph_templates():

View File

@@ -29,8 +29,10 @@ from api.db.services.common_service import CommonService
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.services.llm_service import LLMService, TenantLLMService, LLMBundle
 from api import settings
+from graphrag.utils import get_tags_from_cache, set_tags_to_cache
 from rag.app.resume import forbidden_select_fields4resume
 from rag.nlp.search import index_name
+from rag.settings import TAG_FLD
 from rag.utils import rmSpace, num_tokens_from_string, encoder
 from api.utils.file_utils import get_project_base_directory
@@ -135,6 +137,29 @@ def kb_prompt(kbinfos, max_tokens):
     return knowledges
+
+
+def label_question(question, kbs):
+    tags = None
+    tag_kb_ids = []
+    for kb in kbs:
+        if kb.parser_config.get("tag_kb_ids"):
+            tag_kb_ids.extend(kb.parser_config["tag_kb_ids"])
+    if tag_kb_ids:
+        all_tags = get_tags_from_cache(tag_kb_ids)
+        if not all_tags:
+            all_tags = settings.retrievaler.all_tags_in_portion(kb.tenant_id, tag_kb_ids)
+            set_tags_to_cache(all_tags, tag_kb_ids)
+        else:
+            all_tags = json.loads(all_tags)
+        tag_kbs = KnowledgebaseService.get_by_ids(tag_kb_ids)
+        tags = settings.retrievaler.tag_query(question,
+                                              list(set([kb.tenant_id for kb in tag_kbs])),
+                                              tag_kb_ids,
+                                              all_tags,
+                                              kb.parser_config.get("topn_tags", 3)
+                                              )
+    return tags


 def chat(dialog, messages, stream=True, **kwargs):
     assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
@@ -236,11 +261,14 @@ def chat(dialog, messages, stream=True, **kwargs):
         generate_keyword_ts = timer()
         tenant_ids = list(set([kb.tenant_id for kb in kbs]))
         kbinfos = retriever.retrieval(" ".join(questions), embd_mdl, tenant_ids, dialog.kb_ids, 1, dialog.top_n,
                                       dialog.similarity_threshold,
                                       dialog.vector_similarity_weight,
                                       doc_ids=attachments,
-                                      top=dialog.top_k, aggs=False, rerank_mdl=rerank_mdl)
+                                      top=dialog.top_k, aggs=False, rerank_mdl=rerank_mdl,
+                                      rank_feature=label_question(" ".join(questions), kbs)
+                                      )
     retrieval_ts = timer()
@@ -650,7 +678,10 @@ def ask(question, kb_ids, tenant_id):
     chat_mdl = LLMBundle(tenant_id, LLMType.CHAT)
     max_tokens = chat_mdl.max_length
     tenant_ids = list(set([kb.tenant_id for kb in kbs]))
-    kbinfos = retriever.retrieval(question, embd_mdl, tenant_ids, kb_ids, 1, 12, 0.1, 0.3, aggs=False)
+    kbinfos = retriever.retrieval(question, embd_mdl, tenant_ids, kb_ids,
+                                  1, 12, 0.1, 0.3, aggs=False,
+                                  rank_feature=label_question(question, kbs)
+                                  )
     knowledges = kb_prompt(kbinfos, max_tokens)
     prompt = """
 Role: You're a smart assistant. Your name is Miss R.
@@ -700,3 +731,56 @@ def ask(question, kb_ids, tenant_id):
             answer = ans
             yield {"answer": answer, "reference": {}}
     yield decorate_answer(answer)
+
+
+def content_tagging(chat_mdl, content, all_tags, examples, topn=3):
+    prompt = f"""
+Role: You're a text analyzer.
+Task: Tag (put on some labels) to a given piece of text content based on the examples and the entire tag set.
+Steps::
+  - Comprehend the tag/label set.
+  - Comprehend examples which all consist of both text content and assigned tags with relevance score in format of JSON.
+  - Summarize the text content, and tag it with top {topn} most relevant tags from the set of tag/label and the corresponding relevance score.
+Requirements
+  - The tags MUST be from the tag set.
+  - The output MUST be in JSON format only, the key is tag and the value is its relevance score.
+  - The relevance score must be range from 1 to 10.
+  - Keywords ONLY in output.
+
+# TAG SET
+{", ".join(all_tags)}
+"""
+    for i, ex in enumerate(examples):
+        prompt += """
+# Examples {}
+### Text Content
+{}
+Output:
+{}
+""".format(i, ex["content"], json.dumps(ex[TAG_FLD], indent=2, ensure_ascii=False))
+
+    prompt += f"""
+# Real Data
+### Text Content
+{content}
+"""
+    msg = [
+        {"role": "system", "content": prompt},
+        {"role": "user", "content": "Output: "}
+    ]
+    _, msg = message_fit_in(msg, chat_mdl.max_length)
+    kwd = chat_mdl.chat(prompt, msg[1:], {"temperature": 0.5})
+    if isinstance(kwd, tuple):
+        kwd = kwd[0]
+    if kwd.find("**ERROR**") >= 0:
+        raise Exception(kwd)
+
+    kwd = re.sub(r".*?\{", "{", kwd)
+    return json.loads(kwd)
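A small sketch of how the two helpers above fit together, query-side and ingest-side. `kbs` and `chat_mdl` are assumed to be an already-loaded knowledge-base list and chat-model bundle, the tag names are made up, and the `"tag_feas"` key is only a guess at what `TAG_FLD` resolves to.

```python
# Hypothetical walk-through; all data below is invented for illustration.
from api.db.services.dialog_service import label_question, content_tagging

# Query side: derive {tag: relevance} rank features for a user question from the tag KBs
# referenced by the searched KBs' parser_config["tag_kb_ids"].
tags = label_question("how do I rotate my api key?", kbs)      # e.g. {"security": 7, "account": 4}

# Ingest side: ask the chat model to tag a new chunk, given the full tag set plus a few
# already-tagged examples (each example carries its tag dict under TAG_FLD).
all_tags = ["security", "account", "billing", "deployment"]
examples = [{"content": "Reset your password from the profile page.",
             "tag_feas": {"account": 8, "security": 5}}]        # "tag_feas" assumed == TAG_FLD
chunk_tags = content_tagging(chat_mdl, "API keys can be revoked under Settings.",
                             all_tags, examples, topn=3)
print(tags, chunk_tags)
```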

View File

@@ -43,10 +43,7 @@ class File2DocumentService(CommonService):
     def insert(cls, obj):
         if not cls.save(**obj):
             raise RuntimeError("Database error (File)!")
-        e, obj = cls.get_by_id(obj["id"])
-        if not e:
-            raise RuntimeError("Database error (File retrieval)!")
-        return obj
+        return File2Document(**obj)

     @classmethod
     @DB.connection_context()
@@ -63,9 +60,8 @@ class File2DocumentService(CommonService):
     def update_by_file_id(cls, file_id, obj):
         obj["update_time"] = current_timestamp()
         obj["update_date"] = datetime_format(datetime.now())
-        # num = cls.model.update(obj).where(cls.model.id == file_id).execute()
-        e, obj = cls.get_by_id(cls.model.id)
-        return obj
+        cls.model.update(obj).where(cls.model.id == file_id).execute()
+        return File2Document(**obj)

     @classmethod
     @DB.connection_context()

View File

@@ -251,10 +251,7 @@ class FileService(CommonService):
     def insert(cls, file):
         if not cls.save(**file):
             raise RuntimeError("Database error (File)!")
-        e, file = cls.get_by_id(file["id"])
-        if not e:
-            raise RuntimeError("Database error (File retrieval)!")
-        return file
+        return File(**file)

     @classmethod
     @DB.connection_context()

View File

@@ -35,7 +35,10 @@ class KnowledgebaseService(CommonService):
     @classmethod
     @DB.connection_context()
     def get_by_tenant_ids(cls, joined_tenant_ids, user_id,
-                          page_number, items_per_page, orderby, desc, keywords):
+                          page_number, items_per_page,
+                          orderby, desc, keywords,
+                          parser_id=None
+                          ):
         fields = [
             cls.model.id,
             cls.model.avatar,
@@ -67,6 +70,8 @@ class KnowledgebaseService(CommonService):
                     cls.model.tenant_id == user_id))
                 & (cls.model.status == StatusEnum.VALID.value)
             )
+        if parser_id:
+            kbs = kbs.where(cls.model.parser_id == parser_id)
         if desc:
             kbs = kbs.order_by(cls.model.getter_by(orderby).desc())
         else:

View File

@@ -69,6 +69,7 @@ class TaskService(CommonService):
             Knowledgebase.language,
             Knowledgebase.embd_id,
             Knowledgebase.pagerank,
+            Knowledgebase.parser_config.alias("kb_parser_config"),
             Tenant.img2txt_id,
             Tenant.asr_id,
             Tenant.llm_id,

View File

@@ -140,7 +140,7 @@ def init_settings():
     API_KEY = LLM.get("api_key", "")
     PARSERS = LLM.get(
         "parsers",
-        "naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,knowledge_graph:Knowledge Graph,email:Email")
+        "naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,knowledge_graph:Knowledge Graph,email:Email,tag:Tag")
     HOST_IP = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("host", "127.0.0.1")
     HOST_PORT = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("http_port")

View File

@@ -173,6 +173,7 @@ def validate_request(*args, **kwargs):
     return wrapper

+
 def not_allowed_parameters(*params):
     def decorator(f):
         def wrapper(*args, **kwargs):
@@ -182,7 +183,9 @@ def not_allowed_parameters(*params):
                 return get_json_result(
                     code=settings.RetCode.ARGUMENT_ERROR, message=f"Parameter {param} isn't allowed")
             return f(*args, **kwargs)
+
         return wrapper
+
     return decorator
@@ -207,6 +210,7 @@ def get_json_result(code=settings.RetCode.SUCCESS, message='success', data=None)
     response = {"code": code, "message": message, "data": data}
     return jsonify(response)

+
 def apikey_required(func):
     @wraps(func)
     def decorated_function(*args, **kwargs):
@@ -282,17 +286,18 @@ def construct_error_response(e):
 def token_required(func):
     @wraps(func)
     def decorated_function(*args, **kwargs):
-        authorization_str=flask_request.headers.get('Authorization')
+        authorization_str = flask_request.headers.get('Authorization')
         if not authorization_str:
-            return get_json_result(data=False,message="`Authorization` can't be empty")
-        authorization_list=authorization_str.split()
+            return get_json_result(data=False, message="`Authorization` can't be empty")
+        authorization_list = authorization_str.split()
         if len(authorization_list) < 2:
-            return get_json_result(data=False,message="Please check your authorization format.")
+            return get_json_result(data=False, message="Please check your authorization format.")
         token = authorization_list[1]
         objs = APIToken.query(token=token)
         if not objs:
             return get_json_result(
-                data=False, message='Authentication error: API key is invalid!', code=settings.RetCode.AUTHENTICATION_ERROR
+                data=False, message='Authentication error: API key is invalid!',
+                code=settings.RetCode.AUTHENTICATION_ERROR
             )
         kwargs['tenant_id'] = objs[0].tenant_id
         return func(*args, **kwargs)
@@ -330,35 +335,41 @@ def generate_confirmation_token(tenent_id):
     return "ragflow-" + serializer.dumps(get_uuid(), salt=tenent_id)[2:34]

-def valid(permission,valid_permission,language,valid_language,chunk_method,valid_chunk_method):
-    if valid_parameter(permission,valid_permission):
-        return valid_parameter(permission,valid_permission)
-    if valid_parameter(language,valid_language):
-        return valid_parameter(language,valid_language)
-    if valid_parameter(chunk_method,valid_chunk_method):
-        return valid_parameter(chunk_method,valid_chunk_method)
+def valid(permission, valid_permission, language, valid_language, chunk_method, valid_chunk_method):
+    if valid_parameter(permission, valid_permission):
+        return valid_parameter(permission, valid_permission)
+    if valid_parameter(language, valid_language):
+        return valid_parameter(language, valid_language)
+    if valid_parameter(chunk_method, valid_chunk_method):
+        return valid_parameter(chunk_method, valid_chunk_method)

-def valid_parameter(parameter,valid_values):
+
+def valid_parameter(parameter, valid_values):
     if parameter and parameter not in valid_values:
         return get_error_data_result(f"'{parameter}' is not in {valid_values}")

-def get_parser_config(chunk_method,parser_config):
+
+def get_parser_config(chunk_method, parser_config):
     if parser_config:
         return parser_config
     if not chunk_method:
         chunk_method = "naive"
-    key_mapping={"naive":{"chunk_token_num": 128, "delimiter": "\\n!?;。;!?", "html4excel": False,"layout_recognize": True, "raptor": {"use_raptor": False}},
-                 "qa":{"raptor":{"use_raptor":False}},
-                 "resume":None,
-                 "manual":{"raptor":{"use_raptor":False}},
-                 "table":None,
-                 "paper":{"raptor":{"use_raptor":False}},
-                 "book":{"raptor":{"use_raptor":False}},
-                 "laws":{"raptor":{"use_raptor":False}},
-                 "presentation":{"raptor":{"use_raptor":False}},
-                 "one":None,
-                 "knowledge_graph":{"chunk_token_num":8192,"delimiter":"\\n!?;。;!?","entity_types":["organization","person","location","event","time"]},
-                 "email":None,
-                 "picture":None}
-    parser_config=key_mapping[chunk_method]
-    return parser_config
+    key_mapping = {
+        "naive": {"chunk_token_num": 128, "delimiter": "\\n!?;。;!?", "html4excel": False, "layout_recognize": True,
+                  "raptor": {"use_raptor": False}},
+        "qa": {"raptor": {"use_raptor": False}},
+        "tag": None,
+        "resume": None,
+        "manual": {"raptor": {"use_raptor": False}},
+        "table": None,
+        "paper": {"raptor": {"use_raptor": False}},
+        "book": {"raptor": {"use_raptor": False}},
+        "laws": {"raptor": {"use_raptor": False}},
+        "presentation": {"raptor": {"use_raptor": False}},
+        "one": None,
+        "knowledge_graph": {"chunk_token_num": 8192, "delimiter": "\\n!?;。;!?",
+                            "entity_types": ["organization", "person", "location", "event", "time"]},
+        "email": None,
+        "picture": None}
+    parser_config = key_mapping[chunk_method]
+    return parser_config

View File

@@ -10,6 +10,7 @@
     "title_sm_tks": {"type": "varchar", "default": "", "analyzer": "whitespace"},
     "name_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"},
     "important_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"},
+    "tag_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"},
     "important_tks": {"type": "varchar", "default": "", "analyzer": "whitespace"},
     "question_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"},
     "question_tks": {"type": "varchar", "default": "", "analyzer": "whitespace"},
@@ -27,5 +28,6 @@
     "available_int": {"type": "integer", "default": 1},
     "knowledge_graph_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"},
     "entities_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"},
-    "pagerank_fea": {"type": "integer", "default": 0}
+    "pagerank_fea": {"type": "integer", "default": 0},
+    "tag_fea": {"type": "integer", "default": 0}
 }

View File

@@ -111,4 +111,23 @@ def set_embed_cache(llmnm, txt, arr):
     k = hasher.hexdigest()
     arr = json.dumps(arr.tolist() if isinstance(arr, np.ndarray) else arr)
     REDIS_CONN.set(k, arr.encode("utf-8"), 24*3600)
+
+
+def get_tags_from_cache(kb_ids):
+    hasher = xxhash.xxh64()
+    hasher.update(str(kb_ids).encode("utf-8"))
+
+    k = hasher.hexdigest()
+    bin = REDIS_CONN.get(k)
+    if not bin:
+        return
+    return bin
+
+
+def set_tags_to_cache(kb_ids, tags):
+    hasher = xxhash.xxh64()
+    hasher.update(str(kb_ids).encode("utf-8"))
+
+    k = hasher.hexdigest()
+    REDIS_CONN.set(k, json.dumps(tags).encode("utf-8"), 600)

View File

@@ -26,6 +26,7 @@ from docx import Document
 from PIL import Image
 from markdown import markdown
+

 class Excel(ExcelParser):
     def __call__(self, fnm, binary=None, callback=None):
         if not binary:
@@ -58,11 +59,11 @@ class Excel(ExcelParser):
                 if len(res) % 999 == 0:
                     callback(len(res) *
                              0.6 /
-                             total, ("Extract Q&A: {}".format(len(res)) +
+                             total, ("Extract pairs: {}".format(len(res)) +
                                      (f"{len(fails)} failure, line: %s..." %
                                       (",".join(fails[:3])) if fails else "")))

-        callback(0.6, ("Extract Q&A: {}. ".format(len(res)) + (
+        callback(0.6, ("Extract pairs: {}. ".format(len(res)) + (
             f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
         self.is_english = is_english(
             [rmPrefix(q) for q, _ in random_choices(res, k=30) if len(q) > 1])
@@ -269,7 +270,7 @@ def beAdocPdf(d, q, a, eng, image, poss):
     return d

-def beAdocDocx(d, q, a, eng, image):
+def beAdocDocx(d, q, a, eng, image, row_num=-1):
     qprefix = "Question: " if eng else "问题:"
     aprefix = "Answer: " if eng else "回答:"
     d["content_with_weight"] = "\t".join(
@@ -277,16 +278,20 @@ def beAdocDocx(d, q, a, eng, image):
     d["content_ltks"] = rag_tokenizer.tokenize(q)
     d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
     d["image"] = image
+    if row_num >= 0:
+        d["top_int"] = [row_num]
     return d

-def beAdoc(d, q, a, eng):
+def beAdoc(d, q, a, eng, row_num=-1):
     qprefix = "Question: " if eng else "问题:"
     aprefix = "Answer: " if eng else "回答:"
     d["content_with_weight"] = "\t".join(
         [qprefix + rmPrefix(q), aprefix + rmPrefix(a)])
     d["content_ltks"] = rag_tokenizer.tokenize(q)
     d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
+    if row_num >= 0:
+        d["top_int"] = [row_num]
     return d
@@ -316,8 +321,8 @@ def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs):
     if re.search(r"\.xlsx?$", filename, re.IGNORECASE):
         callback(0.1, "Start to parse.")
         excel_parser = Excel()
-        for q, a in excel_parser(filename, binary, callback):
-            res.append(beAdoc(deepcopy(doc), q, a, eng))
+        for ii, (q, a) in enumerate(excel_parser(filename, binary, callback)):
+            res.append(beAdoc(deepcopy(doc), q, a, eng, ii))
         return res
     elif re.search(r"\.(txt)$", filename, re.IGNORECASE):
@@ -344,7 +349,7 @@ def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs):
                 fails.append(str(i+1))
             elif len(arr) == 2:
                 if question and answer:
-                    res.append(beAdoc(deepcopy(doc), question, answer, eng))
+                    res.append(beAdoc(deepcopy(doc), question, answer, eng, i))
                 question, answer = arr
             i += 1
             if len(res) % 999 == 0:
@@ -352,7 +357,7 @@ def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs):
                     f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
         if question:
-            res.append(beAdoc(deepcopy(doc), question, answer, eng))
+            res.append(beAdoc(deepcopy(doc), question, answer, eng, len(lines)))
         callback(0.6, ("Extract Q&A: {}".format(len(res)) + (
             f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
@@ -378,14 +383,14 @@ def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs):
                 fails.append(str(i + 1))
             elif len(row) == 2:
                 if question and answer:
-                    res.append(beAdoc(deepcopy(doc), question, answer, eng))
+                    res.append(beAdoc(deepcopy(doc), question, answer, eng, i))
                 question, answer = row
             if len(res) % 999 == 0:
                 callback(len(res) * 0.6 / len(lines), ("Extract Q&A: {}".format(len(res)) + (
                     f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
         if question:
-            res.append(beAdoc(deepcopy(doc), question, answer, eng))
+            res.append(beAdoc(deepcopy(doc), question, answer, eng, len(reader)))
         callback(0.6, ("Extract Q&A: {}".format(len(res)) + (
             f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
@@ -420,7 +425,7 @@ def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs):
                 if last_answer.strip():
                     sum_question = '\n'.join(question_stack)
                     if sum_question:
-                        res.append(beAdoc(deepcopy(doc), sum_question, markdown(last_answer, extensions=['markdown.extensions.tables']), eng))
+                        res.append(beAdoc(deepcopy(doc), sum_question, markdown(last_answer, extensions=['markdown.extensions.tables']), eng, index))
                     last_answer = ''
                 i = question_level
@@ -432,7 +437,7 @@ def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs):
         if last_answer.strip():
             sum_question = '\n'.join(question_stack)
             if sum_question:
-                res.append(beAdoc(deepcopy(doc), sum_question, markdown(last_answer, extensions=['markdown.extensions.tables']), eng))
+                res.append(beAdoc(deepcopy(doc), sum_question, markdown(last_answer, extensions=['markdown.extensions.tables']), eng, index))
         return res
     elif re.search(r"\.docx$", filename, re.IGNORECASE):
@@ -440,8 +445,8 @@ def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs):
         qai_list, tbls = docx_parser(filename, binary,
                                      from_page=0, to_page=10000, callback=callback)
         res = tokenize_table(tbls, doc, eng)
-        for q, a, image in qai_list:
-            res.append(beAdocDocx(deepcopy(doc), q, a, eng, image))
+        for i, (q, a, image) in enumerate(qai_list):
+            res.append(beAdocDocx(deepcopy(doc), q, a, eng, image, i))
         return res

     raise NotImplementedError(

rag/app/tag.py (new file, 125 lines)
View File

@@ -0,0 +1,125 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import re
import csv
from copy import deepcopy
from deepdoc.parser.utils import get_text
from rag.app.qa import Excel
from rag.nlp import rag_tokenizer
def beAdoc(d, q, a, eng, row_num=-1):
d["content_with_weight"] = q
d["content_ltks"] = rag_tokenizer.tokenize(q)
d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
d["tag_kwd"] = [t.strip() for t in a.split(",") if t.strip()]
if row_num >= 0:
d["top_int"] = [row_num]
return d
def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs):
"""
Excel and csv(txt) format files are supported.
If the file is in excel format, there should be 2 column content and tags without header.
And content column is ahead of tags column.
And it's O.K if it has multiple sheets as long as the columns are rightly composed.
If it's in csv format, it should be UTF-8 encoded. Use TAB as delimiter to separate content and tags.
All the deformed lines will be ignored.
Every pair will be treated as a chunk.
"""
eng = lang.lower() == "english"
res = []
doc = {
"docnm_kwd": filename,
"title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", filename))
}
if re.search(r"\.xlsx?$", filename, re.IGNORECASE):
callback(0.1, "Start to parse.")
excel_parser = Excel()
for ii, (q, a) in enumerate(excel_parser(filename, binary, callback)):
res.append(beAdoc(deepcopy(doc), q, a, eng, ii))
return res
elif re.search(r"\.(txt)$", filename, re.IGNORECASE):
callback(0.1, "Start to parse.")
txt = get_text(filename, binary)
lines = txt.split("\n")
comma, tab = 0, 0
for line in lines:
if len(line.split(",")) == 2:
comma += 1
if len(line.split("\t")) == 2:
tab += 1
delimiter = "\t" if tab >= comma else ","
fails = []
content = ""
i = 0
while i < len(lines):
arr = lines[i].split(delimiter)
if len(arr) != 2:
content += "\n" + lines[i]
elif len(arr) == 2:
content += "\n" + arr[0]
res.append(beAdoc(deepcopy(doc), content, arr[1], eng, i))
content = ""
i += 1
if len(res) % 999 == 0:
callback(len(res) * 0.6 / len(lines), ("Extract TAG: {}".format(len(res)) + (
f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
callback(0.6, ("Extract TAG: {}".format(len(res)) + (
f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
return res
elif re.search(r"\.(csv)$", filename, re.IGNORECASE):
callback(0.1, "Start to parse.")
txt = get_text(filename, binary)
lines = txt.split("\n")
delimiter = "\t" if any("\t" in line for line in lines) else ","
fails = []
content = ""
res = []
reader = csv.reader(lines, delimiter=delimiter)
for i, row in enumerate(reader):
if len(row) != 2:
content += "\n" + lines[i]
elif len(row) == 2:
content += "\n" + row[0]
res.append(beAdoc(deepcopy(doc), content, row[1], eng, i))
content = ""
if len(res) % 999 == 0:
callback(len(res) * 0.6 / len(lines), ("Extract Tags: {}".format(len(res)) + (
f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
callback(0.6, ("Extract TAG : {}".format(len(res)) + (
f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
return res
raise NotImplementedError(
"Excel, csv(txt) format files are supported.")
if __name__ == "__main__":
import sys
def dummy(prog=None, msg=""):
pass
chunk(sys.argv[1], from_page=0, to_page=10, callback=dummy)
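To make the accepted input concrete, here is a hedged example of a tab-separated tag file and a direct call to the parser above; the sample rows, file name, and callback are made up.

```python
# Hypothetical two-column input: content<TAB>comma-separated tags, one pair per line.
from rag.app import tag

sample = (
    "How can I reset my password?\taccount,security\n"
    "Invoice shows the wrong amount\tbilling\n"
)


def progress(prog=None, msg=""):
    print(prog, msg)


chunks = tag.chunk("faq_tags.txt", binary=sample.encode("utf-8"),
                   lang="English", callback=progress)
# Each chunk keeps the content column plus a tag_kwd list and its row number in top_int,
# e.g. chunks[0]["tag_kwd"] -> ["account", "security"].
```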

View File

@@ -59,13 +59,15 @@ class FulltextQueryer:
                 "",
             ),
             (r"(^| )(what|who|how|which|where|why)('re|'s)? ", " "),
-            (r"(^| )('s|'re|is|are|were|was|do|does|did|don't|doesn't|didn't|has|have|be|there|you|me|your|my|mine|just|please|may|i|should|would|wouldn't|will|won't|done|go|for|with|so|the|a|an|by|i'm|it's|he's|she's|they|they're|you're|as|by|on|in|at|up|out|down|of|to|or|and|if) ", " ")
+            (
+                r"(^| )('s|'re|is|are|were|was|do|does|did|don't|doesn't|didn't|has|have|be|there|you|me|your|my|mine|just|please|may|i|should|would|wouldn't|will|won't|done|go|for|with|so|the|a|an|by|i'm|it's|he's|she's|they|they're|you're|as|by|on|in|at|up|out|down|of|to|or|and|if) ",
+                " ")
         ]
         for r, p in patts:
             txt = re.sub(r, p, txt, flags=re.IGNORECASE)
         return txt

-    def question(self, txt, tbl="qa", min_match:float=0.6):
+    def question(self, txt, tbl="qa", min_match: float = 0.6):
         txt = re.sub(
             r"[ :|\r\n\t,,。??/`!&^%%()\[\]{}<>]+",
             " ",
@@ -90,7 +92,8 @@ class FulltextQueryer:
             syn = ["\"{}\"^{:.4f}".format(s, w / 4.) for s in syn if s.strip()]
             syns.append(" ".join(syn))

-        q = ["({}^{:.4f}".format(tk, w) + " {})".format(syn) for (tk, w), syn in zip(tks_w, syns) if tk and not re.match(r"[.^+\(\)-]", tk)]
+        q = ["({}^{:.4f}".format(tk, w) + " {})".format(syn) for (tk, w), syn in zip(tks_w, syns) if
+             tk and not re.match(r"[.^+\(\)-]", tk)]
         for i in range(1, len(tks_w)):
             left, right = tks_w[i - 1][0].strip(), tks_w[i][0].strip()
             if not left or not right:
@@ -155,7 +158,7 @@ class FulltextQueryer:
                 if len(keywords) < 32:
                     keywords.extend([s for s in tk_syns if s])
                 tk_syns = [rag_tokenizer.fine_grained_tokenize(s) for s in tk_syns if s]
-                tk_syns = [f"\"{s}\"" if s.find(" ")>0 else s for s in tk_syns]
+                tk_syns = [f"\"{s}\"" if s.find(" ") > 0 else s for s in tk_syns]

                 if len(keywords) >= 32:
                     break
@@ -174,8 +177,6 @@ class FulltextQueryer:
                 if len(twts) > 1:
                     tms += ' ("%s"~2)^1.5' % rag_tokenizer.tokenize(tt)
-                if re.match(r"[0-9a-z ]+$", tt):
-                    tms = f'("{tt}" OR "%s")' % rag_tokenizer.tokenize(tt)

                 syns = " OR ".join(
                     [
@@ -232,3 +233,25 @@ class FulltextQueryer:
         for k, v in qtwt.items():
             q += v
         return s / q
+
+    def paragraph(self, content_tks: str, keywords: list = [], keywords_topn=30):
+        if isinstance(content_tks, str):
+            content_tks = [c.strip() for c in content_tks.strip() if c.strip()]
+        tks_w = self.tw.weights(content_tks, preprocess=False)
+
+        keywords = [f'"{k.strip()}"' for k in keywords]
+        for tk, w in sorted(tks_w, key=lambda x: x[1] * -1)[:keywords_topn]:
+            tk_syns = self.syn.lookup(tk)
+            tk_syns = [FulltextQueryer.subSpecialChar(s) for s in tk_syns]
+            tk_syns = [rag_tokenizer.fine_grained_tokenize(s) for s in tk_syns if s]
+            tk_syns = [f"\"{s}\"" if s.find(" ") > 0 else s for s in tk_syns]
+            tk = FulltextQueryer.subSpecialChar(tk)
+            if tk.find(" ") > 0:
+                tk = '"%s"' % tk
+            if tk_syns:
+                tk = f"({tk} OR (%s)^0.2)" % " ".join(tk_syns)
+            if tk:
+                keywords.append(f"{tk}^{w}")
+
+        return MatchTextExpr(self.query_fields, " ".join(keywords), 100,
+                             {"minimum_should_match": min(3, len(keywords) / 10)})

View File

@ -13,11 +13,11 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# #
import logging import logging
import re import re
from dataclasses import dataclass from dataclasses import dataclass
from rag.settings import TAG_FLD, PAGERANK_FLD
from rag.utils import rmSpace from rag.utils import rmSpace
from rag.nlp import rag_tokenizer, query from rag.nlp import rag_tokenizer, query
import numpy as np import numpy as np
@ -47,7 +47,8 @@ class Dealer:
qv, _ = emb_mdl.encode_queries(txt) qv, _ = emb_mdl.encode_queries(txt)
shape = np.array(qv).shape shape = np.array(qv).shape
if len(shape) > 1: if len(shape) > 1:
raise Exception(f"Dealer.get_vector returned array's shape {shape} doesn't match expectation(exact one dimension).") raise Exception(
f"Dealer.get_vector returned array's shape {shape} doesn't match expectation(exact one dimension).")
embedding_data = [float(v) for v in qv] embedding_data = [float(v) for v in qv]
vector_column_name = f"q_{len(embedding_data)}_vec" vector_column_name = f"q_{len(embedding_data)}_vec"
return MatchDenseExpr(vector_column_name, embedding_data, 'float', 'cosine', topk, {"similarity": similarity}) return MatchDenseExpr(vector_column_name, embedding_data, 'float', 'cosine', topk, {"similarity": similarity})
@ -63,7 +64,12 @@ class Dealer:
condition[key] = req[key] condition[key] = req[key]
return condition return condition
def search(self, req, idx_names: str | list[str], kb_ids: list[str], emb_mdl=None, highlight = False): def search(self, req, idx_names: str | list[str],
kb_ids: list[str],
emb_mdl=None,
highlight=False,
rank_feature: dict | None = None
):
filters = self.get_filters(req) filters = self.get_filters(req)
orderBy = OrderByExpr() orderBy = OrderByExpr()
@ -72,9 +78,11 @@ class Dealer:
ps = int(req.get("size", topk)) ps = int(req.get("size", topk))
offset, limit = pg * ps, ps offset, limit = pg * ps, ps
src = req.get("fields", ["docnm_kwd", "content_ltks", "kb_id", "img_id", "title_tks", "important_kwd", "position_int", src = req.get("fields",
"doc_id", "page_num_int", "top_int", "create_timestamp_flt", "knowledge_graph_kwd", "question_kwd", "question_tks", ["docnm_kwd", "content_ltks", "kb_id", "img_id", "title_tks", "important_kwd", "position_int",
"available_int", "content_with_weight", "pagerank_fea"]) "doc_id", "page_num_int", "top_int", "create_timestamp_flt", "knowledge_graph_kwd",
"question_kwd", "question_tks",
"available_int", "content_with_weight", PAGERANK_FLD, TAG_FLD])
kwds = set([]) kwds = set([])
qst = req.get("question", "") qst = req.get("question", "")
@ -85,15 +93,16 @@ class Dealer:
orderBy.asc("top_int") orderBy.asc("top_int")
orderBy.desc("create_timestamp_flt") orderBy.desc("create_timestamp_flt")
res = self.dataStore.search(src, [], filters, [], orderBy, offset, limit, idx_names, kb_ids) res = self.dataStore.search(src, [], filters, [], orderBy, offset, limit, idx_names, kb_ids)
total=self.dataStore.getTotal(res) total = self.dataStore.getTotal(res)
logging.debug("Dealer.search TOTAL: {}".format(total)) logging.debug("Dealer.search TOTAL: {}".format(total))
else: else:
highlightFields = ["content_ltks", "title_tks"] if highlight else [] highlightFields = ["content_ltks", "title_tks"] if highlight else []
matchText, keywords = self.qryr.question(qst, min_match=0.3) matchText, keywords = self.qryr.question(qst, min_match=0.3)
if emb_mdl is None: if emb_mdl is None:
matchExprs = [matchText] matchExprs = [matchText]
res = self.dataStore.search(src, highlightFields, filters, matchExprs, orderBy, offset, limit, idx_names, kb_ids) res = self.dataStore.search(src, highlightFields, filters, matchExprs, orderBy, offset, limit,
total=self.dataStore.getTotal(res) idx_names, kb_ids, rank_feature=rank_feature)
total = self.dataStore.getTotal(res)
logging.debug("Dealer.search TOTAL: {}".format(total)) logging.debug("Dealer.search TOTAL: {}".format(total))
else: else:
matchDense = self.get_vector(qst, emb_mdl, topk, req.get("similarity", 0.1)) matchDense = self.get_vector(qst, emb_mdl, topk, req.get("similarity", 0.1))
@ -103,8 +112,9 @@ class Dealer:
fusionExpr = FusionExpr("weighted_sum", topk, {"weights": "0.05, 0.95"}) fusionExpr = FusionExpr("weighted_sum", topk, {"weights": "0.05, 0.95"})
matchExprs = [matchText, matchDense, fusionExpr] matchExprs = [matchText, matchDense, fusionExpr]
res = self.dataStore.search(src, highlightFields, filters, matchExprs, orderBy, offset, limit, idx_names, kb_ids) res = self.dataStore.search(src, highlightFields, filters, matchExprs, orderBy, offset, limit,
total=self.dataStore.getTotal(res) idx_names, kb_ids, rank_feature=rank_feature)
total = self.dataStore.getTotal(res)
logging.debug("Dealer.search TOTAL: {}".format(total)) logging.debug("Dealer.search TOTAL: {}".format(total))
# If result is empty, try again with lower min_match # If result is empty, try again with lower min_match
@ -112,8 +122,9 @@ class Dealer:
matchText, _ = self.qryr.question(qst, min_match=0.1)
filters.pop("doc_ids", None)
matchDense.extra_options["similarity"] = 0.17
res = self.dataStore.search(src, highlightFields, filters, [matchText, matchDense, fusionExpr],
orderBy, offset, limit, idx_names, kb_ids, rank_feature=rank_feature)
total = self.dataStore.getTotal(res)
logging.debug("Dealer.search 2 TOTAL: {}".format(total))
for k in keywords:
@ -126,8 +137,8 @@ class Dealer:
kwds.add(kk)
logging.debug(f"TOTAL: {total}")
ids = self.dataStore.getChunkIds(res)
keywords = list(kwds)
highlight = self.dataStore.getHighlight(res, keywords, "content_with_weight")
aggs = self.dataStore.getAggregation(res, "docnm_kwd")
return self.SearchResult(
@ -188,13 +199,13 @@ class Dealer:
ans_v, _ = embd_mdl.encode(pieces_)
assert len(ans_v[0]) == len(chunk_v[0]), "The dimension of query and chunk do not match: {} vs. {}".format(
len(ans_v[0]), len(chunk_v[0]))
chunks_tks = [rag_tokenizer.tokenize(self.qryr.rmWWW(ck)).split()
for ck in chunks]
cites = {}
thr = 0.63
while thr > 0.3 and len(cites.keys()) == 0 and pieces_ and chunks_tks:
for i, a in enumerate(pieces_):
sim, tksim, vtsim = self.qryr.hybrid_similarity(ans_v[i],
chunk_v,
@ -228,20 +239,44 @@ class Dealer:
return res, seted
def _rank_feature_scores(self, query_rfea, search_res):
## For rank feature(tag_fea) scores.
rank_fea = []
pageranks = []
for chunk_id in search_res.ids:
pageranks.append(search_res.field[chunk_id].get(PAGERANK_FLD, 0))
pageranks = np.array(pageranks, dtype=float)
if not query_rfea:
return np.array([0 for _ in range(len(search_res.ids))]) + pageranks
q_denor = np.sqrt(np.sum([s*s for t,s in query_rfea.items() if t != PAGERANK_FLD]))
for i in search_res.ids:
nor, denor = 0, 0
for t, sc in eval(search_res.field[i].get(TAG_FLD, "{}")).items():
if t in query_rfea:
nor += query_rfea[t] * sc
denor += sc * sc
if denor == 0:
rank_fea.append(0)
else:
rank_fea.append(nor/np.sqrt(denor)/q_denor)
return np.array(rank_fea)*10. + pageranks
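To make the new _rank_feature_scores easier to follow, here is a minimal standalone sketch of the same scoring idea for a single chunk; the function and variable names below are illustrative and are not part of this commit. The score is a cosine-style overlap between the query's tag weights and the chunk's stored tag weights, scaled by 10 and offset by the chunk's pagerank.

import numpy as np

def tag_feature_score(query_tags: dict, chunk_tags: dict, pagerank: float = 0.0) -> float:
    # Overlap between query tag weights and chunk tag weights, mirroring _rank_feature_scores.
    q_denor = np.sqrt(sum(w * w for w in query_tags.values()))
    nor = sum(w * chunk_tags[t] for t, w in query_tags.items() if t in chunk_tags)
    denor = sum(s * s for t, s in chunk_tags.items() if t in query_tags)
    if not q_denor or not denor:
        return float(pagerank)
    return 10.0 * nor / np.sqrt(denor) / q_denor + pagerank

# e.g. a query tagged {"python": 3, "database": 1} against a chunk tagged {"python": 5}:
print(tag_feature_score({"python": 3, "database": 1}, {"python": 5}, pagerank=10))  # ~19.49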
def rerank(self, sres, query, tkweight=0.3,
vtweight=0.7, cfield="content_ltks",
rank_feature: dict | None = None
):
_, keywords = self.qryr.question(query)
vector_size = len(sres.query_vector)
vector_column = f"q_{vector_size}_vec"
zero_vector = [0.0] * vector_size
ins_embd = []
for chunk_id in sres.ids:
vector = sres.field[chunk_id].get(vector_column, zero_vector)
if isinstance(vector, str):
vector = [float(v) for v in vector.split("\t")]
ins_embd.append(vector)
if not ins_embd:
return [], [], []
@ -254,18 +289,22 @@ class Dealer:
title_tks = [t for t in sres.field[i].get("title_tks", "").split() if t]
question_tks = [t for t in sres.field[i].get("question_tks", "").split() if t]
important_kwd = sres.field[i].get("important_kwd", [])
tks = content_ltks + title_tks * 2 + important_kwd * 5 + question_tks * 6
ins_tw.append(tks)
## For rank feature(tag_fea) scores.
rank_fea = self._rank_feature_scores(rank_feature, sres)
sim, tksim, vtsim = self.qryr.hybrid_similarity(sres.query_vector,
ins_embd,
keywords,
ins_tw, tkweight, vtweight)
return sim + rank_fea, tksim, vtsim
def rerank_by_model(self, rerank_mdl, sres, query, tkweight=0.3,
vtweight=0.7, cfield="content_ltks",
rank_feature: dict | None = None):
_, keywords = self.qryr.question(query)
for i in sres.ids:
@ -280,9 +319,11 @@ class Dealer:
ins_tw.append(tks)
tksim = self.qryr.token_similarity(keywords, ins_tw)
vtsim, _ = rerank_mdl.similarity(query, [rmSpace(" ".join(tks)) for tks in ins_tw])
## For rank feature(tag_fea) scores.
rank_fea = self._rank_feature_scores(rank_feature, sres)
return tkweight * (np.array(tksim)+rank_fea) + vtweight * vtsim, tksim, vtsim
def hybrid_similarity(self, ans_embd, ins_embd, ans, inst):
return self.qryr.hybrid_similarity(ans_embd,
@ -291,13 +332,15 @@ class Dealer:
rag_tokenizer.tokenize(inst).split())
def retrieval(self, question, embd_mdl, tenant_ids, kb_ids, page, page_size, similarity_threshold=0.2,
vector_similarity_weight=0.3, top=1024, doc_ids=None, aggs=True,
rerank_mdl=None, highlight=False,
rank_feature: dict | None = {PAGERANK_FLD: 10}):
ranks = {"total": 0, "chunks": [], "doc_aggs": {}}
if not question:
return ranks
RERANK_PAGE_LIMIT = 3
req = {"kb_ids": kb_ids, "doc_ids": doc_ids, "size": max(page_size * RERANK_PAGE_LIMIT, 128),
"question": question, "vector": True, "topk": top,
"similarity": similarity_threshold,
"available_int": 1}
@ -309,29 +352,30 @@ class Dealer:
if isinstance(tenant_ids, str):
tenant_ids = tenant_ids.split(",")
sres = self.search(req, [index_name(tid) for tid in tenant_ids],
kb_ids, embd_mdl, highlight, rank_feature=rank_feature)
ranks["total"] = sres.total
if page <= RERANK_PAGE_LIMIT:
if rerank_mdl and sres.total > 0:
sim, tsim, vsim = self.rerank_by_model(rerank_mdl,
sres, question, 1 - vector_similarity_weight,
vector_similarity_weight,
rank_feature=rank_feature)
else:
sim, tsim, vsim = self.rerank(
sres, question, 1 - vector_similarity_weight, vector_similarity_weight,
rank_feature=rank_feature)
idx = np.argsort(sim * -1)[(page - 1) * page_size:page * page_size]
else:
sim = tsim = vsim = [1] * len(sres.ids)
idx = list(range(len(sres.ids)))
dim = len(sres.query_vector)
vector_column = f"q_{dim}_vec"
zero_vector = [0.0] * dim
for i in idx:
if sim[i] < similarity_threshold:
break
if len(ranks["chunks"]) >= page_size:
if aggs:
@ -369,8 +413,8 @@ class Dealer:
ranks["doc_aggs"] = [{"doc_name": k, ranks["doc_aggs"] = [{"doc_name": k,
"doc_id": v["doc_id"], "doc_id": v["doc_id"],
"count": v["count"]} for k, "count": v["count"]} for k,
v in sorted(ranks["doc_aggs"].items(), v in sorted(ranks["doc_aggs"].items(),
key=lambda x:x[1]["count"] * -1)] key=lambda x: x[1]["count"] * -1)]
ranks["chunks"] = ranks["chunks"][:page_size] ranks["chunks"] = ranks["chunks"][:page_size]
return ranks return ranks
@ -379,15 +423,57 @@ class Dealer:
tbl = self.dataStore.sql(sql, fetch_size, format)
return tbl
def chunk_list(self, doc_id: str, tenant_id: str,
kb_ids: list[str], max_count=1024,
offset=0,
fields=["docnm_kwd", "content_with_weight", "img_id"]):
condition = {"doc_id": doc_id}
res = []
bs = 128
for p in range(offset, max_count, bs):
es_res = self.dataStore.search(fields, [], condition, [], OrderByExpr(), p, bs, index_name(tenant_id),
kb_ids)
dict_chunks = self.dataStore.getFields(es_res, fields)
if dict_chunks:
res.extend(dict_chunks.values())
if len(dict_chunks.values()) < bs:
break
return res
def all_tags(self, tenant_id: str, kb_ids: list[str], S=1000):
res = self.dataStore.search([], [], {}, [], OrderByExpr(), 0, 0, index_name(tenant_id), kb_ids, ["tag_kwd"])
return self.dataStore.getAggregation(res, "tag_kwd")
def all_tags_in_portion(self, tenant_id: str, kb_ids: list[str], S=1000):
res = self.dataStore.search([], [], {}, [], OrderByExpr(), 0, 0, index_name(tenant_id), kb_ids, ["tag_kwd"])
res = self.dataStore.getAggregation(res, "tag_kwd")
total = np.sum([c for _, c in res])
return {t: (c + 1) / (total + S) for t, c in res}
def tag_content(self, tenant_id: str, kb_ids: list[str], doc, all_tags, topn_tags=3, keywords_topn=30, S=1000):
idx_nm = index_name(tenant_id)
match_txt = self.qryr.paragraph(doc["title_tks"] + " " + doc["content_ltks"], doc.get("important_kwd", []), keywords_topn)
res = self.dataStore.search([], [], {}, [match_txt], OrderByExpr(), 0, 0, idx_nm, kb_ids, ["tag_kwd"])
aggs = self.dataStore.getAggregation(res, "tag_kwd")
if not aggs:
return False
cnt = np.sum([c for _, c in aggs])
tag_fea = sorted([(a, round(0.1*(c + 1) / (cnt + S) / (all_tags.get(a, 0.0001)))) for a, c in aggs],
key=lambda x: x[1] * -1)[:topn_tags]
doc[TAG_FLD] = {a: c for a, c in tag_fea if c > 0}
return True
def tag_query(self, question: str, tenant_ids: str | list[str], kb_ids: list[str], all_tags, topn_tags=3, S=1000):
if isinstance(tenant_ids, str):
idx_nms = index_name(tenant_ids)
else:
idx_nms = [index_name(tid) for tid in tenant_ids]
match_txt, _ = self.qryr.question(question, min_match=0.0)
res = self.dataStore.search([], [], {}, [match_txt], OrderByExpr(), 0, 0, idx_nms, kb_ids, ["tag_kwd"])
aggs = self.dataStore.getAggregation(res, "tag_kwd")
if not aggs:
return {}
cnt = np.sum([c for _, c in aggs])
tag_fea = sorted([(a, round(0.1*(c + 1) / (cnt + S) / (all_tags.get(a, 0.0001)))) for a, c in aggs],
key=lambda x: x[1] * -1)[:topn_tags]
return {a: c for a, c in tag_fea if c > 0}
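Taken together, the new helpers are meant to be wired up roughly as follows at query time; this is a hedged sketch rather than code from the diff, and dealer, embd_mdl, tenant_id and kb_ids are assumed to exist in the caller's scope (label_question in dialog_service is the actual caller).

# Hedged usage sketch (not from the diff): tag the query, then feed the tag weights
# into retrieval as rank features alongside the default pagerank boost.
all_tags = dealer.all_tags_in_portion(tenant_id, kb_ids)      # smoothed KB-wide tag frequencies
query_tags = dealer.tag_query("how does pagerank boosting work?",
                              tenant_id, kb_ids, all_tags, topn_tags=3)
ranks = dealer.retrieval("how does pagerank boosting work?", embd_mdl, tenant_id, kb_ids,
                         page=1, page_size=10,
                         rank_feature=query_tags or {PAGERANK_FLD: 10})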

View File

@ -38,6 +38,9 @@ SVR_QUEUE_RETENTION = 60*60
SVR_QUEUE_MAX_LEN = 1024
SVR_CONSUMER_NAME = "rag_flow_svr_consumer"
SVR_CONSUMER_GROUP_NAME = "rag_flow_svr_consumer_group"
PAGERANK_FLD = "pagerank_fea"
TAG_FLD = "tag_feas"
def print_rag_settings():
logging.info(f"MAX_CONTENT_LENGTH: {DOC_MAXIMUM_SIZE}")

View File

@ -16,10 +16,10 @@
# from beartype import BeartypeConf
# from beartype.claw import beartype_all # <-- you didn't sign up for this
# beartype_all(conf=BeartypeConf(violation_type=UserWarning)) # <-- emit warnings from all code
import random
import sys
from api.utils.log_utils import initRootLogger
from graphrag.utils import get_llm_cache, set_llm_cache, get_tags_from_cache, set_tags_to_cache
CONSUMER_NO = "0" if len(sys.argv) < 2 else sys.argv[1]
CONSUMER_NAME = "task_executor_" + CONSUMER_NO
@ -44,7 +44,7 @@ import numpy as np
from peewee import DoesNotExist
from api.db import LLMType, ParserType, TaskStatus
from api.db.services.dialog_service import keyword_extraction, question_proposal, content_tagging
from api.db.services.document_service import DocumentService
from api.db.services.llm_service import LLMBundle
from api.db.services.task_service import TaskService
@ -53,10 +53,10 @@ from api import settings
from api.versions import get_ragflow_version
from api.db.db_models import close_connection
from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive, one, audio, \
knowledge_graph, email, tag
from rag.nlp import search, rag_tokenizer
from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor
from rag.settings import DOC_MAXIMUM_SIZE, SVR_QUEUE_NAME, print_rag_settings, TAG_FLD, PAGERANK_FLD
from rag.utils import num_tokens_from_string
from rag.utils.redis_conn import REDIS_CONN, Payload
from rag.utils.storage_factory import STORAGE_IMPL
@ -78,7 +78,8 @@ FACTORY = {
ParserType.ONE.value: one,
ParserType.AUDIO.value: audio,
ParserType.EMAIL.value: email,
ParserType.KG.value: knowledge_graph,
ParserType.TAG.value: tag
}
CONSUMER_NAME = "task_consumer_" + CONSUMER_NO
@ -199,7 +200,8 @@ def build_chunks(task, progress_callback):
logging.info("From minio({}) {}/{}".format(timer() - st, task["location"], task["name"])) logging.info("From minio({}) {}/{}".format(timer() - st, task["location"], task["name"]))
except TimeoutError: except TimeoutError:
progress_callback(-1, "Internal server error: Fetch file from minio timeout. Could you try it again.") progress_callback(-1, "Internal server error: Fetch file from minio timeout. Could you try it again.")
logging.exception("Minio {}/{} got timeout: Fetch file from minio timeout.".format(task["location"], task["name"])) logging.exception(
"Minio {}/{} got timeout: Fetch file from minio timeout.".format(task["location"], task["name"]))
raise raise
except Exception as e: except Exception as e:
if re.search("(No such file|not found)", str(e)): if re.search("(No such file|not found)", str(e)):
@ -227,7 +229,7 @@ def build_chunks(task, progress_callback):
"kb_id": str(task["kb_id"]) "kb_id": str(task["kb_id"])
} }
if task["pagerank"]: if task["pagerank"]:
doc["pagerank_fea"] = int(task["pagerank"]) doc[PAGERANK_FLD] = int(task["pagerank"])
el = 0 el = 0
for ck in cks: for ck in cks:
d = copy.deepcopy(doc) d = copy.deepcopy(doc)
@ -252,7 +254,8 @@ def build_chunks(task, progress_callback):
STORAGE_IMPL.put(task["kb_id"], d["id"], output_buffer.getvalue()) STORAGE_IMPL.put(task["kb_id"], d["id"], output_buffer.getvalue())
el += timer() - st el += timer() - st
except Exception: except Exception:
logging.exception("Saving image of chunk {}/{}/{} got exception".format(task["location"], task["name"], d["id"])) logging.exception(
"Saving image of chunk {}/{}/{} got exception".format(task["location"], task["name"], d["id"]))
raise raise
d["img_id"] = "{}-{}".format(task["kb_id"], d["id"]) d["img_id"] = "{}-{}".format(task["kb_id"], d["id"])
@ -295,12 +298,43 @@ def build_chunks(task, progress_callback):
d["question_tks"] = rag_tokenizer.tokenize("\n".join(d["question_kwd"])) d["question_tks"] = rag_tokenizer.tokenize("\n".join(d["question_kwd"]))
progress_callback(msg="Question generation completed in {:.2f}s".format(timer() - st)) progress_callback(msg="Question generation completed in {:.2f}s".format(timer() - st))
if task["kb_parser_config"].get("tag_kb_ids", []):
progress_callback(msg="Start to tag for every chunk ...")
kb_ids = task["kb_parser_config"]["tag_kb_ids"]
tenant_id = task["tenant_id"]
topn_tags = task["kb_parser_config"].get("topn_tags", 3)
S = 1000
st = timer()
examples = []
all_tags = get_tags_from_cache(kb_ids)
if not all_tags:
all_tags = settings.retrievaler.all_tags_in_portion(tenant_id, kb_ids, S)
set_tags_to_cache(kb_ids, all_tags)
else:
all_tags = json.loads(all_tags)
chat_mdl = LLMBundle(task["tenant_id"], LLMType.CHAT, llm_name=task["llm_id"], lang=task["language"])
for d in docs:
if settings.retrievaler.tag_content(tenant_id, kb_ids, d, all_tags, topn_tags=topn_tags, S=S):
examples.append({"content": d["content_with_weight"], TAG_FLD: d[TAG_FLD]})
continue
cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], all_tags, {"topn": topn_tags})
if not cached:
cached = content_tagging(chat_mdl, d["content_with_weight"], all_tags,
random.choices(examples, k=2) if len(examples)>2 else examples,
topn=topn_tags)
if cached:
set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, all_tags, {"topn": topn_tags})
d[TAG_FLD] = json.loads(cached)
progress_callback(msg="Tagging completed in {:.2f}s".format(timer() - st))
return docs
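For reference, a chunk that passes through the tagging step above ends up carrying two extra rank-feature fields; the values below are made up, only the field names come from the commit.

chunk = {
    "content_with_weight": "PageRank is a link-analysis algorithm ...",
    PAGERANK_FLD: 10,                        # "pagerank_fea", copied from task["pagerank"]
    TAG_FLD: {"algorithm": 3, "graph": 2},   # "tag_feas", top-n weights from tag_content()/content_tagging()
}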
def init_kb(row, vector_size: int):
idxnm = search.index_name(row["tenant_id"])
return settings.docStoreConn.createIdx(idxnm, row.get("kb_id", ""), vector_size)
def embedding(docs, mdl, parser_config=None, callback=None):
@ -381,7 +415,7 @@ def run_raptor(row, chat_mdl, embd_mdl, callback=None):
"title_tks": rag_tokenizer.tokenize(row["name"]) "title_tks": rag_tokenizer.tokenize(row["name"])
} }
if row["pagerank"]: if row["pagerank"]:
doc["pagerank_fea"] = int(row["pagerank"]) doc[PAGERANK_FLD] = int(row["pagerank"])
res = [] res = []
tk_count = 0 tk_count = 0
for content, vctr in chunks[original_length:]: for content, vctr in chunks[original_length:]:
@ -480,7 +514,8 @@ def do_handle_task(task):
doc_store_result = "" doc_store_result = ""
es_bulk_size = 4 es_bulk_size = 4
for b in range(0, len(chunks), es_bulk_size): for b in range(0, len(chunks), es_bulk_size):
doc_store_result = settings.docStoreConn.insert(chunks[b:b + es_bulk_size], search.index_name(task_tenant_id), task_dataset_id) doc_store_result = settings.docStoreConn.insert(chunks[b:b + es_bulk_size], search.index_name(task_tenant_id),
task_dataset_id)
if b % 128 == 0: if b % 128 == 0:
progress_callback(prog=0.8 + 0.1 * (b + 1) / len(chunks), msg="") progress_callback(prog=0.8 + 0.1 * (b + 1) / len(chunks), msg="")
if doc_store_result: if doc_store_result:
@ -493,15 +528,21 @@ def do_handle_task(task):
TaskService.update_chunk_ids(task["id"], chunk_ids_str)
except DoesNotExist:
logging.warning(f"do_handle_task update_chunk_ids failed since task {task['id']} is unknown.")
doc_store_result = settings.docStoreConn.delete({"id": chunk_ids}, search.index_name(task_tenant_id),
task_dataset_id)
return
logging.info("Indexing doc({}), page({}-{}), chunks({}), elapsed: {:.2f}".format(task_document_name, task_from_page,
task_to_page, len(chunks),
timer() - start_ts))
DocumentService.increment_chunk_num(task_doc_id, task_dataset_id, token_count, chunk_count, 0)
time_cost = timer() - start_ts
progress_callback(prog=1.0, msg="Done ({:.2f}s)".format(time_cost))
logging.info(
"Chunk doc({}), page({}-{}), chunks({}), token({}), elapsed:{:.2f}".format(task_document_name, task_from_page,
task_to_page, len(chunks),
token_count, time_cost))
def handle_task():

View File

@ -71,11 +71,13 @@ def findMaxTm(fnm):
pass
return m
tiktoken_cache_dir = get_project_base_directory()
os.environ["TIKTOKEN_CACHE_DIR"] = tiktoken_cache_dir
# encoder = tiktoken.encoding_for_model("gpt-3.5-turbo")
encoder = tiktoken.get_encoding("cl100k_base")
def num_tokens_from_string(string: str) -> int:
"""Returns the number of tokens in a text string."""
try:

View File

@ -176,7 +176,17 @@ class DocStoreConnection(ABC):
@abstractmethod
def search(
self, selectFields: list[str],
highlightFields: list[str],
condition: dict,
matchExprs: list[MatchExpr],
orderBy: OrderByExpr,
offset: int,
limit: int,
indexNames: str|list[str],
knowledgebaseIds: list[str],
aggFields: list[str] = [],
rank_feature: dict | None = None
) -> list[dict] | pl.DataFrame:
"""
Search with given conjunctive equivalent filtering condition and return all fields of matched documents
@ -191,7 +201,7 @@ class DocStoreConnection(ABC):
raise NotImplementedError("Not implemented")
@abstractmethod
def insert(self, rows: list[dict], indexName: str, knowledgebaseId: str = None) -> list[str]:
"""
Update or insert a bulk of rows
"""

View File

@ -9,6 +9,7 @@ from elasticsearch import Elasticsearch, NotFoundError
from elasticsearch_dsl import UpdateByQuery, Q, Search, Index
from elastic_transport import ConnectionTimeout
from rag import settings
from rag.settings import TAG_FLD, PAGERANK_FLD
from rag.utils import singleton
from api.utils.file_utils import get_project_base_directory
import polars as pl
@ -20,6 +21,7 @@ ATTEMPT_TIME = 2
logger = logging.getLogger('ragflow.es_conn')
@singleton
class ESConnection(DocStoreConnection):
def __init__(self):
@ -111,9 +113,19 @@ class ESConnection(DocStoreConnection):
CRUD operations
"""
def search(
self, selectFields: list[str],
highlightFields: list[str],
condition: dict,
matchExprs: list[MatchExpr],
orderBy: OrderByExpr,
offset: int,
limit: int,
indexNames: str | list[str],
knowledgebaseIds: list[str],
aggFields: list[str] = [],
rank_feature: dict | None = None
) -> list[dict] | pl.DataFrame:
""" """
Refers to https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl.html Refers to https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl.html
""" """
@ -175,8 +187,13 @@ class ESConnection(DocStoreConnection):
similarity=similarity,
)
if bqry and rank_feature:
for fld, sc in rank_feature.items():
if fld != PAGERANK_FLD:
fld = f"{TAG_FLD}.{fld}"
bqry.should.append(Q("rank_feature", field=fld, linear={}, boost=sc))
if bqry:
s = s.query(bqry)
for field in highlightFields:
s = s.highlight(field)
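For illustration only (not part of the diff): given a rank_feature dict such as {"python": 3, "pagerank_fea": 10}, the loop above appends one rank_feature clause per entry, prefixing tag fields with tag_feas. and leaving the pagerank field untouched. A minimal elasticsearch_dsl sketch of the clauses it builds:

from elasticsearch_dsl import Q

q_tag = Q("rank_feature", field="tag_feas.python", linear={}, boost=3)
q_pagerank = Q("rank_feature", field="pagerank_fea", linear={}, boost=10)
print(q_tag.to_dict())  # {'rank_feature': {'field': 'tag_feas.python', 'linear': {}, 'boost': 3}}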
@ -187,7 +204,7 @@ class ESConnection(DocStoreConnection):
order = "asc" if order == 0 else "desc" order = "asc" if order == 0 else "desc"
if field in ["page_num_int", "top_int"]: if field in ["page_num_int", "top_int"]:
order_info = {"order": order, "unmapped_type": "float", order_info = {"order": order, "unmapped_type": "float",
"mode": "avg", "numeric_type": "double"} "mode": "avg", "numeric_type": "double"}
elif field.endswith("_int") or field.endswith("_flt"): elif field.endswith("_int") or field.endswith("_flt"):
order_info = {"order": order, "unmapped_type": "float"} order_info = {"order": order, "unmapped_type": "float"}
else: else:
@ -195,8 +212,11 @@ class ESConnection(DocStoreConnection):
orders.append({field: order_info})
s = s.sort(*orders)
for fld in aggFields:
s.aggs.bucket(f'aggs_{fld}', 'terms', field=fld, size=1000000)
if limit > 0:
s = s[offset:offset + limit]
q = s.to_dict()
logger.debug(f"ESConnection.search {str(indexNames)} query: " + json.dumps(q))
@ -240,7 +260,7 @@ class ESConnection(DocStoreConnection):
logger.error("ESConnection.get timeout for 3 times!") logger.error("ESConnection.get timeout for 3 times!")
raise Exception("ESConnection.get timeout.") raise Exception("ESConnection.get timeout.")
def insert(self, documents: list[dict], indexName: str, knowledgebaseId: str) -> list[str]: def insert(self, documents: list[dict], indexName: str, knowledgebaseId: str = None) -> list[str]:
# Refers to https://www.elastic.co/guide/en/elasticsearch/reference/current/docs-bulk.html # Refers to https://www.elastic.co/guide/en/elasticsearch/reference/current/docs-bulk.html
operations = [] operations = []
for d in documents: for d in documents:
@ -292,44 +312,57 @@ class ESConnection(DocStoreConnection):
if str(e).find("Timeout") > 0: if str(e).find("Timeout") > 0:
continue continue
return False return False
else:
# update unspecific maybe-multiple documents
bqry = Q("bool")
for k, v in condition.items():
if not isinstance(k, str) or not v:
continue
if k == "exist":
bqry.filter.append(Q("exists", field=v))
continue
if isinstance(v, list):
bqry.filter.append(Q("terms", **{k: v}))
elif isinstance(v, str) or isinstance(v, int):
bqry.filter.append(Q("term", **{k: v}))
else:
raise Exception(
f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.")
scripts = []
params = {}
for k, v in newValue.items():
if k == "remove":
if isinstance(v, str):
scripts.append(f"ctx._source.remove('{v}');")
if isinstance(v, dict):
for kk, vv in v.items():
scripts.append(f"int i=ctx._source.{kk}.indexOf(params.p_{kk});ctx._source.{kk}.remove(i);")
params[f"p_{kk}"] = vv
continue
if k == "add":
if isinstance(v, dict):
for kk, vv in v.items():
scripts.append(f"ctx._source.{kk}.add(params.pp_{kk});")
params[f"pp_{kk}"] = vv.strip()
continue
if (not isinstance(k, str) or not v) and k != "available_int":
continue
if isinstance(v, str):
scripts.append(f"ctx._source.{k} = '{v}'")
elif isinstance(v, int):
scripts.append(f"ctx._source.{k} = {v}")
else:
raise Exception(
f"newValue `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str.")
ubq = UpdateByQuery(
index=indexName).using(
self.es).query(bqry)
ubq = ubq.script(source="".join(scripts), params=params)
ubq = ubq.params(refresh=True)
ubq = ubq.params(slices=5)
ubq = ubq.params(conflicts="proceed")
for _ in range(ATTEMPT_TIME):
try:
_ = ubq.execute()
return True
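To clarify the new remove/add dict forms handled above, here is a small standalone mirror of the script generation (illustrative only, using the same f-string logic as the method):

def build_scripts(newValue: dict):
    scripts, params = [], {}
    for k, v in newValue.items():
        if k == "remove" and isinstance(v, dict):
            # delete one element from a list field by value
            for kk, vv in v.items():
                scripts.append(f"int i=ctx._source.{kk}.indexOf(params.p_{kk});ctx._source.{kk}.remove(i);")
                params[f"p_{kk}"] = vv
        elif k == "add" and isinstance(v, dict):
            # append one element to a list field
            for kk, vv in v.items():
                scripts.append(f"ctx._source.{kk}.add(params.pp_{kk});")
                params[f"pp_{kk}"] = vv.strip()
    return "".join(scripts), params

# e.g. add one tag and drop another from a chunk's tag_kwd list:
print(build_scripts({"add": {"tag_kwd": "python"}, "remove": {"tag_kwd": "java"}}))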

View File

@ -10,6 +10,7 @@ from infinity.index import IndexInfo, IndexType
from infinity.connection_pool import ConnectionPool
from infinity.errors import ErrorCode
from rag import settings
from rag.settings import PAGERANK_FLD
from rag.utils import singleton
import polars as pl
from polars.series.series import Series
@ -231,8 +232,7 @@ class InfinityConnection(DocStoreConnection):
""" """
def search( def search(
self, self, selectFields: list[str],
selectFields: list[str],
highlightFields: list[str], highlightFields: list[str],
condition: dict, condition: dict,
matchExprs: list[MatchExpr], matchExprs: list[MatchExpr],
@ -241,7 +241,9 @@ class InfinityConnection(DocStoreConnection):
limit: int,
indexNames: str | list[str],
knowledgebaseIds: list[str],
aggFields: list[str] = [],
rank_feature: dict | None = None
) -> list[dict] | pl.DataFrame:
"""
TODO: Infinity doesn't provide highlight
"""
@ -256,7 +258,7 @@ class InfinityConnection(DocStoreConnection):
if essential_field not in selectFields:
selectFields.append(essential_field)
if matchExprs:
for essential_field in ["score()", PAGERANK_FLD]:
selectFields.append(essential_field)
# Prepare expressions common to all tables
@ -346,7 +348,7 @@ class InfinityConnection(DocStoreConnection):
self.connPool.release_conn(inf_conn)
res = concat_dataframes(df_list, selectFields)
if matchExprs:
res = res.sort(pl.col("SCORE") + pl.col(PAGERANK_FLD), descending=True, maintain_order=True)
res = res.limit(limit)
logger.debug(f"INFINITY search final result: {str(res)}")
return res, total_hits_count
@ -378,7 +380,7 @@ class InfinityConnection(DocStoreConnection):
return res_fields.get(chunkId, None)
def insert(
self, documents: list[dict], indexName: str, knowledgebaseId: str = None
) -> list[str]:
inf_conn = self.connPool.get_conn()
db_instance = inf_conn.get_database(self.dbName)
@ -456,7 +458,7 @@ class InfinityConnection(DocStoreConnection):
elif k in ["page_num_int", "top_int"]: elif k in ["page_num_int", "top_int"]:
assert isinstance(v, list) assert isinstance(v, list)
newValue[k] = "_".join(f"{num:08x}" for num in v) newValue[k] = "_".join(f"{num:08x}" for num in v)
elif k == "remove" and v in ["pagerank_fea"]: elif k == "remove" and v in [PAGERANK_FLD]:
del newValue[k] del newValue[k]
newValue[v] = 0 newValue[v] = 0
logger.debug(f"INFINITY update table {table_name}, filter {filter}, newValue {newValue}.") logger.debug(f"INFINITY update table {table_name}, filter {filter}, newValue {newValue}.")

View File

@ -27,7 +27,7 @@ def test_create_dataset_with_invalid_parameter(get_api_key_fixture):
API_KEY = get_api_key_fixture
rag = RAGFlow(API_KEY, HOST_ADDRESS)
valid_chunk_methods = ["naive", "manual", "qa", "table", "paper", "book", "laws", "presentation", "picture", "one",
"knowledge_graph", "email", "tag"]
chunk_method = "invalid_chunk_method"
with pytest.raises(Exception) as exc_info:
rag.create_dataset("test_create_dataset_with_invalid_chunk_method",chunk_method=chunk_method) rag.create_dataset("test_create_dataset_with_invalid_chunk_method",chunk_method=chunk_method)