mirror of
https://git.mirrors.martin98.com/https://github.com/infiniflow/ragflow.git
synced 2025-04-22 14:10:01 +08:00
fix: term weight issue (#3294)
### What problem does this PR solve? ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue)
This commit is contained in:
parent
5205bdab24
commit
8b6e272197
@ -16,11 +16,15 @@
|
||||
import json
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from copy import deepcopy
|
||||
|
||||
from api.db import LLMType
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.settings import retrievaler
|
||||
from api.utils import get_uuid
|
||||
from api.utils.file_utils import get_project_base_directory
|
||||
from rag.nlp import tokenize, search
|
||||
from rag.utils.es_conn import ELASTICSEARCH
|
||||
from ranx import evaluate
|
||||
@ -63,14 +67,34 @@ class Benchmark:
|
||||
d["q_%d_vec" % len(v)] = v
|
||||
return docs
|
||||
|
||||
@staticmethod
|
||||
def init_kb(index_name):
|
||||
idxnm = search.index_name(index_name)
|
||||
if ELASTICSEARCH.indexExist(idxnm):
|
||||
ELASTICSEARCH.deleteIdx(search.index_name(index_name))
|
||||
|
||||
return ELASTICSEARCH.createIdx(idxnm, json.load(
|
||||
open(os.path.join(get_project_base_directory(), "conf", "mapping.json"), "r")))
|
||||
|
||||
def ms_marco_index(self, file_path, index_name):
|
||||
qrels = defaultdict(dict)
|
||||
texts = defaultdict(dict)
|
||||
docs = []
|
||||
filelist = os.listdir(file_path)
|
||||
self.init_kb(index_name)
|
||||
|
||||
max_workers = int(os.environ.get('MAX_WORKERS', 3))
|
||||
exe = ThreadPoolExecutor(max_workers=max_workers)
|
||||
threads = []
|
||||
|
||||
def slow_actions(es_docs, idx_nm):
|
||||
es_docs = self.embedding(es_docs)
|
||||
ELASTICSEARCH.bulk(es_docs, idx_nm)
|
||||
return True
|
||||
|
||||
for dir in filelist:
|
||||
data = pd.read_parquet(os.path.join(file_path, dir))
|
||||
for i in tqdm(range(len(data)), colour="green", desc="Indexing:" + dir):
|
||||
for i in tqdm(range(len(data)), colour="green", desc="Tokenizing:" + dir):
|
||||
|
||||
query = data.iloc[i]['query']
|
||||
for rel, text in zip(data.iloc[i]['passages']['is_selected'], data.iloc[i]['passages']['passage_text']):
|
||||
@ -82,12 +106,17 @@ class Benchmark:
|
||||
texts[d["id"]] = text
|
||||
qrels[query][d["id"]] = int(rel)
|
||||
if len(docs) >= 32:
|
||||
docs = self.embedding(docs)
|
||||
ELASTICSEARCH.bulk(docs, search.index_name(index_name))
|
||||
threads.append(
|
||||
exe.submit(slow_actions, deepcopy(docs), search.index_name(index_name)))
|
||||
docs = []
|
||||
|
||||
docs = self.embedding(docs)
|
||||
ELASTICSEARCH.bulk(docs, search.index_name(index_name))
|
||||
threads.append(
|
||||
exe.submit(slow_actions, deepcopy(docs), search.index_name(index_name)))
|
||||
|
||||
for i in tqdm(range(len(threads)), colour="red", desc="Indexing:" + dir):
|
||||
if not threads[i].result().output:
|
||||
print("Indexing error...")
|
||||
|
||||
return qrels, texts
|
||||
|
||||
def trivia_qa_index(self, file_path, index_name):
|
||||
|
@ -227,7 +227,7 @@ class Dealer:
|
||||
idf2 = np.array([idf(df(t), 1000000000) for t in tks])
|
||||
wts = (0.3 * idf1 + 0.7 * idf2) * \
|
||||
np.array([ner(t) * postag(t) for t in tks])
|
||||
tw = zip(tks, wts)
|
||||
tw = list(zip(tks, wts))
|
||||
else:
|
||||
for tk in tks:
|
||||
tt = self.tokenMerge(self.pretoken(tk, True))
|
||||
|
Loading…
x
Reference in New Issue
Block a user