Run keyword_extraction, question_proposal, content_tagging in thread pool (#5376)

### What problem does this PR solve?

Run keyword_extraction, question_proposal, and content_tagging for each chunk concurrently in a thread pool (via `asyncio.to_thread`) instead of calling them sequentially in a loop.
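A minimal sketch of the concurrency pattern this change applies (not code from this repository): each chunk gets a small coroutine that pushes a blocking LLM helper onto the default thread pool with `asyncio.to_thread`, and all coroutines are gathered and driven by the event loop so the calls overlap. `blocking_llm_call` and the sample chunks are placeholders; the driving call mirrors the diff below.

```python
import asyncio
import time

def blocking_llm_call(chunk: str) -> str:
    # Stand-in for a synchronous LLM helper such as keyword_extraction;
    # the real call blocks on network I/O.
    time.sleep(0.1)
    return f"keywords for: {chunk}"

async def process_chunk(chunk: str) -> str:
    # asyncio.to_thread runs the blocking call in the default thread pool,
    # so the calls for all chunks can overlap instead of running one by one.
    return await asyncio.to_thread(blocking_llm_call, chunk)

chunks = ["chunk one", "chunk two", "chunk three"]
tasks = [process_chunk(c) for c in chunks]
# Same driving call as in the diff: gather all per-chunk coroutines and
# run them to completion on the current event loop.
results = asyncio.get_event_loop().run_until_complete(asyncio.gather(*tasks))
print(results)
```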

### Type of change

- [x] Performance Improvement
Zhichang Yu 2025-02-26 15:21:14 +08:00 committed by GitHub
parent 5859a3df72
commit ffb4cda475
2 changed files with 3069 additions and 3066 deletions


@@ -28,6 +28,7 @@ CONSUMER_NO = "0" if len(sys.argv) < 2 else sys.argv[1]
 CONSUMER_NAME = "task_executor_" + CONSUMER_NO
 initRootLogger(CONSUMER_NAME)
+import asyncio
 import logging
 import os
 from datetime import datetime
@@ -300,36 +301,36 @@ def build_chunks(task, progress_callback):
         st = timer()
         progress_callback(msg="Start to generate keywords for every chunk ...")
         chat_mdl = LLMBundle(task["tenant_id"], LLMType.CHAT, llm_name=task["llm_id"], lang=task["language"])
-        for d in docs:
-            cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], "keywords",
-                                   {"topn": task["parser_config"]["auto_keywords"]})
-            if not cached:
-                cached = keyword_extraction(chat_mdl, d["content_with_weight"],
-                                            task["parser_config"]["auto_keywords"])
-                if cached:
-                    set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, "keywords",
-                                  {"topn": task["parser_config"]["auto_keywords"]})
-            d["important_kwd"] = cached.split(",")
-            d["important_tks"] = rag_tokenizer.tokenize(" ".join(d["important_kwd"]))
-        progress_callback(msg="Keywords generation completed in {:.2f}s".format(timer() - st))
+        async def doc_keyword_extraction(chat_mdl, d, topn):
+            cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], "keywords", {"topn": topn})
+            if not cached:
+                cached = await asyncio.to_thread(keyword_extraction, chat_mdl, d["content_with_weight"], topn)
+                set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, "keywords", {"topn": topn})
+            if cached:
+                d["important_kwd"] = cached.split(",")
+                d["important_tks"] = rag_tokenizer.tokenize(" ".join(d["important_kwd"]))
+            return
+        tasks = [doc_keyword_extraction(chat_mdl, d, task["parser_config"]["auto_keywords"]) for d in docs]
+        asyncio.get_event_loop().run_until_complete(asyncio.gather(*tasks))
+        progress_callback(msg="Keywords generation {} chunks completed in {:.2f}s".format(len(docs), timer() - st))
     if task["parser_config"].get("auto_questions", 0):
         st = timer()
         progress_callback(msg="Start to generate questions for every chunk ...")
         chat_mdl = LLMBundle(task["tenant_id"], LLMType.CHAT, llm_name=task["llm_id"], lang=task["language"])
-        for d in docs:
-            cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], "question",
-                                   {"topn": task["parser_config"]["auto_questions"]})
-            if not cached:
-                cached = question_proposal(chat_mdl, d["content_with_weight"], task["parser_config"]["auto_questions"])
-                if cached:
-                    set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, "question",
-                                  {"topn": task["parser_config"]["auto_questions"]})
-            d["question_kwd"] = cached.split("\n")
-            d["question_tks"] = rag_tokenizer.tokenize("\n".join(d["question_kwd"]))
-        progress_callback(msg="Question generation completed in {:.2f}s".format(timer() - st))
+        async def doc_question_proposal(chat_mdl, d, topn):
+            cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], "question", {"topn": topn})
+            if not cached:
+                cached = await asyncio.to_thread(question_proposal, chat_mdl, d["content_with_weight"], topn)
+                set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, "question", {"topn": topn})
+            if cached:
+                d["question_kwd"] = cached.split("\n")
+                d["question_tks"] = rag_tokenizer.tokenize("\n".join(d["question_kwd"]))
+        tasks = [doc_question_proposal(chat_mdl, d, task["parser_config"]["auto_questions"]) for d in docs]
+        asyncio.get_event_loop().run_until_complete(asyncio.gather(*tasks))
+        progress_callback(msg="Question generation {} chunks completed in {:.2f}s".format(len(docs), timer() - st))
     if task["kb_parser_config"].get("tag_kb_ids", []):
         progress_callback(msg="Start to tag for every chunk ...")
@@ -347,22 +348,27 @@ def build_chunks(task, progress_callback):
         all_tags = json.loads(all_tags)
         chat_mdl = LLMBundle(task["tenant_id"], LLMType.CHAT, llm_name=task["llm_id"], lang=task["language"])
+        docs_to_tag = []
         for d in docs:
             if settings.retrievaler.tag_content(tenant_id, kb_ids, d, all_tags, topn_tags=topn_tags, S=S):
                 examples.append({"content": d["content_with_weight"], TAG_FLD: d[TAG_FLD]})
-                continue
+            else:
+                docs_to_tag.append(d)
+        async def doc_content_tagging(chat_mdl, d, topn_tags):
             cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], all_tags, {"topn": topn_tags})
             if not cached:
-                cached = content_tagging(chat_mdl, d["content_with_weight"], all_tags,
-                                         random.choices(examples, k=2) if len(examples)>2 else examples,
-                                         topn=topn_tags)
+                picked_examples = random.choices(examples, k=2) if len(examples)>2 else examples
+                cached = await asyncio.to_thread(content_tagging, chat_mdl, d["content_with_weight"], all_tags, picked_examples, topn=topn_tags)
                 if cached:
                     cached = json.dumps(cached)
             if cached:
                 set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, all_tags, {"topn": topn_tags})
                 d[TAG_FLD] = json.loads(cached)
-        progress_callback(msg="Tagging completed in {:.2f}s".format(timer() - st))
+        tasks = [doc_content_tagging(chat_mdl, d, topn_tags) for d in docs_to_tag]
+        asyncio.get_event_loop().run_until_complete(asyncio.gather(*tasks))
+        progress_callback(msg="Tagging {} chunks completed in {:.2f}s".format(len(docs), timer() - st))
     return docs
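One detail of the per-chunk helpers above: only a cache miss is pushed onto a worker thread, while a cache hit is answered directly on the event loop, so already-processed chunks stay cheap. A minimal, hypothetical sketch of that shape follows; `cache_get`, `cache_put`, and `slow_llm_call` are stand-ins, not RAGFlow APIs.

```python
import asyncio

# Hypothetical stand-ins for get_llm_cache / set_llm_cache and a blocking LLM helper.
_cache: dict[str, str] = {}

def cache_get(key: str) -> str | None:
    return _cache.get(key)

def cache_put(key: str, value: str) -> None:
    _cache[key] = value

def slow_llm_call(text: str) -> str:
    # Placeholder for a blocking, network-bound model call.
    return text.upper()

async def enrich(text: str) -> str:
    cached = cache_get(text)
    if not cached:
        # Only cache misses pay for a thread-pool hop and an LLM round trip.
        cached = await asyncio.to_thread(slow_llm_call, text)
        cache_put(text, cached)
    return cached

async def main() -> None:
    texts = ["alpha", "beta", "alpha"]
    results = await asyncio.gather(*(enrich(t) for t in texts))
    print(results)  # ['ALPHA', 'BETA', 'ALPHA']

asyncio.run(main())
```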

uv.lock (generated)

File diff suppressed because it is too large.