fix: term weight issue (#3294)

### What problem does this PR solve?



### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
This commit is contained in:
Kevin Hu 2024-11-08 15:49:44 +08:00 committed by GitHub
parent 5205bdab24
commit 8b6e272197
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 35 additions and 6 deletions

View File

@ -16,11 +16,15 @@
import json
import os
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from copy import deepcopy
from api.db import LLMType
from api.db.services.llm_service import LLMBundle
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.settings import retrievaler
from api.utils import get_uuid
from api.utils.file_utils import get_project_base_directory
from rag.nlp import tokenize, search
from rag.utils.es_conn import ELASTICSEARCH
from ranx import evaluate
@ -63,14 +67,34 @@ class Benchmark:
d["q_%d_vec" % len(v)] = v
return docs
@staticmethod
def init_kb(index_name):
idxnm = search.index_name(index_name)
if ELASTICSEARCH.indexExist(idxnm):
ELASTICSEARCH.deleteIdx(search.index_name(index_name))
return ELASTICSEARCH.createIdx(idxnm, json.load(
open(os.path.join(get_project_base_directory(), "conf", "mapping.json"), "r")))
def ms_marco_index(self, file_path, index_name):
qrels = defaultdict(dict)
texts = defaultdict(dict)
docs = []
filelist = os.listdir(file_path)
self.init_kb(index_name)
max_workers = int(os.environ.get('MAX_WORKERS', 3))
exe = ThreadPoolExecutor(max_workers=max_workers)
threads = []
def slow_actions(es_docs, idx_nm):
es_docs = self.embedding(es_docs)
ELASTICSEARCH.bulk(es_docs, idx_nm)
return True
for dir in filelist:
data = pd.read_parquet(os.path.join(file_path, dir))
for i in tqdm(range(len(data)), colour="green", desc="Indexing:" + dir):
for i in tqdm(range(len(data)), colour="green", desc="Tokenizing:" + dir):
query = data.iloc[i]['query']
for rel, text in zip(data.iloc[i]['passages']['is_selected'], data.iloc[i]['passages']['passage_text']):
@ -82,12 +106,17 @@ class Benchmark:
texts[d["id"]] = text
qrels[query][d["id"]] = int(rel)
if len(docs) >= 32:
docs = self.embedding(docs)
ELASTICSEARCH.bulk(docs, search.index_name(index_name))
threads.append(
exe.submit(slow_actions, deepcopy(docs), search.index_name(index_name)))
docs = []
docs = self.embedding(docs)
ELASTICSEARCH.bulk(docs, search.index_name(index_name))
threads.append(
exe.submit(slow_actions, deepcopy(docs), search.index_name(index_name)))
for i in tqdm(range(len(threads)), colour="red", desc="Indexing:" + dir):
if not threads[i].result().output:
print("Indexing error...")
return qrels, texts
def trivia_qa_index(self, file_path, index_name):

View File

@ -227,7 +227,7 @@ class Dealer:
idf2 = np.array([idf(df(t), 1000000000) for t in tks])
wts = (0.3 * idf1 + 0.7 * idf2) * \
np.array([ner(t) * postag(t) for t in tks])
tw = zip(tks, wts)
tw = list(zip(tks, wts))
else:
for tk in tks:
tt = self.tokenMerge(self.pretoken(tk, True))