fix: term weight issue (#3294)

### What problem does this PR solve?

In the term-weight `Dealer`, `tw = zip(tks, wts)` builds a one-shot iterator (Python 3), so any second pass over the token/weight pairs sees an empty sequence and the weights are silently dropped; the pairs are now materialized with `list(zip(...))`. The benchmark indexer is also reworked: `ms_marco_index` recreates the index via a new `init_kb` helper and offloads embedding + `ELASTICSEARCH.bulk` to a `ThreadPoolExecutor` (worker count from the `MAX_WORKERS` env var, default 3).

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
commit 8b6e272197 (parent 5205bdab24)
Kevin Hu, 2024-11-08 15:49:44 +08:00, committed by GitHub
2 changed files with 35 additions and 6 deletions


```diff
@@ -16,11 +16,15 @@
 import json
 import os
 from collections import defaultdict
+from concurrent.futures import ThreadPoolExecutor
+from copy import deepcopy
+
 from api.db import LLMType
 from api.db.services.llm_service import LLMBundle
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.settings import retrievaler
 from api.utils import get_uuid
+from api.utils.file_utils import get_project_base_directory
 from rag.nlp import tokenize, search
 from rag.utils.es_conn import ELASTICSEARCH
 from ranx import evaluate
@@ -63,14 +67,34 @@ class Benchmark:
             d["q_%d_vec" % len(v)] = v
         return docs
 
+    @staticmethod
+    def init_kb(index_name):
+        idxnm = search.index_name(index_name)
+        if ELASTICSEARCH.indexExist(idxnm):
+            ELASTICSEARCH.deleteIdx(search.index_name(index_name))
+
+        return ELASTICSEARCH.createIdx(idxnm, json.load(
+            open(os.path.join(get_project_base_directory(), "conf", "mapping.json"), "r")))
+
     def ms_marco_index(self, file_path, index_name):
         qrels = defaultdict(dict)
         texts = defaultdict(dict)
         docs = []
         filelist = os.listdir(file_path)
+        self.init_kb(index_name)
+
+        max_workers = int(os.environ.get('MAX_WORKERS', 3))
+        exe = ThreadPoolExecutor(max_workers=max_workers)
+        threads = []
+
+        def slow_actions(es_docs, idx_nm):
+            es_docs = self.embedding(es_docs)
+            ELASTICSEARCH.bulk(es_docs, idx_nm)
+            return True
+
         for dir in filelist:
             data = pd.read_parquet(os.path.join(file_path, dir))
-            for i in tqdm(range(len(data)), colour="green", desc="Indexing:" + dir):
+            for i in tqdm(range(len(data)), colour="green", desc="Tokenizing:" + dir):
                 query = data.iloc[i]['query']
                 for rel, text in zip(data.iloc[i]['passages']['is_selected'], data.iloc[i]['passages']['passage_text']):
@@ -82,12 +106,17 @@ class Benchmark:
                     texts[d["id"]] = text
                     qrels[query][d["id"]] = int(rel)
                 if len(docs) >= 32:
-                    docs = self.embedding(docs)
-                    ELASTICSEARCH.bulk(docs, search.index_name(index_name))
+                    threads.append(
+                        exe.submit(slow_actions, deepcopy(docs), search.index_name(index_name)))
                     docs = []
 
-        docs = self.embedding(docs)
-        ELASTICSEARCH.bulk(docs, search.index_name(index_name))
+        threads.append(
+            exe.submit(slow_actions, deepcopy(docs), search.index_name(index_name)))
+
+        for i in tqdm(range(len(threads)), colour="red", desc="Indexing:" + dir):
+            if not threads[i].result():
+                print("Indexing error...")
+
         return qrels, texts
 
     def trivia_qa_index(self, file_path, index_name):
```
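The batching pattern above is worth spelling out: the main thread tokenizes rows and accumulates `docs`, and every 32 docs it submits a snapshot of the batch to a worker that does the slow embedding and bulk-index call, collecting the futures in `threads`. Below is a minimal, self-contained sketch of the same pattern, with a hypothetical `embed_and_index` standing in for `slow_actions` and plain dicts standing in for real chunks:

```python
import os
from concurrent.futures import ThreadPoolExecutor
from copy import deepcopy

def embed_and_index(batch):
    # Stand-in for the slow work (embedding + ELASTICSEARCH.bulk).
    return len(batch)

exe = ThreadPoolExecutor(max_workers=int(os.environ.get('MAX_WORKERS', 3)))
futures, docs = [], []

for i in range(100):
    docs.append({"id": i})
    if len(docs) >= 32:
        # deepcopy hands each worker its own copy of the batch, so the
        # embedding step can mutate the dicts without racing the producer.
        futures.append(exe.submit(embed_and_index, deepcopy(docs)))
        docs = []

futures.append(exe.submit(embed_and_index, deepcopy(docs)))  # flush the tail batch

# result() blocks until the future finishes and re-raises worker exceptions,
# so draining the futures doubles as the error check.
assert sum(f.result() for f in futures) == 100
```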


```diff
@@ -227,7 +227,7 @@ class Dealer:
             idf2 = np.array([idf(df(t), 1000000000) for t in tks])
             wts = (0.3 * idf1 + 0.7 * idf2) * \
                 np.array([ner(t) * postag(t) for t in tks])
-            tw = zip(tks, wts)
+            tw = list(zip(tks, wts))
         else:
             for tk in tks:
                 tt = self.tokenMerge(self.pretoken(tk, True))
```
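
This one-liner is the actual fix for #3294: in Python 3, `zip()` returns a one-shot iterator, so the first pass over `tw` exhausts it and any later pass (e.g. a normalization that sums the weights and then divides each one by the total) silently sees an empty sequence. Wrapping it in `list()` materializes the pairs so they can be traversed repeatedly. A quick illustration with made-up tokens and weights:

```python
tks = ["deep", "learning"]
wts = [0.8, 0.2]

# Buggy: zip() yields each pair exactly once.
tw = zip(tks, wts)
total = sum(w for _, w in tw)           # first pass drains the iterator
print([(t, w / total) for t, w in tw])  # [] -- second pass sees nothing

# Fixed: materialize the pairs, then iterate as often as needed.
tw = list(zip(tks, wts))
total = sum(w for _, w in tw)
print([(t, w / total) for t, w in tw])  # [('deep', 0.8), ('learning', 0.2)]
```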