mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-13 21:15:57 +08:00
Fix/reranking mode is null (#7012)
This commit is contained in:
parent
c110888aee
commit
28d4e5b045
@ -28,7 +28,7 @@ class RetrievalService:
|
|||||||
@classmethod
|
@classmethod
|
||||||
def retrieve(cls, retrival_method: str, dataset_id: str, query: str,
|
def retrieve(cls, retrival_method: str, dataset_id: str, query: str,
|
||||||
top_k: int, score_threshold: Optional[float] = .0,
|
top_k: int, score_threshold: Optional[float] = .0,
|
||||||
reranking_model: Optional[dict] = None, reranking_mode: Optional[str] = None,
|
reranking_model: Optional[dict] = None, reranking_mode: Optional[str] = 'reranking_model',
|
||||||
weights: Optional[dict] = None):
|
weights: Optional[dict] = None):
|
||||||
dataset = db.session.query(Dataset).filter(
|
dataset = db.session.query(Dataset).filter(
|
||||||
Dataset.id == dataset_id
|
Dataset.id == dataset_id
|
||||||
@ -36,10 +36,6 @@ class RetrievalService:
|
|||||||
if not dataset or dataset.available_document_count == 0 or dataset.available_segment_count == 0:
|
if not dataset or dataset.available_document_count == 0 or dataset.available_segment_count == 0:
|
||||||
return []
|
return []
|
||||||
all_documents = []
|
all_documents = []
|
||||||
keyword_search_documents = []
|
|
||||||
embedding_search_documents = []
|
|
||||||
full_text_search_documents = []
|
|
||||||
hybrid_search_documents = []
|
|
||||||
threads = []
|
threads = []
|
||||||
exceptions = []
|
exceptions = []
|
||||||
# retrieval_model source with keyword
|
# retrieval_model source with keyword
|
||||||
|
@ -278,6 +278,7 @@ class DatasetRetrieval:
|
|||||||
query=query,
|
query=query,
|
||||||
top_k=top_k, score_threshold=score_threshold,
|
top_k=top_k, score_threshold=score_threshold,
|
||||||
reranking_model=reranking_model,
|
reranking_model=reranking_model,
|
||||||
|
reranking_mode=retrieval_model_config.get('reranking_mode', 'reranking_model'),
|
||||||
weights=retrieval_model_config.get('weights', None),
|
weights=retrieval_model_config.get('weights', None),
|
||||||
)
|
)
|
||||||
self._on_query(query, [dataset_id], app_id, user_from, user_id)
|
self._on_query(query, [dataset_id], app_id, user_from, user_id)
|
||||||
@ -431,10 +432,12 @@ class DatasetRetrieval:
|
|||||||
dataset_id=dataset.id,
|
dataset_id=dataset.id,
|
||||||
query=query,
|
query=query,
|
||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
score_threshold=retrieval_model['score_threshold']
|
score_threshold=retrieval_model.get('score_threshold', .0)
|
||||||
if retrieval_model['score_threshold_enabled'] else None,
|
if retrieval_model['score_threshold_enabled'] else None,
|
||||||
reranking_model=retrieval_model['reranking_model']
|
reranking_model=retrieval_model.get('reranking_model', None)
|
||||||
if retrieval_model['reranking_enable'] else None,
|
if retrieval_model['reranking_enable'] else None,
|
||||||
|
reranking_mode=retrieval_model.get('reranking_mode')
|
||||||
|
if retrieval_model.get('reranking_mode') else 'reranking_model',
|
||||||
weights=retrieval_model.get('weights', None),
|
weights=retrieval_model.get('weights', None),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -177,10 +177,12 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
|
|||||||
dataset_id=dataset.id,
|
dataset_id=dataset.id,
|
||||||
query=query,
|
query=query,
|
||||||
top_k=self.top_k,
|
top_k=self.top_k,
|
||||||
score_threshold=retrieval_model['score_threshold']
|
score_threshold=retrieval_model.get('score_threshold', .0)
|
||||||
if retrieval_model['score_threshold_enabled'] else None,
|
if retrieval_model['score_threshold_enabled'] else None,
|
||||||
reranking_model=retrieval_model['reranking_model']
|
reranking_model=retrieval_model.get('reranking_model', None)
|
||||||
if retrieval_model['reranking_enable'] else None,
|
if retrieval_model['reranking_enable'] else None,
|
||||||
|
reranking_mode=retrieval_model.get('reranking_mode')
|
||||||
|
if retrieval_model.get('reranking_mode') else 'reranking_model',
|
||||||
weights=retrieval_model.get('weights', None),
|
weights=retrieval_model.get('weights', None),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -14,6 +14,7 @@ default_retrieval_model = {
|
|||||||
'reranking_provider_name': '',
|
'reranking_provider_name': '',
|
||||||
'reranking_model_name': ''
|
'reranking_model_name': ''
|
||||||
},
|
},
|
||||||
|
'reranking_mode': 'reranking_model',
|
||||||
'top_k': 2,
|
'top_k': 2,
|
||||||
'score_threshold_enabled': False
|
'score_threshold_enabled': False
|
||||||
}
|
}
|
||||||
@ -71,14 +72,15 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
|||||||
else:
|
else:
|
||||||
if self.top_k > 0:
|
if self.top_k > 0:
|
||||||
# retrieval source
|
# retrieval source
|
||||||
documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'],
|
documents = RetrievalService.retrieve(retrival_method=retrieval_model.get('search_method', 'semantic_search'),
|
||||||
dataset_id=dataset.id,
|
dataset_id=dataset.id,
|
||||||
query=query,
|
query=query,
|
||||||
top_k=self.top_k,
|
top_k=self.top_k,
|
||||||
score_threshold=retrieval_model['score_threshold']
|
score_threshold=retrieval_model.get('score_threshold', .0)
|
||||||
if retrieval_model['score_threshold_enabled'] else None,
|
if retrieval_model['score_threshold_enabled'] else None,
|
||||||
reranking_model=retrieval_model['reranking_model']
|
reranking_model=retrieval_model.get('reranking_model', None),
|
||||||
if retrieval_model['reranking_enable'] else None,
|
reranking_mode=retrieval_model.get('reranking_mode')
|
||||||
|
if retrieval_model.get('reranking_mode') else 'reranking_model',
|
||||||
weights=retrieval_model.get('weights', None),
|
weights=retrieval_model.get('weights', None),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
@ -42,11 +42,11 @@ class HitTestingService:
|
|||||||
dataset_id=dataset.id,
|
dataset_id=dataset.id,
|
||||||
query=cls.escape_query_for_search(query),
|
query=cls.escape_query_for_search(query),
|
||||||
top_k=retrieval_model.get('top_k', 2),
|
top_k=retrieval_model.get('top_k', 2),
|
||||||
score_threshold=retrieval_model['score_threshold']
|
score_threshold=retrieval_model.get('score_threshold', .0)
|
||||||
if retrieval_model['score_threshold_enabled'] else None,
|
if retrieval_model['score_threshold_enabled'] else None,
|
||||||
reranking_model=retrieval_model['reranking_model']
|
reranking_model=retrieval_model.get('reranking_model', None),
|
||||||
if retrieval_model['reranking_enable'] else None,
|
reranking_mode=retrieval_model.get('reranking_mode')
|
||||||
reranking_mode=retrieval_model.get('reranking_mode', None),
|
if retrieval_model.get('reranking_mode') else 'reranking_model',
|
||||||
weights=retrieval_model.get('weights', None),
|
weights=retrieval_model.get('weights', None),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user