mirror of
https://git.mirrors.martin98.com/https://github.com/infiniflow/ragflow.git
synced 2025-08-12 04:49:05 +08:00
Add benchmark ndcg@10 (#2326)
### What problem does this PR solve? ### Type of change - [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
parent
336a639164
commit
7241c73c7a
94
rag/benchmark.py
Normal file
94
rag/benchmark.py
Normal file
@ -0,0 +1,94 @@
|
|||||||
|
#
|
||||||
|
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
from collections import defaultdict
|
||||||
|
from api.db import FileType, TaskStatus, ParserType, LLMType
|
||||||
|
from api.db.services.llm_service import LLMBundle
|
||||||
|
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||||
|
from api.settings import retrievaler
|
||||||
|
from api.utils import get_uuid
|
||||||
|
from rag.nlp import tokenize, search
|
||||||
|
from rag.utils.es_conn import ELASTICSEARCH
|
||||||
|
from ranx import evaluate
|
||||||
|
|
||||||
|
|
||||||
|
class benchmark_ndcg10:
    """Index a labelled benchmark file and score retrieval with ndcg@10.

    Calling an instance with a file path reads (query, text, relevance)
    records, writes every text to the ``benchmark`` search index together
    with its embedding vector, runs each distinct query through the
    retriever and returns the value of ``ranx.evaluate(..., "ndcg@10")``.
    """

    # Number of documents buffered before they are embedded and bulk-indexed.
    _BULK_SIZE = 32

    def __init__(self, kb_id):
        """Load retrieval settings and the embedding model of a knowledge base.

        Args:
            kb_id: Identifier handed to ``KnowledgebaseService.get_by_id``.

        Raises:
            LookupError: If the service returns no knowledge base object.
        """
        _, kb = KnowledgebaseService.get_by_id(kb_id)
        if kb is None:
            # Fail with a clear message instead of an AttributeError on None.
            raise LookupError("Knowledge base not found: %s" % kb_id)
        self.similarity_threshold = kb.similarity_threshold
        self.vector_similarity_weight = kb.vector_similarity_weight
        self.embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING, llm_name=kb.embd_id, lang=kb.language)

    def _get_benchmarks(self, query, count=16):
        """Search the ``benchmark`` index for ``query``.

        Args:
            query: Question text placed in the search request.
            count: Value of the request's ``size`` field.

        Returns:
            Whatever ``retrievaler.search`` returns for the request.
        """
        req = {"question": query, "size": count, "vector": True, "similarity": self.similarity_threshold}
        return retrievaler.search(req, search.index_name("benchmark"), self.embd_mdl)

    def _get_retrieval(self, qrels):
        """Build the run (query -> {doc id -> score}) for every query in ``qrels``.

        Args:
            qrels: Mapping whose keys are the queries to evaluate.

        Returns:
            A ``defaultdict(dict)`` mapping each query to the rerank score of
            every id found in its search result.
        """
        run = defaultdict(dict)
        for query in list(qrels):
            sres = self._get_benchmarks(query)
            # NOTE(review): the two weights are presumably the term-similarity
            # and vector-similarity weights, in that order - confirm against
            # retrievaler.rerank.
            sim, _, _ = retrievaler.rerank(sres, query, 1 - self.vector_similarity_weight,
                                           self.vector_similarity_weight)
            for index, chunk_id in enumerate(sres.ids):
                run[query][chunk_id] = sim[index]
        return run

    def embedding(self, docs, batch_size=16):
        """Attach an embedding vector to every document, in place.

        Args:
            docs: List of dicts that each carry a ``content_with_weight`` entry.
            batch_size: Number of texts sent to the embedding model per call.

        Returns:
            The same ``docs`` list; each dict gains a ``q_<dim>_vec`` key whose
            value is the vector as a list and ``<dim>`` is the vector length.
        """
        vects = []
        cnts = [d["content_with_weight"] for d in docs]
        for start in range(0, len(cnts), batch_size):
            # encode() returns a pair; only the vectors are needed here.
            vts, _ = self.embd_mdl.encode(cnts[start: start + batch_size])
            vects.extend(vts.tolist())
        assert len(docs) == len(vects), "embedding model returned a wrong number of vectors"
        for d, v in zip(docs, vects):
            d["q_%d_vec" % len(v)] = v
        return docs

    @staticmethod
    def _parse_line(line):
        """Split one benchmark record into ``(query, text, relevance)``.

        A line containing exactly two tabs is treated as tab-separated so the
        query and the text may contain spaces; any other line is split on
        whitespace and must consist of exactly three tokens.

        Raises:
            ValueError: If the line does not have three fields or the
                relevance is not a number.
        """
        line = line.strip("\n")
        if line.count("\t") == 2:
            query, text, rel = (part.strip() for part in line.split("\t"))
        else:
            query, text, rel = line.split()
        return query, text, float(rel)

    def _index(self, docs):
        """Embed ``docs`` and bulk-write them to the ``benchmark`` index."""
        if not docs:
            return
        ELASTICSEARCH.bulk(self.embedding(docs), search.index_name("benchmark"))

    def __call__(self, file_path):
        """Index the benchmark file and return its ndcg@10 score.

        Args:
            file_path: Path of a UTF-8 text file with one
                ``query text relevance`` record per line (see ``_parse_line``).
                Blank lines are ignored.

        Returns:
            The result of ``ranx.evaluate(qrels, run, "ndcg@10")``.
        """
        qrels = defaultdict(dict)
        docs = []
        with open(file_path, encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    continue
                query, text, rel = self._parse_line(line)
                d = {"id": get_uuid()}
                # presumably fills content_with_weight, which embedding() reads
                # - verify against rag.nlp.tokenize
                tokenize(d, text)
                docs.append(d)
                qrels[query][d["id"]] = rel
                if len(docs) >= self._BULK_SIZE:
                    # Every batch is embedded before indexing; without the
                    # vector field the documents are written without vectors.
                    self._index(docs)
                    docs = []
        self._index(docs)

        run = self._get_retrieval(qrels)
        return evaluate(qrels, run, "ndcg@10")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    # Command-line entry point: both options are mandatory. The knowledge
    # base id selects the retrieval settings, the file path names the
    # labelled benchmark file to index and evaluate.
    cli = argparse.ArgumentParser()
    for short_flag, long_flag, description in (
        ('-f', '--filepath', "file path"),
        ('-k', '--kb_id', "kb_id"),
    ):
        cli.add_argument(short_flag, long_flag, default='', help=description,
                         action='store', required=True)
    options = cli.parse_args()

    runner = benchmark_ndcg10(options.kb_id)
    score = runner(options.filepath)
    print(score)
|
@ -70,6 +70,7 @@ python_dateutil==2.8.2
|
|||||||
python_pptx==0.6.23
|
python_pptx==0.6.23
|
||||||
pywencai==0.12.2
|
pywencai==0.12.2
|
||||||
qianfan==0.4.6
|
qianfan==0.4.6
|
||||||
|
ranx==0.3.20
|
||||||
readability_lxml==0.8.1
|
readability_lxml==0.8.1
|
||||||
redis==5.0.3
|
redis==5.0.3
|
||||||
Requests==2.32.2
|
Requests==2.32.2
|
||||||
|
@ -171,3 +171,4 @@ vertexai==1.64.0
|
|||||||
yfinance==0.2.43
|
yfinance==0.2.43
|
||||||
pywencai==0.12.2
|
pywencai==0.12.2
|
||||||
akshare==1.14.72
|
akshare==1.14.72
|
||||||
|
ranx==0.3.20
|
||||||
|
Loading…
x
Reference in New Issue
Block a user