### What problem does this PR solve?

#4367

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
Kevin Hu 2025-01-09 17:07:21 +08:00 committed by GitHub
parent f892d7d426
commit c5da3cdd97
30 changed files with 736 additions and 202 deletions


@ -19,6 +19,7 @@ from abc import ABC
import pandas as pd
from api.db import LLMType
from api.db.services.dialog_service import label_question
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle
from api import settings
@ -70,7 +71,8 @@ class Retrieval(ComponentBase, ABC):
kbinfos = settings.retrievaler.retrieval(query, embd_mdl, kbs[0].tenant_id, self._param.kb_ids,
1, self._param.top_n,
self._param.similarity_threshold, 1 - self._param.keywords_similarity_weight,
aggs=False, rerank_mdl=rerank_mdl)
aggs=False, rerank_mdl=rerank_mdl,
rank_feature=label_question(query, kbs))
if not kbinfos["chunks"]:
df = Retrieval.be_output("")


@ -25,7 +25,7 @@ from api.db import FileType, LLMType, ParserType, FileSource
from api.db.db_models import APIToken, Task, File
from api.db.services import duplicate_name
from api.db.services.api_service import APITokenService, API4ConversationService
from api.db.services.dialog_service import DialogService, chat, keyword_extraction
from api.db.services.dialog_service import DialogService, chat, keyword_extraction, label_question
from api.db.services.document_service import DocumentService, doc_upload_and_parse
from api.db.services.file2document_service import File2DocumentService
from api.db.services.file_service import FileService
@ -840,7 +840,8 @@ def retrieval():
question += keyword_extraction(chat_mdl, question)
ranks = settings.retrievaler.retrieval(question, embd_mdl, kbs[0].tenant_id, kb_ids, page, size,
similarity_threshold, vector_similarity_weight, top,
doc_ids, rerank_mdl=rerank_mdl)
doc_ids, rerank_mdl=rerank_mdl,
rank_feature=label_question(question, kbs))
for c in ranks["chunks"]:
c.pop("vector", None)
return get_json_result(data=ranks)


@ -19,9 +19,10 @@ import json
from flask import request
from flask_login import login_required, current_user
from api.db.services.dialog_service import keyword_extraction
from api.db.services.dialog_service import keyword_extraction, label_question
from rag.app.qa import rmPrefix, beAdoc
from rag.nlp import search, rag_tokenizer
from rag.settings import PAGERANK_FLD
from rag.utils import rmSpace
from api.db import LLMType, ParserType
from api.db.services.knowledgebase_service import KnowledgebaseService
@ -124,10 +125,14 @@ def set():
"content_with_weight": req["content_with_weight"]}
d["content_ltks"] = rag_tokenizer.tokenize(req["content_with_weight"])
d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
d["important_kwd"] = req["important_kwd"]
d["important_tks"] = rag_tokenizer.tokenize(" ".join(req["important_kwd"]))
d["question_kwd"] = req["question_kwd"]
d["question_tks"] = rag_tokenizer.tokenize("\n".join(req["question_kwd"]))
if req.get("important_kwd"):
d["important_kwd"] = req["important_kwd"]
d["important_tks"] = rag_tokenizer.tokenize(" ".join(req["important_kwd"]))
if req.get("question_kwd"):
d["question_kwd"] = req["question_kwd"]
d["question_tks"] = rag_tokenizer.tokenize("\n".join(req["question_kwd"]))
if req.get("tag_kwd"):
d["tag_kwd"] = req["tag_kwd"]
if "available_int" in req:
d["available_int"] = req["available_int"]
@ -220,7 +225,7 @@ def create():
e, doc = DocumentService.get_by_id(req["doc_id"])
if not e:
return get_data_error_result(message="Document not found!")
d["kb_id"] = doc.kb_id
d["kb_id"] = [doc.kb_id]
d["docnm_kwd"] = doc.name
d["title_tks"] = rag_tokenizer.tokenize(doc.name)
d["doc_id"] = doc.id
@ -233,7 +238,7 @@ def create():
if not e:
return get_data_error_result(message="Knowledgebase not found!")
if kb.pagerank:
d["pagerank_fea"] = kb.pagerank
d[PAGERANK_FLD] = kb.pagerank
embd_id = DocumentService.get_embd_id(req["doc_id"])
embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING.value, embd_id)
@ -294,12 +299,16 @@ def retrieval_test():
chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
question += keyword_extraction(chat_mdl, question)
labels = label_question(question, [kb])
retr = settings.retrievaler if kb.parser_id != ParserType.KG else settings.kg_retrievaler
ranks = retr.retrieval(question, embd_mdl, tenant_ids, kb_ids, page, size,
similarity_threshold, vector_similarity_weight, top,
doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"))
doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"),
rank_feature=labels
)
for c in ranks["chunks"]:
c.pop("vector", None)
ranks["labels"] = labels
return get_json_result(data=ranks)
except Exception as e:


@ -25,7 +25,7 @@ from flask import request, Response
from flask_login import login_required, current_user
from api.db import LLMType
from api.db.services.dialog_service import DialogService, chat, ask
from api.db.services.dialog_service import DialogService, chat, ask, label_question
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle, TenantService, TenantLLMService
from api import settings
@ -379,8 +379,11 @@ def mindmap():
embd_mdl = TenantLLMService.model_instance(
kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
chat_mdl = LLMBundle(current_user.id, LLMType.CHAT)
ranks = settings.retrievaler.retrieval(req["question"], embd_mdl, kb.tenant_id, kb_ids, 1, 12,
0.3, 0.3, aggs=False)
question = req["question"]
ranks = settings.retrievaler.retrieval(question, embd_mdl, kb.tenant_id, kb_ids, 1, 12,
0.3, 0.3, aggs=False,
rank_feature=label_question(question, [kb])
)
mindmap = MindMapExtractor(chat_mdl)
mind_map = mindmap([c["content_with_weight"] for c in ranks["chunks"]]).output
if "error" in mind_map:


@ -30,6 +30,7 @@ from api.utils.api_utils import get_json_result
from api import settings
from rag.nlp import search
from api.constants import DATASET_NAME_LIMIT
from rag.settings import PAGERANK_FLD
@manager.route('/create', methods=['post']) # noqa: F821
@ -104,11 +105,11 @@ def update():
if kb.pagerank != req.get("pagerank", 0):
if req.get("pagerank", 0) > 0:
settings.docStoreConn.update({"kb_id": kb.id}, {"pagerank_fea": req["pagerank"]},
settings.docStoreConn.update({"kb_id": kb.id}, {PAGERANK_FLD: req["pagerank"]},
search.index_name(kb.tenant_id), kb.id)
else:
# Elasticsearch requires pagerank_fea be non-zero!
settings.docStoreConn.update({"exist": "pagerank_fea"}, {"remove": "pagerank_fea"},
# Elasticsearch requires PAGERANK_FLD be non-zero!
settings.docStoreConn.update({"exist": PAGERANK_FLD}, {"remove": PAGERANK_FLD},
search.index_name(kb.tenant_id), kb.id)
e, kb = KnowledgebaseService.get_by_id(kb.id)
@ -150,12 +151,14 @@ def list_kbs():
keywords = request.args.get("keywords", "")
page_number = int(request.args.get("page", 1))
items_per_page = int(request.args.get("page_size", 150))
parser_id = request.args.get("parser_id")
orderby = request.args.get("orderby", "create_time")
desc = request.args.get("desc", True)
try:
tenants = TenantService.get_joined_tenants_by_user_id(current_user.id)
kbs, total = KnowledgebaseService.get_by_tenant_ids(
[m["tenant_id"] for m in tenants], current_user.id, page_number, items_per_page, orderby, desc, keywords)
[m["tenant_id"] for m in tenants], current_user.id, page_number,
items_per_page, orderby, desc, keywords, parser_id)
return get_json_result(data={"kbs": kbs, "total": total})
except Exception as e:
return server_error_response(e)
@ -199,3 +202,72 @@ def rm():
return get_json_result(data=True)
except Exception as e:
return server_error_response(e)
@manager.route('/<kb_id>/tags', methods=['GET']) # noqa: F821
@login_required
def list_tags(kb_id):
if not KnowledgebaseService.accessible(kb_id, current_user.id):
return get_json_result(
data=False,
message='No authorization.',
code=settings.RetCode.AUTHENTICATION_ERROR
)
tags = settings.retrievaler.all_tags(current_user.id, [kb_id])
return get_json_result(data=tags)
@manager.route('/tags', methods=['GET']) # noqa: F821
@login_required
def list_tags_from_kbs():
kb_ids = request.args.get("kb_ids", "").split(",")
for kb_id in kb_ids:
if not KnowledgebaseService.accessible(kb_id, current_user.id):
return get_json_result(
data=False,
message='No authorization.',
code=settings.RetCode.AUTHENTICATION_ERROR
)
tags = settings.retrievaler.all_tags(current_user.id, kb_ids)
return get_json_result(data=tags)
@manager.route('/<kb_id>/rm_tags', methods=['POST']) # noqa: F821
@login_required
def rm_tags(kb_id):
req = request.json
if not KnowledgebaseService.accessible(kb_id, current_user.id):
return get_json_result(
data=False,
message='No authorization.',
code=settings.RetCode.AUTHENTICATION_ERROR
)
e, kb = KnowledgebaseService.get_by_id(kb_id)
for t in req["tags"]:
settings.docStoreConn.update({"tag_kwd": t, "kb_id": [kb_id]},
{"remove": {"tag_kwd": t}},
search.index_name(kb.tenant_id),
kb_id)
return get_json_result(data=True)
@manager.route('/<kb_id>/rename_tag', methods=['POST']) # noqa: F821
@login_required
def rename_tags(kb_id):
req = request.json
if not KnowledgebaseService.accessible(kb_id, current_user.id):
return get_json_result(
data=False,
message='No authorization.',
code=settings.RetCode.AUTHENTICATION_ERROR
)
e, kb = KnowledgebaseService.get_by_id(kb_id)
settings.docStoreConn.update({"tag_kwd": req["from_tag"], "kb_id": [kb_id]},
{"remove": {"tag_kwd": req["from_tag"].strip()}, "add": {"tag_kwd": req["to_tag"]}},
search.index_name(kb.tenant_id),
kb_id)
return get_json_result(data=True)
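
For reference, a minimal sketch of how a client might exercise the tag-management endpoints added above. The host, port, route prefix, and session handling are assumptions for illustration, not part of this diff.

```python
import requests

BASE = "http://127.0.0.1:9380/v1/kb"   # assumed server address and route prefix
session = requests.Session()            # assumes a logged-in session (login_required)

kb_id = "my_kb_id"                      # hypothetical knowledge-base id

# List all tags aggregated over the KB's chunks.
tags = session.get(f"{BASE}/{kb_id}/tags").json()["data"]

# Remove a tag from every chunk that carries it.
session.post(f"{BASE}/{kb_id}/rm_tags", json={"tags": ["finance"]})

# Rename a tag across the KB.
session.post(f"{BASE}/{kb_id}/rename_tag",
             json={"from_tag": "finance", "to_tag": "fin"})
```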


@ -73,7 +73,8 @@ def create(tenant_id):
chunk_method:
type: string
enum: ["naive", "manual", "qa", "table", "paper", "book", "laws",
"presentation", "picture", "one", "knowledge_graph", "email"]
"presentation", "picture", "one", "knowledge_graph", "email", "tag"
]
description: Chunking method.
parser_config:
type: object
@ -108,6 +109,7 @@ def create(tenant_id):
"one",
"knowledge_graph",
"email",
"tag"
]
check_validation = valid(
permission,
@ -302,7 +304,8 @@ def update(tenant_id, dataset_id):
chunk_method:
type: string
enum: ["naive", "manual", "qa", "table", "paper", "book", "laws",
"presentation", "picture", "one", "knowledge_graph", "email"]
"presentation", "picture", "one", "knowledge_graph", "email", "tag"
]
description: Updated chunking method.
parser_config:
type: object
@ -339,6 +342,7 @@ def update(tenant_id, dataset_id):
"one",
"knowledge_graph",
"email",
"tag"
]
check_validation = valid(
permission,


@ -16,6 +16,7 @@
from flask import request, jsonify
from api.db import LLMType, ParserType
from api.db.services.dialog_service import label_question
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle
from api import settings
@ -54,7 +55,8 @@ def retrieval(tenant_id):
page_size=top,
similarity_threshold=similarity_threshold,
vector_similarity_weight=0.3,
top=top
top=top,
rank_feature=label_question(question, [kb])
)
records = []
for c in ranks["chunks"]:


@ -16,7 +16,7 @@
import pathlib
import datetime
from api.db.services.dialog_service import keyword_extraction
from api.db.services.dialog_service import keyword_extraction, label_question
from rag.app.qa import rmPrefix, beAdoc
from rag.nlp import rag_tokenizer
from api.db import LLMType, ParserType
@ -276,6 +276,7 @@ def update_doc(tenant_id, dataset_id, document_id):
"one",
"knowledge_graph",
"email",
"tag"
}
if req.get("chunk_method") not in valid_chunk_method:
return get_error_data_result(
@ -1355,6 +1356,7 @@ def retrieval_test(tenant_id):
doc_ids,
rerank_mdl=rerank_mdl,
highlight=highlight,
rank_feature=label_question(question, kbs)
)
for c in ranks["chunks"]:
c.pop("vector", None)


@ -89,6 +89,7 @@ class ParserType(StrEnum):
AUDIO = "audio"
EMAIL = "email"
KG = "knowledge_graph"
TAG = "tag"
class FileSource(StrEnum):


@ -133,7 +133,7 @@ def init_llm_factory():
TenantLLMService.filter_update([TenantLLMService.model.llm_factory == "QAnything"], {"llm_factory": "Youdao"})
TenantLLMService.filter_update([TenantLLMService.model.llm_factory == "cohere"], {"llm_factory": "Cohere"})
TenantService.filter_update([1 == 1], {
"parser_ids": "naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,knowledge_graph:Knowledge Graph,email:Email"})
"parser_ids": "naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,knowledge_graph:Knowledge Graph,email:Email,tag:Tag"})
## insert openai two embedding models to the current openai user.
# print("Start to insert 2 OpenAI embedding models...")
tenant_ids = set([row["tenant_id"] for row in TenantLLMService.get_openai_models()])
@ -153,14 +153,7 @@ def init_llm_factory():
break
for kb_id in KnowledgebaseService.get_all_ids():
KnowledgebaseService.update_by_id(kb_id, {"doc_num": DocumentService.get_kb_doc_count(kb_id)})
"""
drop table llm;
drop table llm_factories;
update tenant set parser_ids='naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,knowledge_graph:Knowledge Graph';
alter table knowledgebase modify avatar longtext;
alter table user modify avatar longtext;
alter table dialog modify icon longtext;
"""
def add_graph_templates():


@ -29,8 +29,10 @@ from api.db.services.common_service import CommonService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMService, TenantLLMService, LLMBundle
from api import settings
from graphrag.utils import get_tags_from_cache, set_tags_to_cache
from rag.app.resume import forbidden_select_fields4resume
from rag.nlp.search import index_name
from rag.settings import TAG_FLD
from rag.utils import rmSpace, num_tokens_from_string, encoder
from api.utils.file_utils import get_project_base_directory
@ -135,6 +137,29 @@ def kb_prompt(kbinfos, max_tokens):
return knowledges
def label_question(question, kbs):
tags = None
tag_kb_ids = []
for kb in kbs:
if kb.parser_config.get("tag_kb_ids"):
tag_kb_ids.extend(kb.parser_config["tag_kb_ids"])
if tag_kb_ids:
all_tags = get_tags_from_cache(tag_kb_ids)
if not all_tags:
all_tags = settings.retrievaler.all_tags_in_portion(kb.tenant_id, tag_kb_ids)
set_tags_to_cache(tag_kb_ids, all_tags)
else:
all_tags = json.loads(all_tags)
tag_kbs = KnowledgebaseService.get_by_ids(tag_kb_ids)
tags = settings.retrievaler.tag_query(question,
list(set([kb.tenant_id for kb in tag_kbs])),
tag_kb_ids,
all_tags,
kb.parser_config.get("topn_tags", 3)
)
return tags
def chat(dialog, messages, stream=True, **kwargs):
assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
@ -236,11 +261,14 @@ def chat(dialog, messages, stream=True, **kwargs):
generate_keyword_ts = timer()
tenant_ids = list(set([kb.tenant_id for kb in kbs]))
kbinfos = retriever.retrieval(" ".join(questions), embd_mdl, tenant_ids, dialog.kb_ids, 1, dialog.top_n,
dialog.similarity_threshold,
dialog.vector_similarity_weight,
doc_ids=attachments,
top=dialog.top_k, aggs=False, rerank_mdl=rerank_mdl)
top=dialog.top_k, aggs=False, rerank_mdl=rerank_mdl,
rank_feature=label_question(" ".join(questions), kbs)
)
retrieval_ts = timer()
@ -650,7 +678,10 @@ def ask(question, kb_ids, tenant_id):
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT)
max_tokens = chat_mdl.max_length
tenant_ids = list(set([kb.tenant_id for kb in kbs]))
kbinfos = retriever.retrieval(question, embd_mdl, tenant_ids, kb_ids, 1, 12, 0.1, 0.3, aggs=False)
kbinfos = retriever.retrieval(question, embd_mdl, tenant_ids, kb_ids,
1, 12, 0.1, 0.3, aggs=False,
rank_feature=label_question(question, kbs)
)
knowledges = kb_prompt(kbinfos, max_tokens)
prompt = """
Role: You're a smart assistant. Your name is Miss R.
@ -700,3 +731,56 @@ def ask(question, kb_ids, tenant_id):
answer = ans
yield {"answer": answer, "reference": {}}
yield decorate_answer(answer)
def content_tagging(chat_mdl, content, all_tags, examples, topn=3):
prompt = f"""
Role: You're a text analyzer.
Task: Tag (i.e., label) a given piece of text content based on the examples and the entire tag set.
Steps:
- Comprehend the tag/label set.
- Comprehend the examples, each consisting of text content and its assigned tags with relevance scores, in JSON format.
- Summarize the text content, and tag it with the top {topn} most relevant tags from the tag/label set along with the corresponding relevance scores.
Requirements:
- The tags MUST be from the tag set.
- The output MUST be in JSON format only; the key is the tag and the value is its relevance score.
- The relevance score must range from 1 to 10.
- Output keywords ONLY.
# TAG SET
{", ".join(all_tags)}
"""
for i, ex in enumerate(examples):
prompt += """
# Example {}
### Text Content
{}
Output:
{}
""".format(i, ex["content"], json.dumps(ex[TAG_FLD], indent=2, ensure_ascii=False))
prompt += f"""
# Real Data
### Text Content
{content}
"""
msg = [
{"role": "system", "content": prompt},
{"role": "user", "content": "Output: "}
]
_, msg = message_fit_in(msg, chat_mdl.max_length)
kwd = chat_mdl.chat(prompt, msg[1:], {"temperature": 0.5})
if isinstance(kwd, tuple):
kwd = kwd[0]
if kwd.find("**ERROR**") >= 0:
raise Exception(kwd)
kwd = re.sub(r".*?\{", "{", kwd)
return json.loads(kwd)
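
The regex cleanup at the end of `content_tagging` assumes the model may prepend chatter before the JSON object. A self-contained illustration of that parsing step; the raw completion and tag names are made up:

```python
import json
import re

# Hypothetical raw completion from the chat model.
raw = 'Sure, here are the tags: {"finance": 8, "banking": 5, "legal": 3}'

# content_tagging strips everything before the first "{" and parses the rest.
parsed = json.loads(re.sub(r".*?\{", "{", raw))
assert parsed == {"finance": 8, "banking": 5, "legal": 3}
```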


@ -43,10 +43,7 @@ class File2DocumentService(CommonService):
def insert(cls, obj):
if not cls.save(**obj):
raise RuntimeError("Database error (File)!")
e, obj = cls.get_by_id(obj["id"])
if not e:
raise RuntimeError("Database error (File retrieval)!")
return obj
return File2Document(**obj)
@classmethod
@DB.connection_context()
@ -63,9 +60,8 @@ class File2DocumentService(CommonService):
def update_by_file_id(cls, file_id, obj):
obj["update_time"] = current_timestamp()
obj["update_date"] = datetime_format(datetime.now())
# num = cls.model.update(obj).where(cls.model.id == file_id).execute()
e, obj = cls.get_by_id(cls.model.id)
return obj
cls.model.update(obj).where(cls.model.id == file_id).execute()
return File2Document(**obj)
@classmethod
@DB.connection_context()


@ -251,10 +251,7 @@ class FileService(CommonService):
def insert(cls, file):
if not cls.save(**file):
raise RuntimeError("Database error (File)!")
e, file = cls.get_by_id(file["id"])
if not e:
raise RuntimeError("Database error (File retrieval)!")
return file
return File(**file)
@classmethod
@DB.connection_context()


@ -35,7 +35,10 @@ class KnowledgebaseService(CommonService):
@classmethod
@DB.connection_context()
def get_by_tenant_ids(cls, joined_tenant_ids, user_id,
page_number, items_per_page, orderby, desc, keywords):
page_number, items_per_page,
orderby, desc, keywords,
parser_id=None
):
fields = [
cls.model.id,
cls.model.avatar,
@ -67,6 +70,8 @@ class KnowledgebaseService(CommonService):
cls.model.tenant_id == user_id))
& (cls.model.status == StatusEnum.VALID.value)
)
if parser_id:
kbs = kbs.where(cls.model.parser_id == parser_id)
if desc:
kbs = kbs.order_by(cls.model.getter_by(orderby).desc())
else:


@ -69,6 +69,7 @@ class TaskService(CommonService):
Knowledgebase.language,
Knowledgebase.embd_id,
Knowledgebase.pagerank,
Knowledgebase.parser_config.alias("kb_parser_config"),
Tenant.img2txt_id,
Tenant.asr_id,
Tenant.llm_id,


@ -140,7 +140,7 @@ def init_settings():
API_KEY = LLM.get("api_key", "")
PARSERS = LLM.get(
"parsers",
"naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,knowledge_graph:Knowledge Graph,email:Email")
"naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,knowledge_graph:Knowledge Graph,email:Email,tag:Tag")
HOST_IP = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("host", "127.0.0.1")
HOST_PORT = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("http_port")


@ -173,6 +173,7 @@ def validate_request(*args, **kwargs):
return wrapper
def not_allowed_parameters(*params):
def decorator(f):
def wrapper(*args, **kwargs):
@ -182,7 +183,9 @@ def not_allowed_parameters(*params):
return get_json_result(
code=settings.RetCode.ARGUMENT_ERROR, message=f"Parameter {param} isn't allowed")
return f(*args, **kwargs)
return wrapper
return decorator
@ -207,6 +210,7 @@ def get_json_result(code=settings.RetCode.SUCCESS, message='success', data=None)
response = {"code": code, "message": message, "data": data}
return jsonify(response)
def apikey_required(func):
@wraps(func)
def decorated_function(*args, **kwargs):
@ -282,17 +286,18 @@ def construct_error_response(e):
def token_required(func):
@wraps(func)
def decorated_function(*args, **kwargs):
authorization_str=flask_request.headers.get('Authorization')
authorization_str = flask_request.headers.get('Authorization')
if not authorization_str:
return get_json_result(data=False,message="`Authorization` can't be empty")
authorization_list=authorization_str.split()
return get_json_result(data=False, message="`Authorization` can't be empty")
authorization_list = authorization_str.split()
if len(authorization_list) < 2:
return get_json_result(data=False,message="Please check your authorization format.")
return get_json_result(data=False, message="Please check your authorization format.")
token = authorization_list[1]
objs = APIToken.query(token=token)
if not objs:
return get_json_result(
data=False, message='Authentication error: API key is invalid!', code=settings.RetCode.AUTHENTICATION_ERROR
data=False, message='Authentication error: API key is invalid!',
code=settings.RetCode.AUTHENTICATION_ERROR
)
kwargs['tenant_id'] = objs[0].tenant_id
return func(*args, **kwargs)
@ -330,35 +335,41 @@ def generate_confirmation_token(tenent_id):
return "ragflow-" + serializer.dumps(get_uuid(), salt=tenent_id)[2:34]
def valid(permission,valid_permission,language,valid_language,chunk_method,valid_chunk_method):
if valid_parameter(permission,valid_permission):
return valid_parameter(permission,valid_permission)
if valid_parameter(language,valid_language):
return valid_parameter(language,valid_language)
if valid_parameter(chunk_method,valid_chunk_method):
return valid_parameter(chunk_method,valid_chunk_method)
def valid(permission, valid_permission, language, valid_language, chunk_method, valid_chunk_method):
if valid_parameter(permission, valid_permission):
return valid_parameter(permission, valid_permission)
if valid_parameter(language, valid_language):
return valid_parameter(language, valid_language)
if valid_parameter(chunk_method, valid_chunk_method):
return valid_parameter(chunk_method, valid_chunk_method)
def valid_parameter(parameter,valid_values):
def valid_parameter(parameter, valid_values):
if parameter and parameter not in valid_values:
return get_error_data_result(f"'{parameter}' is not in {valid_values}")
return get_error_data_result(f"'{parameter}' is not in {valid_values}")
def get_parser_config(chunk_method,parser_config):
def get_parser_config(chunk_method, parser_config):
if parser_config:
return parser_config
if not chunk_method:
chunk_method = "naive"
key_mapping={"naive":{"chunk_token_num": 128, "delimiter": "\\n!?;。;!?", "html4excel": False,"layout_recognize": True, "raptor": {"use_raptor": False}},
"qa":{"raptor":{"use_raptor":False}},
"resume":None,
"manual":{"raptor":{"use_raptor":False}},
"table":None,
"paper":{"raptor":{"use_raptor":False}},
"book":{"raptor":{"use_raptor":False}},
"laws":{"raptor":{"use_raptor":False}},
"presentation":{"raptor":{"use_raptor":False}},
"one":None,
"knowledge_graph":{"chunk_token_num":8192,"delimiter":"\\n!?;。;!?","entity_types":["organization","person","location","event","time"]},
"email":None,
"picture":None}
parser_config=key_mapping[chunk_method]
return parser_config
key_mapping = {
"naive": {"chunk_token_num": 128, "delimiter": "\\n!?;。;!?", "html4excel": False, "layout_recognize": True,
"raptor": {"use_raptor": False}},
"qa": {"raptor": {"use_raptor": False}},
"tag": None,
"resume": None,
"manual": {"raptor": {"use_raptor": False}},
"table": None,
"paper": {"raptor": {"use_raptor": False}},
"book": {"raptor": {"use_raptor": False}},
"laws": {"raptor": {"use_raptor": False}},
"presentation": {"raptor": {"use_raptor": False}},
"one": None,
"knowledge_graph": {"chunk_token_num": 8192, "delimiter": "\\n!?;。;!?",
"entity_types": ["organization", "person", "location", "event", "time"]},
"email": None,
"picture": None}
parser_config = key_mapping[chunk_method]
return parser_config


@ -10,6 +10,7 @@
"title_sm_tks": {"type": "varchar", "default": "", "analyzer": "whitespace"},
"name_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"},
"important_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"},
"tag_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"},
"important_tks": {"type": "varchar", "default": "", "analyzer": "whitespace"},
"question_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"},
"question_tks": {"type": "varchar", "default": "", "analyzer": "whitespace"},
@ -27,5 +28,6 @@
"available_int": {"type": "integer", "default": 1},
"knowledge_graph_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"},
"entities_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"},
"pagerank_fea": {"type": "integer", "default": 0}
"pagerank_fea": {"type": "integer", "default": 0},
"tag_fea": {"type": "integer", "default": 0}
}


@ -111,4 +111,23 @@ def set_embed_cache(llmnm, txt, arr):
k = hasher.hexdigest()
arr = json.dumps(arr.tolist() if isinstance(arr, np.ndarray) else arr)
REDIS_CONN.set(k, arr.encode("utf-8"), 24*3600)
def get_tags_from_cache(kb_ids):
hasher = xxhash.xxh64()
hasher.update(str(kb_ids).encode("utf-8"))
k = hasher.hexdigest()
bin = REDIS_CONN.get(k)
if not bin:
return
return bin
def set_tags_to_cache(kb_ids, tags):
hasher = xxhash.xxh64()
hasher.update(str(kb_ids).encode("utf-8"))
k = hasher.hexdigest()
REDIS_CONN.set(k, json.dumps(tags).encode("utf-8"), 600)
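
A small sketch of the cache scheme above, assuming xxhash and a Redis-like store: the key is the xxh64 digest of `str(kb_ids)` and the value is JSON with a 600-second TTL. Note that `get_tags_from_cache` returns the raw cached payload, so callers must `json.loads` it themselves, and the first positional argument of both functions is the kb-id list.

```python
import json
import xxhash

def tag_cache_key(kb_ids):
    # Same derivation as above: xxh64 over the stringified kb-id list.
    hasher = xxhash.xxh64()
    hasher.update(str(kb_ids).encode("utf-8"))
    return hasher.hexdigest()

key = tag_cache_key(["kb1", "kb2"])                      # hypothetical ids
payload = json.dumps({"finance": 0.01, "legal": 0.002})  # tag -> corpus portion
# REDIS_CONN.set(key, payload.encode("utf-8"), 600)      # as in set_tags_to_cache
```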


@ -26,6 +26,7 @@ from docx import Document
from PIL import Image
from markdown import markdown
class Excel(ExcelParser):
def __call__(self, fnm, binary=None, callback=None):
if not binary:
@ -58,11 +59,11 @@ class Excel(ExcelParser):
if len(res) % 999 == 0:
callback(len(res) *
0.6 /
total, ("Extract Q&A: {}".format(len(res)) +
total, ("Extract pairs: {}".format(len(res)) +
(f"{len(fails)} failure, line: %s..." %
(",".join(fails[:3])) if fails else "")))
callback(0.6, ("Extract Q&A: {}. ".format(len(res)) + (
callback(0.6, ("Extract pairs: {}. ".format(len(res)) + (
f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
self.is_english = is_english(
[rmPrefix(q) for q, _ in random_choices(res, k=30) if len(q) > 1])
@ -269,7 +270,7 @@ def beAdocPdf(d, q, a, eng, image, poss):
return d
def beAdocDocx(d, q, a, eng, image):
def beAdocDocx(d, q, a, eng, image, row_num=-1):
qprefix = "Question: " if eng else "问题:"
aprefix = "Answer: " if eng else "回答:"
d["content_with_weight"] = "\t".join(
@ -277,16 +278,20 @@ def beAdocDocx(d, q, a, eng, image):
d["content_ltks"] = rag_tokenizer.tokenize(q)
d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
d["image"] = image
if row_num >= 0:
d["top_int"] = [row_num]
return d
def beAdoc(d, q, a, eng):
def beAdoc(d, q, a, eng, row_num=-1):
qprefix = "Question: " if eng else "问题:"
aprefix = "Answer: " if eng else "回答:"
d["content_with_weight"] = "\t".join(
[qprefix + rmPrefix(q), aprefix + rmPrefix(a)])
d["content_ltks"] = rag_tokenizer.tokenize(q)
d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
if row_num >= 0:
d["top_int"] = [row_num]
return d
@ -316,8 +321,8 @@ def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs):
if re.search(r"\.xlsx?$", filename, re.IGNORECASE):
callback(0.1, "Start to parse.")
excel_parser = Excel()
for q, a in excel_parser(filename, binary, callback):
res.append(beAdoc(deepcopy(doc), q, a, eng))
for ii, (q, a) in enumerate(excel_parser(filename, binary, callback)):
res.append(beAdoc(deepcopy(doc), q, a, eng, ii))
return res
elif re.search(r"\.(txt)$", filename, re.IGNORECASE):
@ -344,7 +349,7 @@ def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs):
fails.append(str(i+1))
elif len(arr) == 2:
if question and answer:
res.append(beAdoc(deepcopy(doc), question, answer, eng))
res.append(beAdoc(deepcopy(doc), question, answer, eng, i))
question, answer = arr
i += 1
if len(res) % 999 == 0:
@ -352,7 +357,7 @@ def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs):
f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
if question:
res.append(beAdoc(deepcopy(doc), question, answer, eng))
res.append(beAdoc(deepcopy(doc), question, answer, eng, len(lines)))
callback(0.6, ("Extract Q&A: {}".format(len(res)) + (
f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
@ -378,14 +383,14 @@ def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs):
fails.append(str(i + 1))
elif len(row) == 2:
if question and answer:
res.append(beAdoc(deepcopy(doc), question, answer, eng))
res.append(beAdoc(deepcopy(doc), question, answer, eng, i))
question, answer = row
if len(res) % 999 == 0:
callback(len(res) * 0.6 / len(lines), ("Extract Q&A: {}".format(len(res)) + (
f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
if question:
res.append(beAdoc(deepcopy(doc), question, answer, eng))
res.append(beAdoc(deepcopy(doc), question, answer, eng, len(reader)))
callback(0.6, ("Extract Q&A: {}".format(len(res)) + (
f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
@ -420,7 +425,7 @@ def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs):
if last_answer.strip():
sum_question = '\n'.join(question_stack)
if sum_question:
res.append(beAdoc(deepcopy(doc), sum_question, markdown(last_answer, extensions=['markdown.extensions.tables']), eng))
res.append(beAdoc(deepcopy(doc), sum_question, markdown(last_answer, extensions=['markdown.extensions.tables']), eng, index))
last_answer = ''
i = question_level
@ -432,7 +437,7 @@ def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs):
if last_answer.strip():
sum_question = '\n'.join(question_stack)
if sum_question:
res.append(beAdoc(deepcopy(doc), sum_question, markdown(last_answer, extensions=['markdown.extensions.tables']), eng))
res.append(beAdoc(deepcopy(doc), sum_question, markdown(last_answer, extensions=['markdown.extensions.tables']), eng, index))
return res
elif re.search(r"\.docx$", filename, re.IGNORECASE):
@ -440,8 +445,8 @@ def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs):
qai_list, tbls = docx_parser(filename, binary,
from_page=0, to_page=10000, callback=callback)
res = tokenize_table(tbls, doc, eng)
for q, a, image in qai_list:
res.append(beAdocDocx(deepcopy(doc), q, a, eng, image))
for i, (q, a, image) in enumerate(qai_list):
res.append(beAdocDocx(deepcopy(doc), q, a, eng, image, i))
return res
raise NotImplementedError(

rag/app/tag.py (new file, 125 lines)

@ -0,0 +1,125 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import re
import csv
from copy import deepcopy
from deepdoc.parser.utils import get_text
from rag.app.qa import Excel
from rag.nlp import rag_tokenizer
def beAdoc(d, q, a, eng, row_num=-1):
d["content_with_weight"] = q
d["content_ltks"] = rag_tokenizer.tokenize(q)
d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
d["tag_kwd"] = [t.strip() for t in a.split(",") if t.strip()]
if row_num >= 0:
d["top_int"] = [row_num]
return d
def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs):
"""
Excel and csv(txt) format files are supported.
If the file is in excel format, there should be 2 column content and tags without header.
And content column is ahead of tags column.
And it's O.K if it has multiple sheets as long as the columns are rightly composed.
If it's in csv format, it should be UTF-8 encoded. Use TAB as delimiter to separate content and tags.
All the deformed lines will be ignored.
Every pair will be treated as a chunk.
"""
eng = lang.lower() == "english"
res = []
doc = {
"docnm_kwd": filename,
"title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", filename))
}
if re.search(r"\.xlsx?$", filename, re.IGNORECASE):
callback(0.1, "Start to parse.")
excel_parser = Excel()
for ii, (q, a) in enumerate(excel_parser(filename, binary, callback)):
res.append(beAdoc(deepcopy(doc), q, a, eng, ii))
return res
elif re.search(r"\.(txt)$", filename, re.IGNORECASE):
callback(0.1, "Start to parse.")
txt = get_text(filename, binary)
lines = txt.split("\n")
comma, tab = 0, 0
for line in lines:
if len(line.split(",")) == 2:
comma += 1
if len(line.split("\t")) == 2:
tab += 1
delimiter = "\t" if tab >= comma else ","
fails = []
content = ""
i = 0
while i < len(lines):
arr = lines[i].split(delimiter)
if len(arr) != 2:
content += "\n" + lines[i]
elif len(arr) == 2:
content += "\n" + arr[0]
res.append(beAdoc(deepcopy(doc), content, arr[1], eng, i))
content = ""
i += 1
if len(res) % 999 == 0:
callback(len(res) * 0.6 / len(lines), ("Extract TAG: {}".format(len(res)) + (
f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
callback(0.6, ("Extract TAG: {}".format(len(res)) + (
f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
return res
elif re.search(r"\.(csv)$", filename, re.IGNORECASE):
callback(0.1, "Start to parse.")
txt = get_text(filename, binary)
lines = txt.split("\n")
delimiter = "\t" if any("\t" in line for line in lines) else ","
fails = []
content = ""
res = []
reader = csv.reader(lines, delimiter=delimiter)
for i, row in enumerate(reader):
if len(row) != 2:
content += "\n" + lines[i]
elif len(row) == 2:
content += "\n" + row[0]
res.append(beAdoc(deepcopy(doc), content, row[1], eng, i))
content = ""
if len(res) % 999 == 0:
callback(len(res) * 0.6 / len(lines), ("Extract Tags: {}".format(len(res)) + (
f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
callback(0.6, ("Extract TAG : {}".format(len(res)) + (
f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
return res
raise NotImplementedError(
"Excel, csv(txt) format files are supported.")
if __name__ == "__main__":
import sys
def dummy(prog=None, msg=""):
pass
chunk(sys.argv[1], from_page=0, to_page=10, callback=dummy)
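
To make the format described in the docstring concrete, here is a hypothetical TAB-delimited txt input for this parser; the second column holds comma-separated tags:

```python
# Three well-formed rows: content<TAB>tag1,tag2
sample = (
    "How to reset my password?\taccount,security\n"
    "Refund policy for damaged goods\trefund,policy\n"
    "Office opening hours\tgeneral\n"
)
# Each row becomes one chunk with, e.g. for the first row:
#   content_with_weight = "How to reset my password?"
#   tag_kwd             = ["account", "security"]
#   top_int             = [0]   (the row number)
```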


@ -59,13 +59,15 @@ class FulltextQueryer:
"",
),
(r"(^| )(what|who|how|which|where|why)('re|'s)? ", " "),
(r"(^| )('s|'re|is|are|were|was|do|does|did|don't|doesn't|didn't|has|have|be|there|you|me|your|my|mine|just|please|may|i|should|would|wouldn't|will|won't|done|go|for|with|so|the|a|an|by|i'm|it's|he's|she's|they|they're|you're|as|by|on|in|at|up|out|down|of|to|or|and|if) ", " ")
(
r"(^| )('s|'re|is|are|were|was|do|does|did|don't|doesn't|didn't|has|have|be|there|you|me|your|my|mine|just|please|may|i|should|would|wouldn't|will|won't|done|go|for|with|so|the|a|an|by|i'm|it's|he's|she's|they|they're|you're|as|by|on|in|at|up|out|down|of|to|or|and|if) ",
" ")
]
for r, p in patts:
txt = re.sub(r, p, txt, flags=re.IGNORECASE)
return txt
def question(self, txt, tbl="qa", min_match:float=0.6):
def question(self, txt, tbl="qa", min_match: float = 0.6):
txt = re.sub(
r"[ :|\r\n\t,,。??/`!&^%%()\[\]{}<>]+",
" ",
@ -90,7 +92,8 @@ class FulltextQueryer:
syn = ["\"{}\"^{:.4f}".format(s, w / 4.) for s in syn if s.strip()]
syns.append(" ".join(syn))
q = ["({}^{:.4f}".format(tk, w) + " {})".format(syn) for (tk, w), syn in zip(tks_w, syns) if tk and not re.match(r"[.^+\(\)-]", tk)]
q = ["({}^{:.4f}".format(tk, w) + " {})".format(syn) for (tk, w), syn in zip(tks_w, syns) if
tk and not re.match(r"[.^+\(\)-]", tk)]
for i in range(1, len(tks_w)):
left, right = tks_w[i - 1][0].strip(), tks_w[i][0].strip()
if not left or not right:
@ -155,7 +158,7 @@ class FulltextQueryer:
if len(keywords) < 32:
keywords.extend([s for s in tk_syns if s])
tk_syns = [rag_tokenizer.fine_grained_tokenize(s) for s in tk_syns if s]
tk_syns = [f"\"{s}\"" if s.find(" ")>0 else s for s in tk_syns]
tk_syns = [f"\"{s}\"" if s.find(" ") > 0 else s for s in tk_syns]
if len(keywords) >= 32:
break
@ -174,8 +177,6 @@ class FulltextQueryer:
if len(twts) > 1:
tms += ' ("%s"~2)^1.5' % rag_tokenizer.tokenize(tt)
if re.match(r"[0-9a-z ]+$", tt):
tms = f'("{tt}" OR "%s")' % rag_tokenizer.tokenize(tt)
syns = " OR ".join(
[
@ -232,3 +233,25 @@ class FulltextQueryer:
for k, v in qtwt.items():
q += v
return s / q
def paragraph(self, content_tks: str, keywords: list = [], keywords_topn=30):
if isinstance(content_tks, str):
content_tks = [c.strip() for c in content_tks.strip() if c.strip()]
tks_w = self.tw.weights(content_tks, preprocess=False)
keywords = [f'"{k.strip()}"' for k in keywords]
for tk, w in sorted(tks_w, key=lambda x: x[1] * -1)[:keywords_topn]:
tk_syns = self.syn.lookup(tk)
tk_syns = [FulltextQueryer.subSpecialChar(s) for s in tk_syns]
tk_syns = [rag_tokenizer.fine_grained_tokenize(s) for s in tk_syns if s]
tk_syns = [f"\"{s}\"" if s.find(" ") > 0 else s for s in tk_syns]
tk = FulltextQueryer.subSpecialChar(tk)
if tk.find(" ") > 0:
tk = '"%s"' % tk
if tk_syns:
tk = f"({tk} OR (%s)^0.2)" % " ".join(tk_syns)
if tk:
keywords.append(f"{tk}^{w}")
return MatchTextExpr(self.query_fields, " ".join(keywords), 100,
{"minimum_should_match": min(3, len(keywords) / 10)})


@ -13,11 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import re
from dataclasses import dataclass
from rag.settings import TAG_FLD, PAGERANK_FLD
from rag.utils import rmSpace
from rag.nlp import rag_tokenizer, query
import numpy as np
@ -47,7 +47,8 @@ class Dealer:
qv, _ = emb_mdl.encode_queries(txt)
shape = np.array(qv).shape
if len(shape) > 1:
raise Exception(f"Dealer.get_vector returned array's shape {shape} doesn't match expectation(exact one dimension).")
raise Exception(
f"Dealer.get_vector returned array's shape {shape} doesn't match expectation(exact one dimension).")
embedding_data = [float(v) for v in qv]
vector_column_name = f"q_{len(embedding_data)}_vec"
return MatchDenseExpr(vector_column_name, embedding_data, 'float', 'cosine', topk, {"similarity": similarity})
@ -63,7 +64,12 @@ class Dealer:
condition[key] = req[key]
return condition
def search(self, req, idx_names: str | list[str], kb_ids: list[str], emb_mdl=None, highlight = False):
def search(self, req, idx_names: str | list[str],
kb_ids: list[str],
emb_mdl=None,
highlight=False,
rank_feature: dict | None = None
):
filters = self.get_filters(req)
orderBy = OrderByExpr()
@ -72,9 +78,11 @@ class Dealer:
ps = int(req.get("size", topk))
offset, limit = pg * ps, ps
src = req.get("fields", ["docnm_kwd", "content_ltks", "kb_id", "img_id", "title_tks", "important_kwd", "position_int",
"doc_id", "page_num_int", "top_int", "create_timestamp_flt", "knowledge_graph_kwd", "question_kwd", "question_tks",
"available_int", "content_with_weight", "pagerank_fea"])
src = req.get("fields",
["docnm_kwd", "content_ltks", "kb_id", "img_id", "title_tks", "important_kwd", "position_int",
"doc_id", "page_num_int", "top_int", "create_timestamp_flt", "knowledge_graph_kwd",
"question_kwd", "question_tks",
"available_int", "content_with_weight", PAGERANK_FLD, TAG_FLD])
kwds = set([])
qst = req.get("question", "")
@ -85,15 +93,16 @@ class Dealer:
orderBy.asc("top_int")
orderBy.desc("create_timestamp_flt")
res = self.dataStore.search(src, [], filters, [], orderBy, offset, limit, idx_names, kb_ids)
total=self.dataStore.getTotal(res)
total = self.dataStore.getTotal(res)
logging.debug("Dealer.search TOTAL: {}".format(total))
else:
highlightFields = ["content_ltks", "title_tks"] if highlight else []
matchText, keywords = self.qryr.question(qst, min_match=0.3)
if emb_mdl is None:
matchExprs = [matchText]
res = self.dataStore.search(src, highlightFields, filters, matchExprs, orderBy, offset, limit, idx_names, kb_ids)
total=self.dataStore.getTotal(res)
res = self.dataStore.search(src, highlightFields, filters, matchExprs, orderBy, offset, limit,
idx_names, kb_ids, rank_feature=rank_feature)
total = self.dataStore.getTotal(res)
logging.debug("Dealer.search TOTAL: {}".format(total))
else:
matchDense = self.get_vector(qst, emb_mdl, topk, req.get("similarity", 0.1))
@ -103,8 +112,9 @@ class Dealer:
fusionExpr = FusionExpr("weighted_sum", topk, {"weights": "0.05, 0.95"})
matchExprs = [matchText, matchDense, fusionExpr]
res = self.dataStore.search(src, highlightFields, filters, matchExprs, orderBy, offset, limit, idx_names, kb_ids)
total=self.dataStore.getTotal(res)
res = self.dataStore.search(src, highlightFields, filters, matchExprs, orderBy, offset, limit,
idx_names, kb_ids, rank_feature=rank_feature)
total = self.dataStore.getTotal(res)
logging.debug("Dealer.search TOTAL: {}".format(total))
# If result is empty, try again with lower min_match
@ -112,8 +122,9 @@ class Dealer:
matchText, _ = self.qryr.question(qst, min_match=0.1)
filters.pop("doc_ids", None)
matchDense.extra_options["similarity"] = 0.17
res = self.dataStore.search(src, highlightFields, filters, [matchText, matchDense, fusionExpr], orderBy, offset, limit, idx_names, kb_ids)
total=self.dataStore.getTotal(res)
res = self.dataStore.search(src, highlightFields, filters, [matchText, matchDense, fusionExpr],
orderBy, offset, limit, idx_names, kb_ids, rank_feature=rank_feature)
total = self.dataStore.getTotal(res)
logging.debug("Dealer.search 2 TOTAL: {}".format(total))
for k in keywords:
@ -126,8 +137,8 @@ class Dealer:
kwds.add(kk)
logging.debug(f"TOTAL: {total}")
ids=self.dataStore.getChunkIds(res)
keywords=list(kwds)
ids = self.dataStore.getChunkIds(res)
keywords = list(kwds)
highlight = self.dataStore.getHighlight(res, keywords, "content_with_weight")
aggs = self.dataStore.getAggregation(res, "docnm_kwd")
return self.SearchResult(
@ -188,13 +199,13 @@ class Dealer:
ans_v, _ = embd_mdl.encode(pieces_)
assert len(ans_v[0]) == len(chunk_v[0]), "The dimension of query and chunk do not match: {} vs. {}".format(
len(ans_v[0]), len(chunk_v[0]))
len(ans_v[0]), len(chunk_v[0]))
chunks_tks = [rag_tokenizer.tokenize(self.qryr.rmWWW(ck)).split()
for ck in chunks]
cites = {}
thr = 0.63
while thr>0.3 and len(cites.keys()) == 0 and pieces_ and chunks_tks:
while thr > 0.3 and len(cites.keys()) == 0 and pieces_ and chunks_tks:
for i, a in enumerate(pieces_):
sim, tksim, vtsim = self.qryr.hybrid_similarity(ans_v[i],
chunk_v,
@ -228,20 +239,44 @@ class Dealer:
return res, seted
def _rank_feature_scores(self, query_rfea, search_res):
## For rank feature(tag_fea) scores.
rank_fea = []
pageranks = []
for chunk_id in search_res.ids:
pageranks.append(search_res.field[chunk_id].get(PAGERANK_FLD, 0))
pageranks = np.array(pageranks, dtype=float)
if not query_rfea:
return np.array([0 for _ in range(len(search_res.ids))]) + pageranks
q_denor = np.sqrt(np.sum([s*s for t,s in query_rfea.items() if t != PAGERANK_FLD]))
for i in search_res.ids:
nor, denor = 0, 0
for t, sc in eval(search_res.field[i].get(TAG_FLD, "{}")).items():
if t in query_rfea:
nor += query_rfea[t] * sc
denor += sc * sc
if denor == 0:
rank_fea.append(0)
else:
rank_fea.append(nor/np.sqrt(denor)/q_denor)
return np.array(rank_fea)*10. + pageranks
def rerank(self, sres, query, tkweight=0.3,
vtweight=0.7, cfield="content_ltks"):
vtweight=0.7, cfield="content_ltks",
rank_feature: dict | None = None
):
_, keywords = self.qryr.question(query)
vector_size = len(sres.query_vector)
vector_column = f"q_{vector_size}_vec"
zero_vector = [0.0] * vector_size
ins_embd = []
pageranks = []
for chunk_id in sres.ids:
vector = sres.field[chunk_id].get(vector_column, zero_vector)
if isinstance(vector, str):
vector = [float(v) for v in vector.split("\t")]
ins_embd.append(vector)
pageranks.append(sres.field[chunk_id].get("pagerank_fea", 0))
if not ins_embd:
return [], [], []
@ -254,18 +289,22 @@ class Dealer:
title_tks = [t for t in sres.field[i].get("title_tks", "").split() if t]
question_tks = [t for t in sres.field[i].get("question_tks", "").split() if t]
important_kwd = sres.field[i].get("important_kwd", [])
tks = content_ltks + title_tks*2 + important_kwd*5 + question_tks*6
tks = content_ltks + title_tks * 2 + important_kwd * 5 + question_tks * 6
ins_tw.append(tks)
## For rank feature(tag_fea) scores.
rank_fea = self._rank_feature_scores(rank_feature, sres)
sim, tksim, vtsim = self.qryr.hybrid_similarity(sres.query_vector,
ins_embd,
keywords,
ins_tw, tkweight, vtweight)
return sim+np.array(pageranks, dtype=float), tksim, vtsim
return sim + rank_fea, tksim, vtsim
def rerank_by_model(self, rerank_mdl, sres, query, tkweight=0.3,
vtweight=0.7, cfield="content_ltks"):
vtweight=0.7, cfield="content_ltks",
rank_feature: dict | None = None):
_, keywords = self.qryr.question(query)
for i in sres.ids:
@ -280,9 +319,11 @@ class Dealer:
ins_tw.append(tks)
tksim = self.qryr.token_similarity(keywords, ins_tw)
vtsim,_ = rerank_mdl.similarity(query, [rmSpace(" ".join(tks)) for tks in ins_tw])
vtsim, _ = rerank_mdl.similarity(query, [rmSpace(" ".join(tks)) for tks in ins_tw])
## For rank feature(tag_fea) scores.
rank_fea = self._rank_feature_scores(rank_feature, sres)
return tkweight*np.array(tksim) + vtweight*vtsim, tksim, vtsim
return tkweight * (np.array(tksim)+rank_fea) + vtweight * vtsim, tksim, vtsim
def hybrid_similarity(self, ans_embd, ins_embd, ans, inst):
return self.qryr.hybrid_similarity(ans_embd,
@ -291,13 +332,15 @@ class Dealer:
rag_tokenizer.tokenize(inst).split())
def retrieval(self, question, embd_mdl, tenant_ids, kb_ids, page, page_size, similarity_threshold=0.2,
vector_similarity_weight=0.3, top=1024, doc_ids=None, aggs=True, rerank_mdl=None, highlight=False):
vector_similarity_weight=0.3, top=1024, doc_ids=None, aggs=True,
rerank_mdl=None, highlight=False,
rank_feature: dict | None = {PAGERANK_FLD: 10}):
ranks = {"total": 0, "chunks": [], "doc_aggs": {}}
if not question:
return ranks
RERANK_PAGE_LIMIT = 3
req = {"kb_ids": kb_ids, "doc_ids": doc_ids, "size": max(page_size*RERANK_PAGE_LIMIT, 128),
req = {"kb_ids": kb_ids, "doc_ids": doc_ids, "size": max(page_size * RERANK_PAGE_LIMIT, 128),
"question": question, "vector": True, "topk": top,
"similarity": similarity_threshold,
"available_int": 1}
@ -309,29 +352,30 @@ class Dealer:
if isinstance(tenant_ids, str):
tenant_ids = tenant_ids.split(",")
sres = self.search(req, [index_name(tid) for tid in tenant_ids], kb_ids, embd_mdl, highlight)
sres = self.search(req, [index_name(tid) for tid in tenant_ids],
kb_ids, embd_mdl, highlight, rank_feature=rank_feature)
ranks["total"] = sres.total
if page <= RERANK_PAGE_LIMIT:
if rerank_mdl and sres.total > 0:
sim, tsim, vsim = self.rerank_by_model(rerank_mdl,
sres, question, 1 - vector_similarity_weight, vector_similarity_weight)
sres, question, 1 - vector_similarity_weight,
vector_similarity_weight,
rank_feature=rank_feature)
else:
sim, tsim, vsim = self.rerank(
sres, question, 1 - vector_similarity_weight, vector_similarity_weight)
idx = np.argsort(sim * -1)[(page-1)*page_size:page*page_size]
sres, question, 1 - vector_similarity_weight, vector_similarity_weight,
rank_feature=rank_feature)
idx = np.argsort(sim * -1)[(page - 1) * page_size:page * page_size]
else:
sim = tsim = vsim = [1]*len(sres.ids)
sim = tsim = vsim = [1] * len(sres.ids)
idx = list(range(len(sres.ids)))
def floor_sim(score):
return (int(score * 100.)%100)/100.
dim = len(sres.query_vector)
vector_column = f"q_{dim}_vec"
zero_vector = [0.0] * dim
for i in idx:
if floor_sim(sim[i]) < similarity_threshold:
if sim[i] < similarity_threshold:
break
if len(ranks["chunks"]) >= page_size:
if aggs:
@ -369,8 +413,8 @@ class Dealer:
ranks["doc_aggs"] = [{"doc_name": k,
"doc_id": v["doc_id"],
"count": v["count"]} for k,
v in sorted(ranks["doc_aggs"].items(),
key=lambda x:x[1]["count"] * -1)]
v in sorted(ranks["doc_aggs"].items(),
key=lambda x: x[1]["count"] * -1)]
ranks["chunks"] = ranks["chunks"][:page_size]
return ranks
@ -379,15 +423,57 @@ class Dealer:
tbl = self.dataStore.sql(sql, fetch_size, format)
return tbl
def chunk_list(self, doc_id: str, tenant_id: str, kb_ids: list[str], max_count=1024, fields=["docnm_kwd", "content_with_weight", "img_id"]):
def chunk_list(self, doc_id: str, tenant_id: str,
kb_ids: list[str], max_count=1024,
offset=0,
fields=["docnm_kwd", "content_with_weight", "img_id"]):
condition = {"doc_id": doc_id}
res = []
bs = 128
for p in range(0, max_count, bs):
es_res = self.dataStore.search(fields, [], condition, [], OrderByExpr(), p, bs, index_name(tenant_id), kb_ids)
for p in range(offset, max_count, bs):
es_res = self.dataStore.search(fields, [], condition, [], OrderByExpr(), p, bs, index_name(tenant_id),
kb_ids)
dict_chunks = self.dataStore.getFields(es_res, fields)
if dict_chunks:
res.extend(dict_chunks.values())
if len(dict_chunks.values()) < bs:
break
return res
def all_tags(self, tenant_id: str, kb_ids: list[str], S=1000):
res = self.dataStore.search([], [], {}, [], OrderByExpr(), 0, 0, index_name(tenant_id), kb_ids, ["tag_kwd"])
return self.dataStore.getAggregation(res, "tag_kwd")
def all_tags_in_portion(self, tenant_id: str, kb_ids: list[str], S=1000):
res = self.dataStore.search([], [], {}, [], OrderByExpr(), 0, 0, index_name(tenant_id), kb_ids, ["tag_kwd"])
res = self.dataStore.getAggregation(res, "tag_kwd")
total = np.sum([c for _, c in res])
return {t: (c + 1) / (total + S) for t, c in res}
def tag_content(self, tenant_id: str, kb_ids: list[str], doc, all_tags, topn_tags=3, keywords_topn=30, S=1000):
idx_nm = index_name(tenant_id)
match_txt = self.qryr.paragraph(doc["title_tks"] + " " + doc["content_ltks"], doc.get("important_kwd", []), keywords_topn)
res = self.dataStore.search([], [], {}, [match_txt], OrderByExpr(), 0, 0, idx_nm, kb_ids, ["tag_kwd"])
aggs = self.dataStore.getAggregation(res, "tag_kwd")
if not aggs:
return False
cnt = np.sum([c for _, c in aggs])
tag_fea = sorted([(a, round(0.1*(c + 1) / (cnt + S) / (all_tags.get(a, 0.0001)))) for a, c in aggs],
key=lambda x: x[1] * -1)[:topn_tags]
doc[TAG_FLD] = {a: c for a, c in tag_fea if c > 0}
return True
def tag_query(self, question: str, tenant_ids: str | list[str], kb_ids: list[str], all_tags, topn_tags=3, S=1000):
if isinstance(tenant_ids, str):
idx_nms = index_name(tenant_ids)
else:
idx_nms = [index_name(tid) for tid in tenant_ids]
match_txt, _ = self.qryr.question(question, min_match=0.0)
res = self.dataStore.search([], [], {}, [match_txt], OrderByExpr(), 0, 0, idx_nms, kb_ids, ["tag_kwd"])
aggs = self.dataStore.getAggregation(res, "tag_kwd")
if not aggs:
return {}
cnt = np.sum([c for _, c in aggs])
tag_fea = sorted([(a, round(0.1*(c + 1) / (cnt + S) / (all_tags.get(a, 0.0001)))) for a, c in aggs],
key=lambda x: x[1] * -1)[:topn_tags]
return {a: c for a, c in tag_fea if c > 0}
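
A worked example of the rank-feature score computed by `_rank_feature_scores` above: a cosine-style match between the query's tag weights (from `label_question`) and a chunk's stored `TAG_FLD` weights, scaled by 10, plus the chunk's pagerank. Numbers are illustrative:

```python
import numpy as np

query_rfea = {"finance": 3, "banking": 1}  # hypothetical label_question() output
chunk_tags = {"finance": 8, "legal": 2}    # hypothetical TAG_FLD of one chunk
pagerank = 2                               # hypothetical PAGERANK_FLD

q_denor = np.sqrt(sum(s * s for s in query_rfea.values()))  # sqrt(10)
nor = denor = 0
for t, sc in chunk_tags.items():
    if t in query_rfea:                    # only shared tags contribute
        nor += query_rfea[t] * sc          # 3 * 8 = 24
        denor += sc * sc                   # 64
score = (nor / np.sqrt(denor) / q_denor) * 10.0 + pagerank if denor else pagerank
# score ≈ 24 / 8 / 3.162 * 10 + 2 ≈ 11.49, added on top of the hybrid similarity
```

Relatedly, `tag_query` and `tag_content` rate candidate tags by an over-representation ratio, roughly `0.1 * (c + 1) / (cnt + S)` divided by the tag's corpus-wide portion from `all_tags_in_portion`, i.e. an additively smoothed frequency comparison.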


@ -38,6 +38,9 @@ SVR_QUEUE_RETENTION = 60*60
SVR_QUEUE_MAX_LEN = 1024
SVR_CONSUMER_NAME = "rag_flow_svr_consumer"
SVR_CONSUMER_GROUP_NAME = "rag_flow_svr_consumer_group"
PAGERANK_FLD = "pagerank_fea"
TAG_FLD = "tag_feas"
def print_rag_settings():
logging.info(f"MAX_CONTENT_LENGTH: {DOC_MAXIMUM_SIZE}")


@ -16,10 +16,10 @@
# from beartype import BeartypeConf
# from beartype.claw import beartype_all # <-- you didn't sign up for this
# beartype_all(conf=BeartypeConf(violation_type=UserWarning)) # <-- emit warnings from all code
import random
import sys
from api.utils.log_utils import initRootLogger
from graphrag.utils import get_llm_cache, set_llm_cache
from graphrag.utils import get_llm_cache, set_llm_cache, get_tags_from_cache, set_tags_to_cache
CONSUMER_NO = "0" if len(sys.argv) < 2 else sys.argv[1]
CONSUMER_NAME = "task_executor_" + CONSUMER_NO
@ -44,7 +44,7 @@ import numpy as np
from peewee import DoesNotExist
from api.db import LLMType, ParserType, TaskStatus
from api.db.services.dialog_service import keyword_extraction, question_proposal
from api.db.services.dialog_service import keyword_extraction, question_proposal, content_tagging
from api.db.services.document_service import DocumentService
from api.db.services.llm_service import LLMBundle
from api.db.services.task_service import TaskService
@ -53,10 +53,10 @@ from api import settings
from api.versions import get_ragflow_version
from api.db.db_models import close_connection
from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive, one, audio, \
knowledge_graph, email
knowledge_graph, email, tag
from rag.nlp import search, rag_tokenizer
from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor
from rag.settings import DOC_MAXIMUM_SIZE, SVR_QUEUE_NAME, print_rag_settings
from rag.settings import DOC_MAXIMUM_SIZE, SVR_QUEUE_NAME, print_rag_settings, TAG_FLD, PAGERANK_FLD
from rag.utils import num_tokens_from_string
from rag.utils.redis_conn import REDIS_CONN, Payload
from rag.utils.storage_factory import STORAGE_IMPL
@ -78,7 +78,8 @@ FACTORY = {
ParserType.ONE.value: one,
ParserType.AUDIO.value: audio,
ParserType.EMAIL.value: email,
ParserType.KG.value: knowledge_graph
ParserType.KG.value: knowledge_graph,
ParserType.TAG.value: tag
}
CONSUMER_NAME = "task_consumer_" + CONSUMER_NO
@ -199,7 +200,8 @@ def build_chunks(task, progress_callback):
logging.info("From minio({}) {}/{}".format(timer() - st, task["location"], task["name"]))
except TimeoutError:
progress_callback(-1, "Internal server error: Fetch file from minio timeout. Could you try it again.")
logging.exception("Minio {}/{} got timeout: Fetch file from minio timeout.".format(task["location"], task["name"]))
logging.exception(
"Minio {}/{} got timeout: Fetch file from minio timeout.".format(task["location"], task["name"]))
raise
except Exception as e:
if re.search("(No such file|not found)", str(e)):
@ -227,7 +229,7 @@ def build_chunks(task, progress_callback):
"kb_id": str(task["kb_id"])
}
if task["pagerank"]:
doc["pagerank_fea"] = int(task["pagerank"])
doc[PAGERANK_FLD] = int(task["pagerank"])
el = 0
for ck in cks:
d = copy.deepcopy(doc)
@ -252,7 +254,8 @@ def build_chunks(task, progress_callback):
STORAGE_IMPL.put(task["kb_id"], d["id"], output_buffer.getvalue())
el += timer() - st
except Exception:
logging.exception("Saving image of chunk {}/{}/{} got exception".format(task["location"], task["name"], d["id"]))
logging.exception(
"Saving image of chunk {}/{}/{} got exception".format(task["location"], task["name"], d["id"]))
raise
d["img_id"] = "{}-{}".format(task["kb_id"], d["id"])
@ -295,12 +298,43 @@ def build_chunks(task, progress_callback):
d["question_tks"] = rag_tokenizer.tokenize("\n".join(d["question_kwd"]))
progress_callback(msg="Question generation completed in {:.2f}s".format(timer() - st))
if task["kb_parser_config"].get("tag_kb_ids", []):
progress_callback(msg="Start to tag for every chunk ...")
kb_ids = task["kb_parser_config"]["tag_kb_ids"]
tenant_id = task["tenant_id"]
topn_tags = task["kb_parser_config"].get("topn_tags", 3)
S = 1000
st = timer()
examples = []
all_tags = get_tags_from_cache(kb_ids)
if not all_tags:
all_tags = settings.retrievaler.all_tags_in_portion(tenant_id, kb_ids, S)
set_tags_to_cache(kb_ids, all_tags)
else:
all_tags = json.loads(all_tags)
chat_mdl = LLMBundle(task["tenant_id"], LLMType.CHAT, llm_name=task["llm_id"], lang=task["language"])
for d in docs:
if settings.retrievaler.tag_content(tenant_id, kb_ids, d, all_tags, topn_tags=topn_tags, S=S):
examples.append({"content": d["content_with_weight"], TAG_FLD: d[TAG_FLD]})
continue
cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], all_tags, {"topn": topn_tags})
if not cached:
cached = json.dumps(content_tagging(chat_mdl, d["content_with_weight"], all_tags,
random.choices(examples, k=2) if len(examples) > 2 else examples,
topn=topn_tags))
if cached:
set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, all_tags, {"topn": topn_tags})
d[TAG_FLD] = json.loads(cached)
progress_callback(msg="Tagging completed in {:.2f}s".format(timer() - st))
return docs
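For context, the tagging pass above is cache-first: the dataset's tag vocabulary comes from cache (or is aggregated in portions of S=1000 and cached), similarity-based tagging via tag_content() covers most chunks and feeds the examples list, and only the remainder goes to the chat model, whose output is cached per content. A minimal sketch of that fallback, with stand-in helper names rather than the actual RAGFlow APIs:

import json
import random

TAG_FLD = "tag_feas"  # assumed value of rag.settings.TAG_FLD

def tag_docs(docs, tag_by_similarity, tag_by_llm, cache_get, cache_set, topn=3):
    examples = []  # similarity-tagged chunks double as few-shot examples
    for d in docs:
        if tag_by_similarity(d):        # cheap path: tags found in the tag KBs
            examples.append({"content": d["content_with_weight"], TAG_FLD: d[TAG_FLD]})
            continue
        cached = cache_get(d["content_with_weight"])    # skip repeat LLM calls
        if not cached:
            fewshot = random.choices(examples, k=2) if len(examples) > 2 else examples
            cached = tag_by_llm(d["content_with_weight"], fewshot, topn=topn)
            if cached:
                cache_set(d["content_with_weight"], cached)
        if cached:
            d[TAG_FLD] = json.loads(cached)
    return docs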
def init_kb(row, vector_size: int):
idxnm = search.index_name(row["tenant_id"])
return settings.docStoreConn.createIdx(idxnm, row.get("kb_id",""), vector_size)
return settings.docStoreConn.createIdx(idxnm, row.get("kb_id", ""), vector_size)
def embedding(docs, mdl, parser_config=None, callback=None):
@ -381,7 +415,7 @@ def run_raptor(row, chat_mdl, embd_mdl, callback=None):
"title_tks": rag_tokenizer.tokenize(row["name"])
}
if row["pagerank"]:
doc["pagerank_fea"] = int(row["pagerank"])
doc[PAGERANK_FLD] = int(row["pagerank"])
res = []
tk_count = 0
for content, vctr in chunks[original_length:]:
@ -480,7 +514,8 @@ def do_handle_task(task):
doc_store_result = ""
es_bulk_size = 4
for b in range(0, len(chunks), es_bulk_size):
doc_store_result = settings.docStoreConn.insert(chunks[b:b + es_bulk_size], search.index_name(task_tenant_id), task_dataset_id)
doc_store_result = settings.docStoreConn.insert(chunks[b:b + es_bulk_size], search.index_name(task_tenant_id),
task_dataset_id)
if b % 128 == 0:
progress_callback(prog=0.8 + 0.1 * (b + 1) / len(chunks), msg="")
if doc_store_result:
@ -493,15 +528,21 @@ def do_handle_task(task):
TaskService.update_chunk_ids(task["id"], chunk_ids_str)
except DoesNotExist:
logging.warning(f"do_handle_task update_chunk_ids failed since task {task['id']} is unknown.")
doc_store_result = settings.docStoreConn.delete({"id": chunk_ids}, search.index_name(task_tenant_id), task_dataset_id)
doc_store_result = settings.docStoreConn.delete({"id": chunk_ids}, search.index_name(task_tenant_id),
task_dataset_id)
return
logging.info("Indexing doc({}), page({}-{}), chunks({}), elapsed: {:.2f}".format(task_document_name, task_from_page, task_to_page, len(chunks), timer() - start_ts))
logging.info("Indexing doc({}), page({}-{}), chunks({}), elapsed: {:.2f}".format(task_document_name, task_from_page,
task_to_page, len(chunks),
timer() - start_ts))
DocumentService.increment_chunk_num(task_doc_id, task_dataset_id, token_count, chunk_count, 0)
time_cost = timer() - start_ts
progress_callback(prog=1.0, msg="Done ({:.2f}s)".format(time_cost))
logging.info("Chunk doc({}), page({}-{}), chunks({}), token({}), elapsed:{:.2f}".format(task_document_name, task_from_page, task_to_page, len(chunks), token_count, time_cost))
logging.info(
"Chunk doc({}), page({}-{}), chunks({}), token({}), elapsed:{:.2f}".format(task_document_name, task_from_page,
task_to_page, len(chunks),
token_count, time_cost))
def handle_task():
View File
@ -71,11 +71,13 @@ def findMaxTm(fnm):
pass
return m
tiktoken_cache_dir = get_project_base_directory()
os.environ["TIKTOKEN_CACHE_DIR"] = tiktoken_cache_dir
# encoder = tiktoken.encoding_for_model("gpt-3.5-turbo")
encoder = tiktoken.get_encoding("cl100k_base")
def num_tokens_from_string(string: str) -> int:
"""Returns the number of tokens in a text string."""
try:
View File
@ -176,7 +176,17 @@ class DocStoreConnection(ABC):
@abstractmethod
def search(
self, selectFields: list[str], highlight: list[str], condition: dict, matchExprs: list[MatchExpr], orderBy: OrderByExpr, offset: int, limit: int, indexNames: str|list[str], knowledgebaseIds: list[str]
self, selectFields: list[str],
highlightFields: list[str],
condition: dict,
matchExprs: list[MatchExpr],
orderBy: OrderByExpr,
offset: int,
limit: int,
indexNames: str|list[str],
knowledgebaseIds: list[str],
aggFields: list[str] = [],
rank_feature: dict | None = None
) -> list[dict] | pl.DataFrame:
"""
Search with the given conjunctive (AND-combined) filtering condition and return all fields of matched documents
@ -191,7 +201,7 @@ class DocStoreConnection(ABC):
raise NotImplementedError("Not implemented")
@abstractmethod
def insert(self, rows: list[dict], indexName: str, knowledgebaseId: str) -> list[str]:
def insert(self, rows: list[dict], indexName: str, knowledgebaseId: str = None) -> list[str]:
"""
Update or insert a bulk of rows
"""
View File
@ -9,6 +9,7 @@ from elasticsearch import Elasticsearch, NotFoundError
from elasticsearch_dsl import UpdateByQuery, Q, Search, Index
from elastic_transport import ConnectionTimeout
from rag import settings
from rag.settings import TAG_FLD, PAGERANK_FLD
from rag.utils import singleton
from api.utils.file_utils import get_project_base_directory
import polars as pl
@ -20,6 +21,7 @@ ATTEMPT_TIME = 2
logger = logging.getLogger('ragflow.es_conn')
@singleton
class ESConnection(DocStoreConnection):
def __init__(self):
@ -111,9 +113,19 @@ class ESConnection(DocStoreConnection):
CRUD operations
"""
def search(self, selectFields: list[str], highlightFields: list[str], condition: dict, matchExprs: list[MatchExpr],
orderBy: OrderByExpr, offset: int, limit: int, indexNames: str | list[str],
knowledgebaseIds: list[str]) -> list[dict] | pl.DataFrame:
def search(
self, selectFields: list[str],
highlightFields: list[str],
condition: dict,
matchExprs: list[MatchExpr],
orderBy: OrderByExpr,
offset: int,
limit: int,
indexNames: str | list[str],
knowledgebaseIds: list[str],
aggFields: list[str] = [],
rank_feature: dict | None = None
) -> list[dict] | pl.DataFrame:
"""
Refers to https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl.html
"""
@ -175,8 +187,13 @@ class ESConnection(DocStoreConnection):
similarity=similarity,
)
if bqry and rank_feature:
for fld, sc in rank_feature.items():
if fld != PAGERANK_FLD:
fld = f"{TAG_FLD}.{fld}"
bqry.should.append(Q("rank_feature", field=fld, linear={}, boost=sc))
if bqry:
bqry.should.append(Q("rank_feature", field="pagerank_fea", linear={}, boost=10))
s = s.query(bqry)
for field in highlightFields:
s = s.highlight(field)
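Each rank_feature entry above becomes a rank_feature should-clause on the bool query, so matching tags add linearly to the BM25/vector score rather than filtering. Roughly the JSON this emits, with "tag_feas" standing in for TAG_FLD:

# Approximate clauses appended to bqry.should (field names assumed):
should_clauses = [
    {"rank_feature": {"field": "pagerank_fea", "linear": {}, "boost": 10}},
    {"rank_feature": {"field": "tag_feas.finance", "linear": {}, "boost": 3}},
]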
@ -187,7 +204,7 @@ class ESConnection(DocStoreConnection):
order = "asc" if order == 0 else "desc"
if field in ["page_num_int", "top_int"]:
order_info = {"order": order, "unmapped_type": "float",
"mode": "avg", "numeric_type": "double"}
"mode": "avg", "numeric_type": "double"}
elif field.endswith("_int") or field.endswith("_flt"):
order_info = {"order": order, "unmapped_type": "float"}
else:
@ -195,8 +212,11 @@ class ESConnection(DocStoreConnection):
orders.append({field: order_info})
s = s.sort(*orders)
for fld in aggFields:
s.aggs.bucket(f'aggs_{fld}', 'terms', field=fld, size=1000000)
if limit > 0:
s = s[offset:offset+limit]
s = s[offset:offset + limit]
q = s.to_dict()
logger.debug(f"ESConnection.search {str(indexNames)} query: " + json.dumps(q))
@ -240,7 +260,7 @@ class ESConnection(DocStoreConnection):
logger.error("ESConnection.get timeout for 3 times!")
raise Exception("ESConnection.get timeout.")
def insert(self, documents: list[dict], indexName: str, knowledgebaseId: str) -> list[str]:
def insert(self, documents: list[dict], indexName: str, knowledgebaseId: str = None) -> list[str]:
# Refers to https://www.elastic.co/guide/en/elasticsearch/reference/current/docs-bulk.html
operations = []
for d in documents:
@ -292,44 +312,57 @@ class ESConnection(DocStoreConnection):
if str(e).find("Timeout") > 0:
continue
return False
else:
# update unspecific maybe-multiple documents
bqry = Q("bool")
for k, v in condition.items():
if not isinstance(k, str) or not v:
continue
if k == "exist":
bqry.filter.append(Q("exists", field=v))
continue
if isinstance(v, list):
bqry.filter.append(Q("terms", **{k: v}))
elif isinstance(v, str) or isinstance(v, int):
bqry.filter.append(Q("term", **{k: v}))
else:
raise Exception(
f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.")
scripts = []
for k, v in newValue.items():
if k == "remove":
scripts.append(f"ctx._source.remove('{v}');")
continue
if (not isinstance(k, str) or not v) and k != "available_int":
continue
if isinstance(v, str):
scripts.append(f"ctx._source.{k} = '{v}'")
elif isinstance(v, int):
scripts.append(f"ctx._source.{k} = {v}")
else:
raise Exception(
f"newValue `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str.")
# update unspecific maybe-multiple documents
bqry = Q("bool")
for k, v in condition.items():
if not isinstance(k, str) or not v:
continue
if k == "exist":
bqry.filter.append(Q("exists", field=v))
continue
if isinstance(v, list):
bqry.filter.append(Q("terms", **{k: v}))
elif isinstance(v, str) or isinstance(v, int):
bqry.filter.append(Q("term", **{k: v}))
else:
raise Exception(
f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.")
scripts = []
params = {}
for k, v in newValue.items():
if k == "remove":
if isinstance(v, str):
scripts.append(f"ctx._source.{k} = '{v}'")
elif isinstance(v, int):
scripts.append(f"ctx._source.{k} = {v}")
else:
raise Exception(
f"newValue `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str.")
scripts.append(f"ctx._source.remove('{v}');")
if isinstance(v, dict):
for kk, vv in v.items():
scripts.append(f"int i=ctx._source.{kk}.indexOf(params.p_{kk});ctx._source.{kk}.remove(i);")
params[f"p_{kk}"] = vv
continue
if k == "add":
if isinstance(v, dict):
for kk, vv in v.items():
scripts.append(f"ctx._source.{kk}.add(params.pp_{kk});")
params[f"pp_{kk}"] = vv.strip()
continue
if (not isinstance(k, str) or not v) and k != "available_int":
continue
if isinstance(v, str):
scripts.append(f"ctx._source.{k} = '{v}'")
elif isinstance(v, int):
scripts.append(f"ctx._source.{k} = {v}")
else:
raise Exception(
f"newValue `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str.")
ubq = UpdateByQuery(
index=indexName).using(
self.es).query(bqry)
ubq = ubq.script(source="; ".join(scripts))
ubq = ubq.script(source="".join(scripts), params=params)
ubq = ubq.params(refresh=True)
ubq = ubq.params(slices=5)
ubq = ubq.params(conflicts="proceed")
for i in range(3):
for _ in range(ATTEMPT_TIME):
try:
_ = ubq.execute()
return True
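With the new add/remove branches, a tag rename across all matching chunks becomes a single update-by-query. For a call shaped like update(condition, {"remove": {"tag_kwd": "old"}, "add": {"tag_kwd": "new"}}, ...), the loop above assembles roughly this Painless source and params (ordering follows dict iteration):

# Illustrative output of the script builder for a tag rename:
source = ("int i=ctx._source.tag_kwd.indexOf(params.p_tag_kwd);"
          "ctx._source.tag_kwd.remove(i);"
          "ctx._source.tag_kwd.add(params.pp_tag_kwd);")
params = {"p_tag_kwd": "old", "pp_tag_kwd": "new"}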
View File
@ -10,6 +10,7 @@ from infinity.index import IndexInfo, IndexType
from infinity.connection_pool import ConnectionPool
from infinity.errors import ErrorCode
from rag import settings
from rag.settings import PAGERANK_FLD
from rag.utils import singleton
import polars as pl
from polars.series.series import Series
@ -231,8 +232,7 @@ class InfinityConnection(DocStoreConnection):
"""
def search(
self,
selectFields: list[str],
self, selectFields: list[str],
highlightFields: list[str],
condition: dict,
matchExprs: list[MatchExpr],
@ -241,7 +241,9 @@ class InfinityConnection(DocStoreConnection):
limit: int,
indexNames: str | list[str],
knowledgebaseIds: list[str],
) -> tuple[pl.DataFrame, int]:
aggFields: list[str] = [],
rank_feature: dict | None = None
) -> list[dict] | pl.DataFrame:
"""
TODO: Infinity doesn't provide highlight
"""
@ -256,7 +258,7 @@ class InfinityConnection(DocStoreConnection):
if essential_field not in selectFields:
selectFields.append(essential_field)
if matchExprs:
for essential_field in ["score()", "pagerank_fea"]:
for essential_field in ["score()", PAGERANK_FLD]:
selectFields.append(essential_field)
# Prepare expressions common to all tables
@ -346,7 +348,7 @@ class InfinityConnection(DocStoreConnection):
self.connPool.release_conn(inf_conn)
res = concat_dataframes(df_list, selectFields)
if matchExprs:
res = res.sort(pl.col("SCORE") + pl.col("pagerank_fea"), descending=True, maintain_order=True)
res = res.sort(pl.col("SCORE") + pl.col(PAGERANK_FLD), descending=True, maintain_order=True)
res = res.limit(limit)
logger.debug(f"INFINITY search final result: {str(res)}")
return res, total_hits_count
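Infinity has no rank_feature query, so the pagerank boost is applied client-side: the stored pagerank column is added to the fused SCORE before sorting. A toy polars illustration of that re-ranking:

import polars as pl

# "pagerank_fea" is the assumed value of PAGERANK_FLD.
df = pl.DataFrame({"id": ["a", "b"], "SCORE": [1.2, 1.5], "pagerank_fea": [10.0, 0.0]})
ranked = df.sort(pl.col("SCORE") + pl.col("pagerank_fea"), descending=True)
# "a" now outranks "b" despite its lower raw score.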
@ -378,7 +380,7 @@ class InfinityConnection(DocStoreConnection):
return res_fields.get(chunkId, None)
def insert(
self, documents: list[dict], indexName: str, knowledgebaseId: str
self, documents: list[dict], indexName: str, knowledgebaseId: str = None
) -> list[str]:
inf_conn = self.connPool.get_conn()
db_instance = inf_conn.get_database(self.dbName)
@ -456,7 +458,7 @@ class InfinityConnection(DocStoreConnection):
elif k in ["page_num_int", "top_int"]:
assert isinstance(v, list)
newValue[k] = "_".join(f"{num:08x}" for num in v)
elif k == "remove" and v in ["pagerank_fea"]:
elif k == "remove" and v in [PAGERANK_FLD]:
del newValue[k]
newValue[v] = 0
logger.debug(f"INFINITY update table {table_name}, filter {filter}, newValue {newValue}.")
View File
@ -27,7 +27,7 @@ def test_create_dataset_with_invalid_parameter(get_api_key_fixture):
API_KEY = get_api_key_fixture
rag = RAGFlow(API_KEY, HOST_ADDRESS)
valid_chunk_methods = ["naive", "manual", "qa", "table", "paper", "book", "laws", "presentation", "picture", "one",
"knowledge_graph", "email"]
"knowledge_graph", "email", "tag"]
chunk_method = "invalid_chunk_method"
with pytest.raises(Exception) as exc_info:
rag.create_dataset("test_create_dataset_with_invalid_chunk_method", chunk_method=chunk_method)
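A minimal happy-path counterpart, assuming the SDK simply forwards the new method (dataset name illustrative):

# "tag" should now be accepted where the invalid method above is rejected.
ds = rag.create_dataset("test_create_dataset_with_tag_method", chunk_method="tag")
assert ds.chunk_method == "tag"  # assumes DataSet echoes chunk_method back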