mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-13 23:45:54 +08:00
fix dataset retrieval in dataset mode (#3334)
This commit is contained in:
parent
826c422ac4
commit
6164604462
@ -34,6 +34,7 @@ class CSVExtractor(BaseExtractor):
|
|||||||
|
|
||||||
def extract(self) -> list[Document]:
|
def extract(self) -> list[Document]:
|
||||||
"""Load data into document objects."""
|
"""Load data into document objects."""
|
||||||
|
docs = []
|
||||||
try:
|
try:
|
||||||
with open(self._file_path, newline="", encoding=self._encoding) as csvfile:
|
with open(self._file_path, newline="", encoding=self._encoding) as csvfile:
|
||||||
docs = self._read_from_file(csvfile)
|
docs = self._read_from_file(csvfile)
|
||||||
|
@ -2,6 +2,7 @@ import threading
|
|||||||
from typing import Optional, cast
|
from typing import Optional, cast
|
||||||
|
|
||||||
from flask import Flask, current_app
|
from flask import Flask, current_app
|
||||||
|
from langchain.tools import BaseTool
|
||||||
|
|
||||||
from core.app.app_config.entities import DatasetEntity, DatasetRetrieveConfigEntity
|
from core.app.app_config.entities import DatasetEntity, DatasetRetrieveConfigEntity
|
||||||
from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity
|
from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity
|
||||||
@ -17,6 +18,8 @@ from core.rag.models.document import Document
|
|||||||
from core.rag.retrieval.router.multi_dataset_function_call_router import FunctionCallMultiDatasetRouter
|
from core.rag.retrieval.router.multi_dataset_function_call_router import FunctionCallMultiDatasetRouter
|
||||||
from core.rag.retrieval.router.multi_dataset_react_route import ReactMultiDatasetRouter
|
from core.rag.retrieval.router.multi_dataset_react_route import ReactMultiDatasetRouter
|
||||||
from core.rerank.rerank import RerankRunner
|
from core.rerank.rerank import RerankRunner
|
||||||
|
from core.tools.tool.dataset_retriever.dataset_multi_retriever_tool import DatasetMultiRetrieverTool
|
||||||
|
from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.dataset import Dataset, DatasetQuery, DocumentSegment
|
from models.dataset import Dataset, DatasetQuery, DocumentSegment
|
||||||
from models.dataset import Document as DatasetDocument
|
from models.dataset import Document as DatasetDocument
|
||||||
@ -373,3 +376,92 @@ class DatasetRetrieval:
|
|||||||
)
|
)
|
||||||
|
|
||||||
all_documents.extend(documents)
|
all_documents.extend(documents)
|
||||||
|
|
||||||
|
def to_dataset_retriever_tool(self, tenant_id: str,
                              dataset_ids: list[str],
                              retrieve_config: DatasetRetrieveConfigEntity,
                              return_resource: bool,
                              invoke_from: InvokeFrom,
                              hit_callback: DatasetIndexToolCallbackHandler) \
        -> Optional[list[BaseTool]]:
    """
    Build the retriever tool(s) used to query the given datasets.

    Datasets that cannot be found for the tenant, or that report zero
    available documents, are ignored. What is built then depends on
    ``retrieve_config.retrieve_strategy``:

    * ``SINGLE``: one ``DatasetRetrieverTool`` per available dataset, using
      the dataset's own retrieval model (or a built-in default when the
      dataset has none).
    * ``MULTIPLE``: exactly one ``DatasetMultiRetrieverTool`` covering all
      available datasets, configured from ``retrieve_config``. The tool is
      created even when no dataset is available.

    :param tenant_id: tenant id
    :param dataset_ids: dataset ids
    :param retrieve_config: retrieve config
    :param return_resource: return resource
    :param invoke_from: invoke from
    :param hit_callback: hit callback
    :return: list of created tools; empty when the strategy is neither
        SINGLE nor MULTIPLE, or when SINGLE is used without any available
        dataset
    """
    tools = []
    available_datasets = []
    for dataset_id in dataset_ids:
        # Look the dataset up within the tenant's scope.
        dataset = db.session.query(Dataset).filter(
            Dataset.tenant_id == tenant_id,
            Dataset.id == dataset_id
        ).first()

        # Skip ids that do not resolve to a dataset of this tenant.
        if not dataset:
            continue

        # Skip datasets that have no available documents to retrieve from.
        if dataset.available_document_count == 0:
            continue

        available_datasets.append(dataset)

    if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
        # Fallback retrieval settings for datasets without a stored
        # retrieval model.
        default_retrieval_model = {
            'search_method': 'semantic_search',
            'reranking_enable': False,
            'reranking_model': {
                'reranking_provider_name': '',
                'reranking_model_name': ''
            },
            'top_k': 2,
            'score_threshold_enabled': False
        }

        for dataset in available_datasets:
            retrieval_model_config = dataset.retrieval_model \
                if dataset.retrieval_model else default_retrieval_model

            # Read top_k tolerantly, like the other settings below, so a
            # stored config without the key falls back to the default
            # instead of raising KeyError.
            top_k = retrieval_model_config.get('top_k', default_retrieval_model['top_k'])

            # The score threshold only applies when explicitly enabled.
            score_threshold = None
            if retrieval_model_config.get("score_threshold_enabled"):
                score_threshold = retrieval_model_config.get("score_threshold")

            tools.append(DatasetRetrieverTool.from_dataset(
                dataset=dataset,
                top_k=top_k,
                score_threshold=score_threshold,
                hit_callbacks=[hit_callback],
                return_resource=return_resource,
                retriever_from=invoke_from.to_source()
            ))
    elif retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
        # NOTE(review): reranking_model is presumably optional on the config
        # entity; treat a missing value as "no reranking configured" rather
        # than failing on `.get` - confirm against the entity definition.
        reranking_model = retrieve_config.reranking_model or {}

        tools.append(DatasetMultiRetrieverTool.from_dataset(
            dataset_ids=[dataset.id for dataset in available_datasets],
            tenant_id=tenant_id,
            top_k=retrieve_config.top_k or 2,
            score_threshold=retrieve_config.score_threshold,
            hit_callbacks=[hit_callback],
            return_resource=return_resource,
            retriever_from=invoke_from.to_source(),
            reranking_provider_name=reranking_model.get('reranking_provider_name'),
            reranking_model_name=reranking_model.get('reranking_model_name')
        ))

    return tools
|
||||||
|
Loading…
x
Reference in New Issue
Block a user