Retrieval Service efficiency optimization (#13543)

Charlie.Wei 2025-02-17 14:09:57 +08:00 committed by GitHub
parent 566e548713
commit 222df44d21
2 changed files with 170 additions and 124 deletions


@@ -1,3 +1,4 @@
import os
from typing import Any, Literal, Optional
from urllib.parse import quote_plus
@@ -166,6 +167,11 @@ class DatabaseConfig(BaseSettings):
default=False,
)
RETRIEVAL_SERVICE_WORKER: NonNegativeInt = Field(
description="If True, enables the retrieval service worker.",
default=os.cpu_count(),
)
@computed_field
def SQLALCHEMY_ENGINE_OPTIONS(self) -> dict[str, Any]:
return {

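The new setting sizes the retrieval thread pool introduced below. As a minimal sketch of how a pydantic-settings field like this resolves (ExampleConfig is illustrative, not from the codebase; under pydantic-settings defaults the env var name matches the field name):

import os

from pydantic import Field, NonNegativeInt
from pydantic_settings import BaseSettings

class ExampleConfig(BaseSettings):
    # Falls back to the CPU core count when the env var is unset;
    # setting RETRIEVAL_SERVICE_WORKER=8 in the environment overrides it.
    RETRIEVAL_SERVICE_WORKER: NonNegativeInt = Field(
        description="Number of worker threads for the retrieval service.",
        default=os.cpu_count(),
    )

print(ExampleConfig().RETRIEVAL_SERVICE_WORKER)

One caveat: os.cpu_count() can return None on unusual platforms, which a stricter config might guard against.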

@@ -1,9 +1,11 @@
import concurrent.futures
import json
import threading
from typing import Optional
from flask import Flask, current_app
from sqlalchemy.orm import load_only
from configs import dify_config
from core.rag.data_post_processor.data_post_processor import DataPostProcessor
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.datasource.vdb.vector_factory import Vector
@@ -27,6 +29,7 @@ default_retrieval_model = {
class RetrievalService:
# Cache precompiled regular expressions to avoid repeated compilation
@classmethod
def retrieve(
cls,
@@ -41,74 +44,62 @@ class RetrievalService:
):
if not query:
return []
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
if not dataset:
return []
dataset = cls._get_dataset(dataset_id)
if not dataset or dataset.available_document_count == 0 or dataset.available_segment_count == 0:
return []
all_documents: list[Document] = []
threads: list[threading.Thread] = []
exceptions: list[str] = []
# retrieval source: keyword search
# Optimize multithreading with thread pools
with concurrent.futures.ThreadPoolExecutor(max_workers=dify_config.RETRIEVAL_SERVICE_WORKER) as executor: # type: ignore
futures = []
if retrieval_method == "keyword_search":
keyword_thread = threading.Thread(
target=RetrievalService.keyword_search,
kwargs={
"flask_app": current_app._get_current_object(), # type: ignore
"dataset_id": dataset_id,
"query": query,
"top_k": top_k,
"all_documents": all_documents,
"exceptions": exceptions,
},
futures.append(
executor.submit(
cls.keyword_search,
flask_app=current_app._get_current_object(), # type: ignore
dataset_id=dataset_id,
query=query,
top_k=top_k,
all_documents=all_documents,
exceptions=exceptions,
)
)
threads.append(keyword_thread)
keyword_thread.start()
# retrieval source: semantic search
if RetrievalMethod.is_support_semantic_search(retrieval_method):
embedding_thread = threading.Thread(
target=RetrievalService.embedding_search,
kwargs={
"flask_app": current_app._get_current_object(), # type: ignore
"dataset_id": dataset_id,
"query": query,
"top_k": top_k,
"score_threshold": score_threshold,
"reranking_model": reranking_model,
"all_documents": all_documents,
"retrieval_method": retrieval_method,
"exceptions": exceptions,
},
futures.append(
executor.submit(
cls.embedding_search,
flask_app=current_app._get_current_object(), # type: ignore
dataset_id=dataset_id,
query=query,
top_k=top_k,
score_threshold=score_threshold,
reranking_model=reranking_model,
all_documents=all_documents,
retrieval_method=retrieval_method,
exceptions=exceptions,
)
)
threads.append(embedding_thread)
embedding_thread.start()
# retrieval source: full-text search
if RetrievalMethod.is_support_fulltext_search(retrieval_method):
full_text_index_thread = threading.Thread(
target=RetrievalService.full_text_index_search,
kwargs={
"flask_app": current_app._get_current_object(), # type: ignore
"dataset_id": dataset_id,
"query": query,
"retrieval_method": retrieval_method,
"score_threshold": score_threshold,
"top_k": top_k,
"reranking_model": reranking_model,
"all_documents": all_documents,
"exceptions": exceptions,
},
futures.append(
executor.submit(
cls.full_text_index_search,
flask_app=current_app._get_current_object(), # type: ignore
dataset_id=dataset_id,
query=query,
top_k=top_k,
score_threshold=score_threshold,
reranking_model=reranking_model,
all_documents=all_documents,
retrieval_method=retrieval_method,
exceptions=exceptions,
)
threads.append(full_text_index_thread)
full_text_index_thread.start()
for thread in threads:
thread.join()
)
concurrent.futures.wait(futures, timeout=30, return_when=concurrent.futures.ALL_COMPLETED)
if exceptions:
exception_message = ";\n".join(exceptions)
raise ValueError(exception_message)
raise ValueError(";\n".join(exceptions))
if retrieval_method == RetrievalMethod.HYBRID_SEARCH.value:
data_post_processor = DataPostProcessor(
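This hunk is the heart of the change: per-strategy threading.Thread objects joined in a loop are replaced by a shared ThreadPoolExecutor, and a single concurrent.futures.wait call bounds the total wall-clock wait. A runnable sketch of the pattern with a stand-in worker (fake_search is illustrative, not from the codebase):

import concurrent.futures

def fake_search(name: str, query: str, all_documents: list, exceptions: list) -> None:
    # Workers append into the caller's lists, mirroring the service's contract.
    try:
        all_documents.append(f"{name}:{query}")
    except Exception as e:
        exceptions.append(str(e))

all_documents: list[str] = []
exceptions: list[str] = []
with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
    futures = [
        executor.submit(fake_search, name, "hello", all_documents, exceptions)
        for name in ("keyword", "embedding", "full_text")
    ]
    # One wait() caps the combined wait across all strategies; the
    # previous per-thread join() loop had no upper bound at all.
    concurrent.futures.wait(futures, timeout=30, return_when=concurrent.futures.ALL_COMPLETED)
if exceptions:
    raise ValueError(";\n".join(exceptions))
print(all_documents)

Note that wait() does not raise on timeout: a strategy that exceeds 30 seconds simply leaves its future unfinished, and failures surface only through the shared exceptions list.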
@@ -133,18 +124,21 @@ class RetrievalService:
)
return all_documents
@classmethod
def _get_dataset(cls, dataset_id: str) -> Optional[Dataset]:
return db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
@classmethod
def keyword_search(
cls, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list, exceptions: list
):
with flask_app.app_context():
try:
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
dataset = cls._get_dataset(dataset_id)
if not dataset:
raise ValueError("dataset not found")
keyword = Keyword(dataset=dataset)
documents = keyword.search(cls.escape_query_for_search(query), top_k=top_k)
all_documents.extend(documents)
except Exception as e:
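The worker methods keep their existing contract: the caller passes the real app object (current_app._get_current_object() unwraps the request-local proxy, which is not safe to hand across threads), and each worker opens its own application context before touching the database. A minimal sketch of that pattern, with a hypothetical do_work:

import threading

from flask import Flask, current_app

app = Flask(__name__)

def do_work(flask_app: Flask, results: list) -> None:
    # Each thread needs its own app context; the current_app proxy
    # would raise "Working outside of application context" otherwise.
    with flask_app.app_context():
        results.append(current_app.name)

with app.app_context():
    results: list[str] = []
    real_app = current_app._get_current_object()  # unwrap the proxy
    t = threading.Thread(target=do_work, args=(real_app, results))
    t.start()
    t.join()
print(results)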
@@ -165,12 +159,11 @@ class RetrievalService:
):
with flask_app.app_context():
try:
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
dataset = cls._get_dataset(dataset_id)
if not dataset:
raise ValueError("dataset not found")
vector = Vector(dataset=dataset)
documents = vector.search_by_vector(
query,
search_type="similarity_score_threshold",
@@ -187,7 +180,7 @@ class RetrievalService:
and retrieval_method == RetrievalMethod.SEMANTIC_SEARCH.value
):
data_post_processor = DataPostProcessor(
str(dataset.tenant_id), RerankMode.RERANKING_MODEL.value, reranking_model, None, False
str(dataset.tenant_id), str(RerankMode.RERANKING_MODEL.value), reranking_model, None, False
)
all_documents.extend(
data_post_processor.invoke(
@@ -217,13 +210,11 @@ class RetrievalService:
):
with flask_app.app_context():
try:
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
dataset = cls._get_dataset(dataset_id)
if not dataset:
raise ValueError("dataset not found")
vector_processor = Vector(
dataset=dataset,
)
vector_processor = Vector(dataset=dataset)
documents = vector_processor.search_by_full_text(cls.escape_query_for_search(query), top_k=top_k)
if documents:
@@ -234,7 +225,7 @@ class RetrievalService:
and retrieval_method == RetrievalMethod.FULL_TEXT_SEARCH.value
):
data_post_processor = DataPostProcessor(
str(dataset.tenant_id), RerankMode.RERANKING_MODEL.value, reranking_model, None, False
str(dataset.tenant_id), str(RerankMode.RERANKING_MODEL.value), reranking_model, None, False
)
all_documents.extend(
data_post_processor.invoke(
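Both single-strategy paths apply the same gate: rerank inline only when a reranking model is configured and the retrieval method is exactly that strategy. Hybrid search skips inline reranking because retrieve() post-processes the merged results once at the end. A schematic of the gate (maybe_rerank and the rerank callable are illustrative):

def maybe_rerank(documents: list, retrieval_method: str, method: str, reranking_model, rerank):
    # Inline rerank only on the pure single-method path; hybrid defers
    # to the single DataPostProcessor pass in retrieve().
    if reranking_model and retrieval_method == method:
        return rerank(documents, reranking_model)
    return documents

docs = ["d1", "d2"]
print(maybe_rerank(docs, "semantic_search", "semantic_search", {"model": "m"},
                   lambda d, m: list(reversed(d))))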
@@ -253,34 +244,74 @@ class RetrievalService:
def escape_query_for_search(query: str) -> str:
return json.dumps(query).strip('"')
@staticmethod
def format_retrieval_documents(documents: list[Document]) -> list[RetrievalSegments]:
@classmethod
def format_retrieval_documents(cls, documents: list[Document]) -> list[RetrievalSegments]:
"""Format retrieval documents with optimized batch processing"""
if not documents:
return []
try:
# Collect document IDs
document_ids = {doc.metadata.get("document_id") for doc in documents if "document_id" in doc.metadata}
if not document_ids:
return []
# Batch query dataset documents
dataset_documents = {
doc.id: doc
for doc in db.session.query(DatasetDocument)
.filter(DatasetDocument.id.in_(document_ids))
.options(load_only(DatasetDocument.id, DatasetDocument.doc_form, DatasetDocument.dataset_id))
.all()
}
records = []
include_segment_ids = []
include_segment_ids = set()
segment_child_map = {}
# Process documents
for document in documents:
document_id = document.metadata.get("document_id")
dataset_document = db.session.query(DatasetDocument).filter(DatasetDocument.id == document_id).first()
if dataset_document:
if document_id not in dataset_documents:
continue
dataset_document = dataset_documents[document_id]
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
# Handle parent-child documents
child_index_node_id = document.metadata.get("doc_id")
result = (
db.session.query(ChildChunk, DocumentSegment)
.join(DocumentSegment, ChildChunk.segment_id == DocumentSegment.id)
child_chunk = (
db.session.query(ChildChunk).filter(ChildChunk.index_node_id == child_index_node_id).first()
)
if not child_chunk:
continue
segment = (
db.session.query(DocumentSegment)
.filter(
ChildChunk.index_node_id == child_index_node_id,
DocumentSegment.dataset_id == dataset_document.dataset_id,
DocumentSegment.enabled == True,
DocumentSegment.status == "completed",
DocumentSegment.id == child_chunk.segment_id,
)
.options(
load_only(
DocumentSegment.id,
DocumentSegment.content,
DocumentSegment.answer,
DocumentSegment.doc_metadata,
)
)
.first()
)
if result:
child_chunk, segment = result
if not segment:
continue
if segment.id not in include_segment_ids:
include_segment_ids.append(segment.id)
include_segment_ids.add(segment.id)
child_chunk_detail = {
"id": child_chunk.id,
"content": child_chunk.content,
@@ -308,9 +339,10 @@ class RetrievalService:
segment_child_map[segment.id]["max_score"], document.metadata.get("score", 0.0)
)
else:
# Handle normal documents
index_node_id = document.metadata.get("doc_id")
if not index_node_id:
continue
else:
index_node_id = document.metadata["doc_id"]
segment = (
db.session.query(DocumentSegment)
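Several child chunks can hit the same parent segment, so the code records each segment once (the include_segment_ids set makes the membership test O(1), versus the old list) and keeps the best child score as max_score. A toy version of that aggregation (the hit data is made up):

hits = [
    {"segment_id": "s1", "chunk": "c1", "score": 0.42},
    {"segment_id": "s1", "chunk": "c2", "score": 0.87},
    {"segment_id": "s2", "chunk": "c3", "score": 0.55},
]
include_segment_ids: set[str] = set()
segment_child_map: dict[str, dict] = {}
for hit in hits:
    sid = hit["segment_id"]
    if sid not in include_segment_ids:
        # First child for this segment: create its record.
        include_segment_ids.add(sid)
        segment_child_map[sid] = {"chunks": [hit["chunk"]], "max_score": hit["score"]}
    else:
        # Later children extend the chunk list and can only raise max_score.
        segment_child_map[sid]["chunks"].append(hit["chunk"])
        segment_child_map[sid]["max_score"] = max(
            segment_child_map[sid]["max_score"], hit["score"]
        )
print(segment_child_map["s1"]["max_score"])  # 0.87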
@@ -325,16 +357,24 @@ class RetrievalService:
if not segment:
continue
include_segment_ids.append(segment.id)
include_segment_ids.add(segment.id)
record = {
"segment": segment,
"score": document.metadata.get("score", None),
"score": document.metadata.get("score"), # type: ignore
"segment_metadata": segment.doc_metadata,
}
records.append(record)
# Add child chunks information to records
for record in records:
if record["segment"].id in segment_child_map:
record["child_chunks"] = segment_child_map[record["segment"].id].get("child_chunks", None)
record["child_chunks"] = segment_child_map[record["segment"].id].get("child_chunks") # type: ignore
record["score"] = segment_child_map[record["segment"].id]["max_score"]
return [RetrievalSegments(**record) for record in records]
except Exception as e:
db.session.rollback()
raise e
finally:
db.session.close()
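The method now wraps its whole body so a failure rolls the session back and the connection is always released, rather than leaking a half-open transaction from a worker thread. The shape of that guard as a standalone sketch (guarded_read is illustrative):

from sqlalchemy import create_engine, text
from sqlalchemy.orm import Session

engine = create_engine("sqlite://")

def guarded_read(session: Session) -> list:
    try:
        return list(session.execute(text("SELECT 1")))
    except Exception:
        # Roll back so no failed transaction lingers on the connection...
        session.rollback()
        raise
    finally:
        # ...and always return the connection to the pool.
        session.close()

print(guarded_read(Session(engine)))  # [(1,)]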