From e4bb943fe5d5d6f096e02c76bab9eee34856282b Mon Sep 17 00:00:00 2001 From: Jyong <76649700+JohnJyong@users.noreply.github.com> Date: Wed, 24 Jul 2024 12:50:11 +0800 Subject: [PATCH] Feat/delete single dataset retrival (#6570) --- .../easy_ui_based_app/dataset/manager.py | 12 +- api/core/app/app_config/entities.py | 4 + .../data_post_processor.py | 53 +++-- .../jieba/jieba_keyword_table_handler.py | 3 +- api/core/rag/datasource/retrieval_service.py | 20 +- .../datasource/vdb/qdrant/qdrant_vector.py | 6 +- api/core/rag/docstore/__init__.py | 0 api/core/rag/rerank/constants/rerank_mode.py | 8 + api/core/rag/rerank/entity/weight.py | 23 ++ .../rag/rerank/{rerank.py => rerank_model.py} | 2 +- api/core/rag/rerank/weight_rerank.py | 178 +++++++++++++++ api/core/rag/retrieval/dataset_retrieval.py | 146 ++++++++++-- api/core/rag/splitter/__init__.py | 0 .../dataset_multi_retriever_tool.py | 7 +- .../dataset_retriever_tool.py | 3 +- .../nodes/knowledge_retrieval/entities.py | 28 +++ .../knowledge_retrieval_node.py | 29 ++- api/fields/dataset_fields.py | 18 ++ api/models/model.py | 4 +- api/poetry.lock | 209 +++++++++++++----- api/pyproject.toml | 5 +- api/services/hit_testing_service.py | 8 +- 22 files changed, 651 insertions(+), 115 deletions(-) create mode 100644 api/core/rag/docstore/__init__.py create mode 100644 api/core/rag/rerank/constants/rerank_mode.py create mode 100644 api/core/rag/rerank/entity/weight.py rename api/core/rag/rerank/{rerank.py => rerank_model.py} (98%) create mode 100644 api/core/rag/rerank/weight_rerank.py create mode 100644 api/core/rag/splitter/__init__.py diff --git a/api/core/app/app_config/easy_ui_based_app/dataset/manager.py b/api/core/app/app_config/easy_ui_based_app/dataset/manager.py index c10aa98dba..7d2381d958 100644 --- a/api/core/app/app_config/easy_ui_based_app/dataset/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/dataset/manager.py @@ -62,7 +62,12 @@ class DatasetConfigManager: return None # dataset configs - dataset_configs = config.get('dataset_configs', {'retrieval_model': 'single'}) + if 'dataset_configs' in config and config.get('dataset_configs'): + dataset_configs = config.get('dataset_configs') + else: + dataset_configs = { + 'retrieval_model': 'multiple' + } query_variable = config.get('dataset_query_variable') if dataset_configs['retrieval_model'] == 'single': @@ -83,9 +88,10 @@ class DatasetConfigManager: retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of( dataset_configs['retrieval_model'] ), - top_k=dataset_configs.get('top_k'), + top_k=dataset_configs.get('top_k', 4), score_threshold=dataset_configs.get('score_threshold'), - reranking_model=dataset_configs.get('reranking_model') + reranking_model=dataset_configs.get('reranking_model'), + weights=dataset_configs.get('weights') ) ) diff --git a/api/core/app/app_config/entities.py b/api/core/app/app_config/entities.py index 9b7012c3fb..9133a35c08 100644 --- a/api/core/app/app_config/entities.py +++ b/api/core/app/app_config/entities.py @@ -159,7 +159,11 @@ class DatasetRetrieveConfigEntity(BaseModel): retrieve_strategy: RetrieveStrategy top_k: Optional[int] = None score_threshold: Optional[float] = None + rerank_mode: Optional[str] = 'reranking_model' reranking_model: Optional[dict] = None + weights: Optional[dict] = None + + class DatasetEntity(BaseModel): diff --git a/api/core/rag/data_post_processor/data_post_processor.py b/api/core/rag/data_post_processor/data_post_processor.py index a0f2947784..2ed6d74187 100644 --- a/api/core/rag/data_post_processor/data_post_processor.py +++ b/api/core/rag/data_post_processor/data_post_processor.py @@ -5,15 +5,20 @@ from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.errors.invoke import InvokeAuthorizationError from core.rag.data_post_processor.reorder import ReorderRunner from core.rag.models.document import Document -from core.rag.rerank.rerank import RerankRunner +from core.rag.rerank.constants.rerank_mode import RerankMode +from core.rag.rerank.entity.weight import KeywordSetting, VectorSetting, Weights +from core.rag.rerank.rerank_model import RerankModelRunner +from core.rag.rerank.weight_rerank import WeightRerankRunner class DataPostProcessor: """Interface for data post-processing document. """ - def __init__(self, tenant_id: str, reranking_model: dict, reorder_enabled: bool = False): - self.rerank_runner = self._get_rerank_runner(reranking_model, tenant_id) + def __init__(self, tenant_id: str, reranking_mode: str, + reranking_model: Optional[dict] = None, weights: Optional[dict] = None, + reorder_enabled: bool = False): + self.rerank_runner = self._get_rerank_runner(reranking_mode, tenant_id, reranking_model, weights) self.reorder_runner = self._get_reorder_runner(reorder_enabled) def invoke(self, query: str, documents: list[Document], score_threshold: Optional[float] = None, @@ -26,19 +31,37 @@ class DataPostProcessor: return documents - def _get_rerank_runner(self, reranking_model: dict, tenant_id: str) -> Optional[RerankRunner]: - if reranking_model: - try: - model_manager = ModelManager() - rerank_model_instance = model_manager.get_model_instance( - tenant_id=tenant_id, - provider=reranking_model['reranking_provider_name'], - model_type=ModelType.RERANK, - model=reranking_model['reranking_model_name'] + def _get_rerank_runner(self, reranking_mode: str, tenant_id: str, reranking_model: Optional[dict] = None, + weights: Optional[dict] = None) -> Optional[RerankModelRunner | WeightRerankRunner]: + if reranking_mode == RerankMode.WEIGHTED_SCORE.value and weights: + return WeightRerankRunner( + tenant_id, + Weights( + weight_type=weights['weight_type'], + vector_setting=VectorSetting( + vector_weight=weights['vector_setting']['vector_weight'], + embedding_provider_name=weights['vector_setting']['embedding_provider_name'], + embedding_model_name=weights['vector_setting']['embedding_model_name'], + ), + keyword_setting=KeywordSetting( + keyword_weight=weights['keyword_setting']['keyword_weight'], + ) ) - except InvokeAuthorizationError: - return None - return RerankRunner(rerank_model_instance) + ) + elif reranking_mode == RerankMode.RERANKING_MODEL.value: + if reranking_model: + try: + model_manager = ModelManager() + rerank_model_instance = model_manager.get_model_instance( + tenant_id=tenant_id, + provider=reranking_model['reranking_provider_name'], + model_type=ModelType.RERANK, + model=reranking_model['reranking_model_name'] + ) + except InvokeAuthorizationError: + return None + return RerankModelRunner(rerank_model_instance) + return None return None def _get_reorder_runner(self, reorder_enabled) -> Optional[ReorderRunner]: diff --git a/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py b/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py index 5f862b8d18..ad669ef515 100644 --- a/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py +++ b/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py @@ -1,4 +1,5 @@ import re +from typing import Optional import jieba from jieba.analyse import default_tfidf @@ -11,7 +12,7 @@ class JiebaKeywordTableHandler: def __init__(self): default_tfidf.stop_words = STOPWORDS - def extract_keywords(self, text: str, max_keywords_per_chunk: int = 10) -> set[str]: + def extract_keywords(self, text: str, max_keywords_per_chunk: Optional[int] = 10) -> set[str]: """Extract keywords with JIEBA tfidf.""" keywords = jieba.analyse.extract_tags( sentence=text, diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py index 702dcec314..abbf4a35a4 100644 --- a/api/core/rag/datasource/retrieval_service.py +++ b/api/core/rag/datasource/retrieval_service.py @@ -6,6 +6,7 @@ from flask import Flask, current_app from core.rag.data_post_processor.data_post_processor import DataPostProcessor from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.datasource.vdb.vector_factory import Vector +from core.rag.rerank.constants.rerank_mode import RerankMode from core.rag.retrieval.retrival_methods import RetrievalMethod from extensions.ext_database import db from models.dataset import Dataset @@ -26,13 +27,19 @@ class RetrievalService: @classmethod def retrieve(cls, retrival_method: str, dataset_id: str, query: str, - top_k: int, score_threshold: Optional[float] = .0, reranking_model: Optional[dict] = None): + top_k: int, score_threshold: Optional[float] = .0, + reranking_model: Optional[dict] = None, reranking_mode: Optional[str] = None, + weights: Optional[dict] = None): dataset = db.session.query(Dataset).filter( Dataset.id == dataset_id ).first() if not dataset or dataset.available_document_count == 0 or dataset.available_segment_count == 0: return [] all_documents = [] + keyword_search_documents = [] + embedding_search_documents = [] + full_text_search_documents = [] + hybrid_search_documents = [] threads = [] exceptions = [] # retrieval_model source with keyword @@ -87,7 +94,8 @@ class RetrievalService: raise Exception(exception_message) if retrival_method == RetrievalMethod.HYBRID_SEARCH.value: - data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_model, False) + data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_mode, + reranking_model, weights, False) all_documents = data_post_processor.invoke( query=query, documents=all_documents, @@ -143,7 +151,9 @@ class RetrievalService: if documents: if reranking_model and retrival_method == RetrievalMethod.SEMANTIC_SEARCH.value: - data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_model, False) + data_post_processor = DataPostProcessor(str(dataset.tenant_id), + RerankMode.RERANKING_MODEL.value, + reranking_model, None, False) all_documents.extend(data_post_processor.invoke( query=query, documents=documents, @@ -175,7 +185,9 @@ class RetrievalService: ) if documents: if reranking_model and retrival_method == RetrievalMethod.FULL_TEXT_SEARCH.value: - data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_model, False) + data_post_processor = DataPostProcessor(str(dataset.tenant_id), + RerankMode.RERANKING_MODEL.value, + reranking_model, None, False) all_documents.extend(data_post_processor.invoke( query=query, documents=documents, diff --git a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py index f9a9389868..77c3f6a271 100644 --- a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py +++ b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py @@ -396,9 +396,11 @@ class QdrantVector(BaseVector): documents = [] for result in results: if result: - documents.append(self._document_from_scored_point( + document = self._document_from_scored_point( result, Field.CONTENT_KEY.value, Field.METADATA_KEY.value - )) + ) + document.metadata['vector'] = result.vector + documents.append(document) return documents diff --git a/api/core/rag/docstore/__init__.py b/api/core/rag/docstore/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/rag/rerank/constants/rerank_mode.py b/api/core/rag/rerank/constants/rerank_mode.py new file mode 100644 index 0000000000..afbb9fd89d --- /dev/null +++ b/api/core/rag/rerank/constants/rerank_mode.py @@ -0,0 +1,8 @@ +from enum import Enum + + +class RerankMode(Enum): + + RERANKING_MODEL = 'reranking_model' + WEIGHTED_SCORE = 'weighted_score' + diff --git a/api/core/rag/rerank/entity/weight.py b/api/core/rag/rerank/entity/weight.py new file mode 100644 index 0000000000..36afc89a21 --- /dev/null +++ b/api/core/rag/rerank/entity/weight.py @@ -0,0 +1,23 @@ +from pydantic import BaseModel + + +class VectorSetting(BaseModel): + vector_weight: float + + embedding_provider_name: str + + embedding_model_name: str + + +class KeywordSetting(BaseModel): + keyword_weight: float + + +class Weights(BaseModel): + """Model for weighted rerank.""" + + weight_type: str + + vector_setting: VectorSetting + + keyword_setting: KeywordSetting diff --git a/api/core/rag/rerank/rerank.py b/api/core/rag/rerank/rerank_model.py similarity index 98% rename from api/core/rag/rerank/rerank.py rename to api/core/rag/rerank/rerank_model.py index 7000f4e0ad..d9067da288 100644 --- a/api/core/rag/rerank/rerank.py +++ b/api/core/rag/rerank/rerank_model.py @@ -4,7 +4,7 @@ from core.model_manager import ModelInstance from core.rag.models.document import Document -class RerankRunner: +class RerankModelRunner: def __init__(self, rerank_model_instance: ModelInstance) -> None: self.rerank_model_instance = rerank_model_instance diff --git a/api/core/rag/rerank/weight_rerank.py b/api/core/rag/rerank/weight_rerank.py new file mode 100644 index 0000000000..d07f94adb7 --- /dev/null +++ b/api/core/rag/rerank/weight_rerank.py @@ -0,0 +1,178 @@ +import math +from collections import Counter +from typing import Optional + +import numpy as np + +from core.embedding.cached_embedding import CacheEmbedding +from core.model_manager import ModelManager +from core.model_runtime.entities.model_entities import ModelType +from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler +from core.rag.models.document import Document +from core.rag.rerank.entity.weight import VectorSetting, Weights + + +class WeightRerankRunner: + + def __init__(self, tenant_id: str, weights: Weights) -> None: + self.tenant_id = tenant_id + self.weights = weights + + def run(self, query: str, documents: list[Document], score_threshold: Optional[float] = None, + top_n: Optional[int] = None, user: Optional[str] = None) -> list[Document]: + """ + Run rerank model + :param query: search query + :param documents: documents for reranking + :param score_threshold: score threshold + :param top_n: top n + :param user: unique user id if needed + + :return: + """ + docs = [] + doc_id = [] + unique_documents = [] + for document in documents: + if document.metadata['doc_id'] not in doc_id: + doc_id.append(document.metadata['doc_id']) + docs.append(document.page_content) + unique_documents.append(document) + + documents = unique_documents + + rerank_documents = [] + query_scores = self._calculate_keyword_score(query, documents) + + query_vector_scores = self._calculate_cosine(self.tenant_id, query, documents, self.weights.vector_setting) + for document, query_score, query_vector_score in zip(documents, query_scores, query_vector_scores): + # format document + score = self.weights.vector_setting.vector_weight * query_vector_score + \ + self.weights.keyword_setting.keyword_weight * query_score + if score_threshold and score < score_threshold: + continue + document.metadata['score'] = score + rerank_documents.append(document) + rerank_documents = sorted(rerank_documents, key=lambda x: x.metadata['score'], reverse=True) + return rerank_documents[:top_n] if top_n else rerank_documents + + def _calculate_keyword_score(self, query: str, documents: list[Document]) -> list[float]: + """ + Calculate BM25 scores + :param query: search query + :param documents: documents for reranking + + :return: + """ + keyword_table_handler = JiebaKeywordTableHandler() + query_keywords = keyword_table_handler.extract_keywords(query, None) + documents_keywords = [] + for document in documents: + # get the document keywords + document_keywords = keyword_table_handler.extract_keywords(document.page_content, None) + document.metadata['keywords'] = document_keywords + documents_keywords.append(document_keywords) + + # Counter query keywords(TF) + query_keyword_counts = Counter(query_keywords) + + # total documents + total_documents = len(documents) + + # calculate all documents' keywords IDF + all_keywords = set() + for document_keywords in documents_keywords: + all_keywords.update(document_keywords) + + keyword_idf = {} + for keyword in all_keywords: + # calculate include query keywords' documents + doc_count_containing_keyword = sum(1 for doc_keywords in documents_keywords if keyword in doc_keywords) + # IDF + keyword_idf[keyword] = math.log((1 + total_documents) / (1 + doc_count_containing_keyword)) + 1 + + query_tfidf = {} + + for keyword, count in query_keyword_counts.items(): + tf = count + idf = keyword_idf.get(keyword, 0) + query_tfidf[keyword] = tf * idf + + # calculate all documents' TF-IDF + documents_tfidf = [] + for document_keywords in documents_keywords: + document_keyword_counts = Counter(document_keywords) + document_tfidf = {} + for keyword, count in document_keyword_counts.items(): + tf = count + idf = keyword_idf.get(keyword, 0) + document_tfidf[keyword] = tf * idf + documents_tfidf.append(document_tfidf) + + def cosine_similarity(vec1, vec2): + intersection = set(vec1.keys()) & set(vec2.keys()) + numerator = sum(vec1[x] * vec2[x] for x in intersection) + + sum1 = sum(vec1[x] ** 2 for x in vec1.keys()) + sum2 = sum(vec2[x] ** 2 for x in vec2.keys()) + denominator = math.sqrt(sum1) * math.sqrt(sum2) + + if not denominator: + return 0.0 + else: + return float(numerator) / denominator + + similarities = [] + for document_tfidf in documents_tfidf: + similarity = cosine_similarity(query_tfidf, document_tfidf) + similarities.append(similarity) + + # for idx, similarity in enumerate(similarities): + # print(f"Document {idx + 1} similarity: {similarity}") + + return similarities + + def _calculate_cosine(self, tenant_id: str, query: str, documents: list[Document], + vector_setting: VectorSetting) -> list[float]: + """ + Calculate Cosine scores + :param query: search query + :param documents: documents for reranking + + :return: + """ + query_vector_scores = [] + + model_manager = ModelManager() + + embedding_model = model_manager.get_model_instance( + tenant_id=tenant_id, + provider=vector_setting.embedding_provider_name, + model_type=ModelType.TEXT_EMBEDDING, + model=vector_setting.embedding_model_name + + ) + cache_embedding = CacheEmbedding(embedding_model) + query_vector = cache_embedding.embed_query(query) + for document in documents: + # calculate cosine similarity + if 'score' in document.metadata: + query_vector_scores.append(document.metadata['score']) + else: + content_vector = document.metadata['vector'] + # transform to NumPy + vec1 = np.array(query_vector) + vec2 = np.array(document.metadata['vector']) + + # calculate dot product + dot_product = np.dot(vec1, vec2) + + # calculate norm + norm_vec1 = np.linalg.norm(vec1) + norm_vec2 = np.linalg.norm(vec2) + + # calculate cosine similarity + cosine_sim = dot_product / (norm_vec1 * norm_vec2) + query_vector_scores.append(cosine_sim) + + return query_vector_scores diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index c1f5e0820c..d51ea2942a 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -1,4 +1,6 @@ +import math import threading +from collections import Counter from typing import Optional, cast from flask import Flask, current_app @@ -14,9 +16,10 @@ from core.model_runtime.entities.model_entities import ModelFeature, ModelType from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.ops.ops_trace_manager import TraceQueueManager, TraceTask, TraceTaskName from core.ops.utils import measure_time +from core.rag.data_post_processor.data_post_processor import DataPostProcessor +from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler from core.rag.datasource.retrieval_service import RetrievalService from core.rag.models.document import Document -from core.rag.rerank.rerank import RerankRunner from core.rag.retrieval.retrival_methods import RetrievalMethod from core.rag.retrieval.router.multi_dataset_function_call_router import FunctionCallMultiDatasetRouter from core.rag.retrieval.router.multi_dataset_react_route import ReactMultiDatasetRouter @@ -132,8 +135,9 @@ class DatasetRetrieval: app_id, tenant_id, user_id, user_from, available_datasets, query, retrieve_config.top_k, retrieve_config.score_threshold, - retrieve_config.reranking_model.get('reranking_provider_name'), - retrieve_config.reranking_model.get('reranking_model_name'), + retrieve_config.rerank_mode, + retrieve_config.reranking_model, + retrieve_config.weights, message_id, ) @@ -272,7 +276,8 @@ class DatasetRetrieval: retrival_method=retrival_method, dataset_id=dataset.id, query=query, top_k=top_k, score_threshold=score_threshold, - reranking_model=reranking_model + reranking_model=reranking_model, + weights=retrieval_model_config.get('weights', None), ) self._on_query(query, [dataset_id], app_id, user_from, user_id) @@ -292,14 +297,18 @@ class DatasetRetrieval: query: str, top_k: int, score_threshold: float, - reranking_provider_name: str, - reranking_model_name: str, + reranking_mode: str, + reranking_model: Optional[dict] = None, + weights: Optional[dict] = None, + reranking_enable: bool = True, message_id: Optional[str] = None, ): threads = [] all_documents = [] dataset_ids = [dataset.id for dataset in available_datasets] + index_type = None for dataset in available_datasets: + index_type = dataset.indexing_technique retrieval_thread = threading.Thread(target=self._retriever, kwargs={ 'flask_app': current_app._get_current_object(), 'dataset_id': dataset.id, @@ -311,23 +320,24 @@ class DatasetRetrieval: retrieval_thread.start() for thread in threads: thread.join() - # do rerank for searched documents - model_manager = ModelManager() - rerank_model_instance = model_manager.get_model_instance( - tenant_id=tenant_id, - provider=reranking_provider_name, - model_type=ModelType.RERANK, - model=reranking_model_name - ) - rerank_runner = RerankRunner(rerank_model_instance) + if reranking_enable: + # do rerank for searched documents + data_post_processor = DataPostProcessor(tenant_id, reranking_mode, + reranking_model, weights, False) - with measure_time() as timer: - all_documents = rerank_runner.run( - query, all_documents, - score_threshold, - top_k - ) + with measure_time() as timer: + all_documents = data_post_processor.invoke( + query=query, + documents=all_documents, + score_threshold=score_threshold, + top_n=top_k + ) + else: + if index_type == "economy": + all_documents = self.calculate_keyword_score(query, all_documents, top_k) + elif index_type == "high_quality": + all_documents = self.calculate_vector_score(all_documents, top_k, score_threshold) self._on_query(query, dataset_ids, app_id, user_from, user_id) if all_documents: @@ -420,7 +430,8 @@ class DatasetRetrieval: score_threshold=retrieval_model['score_threshold'] if retrieval_model['score_threshold_enabled'] else None, reranking_model=retrieval_model['reranking_model'] - if retrieval_model['reranking_enable'] else None + if retrieval_model['reranking_enable'] else None, + weights=retrieval_model.get('weights', None), ) all_documents.extend(documents) @@ -513,3 +524,94 @@ class DatasetRetrieval: tools.append(tool) return tools + + def calculate_keyword_score(self, query: str, documents: list[Document], top_k: int) -> list[Document]: + """ + Calculate keywords scores + :param query: search query + :param documents: documents for reranking + + :return: + """ + keyword_table_handler = JiebaKeywordTableHandler() + query_keywords = keyword_table_handler.extract_keywords(query, None) + documents_keywords = [] + for document in documents: + # get the document keywords + document_keywords = keyword_table_handler.extract_keywords(document.page_content, None) + document.metadata['keywords'] = document_keywords + documents_keywords.append(document_keywords) + + # Counter query keywords(TF) + query_keyword_counts = Counter(query_keywords) + + # total documents + total_documents = len(documents) + + # calculate all documents' keywords IDF + all_keywords = set() + for document_keywords in documents_keywords: + all_keywords.update(document_keywords) + + keyword_idf = {} + for keyword in all_keywords: + # calculate include query keywords' documents + doc_count_containing_keyword = sum(1 for doc_keywords in documents_keywords if keyword in doc_keywords) + # IDF + keyword_idf[keyword] = math.log((1 + total_documents) / (1 + doc_count_containing_keyword)) + 1 + + query_tfidf = {} + + for keyword, count in query_keyword_counts.items(): + tf = count + idf = keyword_idf.get(keyword, 0) + query_tfidf[keyword] = tf * idf + + # calculate all documents' TF-IDF + documents_tfidf = [] + for document_keywords in documents_keywords: + document_keyword_counts = Counter(document_keywords) + document_tfidf = {} + for keyword, count in document_keyword_counts.items(): + tf = count + idf = keyword_idf.get(keyword, 0) + document_tfidf[keyword] = tf * idf + documents_tfidf.append(document_tfidf) + + def cosine_similarity(vec1, vec2): + intersection = set(vec1.keys()) & set(vec2.keys()) + numerator = sum(vec1[x] * vec2[x] for x in intersection) + + sum1 = sum(vec1[x] ** 2 for x in vec1.keys()) + sum2 = sum(vec2[x] ** 2 for x in vec2.keys()) + denominator = math.sqrt(sum1) * math.sqrt(sum2) + + if not denominator: + return 0.0 + else: + return float(numerator) / denominator + + similarities = [] + for document_tfidf in documents_tfidf: + similarity = cosine_similarity(query_tfidf, document_tfidf) + similarities.append(similarity) + + for document, score in zip(documents, similarities): + # format document + document.metadata['score'] = score + documents = sorted(documents, key=lambda x: x.metadata['score'], reverse=True) + return documents[:top_k] if top_k else documents + + def calculate_vector_score(self, all_documents: list[Document], + top_k: int, score_threshold: float) -> list[Document]: + filter_documents = [] + for document in all_documents: + if document.metadata['score'] >= score_threshold: + filter_documents.append(document) + if not filter_documents: + return [] + filter_documents = sorted(filter_documents, key=lambda x: x.metadata['score'], reverse=True) + return filter_documents[:top_k] if top_k else filter_documents + + + diff --git a/api/core/rag/splitter/__init__.py b/api/core/rag/splitter/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py b/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py index eaf58ed5bd..1a0933af16 100644 --- a/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py +++ b/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py @@ -7,7 +7,7 @@ from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCa from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType from core.rag.datasource.retrieval_service import RetrievalService -from core.rag.rerank.rerank import RerankRunner +from core.rag.rerank.rerank_model import RerankModelRunner from core.rag.retrieval.retrival_methods import RetrievalMethod from core.tools.tool.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool from extensions.ext_database import db @@ -72,7 +72,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool): model=self.reranking_model_name ) - rerank_runner = RerankRunner(rerank_model_instance) + rerank_runner = RerankModelRunner(rerank_model_instance) all_documents = rerank_runner.run(query, all_documents, self.score_threshold, self.top_k) for hit_callback in self.hit_callbacks: @@ -180,7 +180,8 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool): score_threshold=retrieval_model['score_threshold'] if retrieval_model['score_threshold_enabled'] else None, reranking_model=retrieval_model['reranking_model'] - if retrieval_model['reranking_enable'] else None + if retrieval_model['reranking_enable'] else None, + weights=retrieval_model.get('weights', None), ) all_documents.extend(documents) \ No newline at end of file diff --git a/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py b/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py index b1e541b8db..397ff7966e 100644 --- a/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py +++ b/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py @@ -78,7 +78,8 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): score_threshold=retrieval_model['score_threshold'] if retrieval_model['score_threshold_enabled'] else None, reranking_model=retrieval_model['reranking_model'] - if retrieval_model['reranking_enable'] else None + if retrieval_model['reranking_enable'] else None, + weights=retrieval_model.get('weights', None), ) else: documents = [] diff --git a/api/core/workflow/nodes/knowledge_retrieval/entities.py b/api/core/workflow/nodes/knowledge_retrieval/entities.py index c8874ff22c..5a9a4a9009 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/entities.py +++ b/api/core/workflow/nodes/knowledge_retrieval/entities.py @@ -13,13 +13,41 @@ class RerankingModelConfig(BaseModel): model: str +class VectorSetting(BaseModel): + """ + Vector Setting. + """ + vector_weight: float + embedding_provider_name: str + embedding_model_name: str + + +class KeywordSetting(BaseModel): + """ + Keyword Setting. + """ + keyword_weight: float + + +class WeightedScoreConfig(BaseModel): + """ + Weighted score Config. + """ + weight_type: str + vector_setting: VectorSetting + keyword_setting: KeywordSetting + + class MultipleRetrievalConfig(BaseModel): """ Multiple Retrieval Config. """ top_k: int score_threshold: Optional[float] = None + reranking_mode: str = 'reranking_model' + reranking_enable: bool = True reranking_model: RerankingModelConfig + weights: WeightedScoreConfig class ModelConfig(BaseModel): diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index ccd45d1383..5f8966f880 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -138,13 +138,38 @@ class KnowledgeRetrievalNode(BaseNode): planning_strategy=planning_strategy ) elif node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE.value: + if node_data.multiple_retrieval_config.reranking_mode == 'reranking_model': + reranking_model = { + 'reranking_provider_name': node_data.multiple_retrieval_config.reranking_model['provider'], + 'reranking_model_name': node_data.multiple_retrieval_config.reranking_model['name'] + } + weights = None + elif node_data.multiple_retrieval_config.reranking_mode == 'weighted_score': + reranking_model = None + weights = { + 'weight_type': node_data.multiple_retrieval_config.weights.weight_type, + 'vector_setting': { + "vector_weight": node_data.multiple_retrieval_config.weights.vector_setting.vector_weight, + "embedding_provider_name": node_data.multiple_retrieval_config.weights.vector_setting.embedding_provider_name, + "embedding_model_name": node_data.multiple_retrieval_config.weights.vector_setting.embedding_model_name, + }, + 'keyword_setting': { + "keyword_weight": node_data.multiple_retrieval_config.weights.keyword_setting.keyword_weight + } + } + else: + reranking_model = None + weights = None all_documents = dataset_retrieval.multiple_retrieve(self.app_id, self.tenant_id, self.user_id, self.user_from.value, available_datasets, query, node_data.multiple_retrieval_config.top_k, node_data.multiple_retrieval_config.score_threshold, - node_data.multiple_retrieval_config.reranking_model.provider, - node_data.multiple_retrieval_config.reranking_model.model) + node_data.multiple_retrieval_config.reranking_mode, + reranking_model, + weights, + node_data.multiple_retrieval_config.reranking_enable, + ) context_list = [] if all_documents: diff --git a/api/fields/dataset_fields.py b/api/fields/dataset_fields.py index 50c5f43540..120b66a92d 100644 --- a/api/fields/dataset_fields.py +++ b/api/fields/dataset_fields.py @@ -18,10 +18,28 @@ reranking_model_fields = { 'reranking_model_name': fields.String } +keyword_setting_fields = { + 'keyword_weight': fields.Float +} + +vector_setting_fields = { + 'vector_weight': fields.Float, + 'embedding_model_name': fields.String, + 'embedding_provider_name': fields.String, +} + +weighted_score_fields = { + 'weight_type': fields.String, + 'keyword_setting': fields.Nested(keyword_setting_fields), + 'vector_setting': fields.Nested(vector_setting_fields), +} + dataset_retrieval_model_fields = { 'search_method': fields.String, 'reranking_enable': fields.Boolean, + 'reranking_mode': fields.String, 'reranking_model': fields.Nested(reranking_model_fields), + 'weights': fields.Nested(weighted_score_fields, allow_null=True), 'top_k': fields.Integer, 'score_threshold_enabled': fields.Boolean, 'score_threshold': fields.Float diff --git a/api/models/model.py b/api/models/model.py index 396cd7ec63..a6f517ea6b 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -328,7 +328,9 @@ class AppModelConfig(db.Model): return {'retrieval_model': 'single'} else: return dataset_configs - return {'retrieval_model': 'single'} + return { + 'retrieval_model': 'multiple', + } @property def file_upload_dict(self) -> dict: diff --git a/api/poetry.lock b/api/poetry.lock index e396920e85..2a277dac2d 100644 --- a/api/poetry.lock +++ b/api/poetry.lock @@ -1242,40 +1242,36 @@ files = [ [[package]] name = "chroma-hnswlib" -version = "0.7.6" +version = "0.7.3" description = "Chromas fork of hnswlib" optional = false python-versions = "*" files = [ - {file = "chroma_hnswlib-0.7.6-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:f35192fbbeadc8c0633f0a69c3d3e9f1a4eab3a46b65458bbcbcabdd9e895c36"}, - {file = "chroma_hnswlib-0.7.6-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:6f007b608c96362b8f0c8b6b2ac94f67f83fcbabd857c378ae82007ec92f4d82"}, - {file = "chroma_hnswlib-0.7.6-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:456fd88fa0d14e6b385358515aef69fc89b3c2191706fd9aee62087b62aad09c"}, - {file = "chroma_hnswlib-0.7.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5dfaae825499c2beaa3b75a12d7ec713b64226df72a5c4097203e3ed532680da"}, - {file = "chroma_hnswlib-0.7.6-cp310-cp310-win_amd64.whl", hash = "sha256:2487201982241fb1581be26524145092c95902cb09fc2646ccfbc407de3328ec"}, - {file = "chroma_hnswlib-0.7.6-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:81181d54a2b1e4727369486a631f977ffc53c5533d26e3d366dda243fb0998ca"}, - {file = "chroma_hnswlib-0.7.6-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:4b4ab4e11f1083dd0a11ee4f0e0b183ca9f0f2ed63ededba1935b13ce2b3606f"}, - {file = "chroma_hnswlib-0.7.6-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:53db45cd9173d95b4b0bdccb4dbff4c54a42b51420599c32267f3abbeb795170"}, - {file = "chroma_hnswlib-0.7.6-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5c093f07a010b499c00a15bc9376036ee4800d335360570b14f7fe92badcdcf9"}, - {file = "chroma_hnswlib-0.7.6-cp311-cp311-win_amd64.whl", hash = "sha256:0540b0ac96e47d0aa39e88ea4714358ae05d64bbe6bf33c52f316c664190a6a3"}, - {file = "chroma_hnswlib-0.7.6-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:e87e9b616c281bfbe748d01705817c71211613c3b063021f7ed5e47173556cb7"}, - {file = "chroma_hnswlib-0.7.6-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:ec5ca25bc7b66d2ecbf14502b5729cde25f70945d22f2aaf523c2d747ea68912"}, - {file = "chroma_hnswlib-0.7.6-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:305ae491de9d5f3c51e8bd52d84fdf2545a4a2bc7af49765cda286b7bb30b1d4"}, - {file = "chroma_hnswlib-0.7.6-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:822ede968d25a2c88823ca078a58f92c9b5c4142e38c7c8b4c48178894a0a3c5"}, - {file = "chroma_hnswlib-0.7.6-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:2fe6ea949047beed19a94b33f41fe882a691e58b70c55fdaa90274ae78be046f"}, - {file = "chroma_hnswlib-0.7.6-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:feceff971e2a2728c9ddd862a9dd6eb9f638377ad98438876c9aeac96c9482f5"}, - {file = "chroma_hnswlib-0.7.6-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bb0633b60e00a2b92314d0bf5bbc0da3d3320be72c7e3f4a9b19f4609dc2b2ab"}, - {file = "chroma_hnswlib-0.7.6-cp37-cp37m-win_amd64.whl", hash = "sha256:a566abe32fab42291f766d667bdbfa234a7f457dcbd2ba19948b7a978c8ca624"}, - {file = "chroma_hnswlib-0.7.6-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:6be47853d9a58dedcfa90fc846af202b071f028bbafe1d8711bf64fe5a7f6111"}, - {file = "chroma_hnswlib-0.7.6-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:3a7af35bdd39a88bffa49f9bb4bf4f9040b684514a024435a1ef5cdff980579d"}, - {file = "chroma_hnswlib-0.7.6-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a53b1f1551f2b5ad94eb610207bde1bb476245fc5097a2bec2b476c653c58bde"}, - {file = "chroma_hnswlib-0.7.6-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3085402958dbdc9ff5626ae58d696948e715aef88c86d1e3f9285a88f1afd3bc"}, - {file = "chroma_hnswlib-0.7.6-cp38-cp38-win_amd64.whl", hash = "sha256:77326f658a15adfb806a16543f7db7c45f06fd787d699e643642d6bde8ed49c4"}, - {file = "chroma_hnswlib-0.7.6-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:93b056ab4e25adab861dfef21e1d2a2756b18be5bc9c292aa252fa12bb44e6ae"}, - {file = "chroma_hnswlib-0.7.6-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:fe91f018b30452c16c811fd6c8ede01f84e5a9f3c23e0758775e57f1c3778871"}, - {file = "chroma_hnswlib-0.7.6-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e6c0e627476f0f4d9e153420d36042dd9c6c3671cfd1fe511c0253e38c2a1039"}, - {file = "chroma_hnswlib-0.7.6-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3e9796a4536b7de6c6d76a792ba03e08f5aaa53e97e052709568e50b4d20c04f"}, - {file = "chroma_hnswlib-0.7.6-cp39-cp39-win_amd64.whl", hash = "sha256:d30e2db08e7ffdcc415bd072883a322de5995eb6ec28a8f8c054103bbd3ec1e0"}, - {file = "chroma_hnswlib-0.7.6.tar.gz", hash = "sha256:4dce282543039681160259d29fcde6151cc9106c6461e0485f57cdccd83059b7"}, + {file = "chroma-hnswlib-0.7.3.tar.gz", hash = "sha256:b6137bedde49fffda6af93b0297fe00429fc61e5a072b1ed9377f909ed95a932"}, + {file = "chroma_hnswlib-0.7.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:59d6a7c6f863c67aeb23e79a64001d537060b6995c3eca9a06e349ff7b0998ca"}, + {file = "chroma_hnswlib-0.7.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:d71a3f4f232f537b6152947006bd32bc1629a8686df22fd97777b70f416c127a"}, + {file = "chroma_hnswlib-0.7.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1c92dc1ebe062188e53970ba13f6b07e0ae32e64c9770eb7f7ffa83f149d4210"}, + {file = "chroma_hnswlib-0.7.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:49da700a6656fed8753f68d44b8cc8ae46efc99fc8a22a6d970dc1697f49b403"}, + {file = "chroma_hnswlib-0.7.3-cp310-cp310-win_amd64.whl", hash = "sha256:108bc4c293d819b56476d8f7865803cb03afd6ca128a2a04d678fffc139af029"}, + {file = "chroma_hnswlib-0.7.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:11e7ca93fb8192214ac2b9c0943641ac0daf8f9d4591bb7b73be808a83835667"}, + {file = "chroma_hnswlib-0.7.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:6f552e4d23edc06cdeb553cdc757d2fe190cdeb10d43093d6a3319f8d4bf1c6b"}, + {file = "chroma_hnswlib-0.7.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f96f4d5699e486eb1fb95849fe35ab79ab0901265805be7e60f4eaa83ce263ec"}, + {file = "chroma_hnswlib-0.7.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:368e57fe9ebae05ee5844840fa588028a023d1182b0cfdb1d13f607c9ea05756"}, + {file = "chroma_hnswlib-0.7.3-cp311-cp311-win_amd64.whl", hash = "sha256:b7dca27b8896b494456db0fd705b689ac6b73af78e186eb6a42fea2de4f71c6f"}, + {file = "chroma_hnswlib-0.7.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:70f897dc6218afa1d99f43a9ad5eb82f392df31f57ff514ccf4eeadecd62f544"}, + {file = "chroma_hnswlib-0.7.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5aef10b4952708f5a1381c124a29aead0c356f8d7d6e0b520b778aaa62a356f4"}, + {file = "chroma_hnswlib-0.7.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7ee2d8d1529fca3898d512079144ec3e28a81d9c17e15e0ea4665697a7923253"}, + {file = "chroma_hnswlib-0.7.3-cp37-cp37m-win_amd64.whl", hash = "sha256:a4021a70e898783cd6f26e00008b494c6249a7babe8774e90ce4766dd288c8ba"}, + {file = "chroma_hnswlib-0.7.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:a8f61fa1d417fda848e3ba06c07671f14806a2585272b175ba47501b066fe6b1"}, + {file = "chroma_hnswlib-0.7.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:d7563be58bc98e8f0866907368e22ae218d6060601b79c42f59af4eccbbd2e0a"}, + {file = "chroma_hnswlib-0.7.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:51b8d411486ee70d7b66ec08cc8b9b6620116b650df9c19076d2d8b6ce2ae914"}, + {file = "chroma_hnswlib-0.7.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9d706782b628e4f43f1b8a81e9120ac486837fbd9bcb8ced70fe0d9b95c72d77"}, + {file = "chroma_hnswlib-0.7.3-cp38-cp38-win_amd64.whl", hash = "sha256:54f053dedc0e3ba657f05fec6e73dd541bc5db5b09aa8bc146466ffb734bdc86"}, + {file = "chroma_hnswlib-0.7.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:e607c5a71c610a73167a517062d302c0827ccdd6e259af6e4869a5c1306ffb5d"}, + {file = "chroma_hnswlib-0.7.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:c2358a795870156af6761890f9eb5ca8cade57eb10c5f046fe94dae1faa04b9e"}, + {file = "chroma_hnswlib-0.7.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7cea425df2e6b8a5e201fff0d922a1cc1d165b3cfe762b1408075723c8892218"}, + {file = "chroma_hnswlib-0.7.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:454df3dd3e97aa784fba7cf888ad191e0087eef0fd8c70daf28b753b3b591170"}, + {file = "chroma_hnswlib-0.7.3-cp39-cp39-win_amd64.whl", hash = "sha256:df587d15007ca701c6de0ee7d5585dd5e976b7edd2b30ac72bc376b3c3f85882"}, ] [package.dependencies] @@ -1283,19 +1279,19 @@ numpy = "*" [[package]] name = "chromadb" -version = "0.5.5" +version = "0.5.1" description = "Chroma." optional = false python-versions = ">=3.8" files = [ - {file = "chromadb-0.5.5-py3-none-any.whl", hash = "sha256:2a5a4b84cb0fc32b380e193be68cdbadf3d9f77dbbf141649be9886e42910ddd"}, - {file = "chromadb-0.5.5.tar.gz", hash = "sha256:84f4bfee320fb4912cbeb4d738f01690891e9894f0ba81f39ee02867102a1c4d"}, + {file = "chromadb-0.5.1-py3-none-any.whl", hash = "sha256:61f1f75a672b6edce7f1c8875c67e2aaaaf130dc1c1684431fbc42ad7240d01d"}, + {file = "chromadb-0.5.1.tar.gz", hash = "sha256:e2b2b6a34c2a949bedcaa42fa7775f40c7f6667848fc8094dcbf97fc0d30bee7"}, ] [package.dependencies] bcrypt = ">=4.0.1" build = ">=1.0.3" -chroma-hnswlib = "0.7.6" +chroma-hnswlib = "0.7.3" fastapi = ">=0.95.2" grpcio = ">=1.58.0" httpx = ">=0.27.0" @@ -1314,6 +1310,7 @@ posthog = ">=2.4.0" pydantic = ">=1.9" pypika = ">=0.48.9" PyYAML = ">=6.0.0" +requests = ">=2.28" tenacity = ">=8.2.3" tokenizers = ">=0.13.2" tqdm = ">=4.65.0" @@ -6081,19 +6078,6 @@ files = [ {file = "pyarrow-17.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:392bc9feabc647338e6c89267635e111d71edad5fcffba204425a7c8d13610d7"}, {file = "pyarrow-17.0.0-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:af5ff82a04b2171415f1410cff7ebb79861afc5dae50be73ce06d6e870615204"}, {file = "pyarrow-17.0.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:edca18eaca89cd6382dfbcff3dd2d87633433043650c07375d095cd3517561d8"}, - {file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7c7916bff914ac5d4a8fe25b7a25e432ff921e72f6f2b7547d1e325c1ad9d155"}, - {file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f553ca691b9e94b202ff741bdd40f6ccb70cdd5fbf65c187af132f1317de6145"}, - {file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:0cdb0e627c86c373205a2f94a510ac4376fdc523f8bb36beab2e7f204416163c"}, - {file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:d7d192305d9d8bc9082d10f361fc70a73590a4c65cf31c3e6926cd72b76bc35c"}, - {file = "pyarrow-17.0.0-cp38-cp38-win_amd64.whl", hash = "sha256:02dae06ce212d8b3244dd3e7d12d9c4d3046945a5933d28026598e9dbbda1fca"}, - {file = "pyarrow-17.0.0-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:13d7a460b412f31e4c0efa1148e1d29bdf18ad1411eb6757d38f8fbdcc8645fb"}, - {file = "pyarrow-17.0.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9b564a51fbccfab5a04a80453e5ac6c9954a9c5ef2890d1bcf63741909c3f8df"}, - {file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:32503827abbc5aadedfa235f5ece8c4f8f8b0a3cf01066bc8d29de7539532687"}, - {file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a155acc7f154b9ffcc85497509bcd0d43efb80d6f733b0dc3bb14e281f131c8b"}, - {file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:dec8d129254d0188a49f8a1fc99e0560dc1b85f60af729f47de4046015f9b0a5"}, - {file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:a48ddf5c3c6a6c505904545c25a4ae13646ae1f8ba703c4df4a1bfe4f4006bda"}, - {file = "pyarrow-17.0.0-cp39-cp39-win_amd64.whl", hash = "sha256:42bf93249a083aca230ba7e2786c5f673507fa97bbd9725a1e2754715151a204"}, - {file = "pyarrow-17.0.0.tar.gz", hash = "sha256:4beca9521ed2c0921c1023e68d097d0299b62c362639ea315572a58f3f50fd28"}, ] [package.dependencies] @@ -6970,6 +6954,23 @@ maintainer = ["zest.releaser[recommended]"] pil = ["pillow (>=9.1.0)"] test = ["coverage", "pytest"] +[[package]] +name = "rank-bm25" +version = "0.2.2" +description = "Various BM25 algorithms for document ranking" +optional = false +python-versions = "*" +files = [ + {file = "rank_bm25-0.2.2-py3-none-any.whl", hash = "sha256:7bd4a95571adadfc271746fa146a4bcfd89c0cf731e49c3d1ad863290adbe8ae"}, + {file = "rank_bm25-0.2.2.tar.gz", hash = "sha256:096ccef76f8188563419aaf384a02f0ea459503fdf77901378d4fd9d87e5e51d"}, +] + +[package.dependencies] +numpy = "*" + +[package.extras] +dev = ["pytest"] + [[package]] name = "rapidfuzz" version = "3.9.4" @@ -7503,6 +7504,93 @@ tensorflow = ["safetensors[numpy]", "tensorflow (>=2.11.0)"] testing = ["h5py (>=3.7.0)", "huggingface-hub (>=0.12.1)", "hypothesis (>=6.70.2)", "pytest (>=7.2.0)", "pytest-benchmark (>=4.0.0)", "safetensors[numpy]", "setuptools-rust (>=1.5.2)"] torch = ["safetensors[numpy]", "torch (>=1.10)"] +[[package]] +name = "scikit-learn" +version = "1.5.1" +description = "A set of python modules for machine learning and data mining" +optional = false +python-versions = ">=3.9" +files = [ + {file = "scikit_learn-1.5.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:781586c414f8cc58e71da4f3d7af311e0505a683e112f2f62919e3019abd3745"}, + {file = "scikit_learn-1.5.1-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:f5b213bc29cc30a89a3130393b0e39c847a15d769d6e59539cd86b75d276b1a7"}, + {file = "scikit_learn-1.5.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1ff4ba34c2abff5ec59c803ed1d97d61b036f659a17f55be102679e88f926fac"}, + {file = "scikit_learn-1.5.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:161808750c267b77b4a9603cf9c93579c7a74ba8486b1336034c2f1579546d21"}, + {file = "scikit_learn-1.5.1-cp310-cp310-win_amd64.whl", hash = "sha256:10e49170691514a94bb2e03787aa921b82dbc507a4ea1f20fd95557862c98dc1"}, + {file = "scikit_learn-1.5.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:154297ee43c0b83af12464adeab378dee2d0a700ccd03979e2b821e7dd7cc1c2"}, + {file = "scikit_learn-1.5.1-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:b5e865e9bd59396220de49cb4a57b17016256637c61b4c5cc81aaf16bc123bbe"}, + {file = "scikit_learn-1.5.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:909144d50f367a513cee6090873ae582dba019cb3fca063b38054fa42704c3a4"}, + {file = "scikit_learn-1.5.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:689b6f74b2c880276e365fe84fe4f1befd6a774f016339c65655eaff12e10cbf"}, + {file = "scikit_learn-1.5.1-cp311-cp311-win_amd64.whl", hash = "sha256:9a07f90846313a7639af6a019d849ff72baadfa4c74c778821ae0fad07b7275b"}, + {file = "scikit_learn-1.5.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:5944ce1faada31c55fb2ba20a5346b88e36811aab504ccafb9f0339e9f780395"}, + {file = "scikit_learn-1.5.1-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:0828673c5b520e879f2af6a9e99eee0eefea69a2188be1ca68a6121b809055c1"}, + {file = "scikit_learn-1.5.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:508907e5f81390e16d754e8815f7497e52139162fd69c4fdbd2dfa5d6cc88915"}, + {file = "scikit_learn-1.5.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:97625f217c5c0c5d0505fa2af28ae424bd37949bb2f16ace3ff5f2f81fb4498b"}, + {file = "scikit_learn-1.5.1-cp312-cp312-win_amd64.whl", hash = "sha256:da3f404e9e284d2b0a157e1b56b6566a34eb2798205cba35a211df3296ab7a74"}, + {file = "scikit_learn-1.5.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:88e0672c7ac21eb149d409c74cc29f1d611d5158175846e7a9c2427bd12b3956"}, + {file = "scikit_learn-1.5.1-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:7b073a27797a283187a4ef4ee149959defc350b46cbf63a84d8514fe16b69855"}, + {file = "scikit_learn-1.5.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b59e3e62d2be870e5c74af4e793293753565c7383ae82943b83383fdcf5cc5c1"}, + {file = "scikit_learn-1.5.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1bd8d3a19d4bd6dc5a7d4f358c8c3a60934dc058f363c34c0ac1e9e12a31421d"}, + {file = "scikit_learn-1.5.1-cp39-cp39-win_amd64.whl", hash = "sha256:5f57428de0c900a98389c4a433d4a3cf89de979b3aa24d1c1d251802aa15e44d"}, + {file = "scikit_learn-1.5.1.tar.gz", hash = "sha256:0ea5d40c0e3951df445721927448755d3fe1d80833b0b7308ebff5d2a45e6414"}, +] + +[package.dependencies] +joblib = ">=1.2.0" +numpy = ">=1.19.5" +scipy = ">=1.6.0" +threadpoolctl = ">=3.1.0" + +[package.extras] +benchmark = ["matplotlib (>=3.3.4)", "memory_profiler (>=0.57.0)", "pandas (>=1.1.5)"] +build = ["cython (>=3.0.10)", "meson-python (>=0.16.0)", "numpy (>=1.19.5)", "scipy (>=1.6.0)"] +docs = ["Pillow (>=7.1.2)", "matplotlib (>=3.3.4)", "memory_profiler (>=0.57.0)", "numpydoc (>=1.2.0)", "pandas (>=1.1.5)", "plotly (>=5.14.0)", "polars (>=0.20.23)", "pooch (>=1.6.0)", "pydata-sphinx-theme (>=0.15.3)", "scikit-image (>=0.17.2)", "seaborn (>=0.9.0)", "sphinx (>=7.3.7)", "sphinx-copybutton (>=0.5.2)", "sphinx-design (>=0.5.0)", "sphinx-gallery (>=0.16.0)", "sphinx-prompt (>=1.4.0)", "sphinx-remove-toctrees (>=1.0.0.post1)", "sphinxcontrib-sass (>=0.3.4)", "sphinxext-opengraph (>=0.9.1)"] +examples = ["matplotlib (>=3.3.4)", "pandas (>=1.1.5)", "plotly (>=5.14.0)", "pooch (>=1.6.0)", "scikit-image (>=0.17.2)", "seaborn (>=0.9.0)"] +install = ["joblib (>=1.2.0)", "numpy (>=1.19.5)", "scipy (>=1.6.0)", "threadpoolctl (>=3.1.0)"] +maintenance = ["conda-lock (==2.5.6)"] +tests = ["black (>=24.3.0)", "matplotlib (>=3.3.4)", "mypy (>=1.9)", "numpydoc (>=1.2.0)", "pandas (>=1.1.5)", "polars (>=0.20.23)", "pooch (>=1.6.0)", "pyamg (>=4.0.0)", "pyarrow (>=12.0.0)", "pytest (>=7.1.2)", "pytest-cov (>=2.9.0)", "ruff (>=0.2.1)", "scikit-image (>=0.17.2)"] + +[[package]] +name = "scipy" +version = "1.14.0" +description = "Fundamental algorithms for scientific computing in Python" +optional = false +python-versions = ">=3.10" +files = [ + {file = "scipy-1.14.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:7e911933d54ead4d557c02402710c2396529540b81dd554fc1ba270eb7308484"}, + {file = "scipy-1.14.0-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:687af0a35462402dd851726295c1a5ae5f987bd6e9026f52e9505994e2f84ef6"}, + {file = "scipy-1.14.0-cp310-cp310-macosx_14_0_arm64.whl", hash = "sha256:07e179dc0205a50721022344fb85074f772eadbda1e1b3eecdc483f8033709b7"}, + {file = "scipy-1.14.0-cp310-cp310-macosx_14_0_x86_64.whl", hash = "sha256:6a9c9a9b226d9a21e0a208bdb024c3982932e43811b62d202aaf1bb59af264b1"}, + {file = "scipy-1.14.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:076c27284c768b84a45dcf2e914d4000aac537da74236a0d45d82c6fa4b7b3c0"}, + {file = "scipy-1.14.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:42470ea0195336df319741e230626b6225a740fd9dce9642ca13e98f667047c0"}, + {file = "scipy-1.14.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:176c6f0d0470a32f1b2efaf40c3d37a24876cebf447498a4cefb947a79c21e9d"}, + {file = "scipy-1.14.0-cp310-cp310-win_amd64.whl", hash = "sha256:ad36af9626d27a4326c8e884917b7ec321d8a1841cd6dacc67d2a9e90c2f0359"}, + {file = "scipy-1.14.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6d056a8709ccda6cf36cdd2eac597d13bc03dba38360f418560a93050c76a16e"}, + {file = "scipy-1.14.0-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:f0a50da861a7ec4573b7c716b2ebdcdf142b66b756a0d392c236ae568b3a93fb"}, + {file = "scipy-1.14.0-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:94c164a9e2498e68308e6e148646e486d979f7fcdb8b4cf34b5441894bdb9caf"}, + {file = "scipy-1.14.0-cp311-cp311-macosx_14_0_x86_64.whl", hash = "sha256:a7d46c3e0aea5c064e734c3eac5cf9eb1f8c4ceee756262f2c7327c4c2691c86"}, + {file = "scipy-1.14.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9eee2989868e274aae26125345584254d97c56194c072ed96cb433f32f692ed8"}, + {file = "scipy-1.14.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9e3154691b9f7ed73778d746da2df67a19d046a6c8087c8b385bc4cdb2cfca74"}, + {file = "scipy-1.14.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:c40003d880f39c11c1edbae8144e3813904b10514cd3d3d00c277ae996488cdb"}, + {file = "scipy-1.14.0-cp311-cp311-win_amd64.whl", hash = "sha256:5b083c8940028bb7e0b4172acafda6df762da1927b9091f9611b0bcd8676f2bc"}, + {file = "scipy-1.14.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:bff2438ea1330e06e53c424893ec0072640dac00f29c6a43a575cbae4c99b2b9"}, + {file = "scipy-1.14.0-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:bbc0471b5f22c11c389075d091d3885693fd3f5e9a54ce051b46308bc787e5d4"}, + {file = "scipy-1.14.0-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:64b2ff514a98cf2bb734a9f90d32dc89dc6ad4a4a36a312cd0d6327170339eb0"}, + {file = "scipy-1.14.0-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:7d3da42fbbbb860211a811782504f38ae7aaec9de8764a9bef6b262de7a2b50f"}, + {file = "scipy-1.14.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d91db2c41dd6c20646af280355d41dfa1ec7eead235642178bd57635a3f82209"}, + {file = "scipy-1.14.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a01cc03bcdc777c9da3cfdcc74b5a75caffb48a6c39c8450a9a05f82c4250a14"}, + {file = "scipy-1.14.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:65df4da3c12a2bb9ad52b86b4dcf46813e869afb006e58be0f516bc370165159"}, + {file = "scipy-1.14.0-cp312-cp312-win_amd64.whl", hash = "sha256:4c4161597c75043f7154238ef419c29a64ac4a7c889d588ea77690ac4d0d9b20"}, + {file = "scipy-1.14.0.tar.gz", hash = "sha256:b5923f48cb840380f9854339176ef21763118a7300a88203ccd0bdd26e58527b"}, +] + +[package.dependencies] +numpy = ">=1.23.5,<2.3" + +[package.extras] +dev = ["cython-lint (>=0.12.2)", "doit (>=0.36.0)", "mypy (==1.10.0)", "pycodestyle", "pydevtool", "rich-click", "ruff (>=0.0.292)", "types-psutil", "typing_extensions"] +doc = ["jupyterlite-pyodide-kernel", "jupyterlite-sphinx (>=0.13.1)", "jupytext", "matplotlib (>=3.5)", "myst-nb", "numpydoc", "pooch", "pydata-sphinx-theme (>=0.15.2)", "sphinx (>=5.0.0)", "sphinx-design (>=0.4.0)"] +test = ["Cython", "array-api-strict", "asv", "gmpy2", "hypothesis (>=6.30)", "meson", "mpmath", "ninja", "pooch", "pytest", "pytest-cov", "pytest-timeout", "pytest-xdist", "scikit-umfpack", "threadpoolctl"] + [[package]] name = "sentry-sdk" version = "1.44.1" @@ -7882,13 +7970,13 @@ test = ["pytest", "tornado (>=4.5)", "typeguard"] [[package]] name = "tencentcloud-sdk-python-common" -version = "3.0.1195" +version = "3.0.1196" description = "Tencent Cloud Common SDK for Python" optional = false python-versions = "*" files = [ - {file = "tencentcloud-sdk-python-common-3.0.1195.tar.gz", hash = "sha256:12acd48a14af327c39edf216bf29c80936f47e351675c4060d695902668ef98d"}, - {file = "tencentcloud_sdk_python_common-3.0.1195-py2.py3-none-any.whl", hash = "sha256:41c21012176f3d5f4e1ba7fbb2cdcac9c4cd62e57e9447abd0076f36146c75f6"}, + {file = "tencentcloud-sdk-python-common-3.0.1196.tar.gz", hash = "sha256:a8acd14f7480987ff0fd1d961ad934b2b7533ab1937d7e3adb74d95dc49954bd"}, + {file = "tencentcloud_sdk_python_common-3.0.1196-py2.py3-none-any.whl", hash = "sha256:5ed438bc3e2818ca8e84b3896aaa2746798fba981bd94b27528eb36efa5b4a30"}, ] [package.dependencies] @@ -7896,17 +7984,28 @@ requests = ">=2.16.0" [[package]] name = "tencentcloud-sdk-python-hunyuan" -version = "3.0.1195" +version = "3.0.1196" description = "Tencent Cloud Hunyuan SDK for Python" optional = false python-versions = "*" files = [ - {file = "tencentcloud-sdk-python-hunyuan-3.0.1195.tar.gz", hash = "sha256:c22c21fcef7465eb845b694f7901311db78da45ec1e8ea80ec6549f248cb10a7"}, - {file = "tencentcloud_sdk_python_hunyuan-3.0.1195-py2.py3-none-any.whl", hash = "sha256:729d19889ebe19258b84f10c950971c07c1be665c72608e475b442dc2b79e0c0"}, + {file = "tencentcloud-sdk-python-hunyuan-3.0.1196.tar.gz", hash = "sha256:ced26497ae5f1b8fcc6cbd12238109274251e82fa1cfedfd6700df776306a36c"}, + {file = "tencentcloud_sdk_python_hunyuan-3.0.1196-py2.py3-none-any.whl", hash = "sha256:d18a19cffeaf4ff8a60670dc2bdb644f3d7ae6a51c30d21b50ded24a9c542248"}, ] [package.dependencies] -tencentcloud-sdk-python-common = "3.0.1195" +tencentcloud-sdk-python-common = "3.0.1196" + +[[package]] +name = "threadpoolctl" +version = "3.5.0" +description = "threadpoolctl" +optional = false +python-versions = ">=3.8" +files = [ + {file = "threadpoolctl-3.5.0-py3-none-any.whl", hash = "sha256:56c1e26c150397e58c4926da8eeee87533b1e32bef131bd4bf6a2f45f3185467"}, + {file = "threadpoolctl-3.5.0.tar.gz", hash = "sha256:082433502dd922bf738de0d8bcc4fdcbf0979ff44c42bd40f5af8a282f6fa107"}, +] [[package]] name = "tidb-vector" @@ -9444,4 +9543,4 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "88dace04a79b56b994195cc627be27336083524724b3520e13bd50fe211b32df" +content-hash = "6b7d8b1333ae9c71ba2e1c5800eecf1535ed3945cd55ebb1e253b7a29ba09559" diff --git a/api/pyproject.toml b/api/pyproject.toml index a316875004..7be3c7af64 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -163,6 +163,7 @@ redis = { version = "~5.0.3", extras = ["hiredis"] } replicate = "~0.22.0" resend = "~0.7.0" safetensors = "~0.4.3" +scikit-learn = "^1.5.1" sentry-sdk = { version = "~1.44.1", extras = ["flask"] } sqlalchemy = "~2.0.29" tencentcloud-sdk-python-hunyuan = "~3.0.1158" @@ -175,7 +176,7 @@ werkzeug = "~3.0.1" xinference-client = "0.9.4" yarl = "~1.9.4" zhipuai = "1.0.7" - +rank-bm25 = "~0.2.2" ############################################################ # Tool dependencies required by tool implementations ############################################################ @@ -200,7 +201,7 @@ cloudscraper = "1.2.71" ############################################################ [tool.poetry.group.vdb.dependencies] -chromadb = "~0.5.1" +chromadb = "0.5.1" oracledb = "~2.2.1" pgvecto-rs = "0.1.4" pgvector = "0.2.5" diff --git a/api/services/hit_testing_service.py b/api/services/hit_testing_service.py index a04ffdfbbe..69274dff09 100644 --- a/api/services/hit_testing_service.py +++ b/api/services/hit_testing_service.py @@ -38,14 +38,16 @@ class HitTestingService: if not retrieval_model: retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model - all_documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'], + all_documents = RetrievalService.retrieve(retrival_method=retrieval_model.get('search_method', 'semantic_search'), dataset_id=dataset.id, query=cls.escape_query_for_search(query), - top_k=retrieval_model['top_k'], + top_k=retrieval_model.get('top_k', 2), score_threshold=retrieval_model['score_threshold'] if retrieval_model['score_threshold_enabled'] else None, reranking_model=retrieval_model['reranking_model'] - if retrieval_model['reranking_enable'] else None + if retrieval_model['reranking_enable'] else None, + reranking_mode=retrieval_model.get('reranking_mode', None), + weights=retrieval_model.get('weights', None), ) end = time.perf_counter()