Mirror of https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git (synced 2025-08-01 01:02:01 +08:00)
Support knowledge metadata filter (#15982)

This commit is contained in:
parent b65f2eb55f
commit abeaea4f79
@@ -81,6 +81,7 @@ from .datasets import (
     datasets_segments,
     external,
     hit_testing,
+    metadata,
     website,
 )

@@ -621,7 +621,7 @@ class DocumentDetailApi(DocumentResource):
             raise InvalidMetadataError(f"Invalid metadata value: {metadata}")

         if metadata == "only":
-            response = {"id": document.id, "doc_type": document.doc_type, "doc_metadata": document.doc_metadata}
+            response = {"id": document.id, "doc_type": document.doc_type, "doc_metadata": document.doc_metadata_details}
         elif metadata == "without":
             dataset_process_rules = DatasetService.get_process_rules(dataset_id)
             document_process_rules = document.dataset_process_rule.to_dict()
@@ -682,7 +682,7 @@ class DocumentDetailApi(DocumentResource):
             "disabled_by": document.disabled_by,
             "archived": document.archived,
             "doc_type": document.doc_type,
-            "doc_metadata": document.doc_metadata,
+            "doc_metadata": document.doc_metadata_details,
             "segment_count": document.segment_count,
             "average_segment_length": document.average_segment_length,
             "hit_count": document.hit_count,
api/controllers/console/datasets/metadata.py (new file, 155 lines)
@@ -0,0 +1,155 @@
from flask_login import current_user  # type: ignore
from flask_restful import Resource, marshal_with, reqparse  # type: ignore
from werkzeug.exceptions import NotFound

from controllers.console import api
from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
from fields.dataset_fields import dataset_metadata_fields
from libs.login import login_required
from services.dataset_service import DatasetService
from services.entities.knowledge_entities.knowledge_entities import (
    MetadataArgs,
    MetadataOperationData,
)
from services.metadata_service import MetadataService


def _validate_name(name):
    if not name or len(name) < 1 or len(name) > 40:
        raise ValueError("Name must be between 1 to 40 characters.")
    return name


def _validate_description_length(description):
    if len(description) > 400:
        raise ValueError("Description cannot exceed 400 characters.")
    return description


class DatasetMetadataCreateApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    @enterprise_license_required
    @marshal_with(dataset_metadata_fields)
    def post(self, dataset_id):
        parser = reqparse.RequestParser()
        parser.add_argument("type", type=str, required=True, nullable=True, location="json")
        parser.add_argument("name", type=str, required=True, nullable=True, location="json")
        args = parser.parse_args()
        metadata_args = MetadataArgs(**args)

        dataset_id_str = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id_str)
        if dataset is None:
            raise NotFound("Dataset not found.")
        DatasetService.check_dataset_permission(dataset, current_user)

        metadata = MetadataService.create_metadata(dataset_id_str, metadata_args)
        return metadata, 201

    @setup_required
    @login_required
    @account_initialization_required
    @enterprise_license_required
    def get(self, dataset_id):
        dataset_id_str = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id_str)
        if dataset is None:
            raise NotFound("Dataset not found.")
        return MetadataService.get_dataset_metadatas(dataset), 200


class DatasetMetadataApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    @enterprise_license_required
    @marshal_with(dataset_metadata_fields)
    def patch(self, dataset_id, metadata_id):
        parser = reqparse.RequestParser()
        parser.add_argument("name", type=str, required=True, nullable=True, location="json")
        args = parser.parse_args()

        dataset_id_str = str(dataset_id)
        metadata_id_str = str(metadata_id)
        dataset = DatasetService.get_dataset(dataset_id_str)
        if dataset is None:
            raise NotFound("Dataset not found.")
        DatasetService.check_dataset_permission(dataset, current_user)

        metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, args.get("name"))
        return metadata, 200

    @setup_required
    @login_required
    @account_initialization_required
    @enterprise_license_required
    def delete(self, dataset_id, metadata_id):
        dataset_id_str = str(dataset_id)
        metadata_id_str = str(metadata_id)
        dataset = DatasetService.get_dataset(dataset_id_str)
        if dataset is None:
            raise NotFound("Dataset not found.")
        DatasetService.check_dataset_permission(dataset, current_user)

        MetadataService.delete_metadata(dataset_id_str, metadata_id_str)
        return 200


class DatasetMetadataBuiltInFieldApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    @enterprise_license_required
    def get(self):
        built_in_fields = MetadataService.get_built_in_fields()
        return {"fields": built_in_fields}, 200


class DatasetMetadataBuiltInFieldActionApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    @enterprise_license_required
    def post(self, dataset_id, action):
        dataset_id_str = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id_str)
        if dataset is None:
            raise NotFound("Dataset not found.")
        DatasetService.check_dataset_permission(dataset, current_user)

        if action == "enable":
            MetadataService.enable_built_in_field(dataset)
        elif action == "disable":
            MetadataService.disable_built_in_field(dataset)
        return 200


class DocumentMetadataEditApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    @enterprise_license_required
    def post(self, dataset_id):
        dataset_id_str = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id_str)
        if dataset is None:
            raise NotFound("Dataset not found.")
        DatasetService.check_dataset_permission(dataset, current_user)

        parser = reqparse.RequestParser()
        parser.add_argument("operation_data", type=list, required=True, nullable=True, location="json")
        args = parser.parse_args()
        metadata_args = MetadataOperationData(**args)

        MetadataService.update_documents_metadata(dataset, metadata_args)

        return 200


api.add_resource(DatasetMetadataCreateApi, "/datasets/<uuid:dataset_id>/metadata")
api.add_resource(DatasetMetadataApi, "/datasets/<uuid:dataset_id>/metadata/<uuid:metadata_id>")
api.add_resource(DatasetMetadataBuiltInFieldApi, "/datasets/metadata/built-in")
api.add_resource(DatasetMetadataBuiltInFieldActionApi, "/datasets/<uuid:dataset_id>/metadata/built-in/<string:action>")
api.add_resource(DocumentMetadataEditApi, "/datasets/<uuid:dataset_id>/documents/metadata")
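For orientation, the routes registered above can be exercised roughly as follows. This is a minimal sketch and not part of the commit: the console base URL, dataset UUID, auth header, and the "string" metadata type are placeholder assumptions, and the console API normally expects an authenticated console session.

import requests

CONSOLE_URL = "http://localhost:5001/console/api"  # assumed local console API base
DATASET_ID = "00000000-0000-0000-0000-000000000000"  # placeholder dataset UUID
HEADERS = {"Authorization": "Bearer <console-token>"}  # placeholder auth

# Create a metadata field on a dataset (DatasetMetadataCreateApi.post)
resp = requests.post(
    f"{CONSOLE_URL}/datasets/{DATASET_ID}/metadata",
    headers=HEADERS,
    json={"type": "string", "name": "department"},  # "string" type is an assumption
)
print(resp.status_code, resp.json())

# List the built-in metadata fields (DatasetMetadataBuiltInFieldApi.get)
resp = requests.get(f"{CONSOLE_URL}/datasets/metadata/built-in", headers=HEADERS)
print(resp.json())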
@@ -1,7 +1,12 @@
 import uuid
 from typing import Optional

-from core.app.app_config.entities import DatasetEntity, DatasetRetrieveConfigEntity
+from core.app.app_config.entities import (
+    DatasetEntity,
+    DatasetRetrieveConfigEntity,
+    MetadataFilteringCondition,
+    ModelConfig,
+)
 from core.entities.agent_entities import PlanningStrategy
 from models.model import AppMode
 from services.dataset_service import DatasetService
@@ -78,6 +83,15 @@ class DatasetConfigManager:
                     retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
                         dataset_configs["retrieval_model"]
                     ),
+                    metadata_filtering_mode=dataset_configs.get("metadata_filtering_mode", "disabled"),
+                    metadata_model_config=ModelConfig(**dataset_configs.get("metadata_model_config"))
+                    if dataset_configs.get("metadata_model_config")
+                    else None,
+                    metadata_filtering_conditions=MetadataFilteringCondition(
+                        **dataset_configs.get("metadata_filtering_conditions", {})
+                    )
+                    if dataset_configs.get("metadata_filtering_conditions")
+                    else None,
                 ),
             )
         else:
@@ -96,6 +110,15 @@ class DatasetConfigManager:
                     weights=dataset_configs.get("weights"),
                     reranking_enabled=dataset_configs.get("reranking_enabled", True),
                     rerank_mode=dataset_configs.get("reranking_mode", "reranking_model"),
+                    metadata_filtering_mode=dataset_configs.get("metadata_filtering_mode", "disabled"),
+                    metadata_model_config=ModelConfig(**dataset_configs.get("metadata_model_config"))
+                    if dataset_configs.get("metadata_model_config")
+                    else None,
+                    metadata_filtering_conditions=MetadataFilteringCondition(
+                        **dataset_configs.get("metadata_filtering_conditions", {})
+                    )
+                    if dataset_configs.get("metadata_filtering_conditions")
+                    else None,
                 ),
             )

@@ -1,10 +1,11 @@
 from collections.abc import Sequence
 from enum import Enum, StrEnum
-from typing import Any, Optional
+from typing import Any, Literal, Optional

 from pydantic import BaseModel, Field, field_validator

 from core.file import FileTransferMethod, FileType, FileUploadConfig
+from core.model_runtime.entities.llm_entities import LLMMode
 from core.model_runtime.entities.message_entities import PromptMessageRole
 from models.model import AppMode

@@ -135,6 +136,55 @@ class ExternalDataVariableEntity(BaseModel):
     config: dict[str, Any] = Field(default_factory=dict)


+SupportedComparisonOperator = Literal[
+    # for string or array
+    "contains",
+    "not contains",
+    "start with",
+    "end with",
+    "is",
+    "is not",
+    "empty",
+    "not empty",
+    # for number
+    "=",
+    "≠",
+    ">",
+    "<",
+    "≥",
+    "≤",
+    # for time
+    "before",
+    "after",
+]
+
+
+class ModelConfig(BaseModel):
+    provider: str
+    name: str
+    mode: LLMMode
+    completion_params: dict[str, Any] = {}
+
+
+class Condition(BaseModel):
+    """
+    Condition detail
+    """
+
+    name: str
+    comparison_operator: SupportedComparisonOperator
+    value: str | Sequence[str] | None | int | float = None
+
+
+class MetadataFilteringCondition(BaseModel):
+    """
+    Metadata Filtering Condition.
+    """
+
+    logical_operator: Optional[Literal["and", "or"]] = "and"
+    conditions: Optional[list[Condition]] = Field(default=None, deprecated=True)
+
+
 class DatasetRetrieveConfigEntity(BaseModel):
     """
     Dataset Retrieve Config Entity.
@@ -171,6 +221,9 @@ class DatasetRetrieveConfigEntity(BaseModel):
     reranking_model: Optional[dict] = None
     weights: Optional[dict] = None
     reranking_enabled: Optional[bool] = True
+    metadata_filtering_mode: Optional[Literal["disabled", "automatic", "manual"]] = "disabled"
+    metadata_model_config: Optional[ModelConfig] = None
+    metadata_filtering_conditions: Optional[MetadataFilteringCondition] = None


 class DatasetEntity(BaseModel):
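To illustrate how the new entities fit together (a sketch only; the metadata field name and value below are hypothetical), a manual metadata filter could be assembled like this and attached to a DatasetRetrieveConfigEntity whose metadata_filtering_mode is "manual":

from core.app.app_config.entities import Condition, MetadataFilteringCondition

# "department" and "finance" are hypothetical; "is" is one of the
# SupportedComparisonOperator literals defined above.
manual_filter = MetadataFilteringCondition(
    logical_operator="and",
    conditions=[
        Condition(name="department", comparison_operator="is", value="finance"),
    ],
)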
@@ -180,6 +180,7 @@ class ChatAppRunner(AppRunner):
                 hit_callback=hit_callback,
                 memory=memory,
                 message_id=message.id,
+                inputs=inputs,
             )

         # reorganize all inputs and template to prompt messages
@@ -139,6 +139,7 @@ class CompletionAppRunner(AppRunner):
                 show_retrieve_source=app_config.additional_features.show_retrieve_source,
                 hit_callback=hit_callback,
                 message_id=message.id,
+                inputs=inputs,
             )

         # reorganize all inputs and template to prompt messages
@@ -88,16 +88,17 @@ class Jieba(BaseKeyword):
         keyword_table = self._get_dataset_keyword_table()

         k = kwargs.get("top_k", 4)
+        document_ids_filter = kwargs.get("document_ids_filter")
         sorted_chunk_indices = self._retrieve_ids_by_query(keyword_table or {}, query, k)

         documents = []
         for chunk_index in sorted_chunk_indices:
-            segment = (
-                db.session.query(DocumentSegment)
-                .filter(DocumentSegment.dataset_id == self.dataset.id, DocumentSegment.index_node_id == chunk_index)
-                .first()
+            segment_query = db.session.query(DocumentSegment).filter(
+                DocumentSegment.dataset_id == self.dataset.id, DocumentSegment.index_node_id == chunk_index
             )
+            if document_ids_filter:
+                segment_query = segment_query.filter(DocumentSegment.document_id.in_(document_ids_filter))
+            segment = segment_query.first()
+
             if segment:
                 documents.append(
@@ -41,6 +41,7 @@ class RetrievalService:
         reranking_model: Optional[dict] = None,
         reranking_mode: str = "reranking_model",
         weights: Optional[dict] = None,
+        document_ids_filter: Optional[list[str]] = None,
     ):
         if not query:
             return []
@@ -64,6 +65,7 @@ class RetrievalService:
                         top_k=top_k,
                         all_documents=all_documents,
                         exceptions=exceptions,
+                        document_ids_filter=document_ids_filter,
                     )
                 )
             if RetrievalMethod.is_support_semantic_search(retrieval_method):
@@ -79,6 +81,7 @@ class RetrievalService:
                         all_documents=all_documents,
                         retrieval_method=retrieval_method,
                         exceptions=exceptions,
+                        document_ids_filter=document_ids_filter,
                     )
                 )
             if RetrievalMethod.is_support_fulltext_search(retrieval_method):
@@ -130,7 +133,14 @@ class RetrievalService:

     @classmethod
     def keyword_search(
-        cls, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list, exceptions: list
+        cls,
+        flask_app: Flask,
+        dataset_id: str,
+        query: str,
+        top_k: int,
+        all_documents: list,
+        exceptions: list,
+        document_ids_filter: Optional[list[str]] = None,
     ):
         with flask_app.app_context():
             try:
@@ -139,7 +149,10 @@ class RetrievalService:
                     raise ValueError("dataset not found")

                 keyword = Keyword(dataset=dataset)
-                documents = keyword.search(cls.escape_query_for_search(query), top_k=top_k)
+                documents = keyword.search(
+                    cls.escape_query_for_search(query), top_k=top_k, document_ids_filter=document_ids_filter
+                )
                 all_documents.extend(documents)
             except Exception as e:
                 exceptions.append(str(e))
@@ -156,6 +169,7 @@ class RetrievalService:
         all_documents: list,
         retrieval_method: str,
         exceptions: list,
+        document_ids_filter: Optional[list[str]] = None,
     ):
         with flask_app.app_context():
             try:
@@ -170,6 +184,7 @@ class RetrievalService:
                     top_k=top_k,
                     score_threshold=score_threshold,
                     filter={"group_id": [dataset.id]},
+                    document_ids_filter=document_ids_filter,
                 )

                 if documents:
@@ -53,7 +53,7 @@ class AnalyticdbVector(BaseVector):
         self.analyticdb_vector.delete_by_metadata_field(key, value)

     def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
-        return self.analyticdb_vector.search_by_vector(query_vector)
+        return self.analyticdb_vector.search_by_vector(query_vector, **kwargs)

     def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
         return self.analyticdb_vector.search_by_full_text(query, **kwargs)
@@ -196,6 +196,11 @@ class AnalyticdbVectorBySql:
         top_k = kwargs.get("top_k", 4)
         if not isinstance(top_k, int) or top_k <= 0:
             raise ValueError("top_k must be a positive integer")
+        document_ids_filter = kwargs.get("document_ids_filter")
+        where_clause = "WHERE 1=1"
+        if document_ids_filter:
+            document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
+            where_clause += f" AND metadata_->>'document_id' IN ({document_ids})"
         score_threshold = float(kwargs.get("score_threshold") or 0.0)
         with self._get_cursor() as cur:
             query_vector_str = json.dumps(query_vector)
@@ -204,7 +209,7 @@ class AnalyticdbVectorBySql:
                 f"SELECT t.id AS id, t.vector AS vector, (1.0 - t.score) AS score, "
                 f"t.page_content as page_content, t.metadata_ AS metadata_ "
                 f"FROM (SELECT id, vector, page_content, metadata_, vector <=> %s AS score "
-                f"FROM {self.table_name} ORDER BY score LIMIT {top_k} ) t",
+                f"FROM {self.table_name} {where_clause} ORDER BY score LIMIT {top_k} ) t",
                 (query_vector_str,),
             )
             documents = []
@@ -224,12 +229,17 @@ class AnalyticdbVectorBySql:
         top_k = kwargs.get("top_k", 4)
         if not isinstance(top_k, int) or top_k <= 0:
             raise ValueError("top_k must be a positive integer")
+        document_ids_filter = kwargs.get("document_ids_filter")
+        where_clause = ""
+        if document_ids_filter:
+            document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
+            where_clause += f"AND metadata_->>'document_id' IN ({document_ids})"
         with self._get_cursor() as cur:
             cur.execute(
                 f"""SELECT id, vector, page_content, metadata_,
                 ts_rank(to_tsvector, to_tsquery_from_text(%s, 'zh_cn'), 32) AS score
                 FROM {self.table_name}
-                WHERE to_tsvector@@to_tsquery_from_text(%s, 'zh_cn')
+                WHERE to_tsvector@@to_tsquery_from_text(%s, 'zh_cn') {where_clause}
                 ORDER BY score DESC
                 LIMIT {top_k}""",
                 (f"'{query}'", f"'{query}'"),
@@ -123,11 +123,21 @@ class BaiduVector(BaseVector):

     def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
         query_vector = [float(val) if isinstance(val, np.float64) else val for val in query_vector]
-        anns = AnnSearch(
-            vector_field=self.field_vector,
-            vector_floats=query_vector,
-            params=HNSWSearchParams(ef=kwargs.get("ef", 10), limit=kwargs.get("top_k", 4)),
-        )
+        document_ids_filter = kwargs.get("document_ids_filter")
+        if document_ids_filter:
+            document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
+            anns = AnnSearch(
+                vector_field=self.field_vector,
+                vector_floats=query_vector,
+                params=HNSWSearchParams(ef=kwargs.get("ef", 10), limit=kwargs.get("top_k", 4)),
+                filter=f"document_id IN ({document_ids})",
+            )
+        else:
+            anns = AnnSearch(
+                vector_field=self.field_vector,
+                vector_floats=query_vector,
+                params=HNSWSearchParams(ef=kwargs.get("ef", 10), limit=kwargs.get("top_k", 4)),
+            )
         res = self._db.table(self._collection_name).search(
             anns=anns,
             projections=[self.field_id, self.field_text, self.field_metadata],
@@ -95,7 +95,15 @@ class ChromaVector(BaseVector):

     def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
         collection = self._client.get_or_create_collection(self._collection_name)
-        results: QueryResult = collection.query(query_embeddings=query_vector, n_results=kwargs.get("top_k", 4))
+        document_ids_filter = kwargs.get("document_ids_filter")
+        if document_ids_filter:
+            results: QueryResult = collection.query(
+                query_embeddings=query_vector,
+                n_results=kwargs.get("top_k", 4),
+                where={"document_id": {"$in": document_ids_filter}},  # type: ignore
+            )
+        else:
+            results: QueryResult = collection.query(query_embeddings=query_vector, n_results=kwargs.get("top_k", 4))  # type: ignore
         score_threshold = float(kwargs.get("score_threshold") or 0.0)

         # Check if results contain data
@@ -117,6 +117,9 @@ class ElasticSearchVector(BaseVector):
         top_k = kwargs.get("top_k", 4)
         num_candidates = math.ceil(top_k * 1.5)
         knn = {"field": Field.VECTOR.value, "query_vector": query_vector, "k": top_k, "num_candidates": num_candidates}
+        document_ids_filter = kwargs.get("document_ids_filter")
+        if document_ids_filter:
+            knn["filter"] = {"terms": {"metadata.document_id": document_ids_filter}}

         results = self._client.search(index=self._collection_name, knn=knn, size=top_k)

@@ -145,6 +148,9 @@ class ElasticSearchVector(BaseVector):

     def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
         query_str = {"match": {Field.CONTENT_KEY.value: query}}
+        document_ids_filter = kwargs.get("document_ids_filter")
+        if document_ids_filter:
+            query_str["filter"] = {"terms": {"metadata.document_id": document_ids_filter}}  # type: ignore
         results = self._client.search(index=self._collection_name, query=query_str, size=kwargs.get("top_k", 4))
         docs = []
         for hit in results["hits"]["hits"]:
@@ -168,7 +168,12 @@ class LindormVectorStore(BaseVector):
             raise ValueError("All elements in query_vector should be floats")

         top_k = kwargs.get("top_k", 10)
-        query = default_vector_search_query(query_vector=query_vector, k=top_k, **kwargs)
+        document_ids_filter = kwargs.get("document_ids_filter")
+        filters = []
+        if document_ids_filter:
+            filters.append({"terms": {"metadata.document_id": document_ids_filter}})
+        query = default_vector_search_query(query_vector=query_vector, k=top_k, filters=filters, **kwargs)

         try:
             params = {}
             if self._using_ugc:
@@ -206,7 +211,10 @@ class LindormVectorStore(BaseVector):
         should = kwargs.get("should")
         minimum_should_match = kwargs.get("minimum_should_match", 0)
         top_k = kwargs.get("top_k", 10)
-        filters = kwargs.get("filter")
+        filters = kwargs.get("filter", [])
+        document_ids_filter = kwargs.get("document_ids_filter")
+        if document_ids_filter:
+            filters.append({"terms": {"metadata.document_id": document_ids_filter}})
         routing = self._routing
         full_text_query = default_text_search_query(
             query_text=query,
@@ -228,12 +228,18 @@ class MilvusVector(BaseVector):
         """
         Search for documents by vector similarity.
         """
+        document_ids_filter = kwargs.get("document_ids_filter")
+        filter = ""
+        if document_ids_filter:
+            document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
+            filter = f'metadata["document_id"] in ({document_ids})'
         results = self._client.search(
             collection_name=self._collection_name,
             data=[query_vector],
             anns_field=Field.VECTOR.value,
             limit=kwargs.get("top_k", 4),
             output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
+            filter=filter,
         )

         return self._process_search_results(
@@ -249,6 +255,11 @@ class MilvusVector(BaseVector):
         if not self._hybrid_search_enabled or not self.field_exists(Field.SPARSE_VECTOR.value):
             logger.warning("Full-text search is not supported in current Milvus version (requires >= 2.5.0)")
             return []
+        document_ids_filter = kwargs.get("document_ids_filter")
+        filter = ""
+        if document_ids_filter:
+            document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
+            filter = f'metadata["document_id"] in ({document_ids})'

         results = self._client.search(
             collection_name=self._collection_name,
@@ -256,6 +267,7 @@ class MilvusVector(BaseVector):
             anns_field=Field.SPARSE_VECTOR.value,
             limit=kwargs.get("top_k", 4),
             output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
+            filter=filter,
         )

         return self._process_search_results(
@@ -133,6 +133,10 @@ class MyScaleVector(BaseVector):
             if self._metric.upper() == "COSINE" and order == SortOrder.ASC and score_threshold > 0.0
             else ""
         )
+        document_ids_filter = kwargs.get("document_ids_filter")
+        if document_ids_filter:
+            document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
+            where_str = f"{where_str} AND metadata['document_id'] in ({document_ids})"
         sql = f"""
             SELECT text, vector, metadata, {dist} as dist FROM {self._config.database}.{self._collection_name}
             {where_str} ORDER BY dist {order.value} LIMIT {top_k}
@@ -154,6 +154,11 @@ class OceanBaseVector(BaseVector):
         return []

     def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
+        document_ids_filter = kwargs.get("document_ids_filter")
+        where_clause = None
+        if document_ids_filter:
+            document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
+            where_clause = f"metadata->>'$.document_id' in ({document_ids})"
         ef_search = kwargs.get("ef_search", self._hnsw_ef_search)
         if ef_search != self._hnsw_ef_search:
             self._client.set_ob_hnsw_ef_search(ef_search)
@@ -167,6 +172,7 @@ class OceanBaseVector(BaseVector):
             distance_func=func.l2_distance,
             output_column_names=["text", "metadata"],
             with_dist=True,
+            where_clause=where_clause,
         )
         docs = []
         for text, metadata, distance in cur:
@@ -154,6 +154,9 @@ class OpenSearchVector(BaseVector):
             "size": kwargs.get("top_k", 4),
             "query": {"knn": {Field.VECTOR.value: {Field.VECTOR.value: query_vector, "k": kwargs.get("top_k", 4)}}},
         }
+        document_ids_filter = kwargs.get("document_ids_filter")
+        if document_ids_filter:
+            query["query"] = {"terms": {"metadata.document_id": document_ids_filter}}

         try:
             response = self._client.search(index=self._collection_name.lower(), body=query)
@@ -179,6 +182,9 @@ class OpenSearchVector(BaseVector):

     def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
         full_text_query = {"query": {"match": {Field.CONTENT_KEY.value: query}}}
+        document_ids_filter = kwargs.get("document_ids_filter")
+        if document_ids_filter:
+            full_text_query["query"]["terms"] = {"metadata.document_id": document_ids_filter}

         response = self._client.search(index=self._collection_name.lower(), body=full_text_query)

@@ -201,10 +201,15 @@ class OracleVector(BaseVector):
         :return: List of Documents that are nearest to the query vector.
         """
         top_k = kwargs.get("top_k", 4)
+        document_ids_filter = kwargs.get("document_ids_filter")
+        where_clause = ""
+        if document_ids_filter:
+            document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
+            where_clause = f"WHERE metadata->>'document_id' in ({document_ids})"
         with self._get_cursor() as cur:
             cur.execute(
                 f"SELECT meta, text, vector_distance(embedding,:1) AS distance FROM {self.table_name}"
-                f" ORDER BY distance fetch first {top_k} rows only",
+                f" {where_clause} ORDER BY distance fetch first {top_k} rows only",
                 [numpy.array(query_vector)],
             )
             docs = []
@@ -257,9 +262,15 @@ class OracleVector(BaseVector):
                 if token not in stop_words:
                     entities.append(token)
         with self._get_cursor() as cur:
+            document_ids_filter = kwargs.get("document_ids_filter")
+            where_clause = ""
+            if document_ids_filter:
+                document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
+                where_clause = f" AND metadata->>'document_id' in ({document_ids}) "
             cur.execute(
                 f"select meta, text, embedding FROM {self.table_name}"
-                f" WHERE CONTAINS(text, :1, 1) > 0 order by score(1) desc fetch first {top_k} rows only",
+                f" WHERE CONTAINS(text, :1, 1) > 0 {where_clause} "
+                f"order by score(1) desc fetch first {top_k} rows only",
                 [" ACCUM ".join(entities)],
             )
             docs = []
@@ -189,6 +189,9 @@ class PGVectoRS(BaseVector):
                 .limit(kwargs.get("top_k", 4))
                 .order_by("distance")
             )
+            document_ids_filter = kwargs.get("document_ids_filter")
+            if document_ids_filter:
+                stmt = stmt.where(self._table.meta["document_id"].in_(document_ids_filter))
             res = session.execute(stmt)
             results = [(row[0], row[1]) for row in res]

@@ -173,10 +173,16 @@ class PGVector(BaseVector):
         top_k = kwargs.get("top_k", 4)
         if not isinstance(top_k, int) or top_k <= 0:
             raise ValueError("top_k must be a positive integer")
+        document_ids_filter = kwargs.get("document_ids_filter")
+        where_clause = ""
+        if document_ids_filter:
+            document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
+            where_clause = f" WHERE metadata->>'document_id' in ({document_ids}) "

         with self._get_cursor() as cur:
             cur.execute(
                 f"SELECT meta, text, embedding <=> %s AS distance FROM {self.table_name}"
+                f" {where_clause}"
                 f" ORDER BY distance LIMIT {top_k}",
                 (json.dumps(query_vector),),
             )
@@ -195,12 +201,18 @@ class PGVector(BaseVector):
         if not isinstance(top_k, int) or top_k <= 0:
             raise ValueError("top_k must be a positive integer")
         with self._get_cursor() as cur:
+            document_ids_filter = kwargs.get("document_ids_filter")
+            where_clause = ""
+            if document_ids_filter:
+                document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
+                where_clause = f" AND metadata->>'document_id' in ({document_ids}) "
             if self.pg_bigm:
                 cur.execute("SET pg_bigm.similarity_limit TO 0.000001")
                 cur.execute(
                     f"""SELECT meta, text, bigm_similarity(unistr(%s), coalesce(text, '')) AS score
                     FROM {self.table_name}
                     WHERE text =%% unistr(%s)
+                    {where_clause}
                     ORDER BY score DESC
                     LIMIT {top_k}""",
                     # f"'{query}'" is required in order to account for whitespace in query
@@ -211,6 +223,7 @@ class PGVector(BaseVector):
                     f"""SELECT meta, text, ts_rank(to_tsvector(coalesce(text, '')), plainto_tsquery(%s)) AS score
                     FROM {self.table_name}
                     WHERE to_tsvector(text) @@ plainto_tsquery(%s)
+                    {where_clause}
                     ORDER BY score DESC
                     LIMIT {top_k}""",
                     # f"'{query}'" is required in order to account for whitespace in query
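As a worked example of the pattern shared by the SQL-backed stores above (the IDs are placeholders), the document_ids_filter kwarg is rendered into a SQL fragment by simple string interpolation:

document_ids_filter = ["doc-1", "doc-2"]  # hypothetical document IDs
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
where_clause = f" WHERE metadata->>'document_id' in ({document_ids}) "
# where_clause ->  " WHERE metadata->>'document_id' in ('doc-1', 'doc-2') "

The IDs are interpolated into the statement rather than passed as bound parameters, so they are assumed to be internally generated document UUIDs rather than user input.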
@@ -286,27 +286,26 @@ class QdrantVector(BaseVector):
         from qdrant_client.http import models
         from qdrant_client.http.exceptions import UnexpectedResponse

-        for node_id in ids:
-            try:
-                filter = models.Filter(
-                    must=[
-                        models.FieldCondition(
-                            key="metadata.doc_id",
-                            match=models.MatchValue(value=node_id),
-                        ),
-                    ],
-                )
-                self._client.delete(
-                    collection_name=self._collection_name,
-                    points_selector=FilterSelector(filter=filter),
-                )
-            except UnexpectedResponse as e:
-                # Collection does not exist, so return
-                if e.status_code == 404:
-                    return
-                # Some other error occurred, so re-raise the exception
-                else:
-                    raise e
+        try:
+            filter = models.Filter(
+                must=[
+                    models.FieldCondition(
+                        key="metadata.doc_id",
+                        match=models.MatchAny(any=ids),
+                    ),
+                ],
+            )
+            self._client.delete(
+                collection_name=self._collection_name,
+                points_selector=FilterSelector(filter=filter),
+            )
+        except UnexpectedResponse as e:
+            # Collection does not exist, so return
+            if e.status_code == 404:
+                return
+            # Some other error occurred, so re-raise the exception
+            else:
+                raise e

     def text_exists(self, id: str) -> bool:
         all_collection_name = []
@@ -331,6 +330,15 @@ class QdrantVector(BaseVector):
                 ),
             ],
         )
+        document_ids_filter = kwargs.get("document_ids_filter")
+        if document_ids_filter:
+            if filter.must:
+                filter.must.append(
+                    models.FieldCondition(
+                        key="metadata.document_id",
+                        match=models.MatchAny(any=document_ids_filter),
+                    )
+                )
         results = self._client.search(
             collection_name=self._collection_name,
             query_vector=query_vector,
@@ -377,6 +385,15 @@ class QdrantVector(BaseVector):
                 ),
             ]
         )
+        document_ids_filter = kwargs.get("document_ids_filter")
+        if document_ids_filter:
+            if scroll_filter.must:
+                scroll_filter.must.append(
+                    models.FieldCondition(
+                        key="metadata.document_id",
+                        match=models.MatchAny(any=document_ids_filter),
+                    )
+                )
         response = self._client.scroll(
             collection_name=self._collection_name,
             scroll_filter=scroll_filter,
@@ -223,8 +223,12 @@ class RelytVector(BaseVector):
         return len(result) > 0

     def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
+        document_ids_filter = kwargs.get("document_ids_filter")
+        filter = kwargs.get("filter", {})
+        if document_ids_filter:
+            filter["document_id"] = document_ids_filter
         results = self.similarity_search_with_score_by_vector(
-            k=int(kwargs.get("top_k", 4)), embedding=query_vector, filter=kwargs.get("filter")
+            k=int(kwargs.get("top_k", 4)), embedding=query_vector, filter=filter
        )

         # Organize results.
@@ -246,9 +250,9 @@ class RelytVector(BaseVector):
         filter_condition = ""
         if filter is not None:
             conditions = [
-                f"metadata->>{key!r} in ({', '.join(map(repr, value))})"
+                f"metadata->>'{key!r}' in ({', '.join(map(repr, value))})"
                 if len(value) > 1
-                else f"metadata->>{key!r} = {value[0]!r}"
+                else f"metadata->>'{key!r}' = {value[0]!r}"
                 for key, value in filter.items()
             ]
             filter_condition = f"WHERE {' AND '.join(conditions)}"
@@ -145,11 +145,16 @@ class TencentVector(BaseVector):
         self._db.collection(self._collection_name).delete(document_ids=ids)

     def delete_by_metadata_field(self, key: str, value: str) -> None:
-        self._db.collection(self._collection_name).delete(filter=Filter(Filter.In(key, [value])))
+        self._db.collection(self._collection_name).delete(filter=Filter(Filter.In(f"metadata.{key}", [value])))

     def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
+        document_ids_filter = kwargs.get("document_ids_filter")
+        filter = None
+        if document_ids_filter:
+            filter = Filter(Filter.In("metadata.document_id", document_ids_filter))
         res = self._db.collection(self._collection_name).search(
             vectors=[query_vector],
+            filter=filter,
             params=document.HNSWSearchParams(ef=kwargs.get("ef", 10)),
             retrieve_vector=False,
             limit=kwargs.get("top_k", 4),
@@ -326,6 +326,18 @@ class TidbOnQdrantVector(BaseVector):
                 ),
             ],
         )
+        document_ids_filter = kwargs.get("document_ids_filter")
+        if document_ids_filter:
+            should_conditions = []
+            for document_id_filter in document_ids_filter:
+                should_conditions.append(
+                    models.FieldCondition(
+                        key="metadata.document_id",
+                        match=models.MatchValue(value=document_id_filter),
+                    )
+                )
+            if should_conditions:
+                filter.should = should_conditions  # type: ignore
         results = self._client.search(
             collection_name=self._collection_name,
             query_vector=query_vector,
@@ -368,6 +380,18 @@ class TidbOnQdrantVector(BaseVector):
                 )
             ]
         )
+        document_ids_filter = kwargs.get("document_ids_filter")
+        if document_ids_filter:
+            should_conditions = []
+            for document_id_filter in document_ids_filter:
+                should_conditions.append(
+                    models.FieldCondition(
+                        key="metadata.document_id",
+                        match=models.MatchValue(value=document_id_filter),
+                    )
+                )
+            if should_conditions:
+                scroll_filter.should = should_conditions  # type: ignore
         response = self._client.scroll(
             collection_name=self._collection_name,
             scroll_filter=scroll_filter,
@@ -196,6 +196,11 @@ class TiDBVector(BaseVector):

         docs = []
         tidb_dist_func = self._get_distance_func()
+        document_ids_filter = kwargs.get("document_ids_filter")
+        where_clause = ""
+        if document_ids_filter:
+            document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
+            where_clause = f" WHERE meta->>'$.document_id' in ({document_ids}) "

         with Session(self._engine) as session:
             select_statement = sql_text(f"""
@@ -206,6 +211,7 @@ class TiDBVector(BaseVector):
                     text,
                     {tidb_dist_func}(vector, :query_vector_str) AS distance
                 FROM {self._collection_name}
+                {where_clause}
                 ORDER BY distance ASC
                 LIMIT :top_k
             ) t
@@ -88,7 +88,20 @@ class UpstashVector(BaseVector):

     def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
         top_k = kwargs.get("top_k", 4)
-        result = self.index.query(vector=query_vector, top_k=top_k, include_metadata=True, include_data=True)
+        document_ids_filter = kwargs.get("document_ids_filter")
+        if document_ids_filter:
+            document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
+            filter = f"document_id in ({document_ids})"
+        else:
+            filter = ""
+        result = self.index.query(
+            vector=query_vector,
+            top_k=top_k,
+            include_metadata=True,
+            include_data=True,
+            include_vectors=False,
+            filter=filter,
+        )
         docs = []
         score_threshold = float(kwargs.get("score_threshold") or 0.0)
         for record in result:
@@ -177,7 +177,11 @@ class VikingDBVector(BaseVector):
             query_vector, limit=kwargs.get("top_k", 4)
         )
         score_threshold = float(kwargs.get("score_threshold") or 0.0)
-        return self._get_search_res(results, score_threshold)
+        docs = self._get_search_res(results, score_threshold)
+        document_ids_filter = kwargs.get("document_ids_filter")
+        if document_ids_filter:
+            docs = [doc for doc in docs if doc.metadata.get("document_id") in document_ids_filter]
+        return docs

     def _get_search_res(self, results, score_threshold) -> list[Document]:
         if len(results) == 0:
@@ -187,8 +187,10 @@ class WeaviateVector(BaseVector):
         query_obj = self._client.query.get(collection_name, properties)

         vector = {"vector": query_vector}
-        if kwargs.get("where_filter"):
-            query_obj = query_obj.with_where(kwargs.get("where_filter"))
+        document_ids_filter = kwargs.get("document_ids_filter")
+        if document_ids_filter:
+            where_filter = {"operator": "ContainsAny", "path": ["document_id"], "valueTextArray": document_ids_filter}
+            query_obj = query_obj.with_where(where_filter)
         result = (
             query_obj.with_near_vector(vector)
             .with_limit(kwargs.get("top_k", 4))
@@ -233,8 +235,10 @@ class WeaviateVector(BaseVector):
         if kwargs.get("search_distance"):
             content["certainty"] = kwargs.get("search_distance")
         query_obj = self._client.query.get(collection_name, properties)
-        if kwargs.get("where_filter"):
-            query_obj = query_obj.with_where(kwargs.get("where_filter"))
+        document_ids_filter = kwargs.get("document_ids_filter")
+        if document_ids_filter:
+            where_filter = {"operator": "ContainsAny", "path": ["document_id"], "valueTextArray": document_ids_filter}
+            query_obj = query_obj.with_where(where_filter)
         query_obj = query_obj.with_additional(["vector"])
         properties = ["text"]
         result = query_obj.with_bm25(query=query, properties=properties).with_limit(kwargs.get("top_k", 4)).do()
api/core/rag/entities/metadata_entities.py  (new file, 45 lines)
@@ -0,0 +1,45 @@
from collections.abc import Sequence
from typing import Literal, Optional

from pydantic import BaseModel, Field

SupportedComparisonOperator = Literal[
    # for string or array
    "contains",
    "not contains",
    "start with",
    "end with",
    "is",
    "is not",
    "empty",
    "not empty",
    # for number
    "=",
    "≠",
    ">",
    "<",
    "≥",
    "≤",
    # for time
    "before",
    "after",
]


class Condition(BaseModel):
    """
    Conditon detail
    """

    name: str
    comparison_operator: SupportedComparisonOperator
    value: str | Sequence[str] | None | int | float = None


class MetadataCondition(BaseModel):
    """
    Metadata Condition.
    """

    logical_operator: Optional[Literal["and", "or"]] = "and"
    conditions: Optional[list[Condition]] = Field(default=None, deprecated=True)
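Note: a minimal usage sketch of the new models, assuming they are imported from the module above; the metadata field names and values are made up for illustration:

from core.rag.entities.metadata_entities import Condition, MetadataCondition

# "uploader" and "rating" are hypothetical metadata field names.
metadata_condition = MetadataCondition(
    logical_operator="and",
    conditions=[
        Condition(name="uploader", comparison_operator="is", value="alice"),
        Condition(name="rating", comparison_operator=">", value=9),
    ],
)
print(metadata_condition.model_dump())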
api/core/rag/index_processor/constant/built_in_field.py  (new file, 15 lines)
@@ -0,0 +1,15 @@
from enum import Enum


class BuiltInField(str, Enum):
    document_name = "document_name"
    uploader = "uploader"
    upload_date = "upload_date"
    last_update_date = "last_update_date"
    source = "source"


class MetadataDataSource(Enum):
    upload_file = "file_upload"
    website_crawl = "website"
    notion_import = "notion"
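Note: BuiltInField subclasses str, so its members can double as doc_metadata keys. A small sketch, assuming that usage; the document values are made up:

from core.rag.index_processor.constant.built_in_field import BuiltInField, MetadataDataSource

# Illustrative document metadata keyed by the built-in field names.
doc_metadata = {
    BuiltInField.document_name: "quarterly_report.pdf",
    BuiltInField.uploader: "alice",
    BuiltInField.source: MetadataDataSource.upload_file.value,
}
print(doc_metadata[BuiltInField.uploader])  # -> "alice"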
@@ -1,35 +1,61 @@
+import json
 import math
+import re
 import threading
-from collections import Counter
-from typing import Any, Optional, cast
+from collections import Counter, defaultdict
+from collections.abc import Generator, Mapping
+from typing import Any, Optional, Union, cast

 from flask import Flask, current_app
+from sqlalchemy import Integer, and_, or_, text
+from sqlalchemy import cast as sqlalchemy_cast

-from core.app.app_config.entities import DatasetEntity, DatasetRetrieveConfigEntity
+from core.app.app_config.entities import (
+    DatasetEntity,
+    DatasetRetrieveConfigEntity,
+    MetadataFilteringCondition,
+    ModelConfig,
+)
 from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity
 from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
 from core.entities.agent_entities import PlanningStrategy
+from core.entities.model_entities import ModelStatus
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_manager import ModelInstance, ModelManager
-from core.model_runtime.entities.message_entities import PromptMessageTool
+from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
+from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
 from core.model_runtime.entities.model_entities import ModelFeature, ModelType
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.ops.entities.trace_entity import TraceTaskName
 from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
 from core.ops.utils import measure_time
+from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
+from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
+from core.prompt.simple_prompt_transform import ModelMode
 from core.rag.data_post_processor.data_post_processor import DataPostProcessor
 from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
 from core.rag.datasource.retrieval_service import RetrievalService
 from core.rag.entities.context_entities import DocumentContext
+from core.rag.entities.metadata_entities import Condition, MetadataCondition
 from core.rag.index_processor.constant.index_type import IndexType
 from core.rag.models.document import Document
 from core.rag.rerank.rerank_type import RerankMode
 from core.rag.retrieval.retrieval_methods import RetrievalMethod
 from core.rag.retrieval.router.multi_dataset_function_call_router import FunctionCallMultiDatasetRouter
 from core.rag.retrieval.router.multi_dataset_react_route import ReactMultiDatasetRouter
+from core.rag.retrieval.template_prompts import (
+    METADATA_FILTER_ASSISTANT_PROMPT_1,
+    METADATA_FILTER_ASSISTANT_PROMPT_2,
+    METADATA_FILTER_COMPLETION_PROMPT,
+    METADATA_FILTER_SYSTEM_PROMPT,
+    METADATA_FILTER_USER_PROMPT_1,
+    METADATA_FILTER_USER_PROMPT_2,
+    METADATA_FILTER_USER_PROMPT_3,
+)
 from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
 from extensions.ext_database import db
-from models.dataset import ChildChunk, Dataset, DatasetQuery, DocumentSegment
+from libs.json_in_md_parser import parse_and_check_json_markdown
+from models.dataset import ChildChunk, Dataset, DatasetMetadata, DatasetQuery, DocumentSegment
 from models.dataset import Document as DatasetDocument
 from services.external_knowledge_service import ExternalDatasetService

@@ -59,6 +85,7 @@ class DatasetRetrieval:
         hit_callback: DatasetIndexToolCallbackHandler,
         message_id: str,
         memory: Optional[TokenBufferMemory] = None,
+        inputs: Optional[Mapping[str, Any]] = None,
     ) -> Optional[str]:
         """
         Retrieve dataset.

@@ -116,6 +143,22 @@ class DatasetRetrieval:
                 continue

             available_datasets.append(dataset)
+        if inputs:
+            inputs = {key: str(value) for key, value in inputs.items()}
+        else:
+            inputs = {}
+        available_datasets_ids = [dataset.id for dataset in available_datasets]
+        metadata_filter_document_ids, metadata_condition = self._get_metadata_filter_condition(
+            available_datasets_ids,
+            query,
+            tenant_id,
+            user_id,
+            retrieve_config.metadata_filtering_mode,  # type: ignore
+            retrieve_config.metadata_model_config,  # type: ignore
+            retrieve_config.metadata_filtering_conditions,
+            inputs,
+        )
+
         all_documents = []
         user_from = "account" if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end_user"
         if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:

@@ -130,6 +173,8 @@ class DatasetRetrieval:
                 model_config,
                 planning_strategy,
                 message_id,
+                metadata_filter_document_ids,
+                metadata_condition,
             )
         elif retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
             all_documents = self.multiple_retrieve(

@@ -146,6 +191,8 @@ class DatasetRetrieval:
                 retrieve_config.weights,
                 retrieve_config.reranking_enabled or True,
                 message_id,
+                metadata_filter_document_ids,
+                metadata_condition,
             )

         dify_documents = [item for item in all_documents if item.provider == "dify"]

@@ -239,6 +286,8 @@ class DatasetRetrieval:
         model_config: ModelConfigWithCredentialsEntity,
         planning_strategy: PlanningStrategy,
         message_id: Optional[str] = None,
+        metadata_filter_document_ids: Optional[dict[str, list[str]]] = None,
+        metadata_condition: Optional[MetadataCondition] = None,
     ):
         tools = []
         for dataset in available_datasets:

@@ -279,6 +328,7 @@ class DatasetRetrieval:
                     dataset_id=dataset_id,
                     query=query,
                     external_retrieval_parameters=dataset.retrieval_model,
+                    metadata_condition=metadata_condition,
                 )
                 for external_document in external_documents:
                     document = Document(

@@ -293,6 +343,15 @@ class DatasetRetrieval:
                     document.metadata["dataset_name"] = dataset.name
                     results.append(document)
             else:
+                if metadata_condition and not metadata_filter_document_ids:
+                    return []
+                document_ids_filter = None
+                if metadata_filter_document_ids:
+                    document_ids = metadata_filter_document_ids.get(dataset.id, [])
+                    if document_ids:
+                        document_ids_filter = document_ids
+                    else:
+                        return []
                 retrieval_model_config = dataset.retrieval_model or default_retrieval_model

                 # get top k

@@ -324,6 +383,7 @@ class DatasetRetrieval:
                         reranking_model=reranking_model,
                         reranking_mode=retrieval_model_config.get("reranking_mode", "reranking_model"),
                         weights=retrieval_model_config.get("weights", None),
+                        document_ids_filter=document_ids_filter,
                     )
                 self._on_query(query, [dataset_id], app_id, user_from, user_id)

@@ -348,6 +408,8 @@ class DatasetRetrieval:
         weights: Optional[dict[str, Any]] = None,
         reranking_enable: bool = True,
         message_id: Optional[str] = None,
+        metadata_filter_document_ids: Optional[dict[str, list[str]]] = None,
+        metadata_condition: Optional[MetadataCondition] = None,
     ):
         if not available_datasets:
             return []

@@ -387,6 +449,16 @@ class DatasetRetrieval:

         for dataset in available_datasets:
             index_type = dataset.indexing_technique
+            document_ids_filter = None
+            if dataset.provider != "external":
+                if metadata_condition and not metadata_filter_document_ids:
+                    continue
+                if metadata_filter_document_ids:
+                    document_ids = metadata_filter_document_ids.get(dataset.id, [])
+                    if document_ids:
+                        document_ids_filter = document_ids
+                    else:
+                        continue
             retrieval_thread = threading.Thread(
                 target=self._retriever,
                 kwargs={

@@ -395,6 +467,8 @@ class DatasetRetrieval:
                     "query": query,
                     "top_k": top_k,
                     "all_documents": all_documents,
+                    "document_ids_filter": document_ids_filter,
+                    "metadata_condition": metadata_condition,
                 },
             )
             threads.append(retrieval_thread)

@@ -493,7 +567,16 @@ class DatasetRetrieval:
             db.session.add_all(dataset_queries)
         db.session.commit()

-    def _retriever(self, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list):
+    def _retriever(
+        self,
+        flask_app: Flask,
+        dataset_id: str,
+        query: str,
+        top_k: int,
+        all_documents: list,
+        document_ids_filter: Optional[list[str]] = None,
+        metadata_condition: Optional[MetadataCondition] = None,
+    ):
         with flask_app.app_context():
             dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()

@@ -506,6 +589,7 @@ class DatasetRetrieval:
                     dataset_id=dataset_id,
                     query=query,
                     external_retrieval_parameters=dataset.retrieval_model,
+                    metadata_condition=metadata_condition,
                 )
                 for external_document in external_documents:
                     document = Document(

@@ -546,6 +630,7 @@ class DatasetRetrieval:
                     else None,
                     reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
                     weights=retrieval_model.get("weights", None),
+                    document_ids_filter=document_ids_filter,
                 )

                 all_documents.extend(documents)

@@ -733,3 +818,340 @@ class DatasetRetrieval:
             filter_documents, key=lambda x: x.metadata.get("score", 0) if x.metadata else 0, reverse=True
         )
         return filter_documents[:top_k] if top_k else filter_documents
+
+    def _get_metadata_filter_condition(
+        self,
+        dataset_ids: list,
+        query: str,
+        tenant_id: str,
+        user_id: str,
+        metadata_filtering_mode: str,
+        metadata_model_config: ModelConfig,
+        metadata_filtering_conditions: Optional[MetadataFilteringCondition],
+        inputs: dict,
+    ) -> tuple[Optional[dict[str, list[str]]], Optional[MetadataCondition]]:
+        document_query = db.session.query(DatasetDocument).filter(
+            DatasetDocument.dataset_id.in_(dataset_ids),
+            DatasetDocument.indexing_status == "completed",
+            DatasetDocument.enabled == True,
+            DatasetDocument.archived == False,
+        )
+        filters = []  # type: ignore
+        metadata_condition = None
+        if metadata_filtering_mode == "disabled":
+            return None, None
+        elif metadata_filtering_mode == "automatic":
+            automatic_metadata_filters = self._automatic_metadata_filter_func(
+                dataset_ids, query, tenant_id, user_id, metadata_model_config
+            )
+            if automatic_metadata_filters:
+                conditions = []
+                for filter in automatic_metadata_filters:
+                    self._process_metadata_filter_func(
+                        filter.get("condition"),  # type: ignore
+                        filter.get("metadata_name"),  # type: ignore
+                        filter.get("value"),
+                        filters,  # type: ignore
+                    )
+                    conditions.append(
+                        Condition(
+                            name=filter.get("metadata_name"),  # type: ignore
+                            comparison_operator=filter.get("condition"),  # type: ignore
+                            value=filter.get("value"),
+                        )
+                    )
+                metadata_condition = MetadataCondition(
+                    logical_operator=metadata_filtering_conditions.logical_operator,  # type: ignore
+                    conditions=conditions,
+                )
+        elif metadata_filtering_mode == "manual":
+            if metadata_filtering_conditions:
+                metadata_condition = MetadataCondition(**metadata_filtering_conditions.model_dump())
+                for condition in metadata_filtering_conditions.conditions:  # type: ignore
+                    metadata_name = condition.name
+                    expected_value = condition.value
+                    if expected_value or condition.comparison_operator in ("empty", "not empty"):
+                        if isinstance(expected_value, str):
+                            expected_value = self._replace_metadata_filter_value(expected_value, inputs)
+                        filters = self._process_metadata_filter_func(
+                            condition.comparison_operator, metadata_name, expected_value, filters
+                        )
+        else:
+            raise ValueError("Invalid metadata filtering mode")
+        if filters:
+            if metadata_filtering_conditions.logical_operator == "or":  # type: ignore
+                document_query = document_query.filter(or_(*filters))
+            else:
+                document_query = document_query.filter(and_(*filters))
+        documents = document_query.all()
+        # group by dataset_id
+        metadata_filter_document_ids = defaultdict(list) if documents else None  # type: ignore
+        for document in documents:
+            metadata_filter_document_ids[document.dataset_id].append(document.id)  # type: ignore
+        return metadata_filter_document_ids, metadata_condition
+
+    def _replace_metadata_filter_value(self, text: str, inputs: dict) -> str:
+        def replacer(match):
+            key = match.group(1)
+            return str(inputs.get(key, f"{{{{{key}}}}}"))
+
+        pattern = re.compile(r"\{\{(\w+)\}\}")
+        return pattern.sub(replacer, text)
+
+    def _automatic_metadata_filter_func(
+        self, dataset_ids: list, query: str, tenant_id: str, user_id: str, metadata_model_config: ModelConfig
+    ) -> Optional[list[dict[str, Any]]]:
+        # get all metadata field
+        metadata_fields = db.session.query(DatasetMetadata).filter(DatasetMetadata.dataset_id.in_(dataset_ids)).all()
+        all_metadata_fields = [metadata_field.name for metadata_field in metadata_fields]
+        # get metadata model config
+        if metadata_model_config is None:
+            raise ValueError("metadata_model_config is required")
+        # get metadata model instance
+        # fetch model config
+        model_instance, model_config = self._fetch_model_config(tenant_id, metadata_model_config)
+
+        # fetch prompt messages
+        prompt_messages, stop = self._get_prompt_template(
+            model_config=model_config,
+            mode=metadata_model_config.mode,
+            metadata_fields=all_metadata_fields,
+            query=query or "",
+        )
+
+        result_text = ""
+        try:
+            # handle invoke result
+            invoke_result = cast(
+                Generator[LLMResult, None, None],
+                model_instance.invoke_llm(
+                    prompt_messages=prompt_messages,
+                    model_parameters=model_config.parameters,
+                    stop=stop,
+                    stream=True,
+                    user=user_id,
+                ),
+            )
+
+            # handle invoke result
+            result_text, usage = self._handle_invoke_result(invoke_result=invoke_result)
+
+            result_text_json = parse_and_check_json_markdown(result_text, [])
+            automatic_metadata_filters = []
+            if "metadata_map" in result_text_json:
+                metadata_map = result_text_json["metadata_map"]
+                for item in metadata_map:
+                    if item.get("metadata_field_name") in all_metadata_fields:
+                        automatic_metadata_filters.append(
+                            {
+                                "metadata_name": item.get("metadata_field_name"),
+                                "value": item.get("metadata_field_value"),
+                                "condition": item.get("comparison_operator"),
+                            }
+                        )
+        except Exception as e:
+            return None
+        return automatic_metadata_filters
+
+    def _process_metadata_filter_func(self, condition: str, metadata_name: str, value: Optional[Any], filters: list):
+        match condition:
+            case "contains":
+                filters.append(
+                    (text("documents.doc_metadata ->> :key LIKE :value")).params(key=metadata_name, value=f"%{value}%")
+                )
+            case "not contains":
+                filters.append(
+                    (text("documents.doc_metadata ->> :key NOT LIKE :value")).params(
+                        key=metadata_name, value=f"%{value}%"
+                    )
+                )
+            case "start with":
+                filters.append(
+                    (text("documents.doc_metadata ->> :key LIKE :value")).params(key=metadata_name, value=f"{value}%")
+                )
+
+            case "end with":
+                filters.append(
+                    (text("documents.doc_metadata ->> :key LIKE :value")).params(key=metadata_name, value=f"%{value}")
+                )
+            case "is" | "=":
+                if isinstance(value, str):
+                    filters.append(DatasetDocument.doc_metadata[metadata_name] == f'"{value}"')
+                else:
+                    filters.append(
+                        sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Integer) == value
+                    )
+            case "is not" | "≠":
+                if isinstance(value, str):
+                    filters.append(DatasetDocument.doc_metadata[metadata_name] != f'"{value}"')
+                else:
+                    filters.append(
+                        sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Integer) != value
+                    )
+            case "empty":
+                filters.append(DatasetDocument.doc_metadata[metadata_name].is_(None))
+            case "not empty":
+                filters.append(DatasetDocument.doc_metadata[metadata_name].isnot(None))
+            case "before" | "<":
+                filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Integer) < value)
+            case "after" | ">":
+                filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Integer) > value)
+            case "≤" | ">=":
+                filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Integer) <= value)
+            case "≥" | ">=":
+                filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Integer) >= value)
+            case _:
+                pass
+        return filters
+
+    def _fetch_model_config(
+        self, tenant_id: str, model: ModelConfig
+    ) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
+        """
+        Fetch model config
+        :param node_data: node data
+        :return:
+        """
+        if model is None:
+            raise ValueError("single_retrieval_config is required")
+        model_name = model.name
+        provider_name = model.provider
+
+        model_manager = ModelManager()
+        model_instance = model_manager.get_model_instance(
+            tenant_id=tenant_id, model_type=ModelType.LLM, provider=provider_name, model=model_name
+        )
+
+        provider_model_bundle = model_instance.provider_model_bundle
+        model_type_instance = model_instance.model_type_instance
+        model_type_instance = cast(LargeLanguageModel, model_type_instance)
+
+        model_credentials = model_instance.credentials
+
+        # check model
+        provider_model = provider_model_bundle.configuration.get_provider_model(
+            model=model_name, model_type=ModelType.LLM
+        )
+
+        if provider_model is None:
+            raise ValueError(f"Model {model_name} not exist.")
+
+        if provider_model.status == ModelStatus.NO_CONFIGURE:
+            raise ValueError(f"Model {model_name} credentials is not initialized.")
+        elif provider_model.status == ModelStatus.NO_PERMISSION:
+            raise ValueError(f"Dify Hosted OpenAI {model_name} currently not support.")
+        elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
+            raise ValueError(f"Model provider {provider_name} quota exceeded.")
+
+        # model config
+        completion_params = model.completion_params
+        stop = []
+        if "stop" in completion_params:
+            stop = completion_params["stop"]
+            del completion_params["stop"]
+
+        # get model mode
+        model_mode = model.mode
+        if not model_mode:
+            raise ValueError("LLM mode is required.")
+
+        model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
+
+        if not model_schema:
+            raise ValueError(f"Model {model_name} not exist.")
+
+        return model_instance, ModelConfigWithCredentialsEntity(
+            provider=provider_name,
+            model=model_name,
+            model_schema=model_schema,
+            mode=model_mode,
+            provider_model_bundle=provider_model_bundle,
+            credentials=model_credentials,
+            parameters=completion_params,
+            stop=stop,
+        )
+
+    def _get_prompt_template(
+        self, model_config: ModelConfigWithCredentialsEntity, mode: str, metadata_fields: list, query: str
+    ):
+        model_mode = ModelMode.value_of(mode)
+        input_text = query
+
+        prompt_template: Union[CompletionModelPromptTemplate, list[ChatModelMessage]]
+        if model_mode == ModelMode.CHAT:
+            prompt_template = []
+            system_prompt_messages = ChatModelMessage(role=PromptMessageRole.SYSTEM, text=METADATA_FILTER_SYSTEM_PROMPT)
+            prompt_template.append(system_prompt_messages)
+            user_prompt_message_1 = ChatModelMessage(role=PromptMessageRole.USER, text=METADATA_FILTER_USER_PROMPT_1)
+            prompt_template.append(user_prompt_message_1)
+            assistant_prompt_message_1 = ChatModelMessage(
+                role=PromptMessageRole.ASSISTANT, text=METADATA_FILTER_ASSISTANT_PROMPT_1
+            )
+            prompt_template.append(assistant_prompt_message_1)
+            user_prompt_message_2 = ChatModelMessage(role=PromptMessageRole.USER, text=METADATA_FILTER_USER_PROMPT_2)
+            prompt_template.append(user_prompt_message_2)
+            assistant_prompt_message_2 = ChatModelMessage(
+                role=PromptMessageRole.ASSISTANT, text=METADATA_FILTER_ASSISTANT_PROMPT_2
+            )
+            prompt_template.append(assistant_prompt_message_2)
+            user_prompt_message_3 = ChatModelMessage(
+                role=PromptMessageRole.USER,
+                text=METADATA_FILTER_USER_PROMPT_3.format(
+                    input_text=input_text,
+                    metadata_fields=json.dumps(metadata_fields, ensure_ascii=False),
+                ),
+            )
+            prompt_template.append(user_prompt_message_3)
+        elif model_mode == ModelMode.COMPLETION:
+            prompt_template = CompletionModelPromptTemplate(
+                text=METADATA_FILTER_COMPLETION_PROMPT.format(
+                    input_text=input_text,
+                    metadata_fields=json.dumps(metadata_fields, ensure_ascii=False),
+                )
+            )
+
+        else:
+            raise ValueError(f"Model mode {model_mode} not support.")
+
+        prompt_transform = AdvancedPromptTransform()
+        prompt_messages = prompt_transform.get_prompt(
+            prompt_template=prompt_template,
+            inputs={},
+            query=query or "",
+            files=[],
+            context=None,
+            memory_config=None,
+            memory=None,
+            model_config=model_config,
+        )
+        stop = model_config.stop
+
+        return prompt_messages, stop
+
+    def _handle_invoke_result(self, invoke_result: Generator) -> tuple[str, LLMUsage]:
+        """
+        Handle invoke result
+        :param invoke_result: invoke result
+        :return:
+        """
+        model = None
+        prompt_messages: list[PromptMessage] = []
+        full_text = ""
+        usage = None
+        for result in invoke_result:
+            text = result.delta.message.content
+            full_text += text
+
+            if not model:
+                model = result.model
+
+            if not prompt_messages:
+                prompt_messages = result.prompt_messages
+
+            if not usage and result.delta.usage:
+                usage = result.delta.usage
+
+        if not usage:
+            usage = LLMUsage.empty_usage()
+
+        return full_text, usage
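Note: the `_replace_metadata_filter_value` helper added above substitutes `{{variable}}` placeholders in manual filter values with workflow inputs, leaving unknown placeholders untouched. A self-contained sketch of the same regex logic (function name and values are illustrative, not part of the commit):

import re

def replace_metadata_filter_value(value: str, inputs: dict) -> str:
    # Replace {{key}} with inputs[key]; unknown keys are left as-is.
    def replacer(match):
        key = match.group(1)
        return str(inputs.get(key, f"{{{{{key}}}}}"))

    return re.compile(r"\{\{(\w+)\}\}").sub(replacer, value)

print(replace_metadata_filter_value("{{author}}", {"author": "alice"}))  # -> "alice"
print(replace_metadata_filter_value("{{missing}}", {}))                  # -> "{{missing}}"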
api/core/rag/retrieval/template_prompts.py  (new file, 66 lines)
@@ -0,0 +1,66 @@
METADATA_FILTER_SYSTEM_PROMPT = """
### Job Description',
You are a text metadata extract engine that extract text's metadata based on user input and set the metadata value
### Task
Your task is to ONLY extract the metadatas that exist in the input text from the provided metadata list and Use the following operators ["=", "!=", ">", "<", ">=", "<="] to express logical relationships, then return result in JSON format with the key "metadata_fields" and value "metadata_field_value" and comparison operator "comparison_operator".
### Format
The input text is in the variable input_text. Metadata are specified as a list in the variable metadata_fields.
### Constraint
DO NOT include anything other than the JSON array in your response.
"""  # noqa: E501

METADATA_FILTER_USER_PROMPT_1 = """
{ "input_text": "I want to know which company’s email address test@example.com is?",
"metadata_fields": ["filename", "email", "phone", "address"]
}
"""

METADATA_FILTER_ASSISTANT_PROMPT_1 = """
```json
{"metadata_map": [
    {"metadata_field_name": "email", "metadata_field_value": "test@example.com", "comparison_operator": "="}
]
}
```
"""

METADATA_FILTER_USER_PROMPT_2 = """
{"input_text": "What are the movies with a score of more than 9 in 2024?",
"metadata_fields": ["name", "year", "rating", "country"]}
"""

METADATA_FILTER_ASSISTANT_PROMPT_2 = """
```json
{"metadata_map": [
    {"metadata_field_name": "year", "metadata_field_value": "2024", "comparison_operator": "="},
    {"metadata_field_name": "rating", "metadata_field_value": "9", "comparison_operator": ">"},
]}
```
"""

METADATA_FILTER_USER_PROMPT_3 = """
'{{"input_text": "{input_text}",',
'"metadata_fields": {metadata_fields}}}'
"""

METADATA_FILTER_COMPLETION_PROMPT = """
### Job Description
You are a text metadata extract engine that extract text's metadata based on user input and set the metadata value
### Task
# Your task is to ONLY extract the metadatas that exist in the input text from the provided metadata list and Use the following operators ["=", "!=", ">", "<", ">=", "<="] to express logical relationships, then return result in JSON format with the key "metadata_fields" and value "metadata_field_value" and comparison operator "comparison_operator".
### Format
The input text is in the variable input_text. Metadata are specified as a list in the variable metadata_fields.
### Constraint
DO NOT include anything other than the JSON array in your response.
### Example
Here is the chat example between human and assistant, inside <example></example> XML tags.
<example>
User:{{"input_text": ["I want to know which company’s email address test@example.com is?"], "metadata_fields": ["filename", "email", "phone", "address"]}}
Assistant:{{"metadata_map": [{{"metadata_field_name": "email", "metadata_field_value": "test@example.com", "comparison_operator": "="}}]}}
User:{{"input_text": "What are the movies with a score of more than 9 in 2024?", "metadata_fields": ["name", "year", "rating", "country"]}}
Assistant:{{"metadata_map": [{{"metadata_field_name": "year", "metadata_field_value": "2024", "comparison_operator": "="}, {{"metadata_field_name": "rating", "metadata_field_value": "9", "comparison_operator": ">"}}]}}
</example>
### User Input
{{"input_text" : "{input_text}", "metadata_fields" : {metadata_fields}}}
### Assistant Output
"""  # noqa: E501
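Note: the prompts above instruct the model to return a fenced JSON object with a `metadata_map` array; the retrieval code parses it with `parse_and_check_json_markdown`. A rough sketch of the expected reply and a hand-rolled parse for illustration only (the real code does not parse this way):

import json

reply = """```json
{"metadata_map": [
    {"metadata_field_name": "year", "metadata_field_value": "2024", "comparison_operator": "="}
]}
```"""

# Strip the ```json fence by hand and read the metadata_map entries.
payload = json.loads(reply.strip().removeprefix("```json").removesuffix("```"))
for item in payload.get("metadata_map", []):
    print(item["metadata_field_name"], item["comparison_operator"], item["metadata_field_value"])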
@@ -1,8 +1,10 @@
+from collections.abc import Sequence
 from typing import Any, Literal, Optional

-from pydantic import BaseModel
+from pydantic import BaseModel, Field

 from core.workflow.nodes.base import BaseNodeData
+from core.workflow.nodes.llm.entities import VisionConfig


 class RerankingModelConfig(BaseModel):

@@ -73,6 +75,48 @@ class SingleRetrievalConfig(BaseModel):
     model: ModelConfig


+SupportedComparisonOperator = Literal[
+    # for string or array
+    "contains",
+    "not contains",
+    "start with",
+    "end with",
+    "is",
+    "is not",
+    "empty",
+    "not empty",
+    # for number
+    "=",
+    "≠",
+    ">",
+    "<",
+    "≥",
+    "≤",
+    # for time
+    "before",
+    "after",
+]
+
+
+class Condition(BaseModel):
+    """
+    Conditon detail
+    """
+
+    name: str
+    comparison_operator: SupportedComparisonOperator
+    value: str | Sequence[str] | None | int | float = None
+
+
+class MetadataFilteringCondition(BaseModel):
+    """
+    Metadata Filtering Condition.
+    """
+
+    logical_operator: Optional[Literal["and", "or"]] = "and"
+    conditions: Optional[list[Condition]] = Field(default=None, deprecated=True)
+
+
 class KnowledgeRetrievalNodeData(BaseNodeData):
     """
     Knowledge retrieval Node Data.

@@ -84,3 +128,7 @@ class KnowledgeRetrievalNodeData(BaseNodeData):
     retrieval_mode: Literal["single", "multiple"]
     multiple_retrieval_config: Optional[MultipleRetrievalConfig] = None
     single_retrieval_config: Optional[SingleRetrievalConfig] = None
+    metadata_filtering_mode: Optional[Literal["disabled", "automatic", "manual"]] = "disabled"
+    metadata_model_config: Optional[ModelConfig] = None
+    metadata_filtering_conditions: Optional[MetadataFilteringCondition] = None
+    vision: VisionConfig = Field(default_factory=VisionConfig)
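Note: with the fields added to KnowledgeRetrievalNodeData above, a manual-mode node configuration could carry a filtering block like the following sketch, validated through the new pydantic model; the metadata names and values are made up for illustration:

from core.workflow.nodes.knowledge_retrieval.entities import MetadataFilteringCondition

raw = {
    "logical_operator": "or",
    "conditions": [
        {"name": "source", "comparison_operator": "is", "value": "file_upload"},
        {"name": "upload_date", "comparison_operator": "after", "value": 1700000000},
    ],
}
filtering = MetadataFilteringCondition.model_validate(raw)
print(filtering.logical_operator, len(filtering.conditions or []))  # -> or 2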
@@ -16,3 +16,7 @@ class ModelNotSupportedError(KnowledgeRetrievalNodeError):

 class ModelQuotaExceededError(KnowledgeRetrievalNodeError):
     """Raised when the model provider quota is exceeded."""
+
+
+class InvalidModelTypeError(KnowledgeRetrievalNodeError):
+    """Raised when the model is not a Large Language Model."""
@ -1,32 +1,51 @@
|
|||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
|
from collections import defaultdict
|
||||||
from collections.abc import Mapping, Sequence
|
from collections.abc import Mapping, Sequence
|
||||||
from typing import Any, cast
|
from typing import Any, Optional, cast
|
||||||
|
|
||||||
from sqlalchemy import func
|
from sqlalchemy import Integer, and_, func, or_, text
|
||||||
|
from sqlalchemy import cast as sqlalchemy_cast
|
||||||
|
|
||||||
from core.app.app_config.entities import DatasetRetrieveConfigEntity
|
from core.app.app_config.entities import DatasetRetrieveConfigEntity
|
||||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||||
from core.entities.agent_entities import PlanningStrategy
|
from core.entities.agent_entities import PlanningStrategy
|
||||||
from core.entities.model_entities import ModelStatus
|
from core.entities.model_entities import ModelStatus
|
||||||
from core.model_manager import ModelInstance, ModelManager
|
from core.model_manager import ModelInstance, ModelManager
|
||||||
|
from core.model_runtime.entities.message_entities import PromptMessageRole
|
||||||
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
|
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
|
||||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||||
|
from core.prompt.simple_prompt_transform import ModelMode
|
||||||
from core.rag.datasource.retrieval_service import RetrievalService
|
from core.rag.datasource.retrieval_service import RetrievalService
|
||||||
|
from core.rag.entities.metadata_entities import Condition, MetadataCondition
|
||||||
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||||
from core.variables import StringSegment
|
from core.variables import StringSegment
|
||||||
from core.workflow.entities.node_entities import NodeRunResult
|
from core.workflow.entities.node_entities import NodeRunResult
|
||||||
from core.workflow.nodes.base import BaseNode
|
|
||||||
from core.workflow.nodes.enums import NodeType
|
from core.workflow.nodes.enums import NodeType
|
||||||
|
from core.workflow.nodes.event.event import ModelInvokeCompletedEvent
|
||||||
|
from core.workflow.nodes.knowledge_retrieval.template_prompts import (
|
||||||
|
METADATA_FILTER_ASSISTANT_PROMPT_1,
|
||||||
|
METADATA_FILTER_ASSISTANT_PROMPT_2,
|
||||||
|
METADATA_FILTER_COMPLETION_PROMPT,
|
||||||
|
METADATA_FILTER_SYSTEM_PROMPT,
|
||||||
|
METADATA_FILTER_USER_PROMPT_1,
|
||||||
|
METADATA_FILTER_USER_PROMPT_3,
|
||||||
|
)
|
||||||
|
from core.workflow.nodes.llm.entities import LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate
|
||||||
|
from core.workflow.nodes.llm.node import LLMNode
|
||||||
|
from core.workflow.nodes.question_classifier.template_prompts import QUESTION_CLASSIFIER_USER_PROMPT_2
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from extensions.ext_redis import redis_client
|
from extensions.ext_redis import redis_client
|
||||||
from models.dataset import Dataset, Document, RateLimitLog
|
from libs.json_in_md_parser import parse_and_check_json_markdown
|
||||||
|
from models.dataset import Dataset, DatasetMetadata, Document, RateLimitLog
|
||||||
from models.workflow import WorkflowNodeExecutionStatus
|
from models.workflow import WorkflowNodeExecutionStatus
|
||||||
from services.feature_service import FeatureService
|
from services.feature_service import FeatureService
|
||||||
|
|
||||||
from .entities import KnowledgeRetrievalNodeData
|
from .entities import KnowledgeRetrievalNodeData, ModelConfig
|
||||||
from .exc import (
|
from .exc import (
|
||||||
|
InvalidModelTypeError,
|
||||||
KnowledgeRetrievalNodeError,
|
KnowledgeRetrievalNodeError,
|
||||||
ModelCredentialsNotInitializedError,
|
ModelCredentialsNotInitializedError,
|
||||||
ModelNotExistError,
|
ModelNotExistError,
|
||||||
@ -45,13 +64,14 @@ default_retrieval_model = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
|
class KnowledgeRetrievalNode(LLMNode):
|
||||||
_node_data_cls = KnowledgeRetrievalNodeData
|
_node_data_cls = KnowledgeRetrievalNodeData # type: ignore
|
||||||
_node_type = NodeType.KNOWLEDGE_RETRIEVAL
|
_node_type = NodeType.KNOWLEDGE_RETRIEVAL
|
||||||
|
|
||||||
def _run(self) -> NodeRunResult:
|
def _run(self) -> NodeRunResult: # type: ignore
|
||||||
|
node_data = cast(KnowledgeRetrievalNodeData, self.node_data)
|
||||||
# extract variables
|
# extract variables
|
||||||
variable = self.graph_runtime_state.variable_pool.get(self.node_data.query_variable_selector)
|
variable = self.graph_runtime_state.variable_pool.get(node_data.query_variable_selector)
|
||||||
if not isinstance(variable, StringSegment):
|
if not isinstance(variable, StringSegment):
|
||||||
return NodeRunResult(
|
return NodeRunResult(
|
||||||
status=WorkflowNodeExecutionStatus.FAILED,
|
status=WorkflowNodeExecutionStatus.FAILED,
|
||||||
@ -91,7 +111,7 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
|
|||||||
|
|
||||||
# retrieve knowledge
|
# retrieve knowledge
|
||||||
try:
|
try:
|
||||||
results = self._fetch_dataset_retriever(node_data=self.node_data, query=query)
|
results = self._fetch_dataset_retriever(node_data=node_data, query=query)
|
||||||
outputs = {"result": results}
|
outputs = {"result": results}
|
||||||
return NodeRunResult(
|
return NodeRunResult(
|
||||||
status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, process_data=None, outputs=outputs
|
status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, process_data=None, outputs=outputs
|
||||||
@ -145,11 +165,14 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
|
|||||||
if not dataset:
|
if not dataset:
|
||||||
continue
|
continue
|
||||||
available_datasets.append(dataset)
|
available_datasets.append(dataset)
|
||||||
|
metadata_filter_document_ids, metadata_condition = self._get_metadata_filter_condition(
|
||||||
|
[dataset.id for dataset in available_datasets], query, node_data
|
||||||
|
)
|
||||||
all_documents = []
|
all_documents = []
|
||||||
dataset_retrieval = DatasetRetrieval()
|
dataset_retrieval = DatasetRetrieval()
|
||||||
if node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE.value:
|
if node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE.value:
|
||||||
# fetch model config
|
# fetch model config
|
||||||
model_instance, model_config = self._fetch_model_config(node_data)
|
model_instance, model_config = self._fetch_model_config(node_data.single_retrieval_config.model) # type: ignore
|
||||||
# check model is support tool calling
|
# check model is support tool calling
|
||||||
model_type_instance = model_config.provider_model_bundle.model_type_instance
|
model_type_instance = model_config.provider_model_bundle.model_type_instance
|
||||||
model_type_instance = cast(LargeLanguageModel, model_type_instance)
|
model_type_instance = cast(LargeLanguageModel, model_type_instance)
|
||||||
@ -174,6 +197,8 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
|
|||||||
model_config=model_config,
|
model_config=model_config,
|
||||||
model_instance=model_instance,
|
model_instance=model_instance,
|
||||||
planning_strategy=planning_strategy,
|
planning_strategy=planning_strategy,
|
||||||
|
metadata_filter_document_ids=metadata_filter_document_ids,
|
||||||
|
metadata_condition=metadata_condition,
|
||||||
)
|
)
|
||||||
elif node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE.value:
|
elif node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE.value:
|
||||||
if node_data.multiple_retrieval_config is None:
|
if node_data.multiple_retrieval_config is None:
|
||||||
@ -220,6 +245,8 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
|
|||||||
reranking_model=reranking_model,
|
reranking_model=reranking_model,
|
||||||
weights=weights,
|
weights=weights,
|
||||||
reranking_enable=node_data.multiple_retrieval_config.reranking_enable,
|
reranking_enable=node_data.multiple_retrieval_config.reranking_enable,
|
||||||
|
metadata_filter_document_ids=metadata_filter_document_ids,
|
||||||
|
metadata_condition=metadata_condition,
|
||||||
)
|
)
|
||||||
dify_documents = [item for item in all_documents if item.provider == "dify"]
|
dify_documents = [item for item in all_documents if item.provider == "dify"]
|
||||||
external_documents = [item for item in all_documents if item.provider == "external"]
|
external_documents = [item for item in all_documents if item.provider == "external"]
|
||||||
@ -287,13 +314,187 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
|
|||||||
item["metadata"]["position"] = position
|
item["metadata"]["position"] = position
|
||||||
return retrieval_resource_list
|
return retrieval_resource_list
|
||||||
|
|
||||||
|
def _get_metadata_filter_condition(
|
||||||
|
self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData
|
||||||
|
) -> tuple[Optional[dict[str, list[str]]], Optional[MetadataCondition]]:
|
||||||
|
document_query = db.session.query(Document).filter(
|
||||||
|
Document.dataset_id.in_(dataset_ids),
|
||||||
|
Document.indexing_status == "completed",
|
||||||
|
Document.enabled == True,
|
||||||
|
Document.archived == False,
|
||||||
|
)
|
||||||
|
filters = [] # type: ignore
|
||||||
|
metadata_condition = None
|
||||||
|
if node_data.metadata_filtering_mode == "disabled":
|
||||||
|
return None, None
|
||||||
|
elif node_data.metadata_filtering_mode == "automatic":
|
||||||
|
automatic_metadata_filters = self._automatic_metadata_filter_func(dataset_ids, query, node_data)
|
||||||
|
if automatic_metadata_filters:
|
||||||
|
conditions = []
|
||||||
|
for filter in automatic_metadata_filters:
|
||||||
|
self._process_metadata_filter_func(
|
||||||
|
filter.get("condition", ""),
|
||||||
|
filter.get("metadata_name", ""),
|
||||||
|
filter.get("value"),
|
||||||
|
filters, # type: ignore
|
||||||
|
)
|
||||||
|
conditions.append(
|
||||||
|
Condition(
|
||||||
|
name=filter.get("metadata_name"), # type: ignore
|
||||||
|
comparison_operator=filter.get("condition"), # type: ignore
|
||||||
|
value=filter.get("value"),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
metadata_condition = MetadataCondition(
|
||||||
|
logical_operator=node_data.metadata_filtering_conditions.logical_operator, # type: ignore
|
||||||
|
conditions=conditions,
|
||||||
|
)
|
||||||
|
elif node_data.metadata_filtering_mode == "manual":
|
||||||
|
if node_data.metadata_filtering_conditions:
|
||||||
|
metadata_condition = MetadataCondition(**node_data.metadata_filtering_conditions.model_dump())
|
||||||
|
if node_data.metadata_filtering_conditions:
|
||||||
|
for condition in node_data.metadata_filtering_conditions.conditions: # type: ignore
|
||||||
|
metadata_name = condition.name
|
||||||
|
expected_value = condition.value
|
||||||
|
if expected_value or condition.comparison_operator in ("empty", "not empty"):
|
||||||
|
if isinstance(expected_value, str):
|
||||||
|
expected_value = self.graph_runtime_state.variable_pool.convert_template(
|
||||||
|
expected_value
|
||||||
|
).text
|
||||||
|
|
||||||
|
filters = self._process_metadata_filter_func(
|
||||||
|
condition.comparison_operator, metadata_name, expected_value, filters
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError("Invalid metadata filtering mode")
|
||||||
|
if filters:
|
||||||
|
if node_data.metadata_filtering_conditions.logical_operator == "and": # type: ignore
|
||||||
|
document_query = document_query.filter(and_(*filters))
|
||||||
|
else:
|
||||||
|
document_query = document_query.filter(or_(*filters))
|
||||||
|
documents = document_query.all()
|
||||||
|
# group by dataset_id
|
||||||
|
metadata_filter_document_ids = defaultdict(list) if documents else None # type: ignore
|
||||||
|
for document in documents:
|
||||||
|
metadata_filter_document_ids[document.dataset_id].append(document.id) # type: ignore
|
||||||
|
return metadata_filter_document_ids, metadata_condition
|
||||||
|
|
||||||
|
def _automatic_metadata_filter_func(
|
||||||
|
self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
# get all metadata field
|
||||||
|
metadata_fields = db.session.query(DatasetMetadata).filter(DatasetMetadata.dataset_id.in_(dataset_ids)).all()
|
||||||
|
all_metadata_fields = [metadata_field.field_name for metadata_field in metadata_fields]
|
||||||
|
# get metadata model config
|
||||||
|
metadata_model_config = node_data.metadata_model_config
|
||||||
|
if metadata_model_config is None:
|
||||||
|
raise ValueError("metadata_model_config is required")
|
||||||
|
# get metadata model instance
|
||||||
|
# fetch model config
|
||||||
|
model_instance, model_config = self._fetch_model_config(node_data.metadata_model_config) # type: ignore
|
||||||
|
# fetch prompt messages
|
||||||
|
prompt_template = self._get_prompt_template(
|
||||||
|
node_data=node_data,
|
||||||
|
metadata_fields=all_metadata_fields,
|
||||||
|
query=query or "",
|
||||||
|
)
|
||||||
|
prompt_messages, stop = self._fetch_prompt_messages(
|
||||||
|
prompt_template=prompt_template,
|
||||||
|
sys_query=query,
|
||||||
|
memory=None,
|
||||||
|
model_config=model_config,
|
||||||
|
sys_files=[],
|
||||||
|
vision_enabled=node_data.vision.enabled,
|
||||||
|
vision_detail=node_data.vision.configs.detail,
|
||||||
|
variable_pool=self.graph_runtime_state.variable_pool,
|
||||||
|
jinja2_variables=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
result_text = ""
|
||||||
|
try:
|
||||||
|
# handle invoke result
|
||||||
|
generator = self._invoke_llm(
|
||||||
|
node_data_model=node_data.metadata_model_config, # type: ignore
|
||||||
|
model_instance=model_instance,
|
||||||
|
prompt_messages=prompt_messages,
|
||||||
|
stop=stop,
|
||||||
|
)
|
||||||
|
|
||||||
|
for event in generator:
|
||||||
|
if isinstance(event, ModelInvokeCompletedEvent):
|
||||||
|
result_text = event.text
|
||||||
|
break
|
||||||
|
|
||||||
|
result_text_json = parse_and_check_json_markdown(result_text, [])
|
||||||
|
automatic_metadata_filters = []
|
||||||
|
if "metadata_map" in result_text_json:
|
||||||
|
metadata_map = result_text_json["metadata_map"]
|
||||||
|
for item in metadata_map:
|
||||||
|
if item.get("metadata_field_name") in all_metadata_fields:
|
||||||
|
automatic_metadata_filters.append(
|
||||||
|
{
|
||||||
|
"metadata_name": item.get("metadata_field_name"),
|
||||||
|
"value": item.get("metadata_field_value"),
|
||||||
|
"condition": item.get("comparison_operator"),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
return []
|
||||||
|
return automatic_metadata_filters
|
||||||
|
|
||||||
|
def _process_metadata_filter_func(self, condition: str, metadata_name: str, value: Optional[str], filters: list):
|
||||||
|
match condition:
|
||||||
|
case "contains":
|
||||||
|
filters.append(
|
||||||
|
(text("documents.doc_metadata ->> :key LIKE :value")).params(key=metadata_name, value=f"%{value}%")
|
||||||
|
)
|
||||||
|
case "not contains":
|
||||||
|
filters.append(
|
||||||
|
(text("documents.doc_metadata ->> :key NOT LIKE :value")).params(
|
||||||
|
key=metadata_name, value=f"%{value}%"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
case "start with":
|
||||||
|
filters.append(
|
||||||
|
(text("documents.doc_metadata ->> :key LIKE :value")).params(key=metadata_name, value=f"{value}%")
|
||||||
|
)
|
||||||
|
case "end with":
|
||||||
|
filters.append(
|
||||||
|
(text("documents.doc_metadata ->> :key LIKE :value")).params(key=metadata_name, value=f"%{value}")
|
||||||
|
)
|
||||||
|
case "=" | "is":
|
||||||
|
if isinstance(value, str):
|
||||||
|
filters.append(Document.doc_metadata[metadata_name] == f'"{value}"')
|
||||||
|
else:
|
||||||
|
filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Integer) == value)
|
||||||
|
case "is not" | "≠":
|
||||||
|
if isinstance(value, str):
|
||||||
|
filters.append(Document.doc_metadata[metadata_name] != f'"{value}"')
|
||||||
|
else:
|
||||||
|
filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Integer) != value)
|
||||||
|
case "empty":
|
||||||
|
filters.append(Document.doc_metadata[metadata_name].is_(None))
|
||||||
|
case "not empty":
|
||||||
|
filters.append(Document.doc_metadata[metadata_name].isnot(None))
|
||||||
|
case "before" | "<":
|
||||||
|
filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Integer) < value)
|
||||||
|
case "after" | ">":
|
||||||
|
filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Integer) > value)
|
||||||
|
case "≤" | ">=":
|
||||||
|
filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Integer) <= value)
|
||||||
|
case "≥" | ">=":
|
||||||
|
filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Integer) >= value)
|
||||||
|
case _:
|
||||||
|
pass
|
||||||
|
return filters
    @classmethod
    def _extract_variable_selector_to_variable_mapping(
        cls,
        *,
        graph_config: Mapping[str, Any],
        node_id: str,
-        node_data: KnowledgeRetrievalNodeData,
+        node_data: KnowledgeRetrievalNodeData,  # type: ignore
    ) -> Mapping[str, Sequence[str]]:
        """
        Extract variable selector to variable mapping
@ -306,18 +507,16 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
        variable_mapping[node_id + ".query"] = node_data.query_variable_selector
        return variable_mapping

-    def _fetch_model_config(
-        self, node_data: KnowledgeRetrievalNodeData
-    ) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
+    def _fetch_model_config(self, model: ModelConfig) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:  # type: ignore
        """
        Fetch model config
-        :param node_data: node data
+        :param model: model
        :return:
        """
-        if node_data.single_retrieval_config is None:
-            raise ValueError("single_retrieval_config is required")
-        model_name = node_data.single_retrieval_config.model.name
-        provider_name = node_data.single_retrieval_config.model.provider
+        if model is None:
+            raise ValueError("model is required")
+        model_name = model.name
+        provider_name = model.provider

        model_manager = ModelManager()
        model_instance = model_manager.get_model_instance(
@ -346,14 +545,14 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
            raise ModelQuotaExceededError(f"Model provider {provider_name} quota exceeded.")

        # model config
-        completion_params = node_data.single_retrieval_config.model.completion_params
+        completion_params = model.completion_params
        stop = []
        if "stop" in completion_params:
            stop = completion_params["stop"]
            del completion_params["stop"]

        # get model mode
-        model_mode = node_data.single_retrieval_config.model.mode
+        model_mode = model.mode
        if not model_mode:
            raise ModelNotExistError("LLM mode is required.")
@ -372,3 +571,50 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
            parameters=completion_params,
            stop=stop,
        )
    def _get_prompt_template(self, node_data: KnowledgeRetrievalNodeData, metadata_fields: list, query: str):
        model_mode = ModelMode.value_of(node_data.metadata_model_config.mode)  # type: ignore
        input_text = query
        memory_str = ""

        prompt_messages: list[LLMNodeChatModelMessage] = []
        if model_mode == ModelMode.CHAT:
            system_prompt_messages = LLMNodeChatModelMessage(
                role=PromptMessageRole.SYSTEM, text=METADATA_FILTER_SYSTEM_PROMPT
            )
            prompt_messages.append(system_prompt_messages)
            user_prompt_message_1 = LLMNodeChatModelMessage(
                role=PromptMessageRole.USER, text=METADATA_FILTER_USER_PROMPT_1
            )
            prompt_messages.append(user_prompt_message_1)
            assistant_prompt_message_1 = LLMNodeChatModelMessage(
                role=PromptMessageRole.ASSISTANT, text=METADATA_FILTER_ASSISTANT_PROMPT_1
            )
            prompt_messages.append(assistant_prompt_message_1)
            user_prompt_message_2 = LLMNodeChatModelMessage(
                role=PromptMessageRole.USER, text=METADATA_FILTER_USER_PROMPT_2
            )
            prompt_messages.append(user_prompt_message_2)
            assistant_prompt_message_2 = LLMNodeChatModelMessage(
                role=PromptMessageRole.ASSISTANT, text=METADATA_FILTER_ASSISTANT_PROMPT_2
            )
            prompt_messages.append(assistant_prompt_message_2)
            user_prompt_message_3 = LLMNodeChatModelMessage(
                role=PromptMessageRole.USER,
                text=METADATA_FILTER_USER_PROMPT_3.format(
                    input_text=input_text,
                    metadata_fields=json.dumps(metadata_fields, ensure_ascii=False),
                ),
            )
            prompt_messages.append(user_prompt_message_3)
            return prompt_messages
        elif model_mode == ModelMode.COMPLETION:
            return LLMNodeCompletionModelPromptTemplate(
                text=METADATA_FILTER_COMPLETION_PROMPT.format(
                    input_text=input_text,
                    metadata_fields=json.dumps(metadata_fields, ensure_ascii=False),
                )
            )

        else:
            raise InvalidModelTypeError(f"Model mode {model_mode} is not supported.")
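A minimal standalone sketch (not the node's actual parsing code) of turning the metadata_map reply that these prompts request into (field, operator, value) triples that the filter builder above could consume; the reply text is a made-up example.

import json
import re

reply = """```json
{"metadata_map": [
    {"metadata_field_name": "year", "metadata_field_value": "2024", "comparison_operator": "="},
    {"metadata_field_name": "rating", "metadata_field_value": "9", "comparison_operator": ">"}
]}
```"""

# Strip an optional ```json fence before parsing the JSON body.
payload = re.sub(r"^```(?:json)?|```$", "", reply.strip(), flags=re.MULTILINE).strip()
conditions = [
    (item["metadata_field_name"], item["comparison_operator"], item["metadata_field_value"])
    for item in json.loads(payload)["metadata_map"]
]
print(conditions)  # [('year', '=', '2024'), ('rating', '>', '9')]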
@ -0,0 +1,66 @@
METADATA_FILTER_SYSTEM_PROMPT = """
### Job Description
You are a text metadata extraction engine that extracts the text's metadata based on user input and sets the metadata values.
### Task
Your task is to ONLY extract the metadata that exists in the input text from the provided metadata list and use the following operators ["=", "!=", ">", "<", ">=", "<="] to express logical relationships, then return the result in JSON format with the key "metadata_fields" and value "metadata_field_value" and comparison operator "comparison_operator".
### Format
The input text is in the variable input_text. Metadata are specified as a list in the variable metadata_fields.
### Constraint
DO NOT include anything other than the JSON array in your response.
"""  # noqa: E501

METADATA_FILTER_USER_PROMPT_1 = """
{ "input_text": "I want to know which company’s email address test@example.com is?",
"metadata_fields": ["filename", "email", "phone", "address"]
}
"""

METADATA_FILTER_ASSISTANT_PROMPT_1 = """
```json
{"metadata_map": [
    {"metadata_field_name": "email", "metadata_field_value": "test@example.com", "comparison_operator": "="}
]
}
```
"""

METADATA_FILTER_USER_PROMPT_2 = """
{"input_text": "What are the movies with a score of more than 9 in 2024?",
"metadata_fields": ["name", "year", "rating", "country"]}
"""

METADATA_FILTER_ASSISTANT_PROMPT_2 = """
```json
{"metadata_map": [
    {"metadata_field_name": "year", "metadata_field_value": "2024", "comparison_operator": "="},
    {"metadata_field_name": "rating", "metadata_field_value": "9", "comparison_operator": ">"}
]}
```
"""

METADATA_FILTER_USER_PROMPT_3 = """
{{"input_text": "{input_text}",
"metadata_fields": {metadata_fields}}}
"""

METADATA_FILTER_COMPLETION_PROMPT = """
### Job Description
You are a text metadata extraction engine that extracts the text's metadata based on user input and sets the metadata values.
### Task
Your task is to ONLY extract the metadata that exists in the input text from the provided metadata list and use the following operators ["=", "!=", ">", "<", ">=", "<="] to express logical relationships, then return the result in JSON format with the key "metadata_fields" and value "metadata_field_value" and comparison operator "comparison_operator".
### Format
The input text is in the variable input_text. Metadata are specified as a list in the variable metadata_fields.
### Constraint
DO NOT include anything other than the JSON array in your response.
### Example
Here is the chat example between human and assistant, inside <example></example> XML tags.
<example>
User:{{"input_text": ["I want to know which company’s email address test@example.com is?"], "metadata_fields": ["filename", "email", "phone", "address"]}}
Assistant:{{"metadata_map": [{{"metadata_field_name": "email", "metadata_field_value": "test@example.com", "comparison_operator": "="}}]}}
User:{{"input_text": "What are the movies with a score of more than 9 in 2024?", "metadata_fields": ["name", "year", "rating", "country"]}}
Assistant:{{"metadata_map": [{{"metadata_field_name": "year", "metadata_field_value": "2024", "comparison_operator": "="}}, {{"metadata_field_name": "rating", "metadata_field_value": "9", "comparison_operator": ">"}}]}}
</example>
### User Input
{{"input_text" : "{input_text}", "metadata_fields" : {metadata_fields}}}
### Assistant Output
"""  # noqa: E501
@ -53,6 +53,8 @@ external_knowledge_info_fields = {
    "external_knowledge_api_endpoint": fields.String,
}

+doc_metadata_fields = {"id": fields.String, "name": fields.String, "type": fields.String}
+
dataset_detail_fields = {
    "id": fields.String,
    "name": fields.String,
@ -76,6 +78,8 @@ dataset_detail_fields = {
    "doc_form": fields.String,
    "external_knowledge_info": fields.Nested(external_knowledge_info_fields),
    "external_retrieval_model": fields.Nested(external_retrieval_model_fields, allow_null=True),
+    "doc_metadata": fields.List(fields.Nested(doc_metadata_fields)),
+    "built_in_field_enabled": fields.Boolean,
}

dataset_query_detail_fields = {
@ -87,3 +91,9 @@ dataset_query_detail_fields = {
    "created_by": fields.String,
    "created_at": TimestampField,
}

+dataset_metadata_fields = {
+    "id": fields.String,
+    "type": fields.String,
+    "name": fields.String,
+}
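A standalone sketch of how flask_restful's marshal() renders the new dataset_metadata_fields definition above; the sample record values are invented for illustration.

from flask_restful import fields, marshal

dataset_metadata_fields = {
    "id": fields.String,
    "type": fields.String,
    "name": fields.String,
}

record = {"id": "c0ffee00-0000-0000-0000-000000000000", "type": "string", "name": "author", "extra": "dropped"}
# Only id, type and name are emitted; keys not in the field set are ignored.
print(marshal(record, dataset_metadata_fields))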
@ -3,6 +3,13 @@ from flask_restful import fields  # type: ignore
from fields.dataset_fields import dataset_fields
from libs.helper import TimestampField

+document_metadata_fields = {
+    "id": fields.String,
+    "name": fields.String,
+    "type": fields.String,
+    "value": fields.String,
+}
+
document_fields = {
    "id": fields.String,
    "position": fields.Integer,
@ -25,6 +32,7 @@ document_fields = {
    "word_count": fields.Integer,
    "hit_count": fields.Integer,
    "doc_form": fields.String,
+    "doc_metadata": fields.List(fields.Nested(document_metadata_fields), attribute="doc_metadata_details"),
}

document_with_segments_fields = {
@ -51,6 +59,7 @@ document_with_segments_fields = {
    "hit_count": fields.Integer,
    "completed_segments": fields.Integer,
    "total_segments": fields.Integer,
+    "doc_metadata": fields.List(fields.Nested(document_metadata_fields), attribute="doc_metadata_details"),
}

dataset_and_document_fields = {
@ -0,0 +1,90 @@
"""add_metadata_function

Revision ID: d20049ed0af6
Revises: 08ec4f75af5e
Create Date: 2025-02-27 09:17:48.903213

"""
from alembic import op
import models as models
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = 'd20049ed0af6'
down_revision = 'f051706725cc'
branch_labels = None
depends_on = None


def upgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    op.create_table('dataset_metadata_bindings',
    sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
    sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
    sa.Column('dataset_id', models.types.StringUUID(), nullable=False),
    sa.Column('metadata_id', models.types.StringUUID(), nullable=False),
    sa.Column('document_id', models.types.StringUUID(), nullable=False),
    sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
    sa.Column('created_by', models.types.StringUUID(), nullable=False),
    sa.PrimaryKeyConstraint('id', name='dataset_metadata_binding_pkey')
    )
    with op.batch_alter_table('dataset_metadata_bindings', schema=None) as batch_op:
        batch_op.create_index('dataset_metadata_binding_dataset_idx', ['dataset_id'], unique=False)
        batch_op.create_index('dataset_metadata_binding_document_idx', ['document_id'], unique=False)
        batch_op.create_index('dataset_metadata_binding_metadata_idx', ['metadata_id'], unique=False)
        batch_op.create_index('dataset_metadata_binding_tenant_idx', ['tenant_id'], unique=False)

    op.create_table('dataset_metadatas',
    sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
    sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
    sa.Column('dataset_id', models.types.StringUUID(), nullable=False),
    sa.Column('type', sa.String(length=255), nullable=False),
    sa.Column('name', sa.String(length=255), nullable=False),
    sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
    sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
    sa.Column('created_by', models.types.StringUUID(), nullable=False),
    sa.Column('updated_by', models.types.StringUUID(), nullable=True),
    sa.PrimaryKeyConstraint('id', name='dataset_metadata_pkey')
    )
    with op.batch_alter_table('dataset_metadatas', schema=None) as batch_op:
        batch_op.create_index('dataset_metadata_dataset_idx', ['dataset_id'], unique=False)
        batch_op.create_index('dataset_metadata_tenant_idx', ['tenant_id'], unique=False)

    with op.batch_alter_table('datasets', schema=None) as batch_op:
        batch_op.add_column(sa.Column('built_in_field_enabled', sa.Boolean(), server_default=sa.text('false'), nullable=False))

    with op.batch_alter_table('documents', schema=None) as batch_op:
        batch_op.alter_column('doc_metadata',
               existing_type=postgresql.JSON(astext_type=sa.Text()),
               type_=postgresql.JSONB(astext_type=sa.Text()),
               existing_nullable=True)
        batch_op.create_index('document_metadata_idx', ['doc_metadata'], unique=False, postgresql_using='gin')
    # ### end Alembic commands ###


def downgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    with op.batch_alter_table('documents', schema=None) as batch_op:
        batch_op.drop_index('document_metadata_idx', postgresql_using='gin')
        batch_op.alter_column('doc_metadata',
               existing_type=postgresql.JSONB(astext_type=sa.Text()),
               type_=postgresql.JSON(astext_type=sa.Text()),
               existing_nullable=True)

    with op.batch_alter_table('datasets', schema=None) as batch_op:
        batch_op.drop_column('built_in_field_enabled')

    with op.batch_alter_table('dataset_metadatas', schema=None) as batch_op:
        batch_op.drop_index('dataset_metadata_tenant_idx')
        batch_op.drop_index('dataset_metadata_dataset_idx')

    op.drop_table('dataset_metadatas')
    with op.batch_alter_table('dataset_metadata_bindings', schema=None) as batch_op:
        batch_op.drop_index('dataset_metadata_binding_tenant_idx')
        batch_op.drop_index('dataset_metadata_binding_metadata_idx')
        batch_op.drop_index('dataset_metadata_binding_document_idx')
        batch_op.drop_index('dataset_metadata_binding_dataset_idx')

    op.drop_table('dataset_metadata_bindings')
    # ### end Alembic commands ###
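A sketch of the kind of JSONB query the new GIN index on documents.doc_metadata is meant to serve: a containment match on a metadata key. The connection URL is a placeholder and the filter value is illustrative.

from sqlalchemy import create_engine, text

engine = create_engine("postgresql+psycopg2://user:pass@localhost/dify")  # placeholder DSN

with engine.connect() as conn:
    rows = conn.execute(
        text("SELECT id FROM documents WHERE doc_metadata @> CAST(:cond AS JSONB)"),
        {"cond": '{"year": "2024"}'},
    )
    print([row.id for row in rows])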
@ -16,6 +16,7 @@ from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import Mapped

from configs import dify_config
+from core.rag.index_processor.constant.built_in_field import BuiltInField, MetadataDataSource
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_storage import storage
from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule
@ -60,6 +61,7 @@ class Dataset(db.Model):  # type: ignore[name-defined]
    embedding_model_provider = db.Column(db.String(255), nullable=True)
    collection_binding_id = db.Column(StringUUID, nullable=True)
    retrieval_model = db.Column(JSONB, nullable=True)
+    built_in_field_enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))

    @property
    def dataset_keyword_table(self):
@ -197,6 +199,56 @@ class Dataset(db.Model):  # type: ignore[name-defined]
            "external_knowledge_api_endpoint": json.loads(external_knowledge_api.settings).get("endpoint", ""),
        }

    @property
    def doc_metadata(self):
        dataset_metadatas = db.session.query(DatasetMetadata).filter(DatasetMetadata.dataset_id == self.id).all()

        doc_metadata = [
            {
                "id": dataset_metadata.id,
                "name": dataset_metadata.name,
                "type": dataset_metadata.type,
            }
            for dataset_metadata in dataset_metadatas
        ]
        if self.built_in_field_enabled:
            doc_metadata.append(
                {
                    "id": "built-in",
                    "name": BuiltInField.document_name.value,
                    "type": "string",
                }
            )
            doc_metadata.append(
                {
                    "id": "built-in",
                    "name": BuiltInField.uploader.value,
                    "type": "string",
                }
            )
            doc_metadata.append(
                {
                    "id": "built-in",
                    "name": BuiltInField.upload_date.value,
                    "type": "time",
                }
            )
            doc_metadata.append(
                {
                    "id": "built-in",
                    "name": BuiltInField.last_update_date.value,
                    "type": "time",
                }
            )
            doc_metadata.append(
                {
                    "id": "built-in",
                    "name": BuiltInField.source.value,
                    "type": "string",
                }
            )
        return doc_metadata

    @staticmethod
    def gen_collection_name_by_id(dataset_id: str) -> str:
        normalized_dataset_id = dataset_id.replace("-", "_")
@ -250,6 +302,7 @@ class Document(db.Model):  # type: ignore[name-defined]
        db.Index("document_dataset_id_idx", "dataset_id"),
        db.Index("document_is_paused_idx", "is_paused"),
        db.Index("document_tenant_idx", "tenant_id"),
+        db.Index("document_metadata_idx", "doc_metadata", postgresql_using="gin"),
    )

    # initial fields
@ -306,7 +359,7 @@ class Document(db.Model):  # type: ignore[name-defined]
    archived_at = db.Column(db.DateTime, nullable=True)
    updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
    doc_type = db.Column(db.String(40), nullable=True)
-    doc_metadata = db.Column(db.JSON, nullable=True)
+    doc_metadata = db.Column(JSONB, nullable=True)
    doc_form = db.Column(db.String(255), nullable=False, server_default=db.text("'text_model'::character varying"))
    doc_language = db.Column(db.String(255), nullable=True)
@ -396,12 +449,95 @@ class Document(db.Model):  # type: ignore[name-defined]
            .scalar()
        )

    @property
    def uploader(self):
        user = db.session.query(Account).filter(Account.id == self.created_by).first()
        return user.name if user else None

    @property
    def upload_date(self):
        return self.created_at

    @property
    def last_update_date(self):
        return self.updated_at

    @property
    def doc_metadata_details(self):
        if self.doc_metadata:
            document_metadatas = (
                db.session.query(DatasetMetadata)
                .join(DatasetMetadataBinding, DatasetMetadataBinding.metadata_id == DatasetMetadata.id)
                .filter(
                    DatasetMetadataBinding.dataset_id == self.dataset_id, DatasetMetadataBinding.document_id == self.id
                )
                .all()
            )
            metadata_list = []
            for metadata in document_metadatas:
                metadata_dict = {
                    "id": metadata.id,
                    "name": metadata.name,
                    "type": metadata.type,
                    "value": self.doc_metadata.get(metadata.name),
                }
                metadata_list.append(metadata_dict)
            # deal with built-in fields
            metadata_list.extend(self.get_built_in_fields())

            return metadata_list
        return None

    @property
    def process_rule_dict(self):
        if self.dataset_process_rule_id:
            return self.dataset_process_rule.to_dict()
        return None

    def get_built_in_fields(self):
        built_in_fields = []
        built_in_fields.append(
            {
                "id": "built-in",
                "name": BuiltInField.document_name,
                "type": "string",
                "value": self.name,
            }
        )
        built_in_fields.append(
            {
                "id": "built-in",
                "name": BuiltInField.uploader,
                "type": "string",
                "value": self.uploader,
            }
        )
        built_in_fields.append(
            {
                "id": "built-in",
                "name": BuiltInField.upload_date,
                "type": "time",
                "value": self.created_at.timestamp(),
            }
        )
        built_in_fields.append(
            {
                "id": "built-in",
                "name": BuiltInField.last_update_date,
                "type": "time",
                "value": self.updated_at.timestamp(),
            }
        )
        built_in_fields.append(
            {
                "id": "built-in",
                "name": BuiltInField.source,
                "type": "string",
                "value": MetadataDataSource[self.data_source_type].value,
            }
        )
        return built_in_fields

    def to_dict(self):
        return {
            "id": self.id,
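A pure-Python sketch (no database) of the merge that doc_metadata_details performs: dataset-level metadata definitions combined with the document's stored doc_metadata values. The sample definitions and values are invented.

definitions = [
    {"id": "m1", "name": "author", "type": "string"},
    {"id": "m2", "name": "year", "type": "number"},
]
doc_metadata = {"author": "Alice", "year": 2024}

# For each defined field, look up the stored value (None if the document has no value for it).
details = [
    {"id": d["id"], "name": d["name"], "type": d["type"], "value": doc_metadata.get(d["name"])}
    for d in definitions
]
print(details)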
@ -945,3 +1081,41 @@ class RateLimitLog(db.Model):  # type: ignore[name-defined]
    subscription_plan = db.Column(db.String(255), nullable=False)
    operation = db.Column(db.String(255), nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))


class DatasetMetadata(db.Model):  # type: ignore[name-defined]
    __tablename__ = "dataset_metadatas"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="dataset_metadata_pkey"),
        db.Index("dataset_metadata_tenant_idx", "tenant_id"),
        db.Index("dataset_metadata_dataset_idx", "dataset_id"),
    )

    id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
    tenant_id = db.Column(StringUUID, nullable=False)
    dataset_id = db.Column(StringUUID, nullable=False)
    type = db.Column(db.String(255), nullable=False)
    name = db.Column(db.String(255), nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
    updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
    created_by = db.Column(StringUUID, nullable=False)
    updated_by = db.Column(StringUUID, nullable=True)


class DatasetMetadataBinding(db.Model):  # type: ignore[name-defined]
    __tablename__ = "dataset_metadata_bindings"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="dataset_metadata_binding_pkey"),
        db.Index("dataset_metadata_binding_tenant_idx", "tenant_id"),
        db.Index("dataset_metadata_binding_dataset_idx", "dataset_id"),
        db.Index("dataset_metadata_binding_metadata_idx", "metadata_id"),
        db.Index("dataset_metadata_binding_document_idx", "document_id"),
    )

    id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
    tenant_id = db.Column(StringUUID, nullable=False)
    dataset_id = db.Column(StringUUID, nullable=False)
    metadata_id = db.Column(StringUUID, nullable=False)
    document_id = db.Column(StringUUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
    created_by = db.Column(StringUUID, nullable=False)
api/poetry.lock (generated, 943 lines changed): file diff suppressed because it is too large.
@ -1,3 +1,4 @@
+import copy
import datetime
import json
import logging
@ -17,6 +18,7 @@ from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.plugin.entities.plugin import ModelProviderID
+from core.rag.index_processor.constant.built_in_field import BuiltInField
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from events.dataset_event import dataset_was_deleted
@ -643,9 +645,45 @@ class DocumentService:

        return document

+    @staticmethod
+    def get_document_by_ids(document_ids: list[str]) -> list[Document]:
+        documents = (
+            db.session.query(Document)
+            .filter(
+                Document.id.in_(document_ids),
+                Document.enabled == True,
+                Document.indexing_status == "completed",
+                Document.archived == False,
+            )
+            .all()
+        )
+        return documents
+
    @staticmethod
    def get_document_by_dataset_id(dataset_id: str) -> list[Document]:
-        documents = db.session.query(Document).filter(Document.dataset_id == dataset_id, Document.enabled == True).all()
+        documents = (
+            db.session.query(Document)
+            .filter(
+                Document.dataset_id == dataset_id,
+                Document.enabled == True,
+            )
+            .all()
+        )

        return documents

+    @staticmethod
+    def get_working_documents_by_dataset_id(dataset_id: str) -> list[Document]:
+        documents = (
+            db.session.query(Document)
+            .filter(
+                Document.dataset_id == dataset_id,
+                Document.enabled == True,
+                Document.indexing_status == "completed",
+                Document.archived == False,
+            )
+            .all()
+        )
+
+        return documents
@ -728,8 +766,13 @@ class DocumentService:
        if document.tenant_id != current_user.current_tenant_id:
            raise ValueError("No permission.")

-        document.name = name
+        if dataset.built_in_field_enabled:
+            if document.doc_metadata:
+                doc_metadata = copy.deepcopy(document.doc_metadata)
+                doc_metadata[BuiltInField.document_name.value] = name
+                document.doc_metadata = doc_metadata
+
+        document.name = name
        db.session.add(document)
        db.session.commit()
@ -1128,9 +1171,20 @@ class DocumentService:
            doc_form=document_form,
            doc_language=document_language,
        )
+        doc_metadata = {}
+        if dataset.built_in_field_enabled:
+            doc_metadata = {
+                BuiltInField.document_name: name,
+                BuiltInField.uploader: account.name,
+                BuiltInField.upload_date: datetime.datetime.now(datetime.UTC).strftime("%Y-%m-%d %H:%M:%S"),
+                BuiltInField.last_update_date: datetime.datetime.now(datetime.UTC).strftime("%Y-%m-%d %H:%M:%S"),
+                BuiltInField.source: data_source_type,
+            }
        if metadata is not None:
-            document.doc_metadata = metadata.doc_metadata
+            doc_metadata.update(metadata.doc_metadata)
            document.doc_type = metadata.doc_type
+        if doc_metadata:
+            document.doc_metadata = doc_metadata
        return document

    @staticmethod
@ -125,3 +125,36 @@ class SegmentUpdateArgs(BaseModel):
class ChildChunkUpdateArgs(BaseModel):
    id: Optional[str] = None
    content: str


class MetadataArgs(BaseModel):
    type: Literal["string", "number", "time"]
    name: str


class MetadataUpdateArgs(BaseModel):
    name: str
    value: Optional[str | int | float] = None


class MetadataValueUpdateArgs(BaseModel):
    fields: list[MetadataUpdateArgs]


class MetadataDetail(BaseModel):
    id: str
    name: str
    value: Optional[str | int | float] = None


class DocumentMetadataOperation(BaseModel):
    document_id: str
    metadata_list: list[MetadataDetail]


class MetadataOperationData(BaseModel):
    """
    Metadata operation data
    """

    operation_data: list[DocumentMetadataOperation]
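A sketch of validating a document-metadata operation payload with the Pydantic models above; it assumes the classes are importable inside the api project at the path used elsewhere in this change, and the IDs are made up.

from services.entities.knowledge_entities.knowledge_entities import MetadataOperationData

payload = {
    "operation_data": [
        {
            "document_id": "11111111-2222-3333-4444-555555555555",  # made-up document id
            "metadata_list": [
                {"id": "66666666-7777-8888-9999-000000000000", "name": "author", "value": "Alice"},
                {"id": "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee", "name": "year", "value": 2024},
            ],
        }
    ]
}

data = MetadataOperationData.model_validate(payload)
print(data.operation_data[0].metadata_list[0].name)  # "author"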
@ -8,6 +8,7 @@ import validators

from constants import HIDDEN_VALUE
from core.helper import ssrf_proxy
+from core.rag.entities.metadata_entities import MetadataCondition
from extensions.ext_database import db
from models.dataset import (
    Dataset,
@ -245,7 +246,11 @@ class ExternalDatasetService:

    @staticmethod
    def fetch_external_knowledge_retrieval(
-        tenant_id: str, dataset_id: str, query: str, external_retrieval_parameters: dict
+        tenant_id: str,
+        dataset_id: str,
+        query: str,
+        external_retrieval_parameters: dict,
+        metadata_condition: Optional[MetadataCondition] = None,
    ) -> list:
        external_knowledge_binding = ExternalKnowledgeBindings.query.filter_by(
            dataset_id=dataset_id, tenant_id=tenant_id
@ -272,6 +277,7 @@ class ExternalDatasetService:
            },
            "query": query,
            "knowledge_id": external_knowledge_binding.external_knowledge_id,
+            "metadata_condition": metadata_condition.model_dump() if metadata_condition else None,
        }

        response = ExternalDatasetService.process_external_api(
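For reference, a sketch of the request body an external knowledge retrieval call might carry once metadata_condition is included. The keys outside metadata_condition and the shape of the dumped condition are assumptions for illustration, not confirmed by this diff.

import json

request_body = {
    "retrieval_setting": {"top_k": 2, "score_threshold": 0.5},  # illustrative settings
    "query": "movies rated above 9 in 2024",
    "knowledge_id": "external-knowledge-id",  # placeholder
    "metadata_condition": {  # hypothetical MetadataCondition.model_dump() output
        "logical_operator": "and",
        "conditions": [
            {"name": ["year"], "comparison_operator": "=", "value": "2024"},
            {"name": ["rating"], "comparison_operator": ">", "value": "9"},
        ],
    },
}
print(json.dumps(request_body, indent=2))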
api/services/metadata_service.py (new file, 241 lines)
@ -0,0 +1,241 @@
import copy
import datetime
import logging
from typing import Optional

from flask_login import current_user  # type: ignore

from core.rag.index_processor.constant.built_in_field import BuiltInField, MetadataDataSource
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import Dataset, DatasetMetadata, DatasetMetadataBinding
from services.dataset_service import DocumentService
from services.entities.knowledge_entities.knowledge_entities import (
    MetadataArgs,
    MetadataOperationData,
)


class MetadataService:
    @staticmethod
    def create_metadata(dataset_id: str, metadata_args: MetadataArgs) -> DatasetMetadata:
        # check if metadata name already exists
        if DatasetMetadata.query.filter_by(
            tenant_id=current_user.current_tenant_id, dataset_id=dataset_id, name=metadata_args.name
        ).first():
            raise ValueError("Metadata name already exists.")
        for field in BuiltInField:
            if field.value == metadata_args.name:
                raise ValueError("Metadata name already exists in Built-in fields.")
        metadata = DatasetMetadata(
            tenant_id=current_user.current_tenant_id,
            dataset_id=dataset_id,
            type=metadata_args.type,
            name=metadata_args.name,
            created_by=current_user.id,
        )
        db.session.add(metadata)
        db.session.commit()
        return metadata

    @staticmethod
    def update_metadata_name(dataset_id: str, metadata_id: str, name: str) -> DatasetMetadata:  # type: ignore
        lock_key = f"dataset_metadata_lock_{dataset_id}"
        # check if metadata name already exists
        if DatasetMetadata.query.filter_by(
            tenant_id=current_user.current_tenant_id, dataset_id=dataset_id, name=name
        ).first():
            raise ValueError("Metadata name already exists.")
        for field in BuiltInField:
            if field.value == name:
                raise ValueError("Metadata name already exists in Built-in fields.")
        try:
            MetadataService.knowledge_base_metadata_lock_check(dataset_id, None)
            metadata = DatasetMetadata.query.filter_by(id=metadata_id).first()
            if metadata is None:
                raise ValueError("Metadata not found.")
            old_name = metadata.name
            metadata.name = name
            metadata.updated_by = current_user.id
            metadata.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)

            # update related documents
            dataset_metadata_bindings = DatasetMetadataBinding.query.filter_by(metadata_id=metadata_id).all()
            if dataset_metadata_bindings:
                document_ids = [binding.document_id for binding in dataset_metadata_bindings]
                documents = DocumentService.get_document_by_ids(document_ids)
                for document in documents:
                    doc_metadata = copy.deepcopy(document.doc_metadata)
                    value = doc_metadata.pop(old_name, None)
                    doc_metadata[name] = value
                    document.doc_metadata = doc_metadata
                    db.session.add(document)
            db.session.commit()
            return metadata  # type: ignore
        except Exception:
            logging.exception("Update metadata name failed")
        finally:
            redis_client.delete(lock_key)

    @staticmethod
    def delete_metadata(dataset_id: str, metadata_id: str):
        lock_key = f"dataset_metadata_lock_{dataset_id}"
        try:
            MetadataService.knowledge_base_metadata_lock_check(dataset_id, None)
            metadata = DatasetMetadata.query.filter_by(id=metadata_id).first()
            if metadata is None:
                raise ValueError("Metadata not found.")
            db.session.delete(metadata)

            # deal with related documents
            dataset_metadata_bindings = DatasetMetadataBinding.query.filter_by(metadata_id=metadata_id).all()
            if dataset_metadata_bindings:
                document_ids = [binding.document_id for binding in dataset_metadata_bindings]
                documents = DocumentService.get_document_by_ids(document_ids)
                for document in documents:
                    doc_metadata = copy.deepcopy(document.doc_metadata)
                    doc_metadata.pop(metadata.name, None)
                    document.doc_metadata = doc_metadata
                    db.session.add(document)
            db.session.commit()
            return metadata
        except Exception:
            logging.exception("Delete metadata failed")
        finally:
            redis_client.delete(lock_key)

    @staticmethod
    def get_built_in_fields():
        return [
            {"name": BuiltInField.document_name.value, "type": "string"},
            {"name": BuiltInField.uploader.value, "type": "string"},
            {"name": BuiltInField.upload_date.value, "type": "time"},
            {"name": BuiltInField.last_update_date.value, "type": "time"},
            {"name": BuiltInField.source.value, "type": "string"},
        ]

    @staticmethod
    def enable_built_in_field(dataset: Dataset):
        if dataset.built_in_field_enabled:
            return
        lock_key = f"dataset_metadata_lock_{dataset.id}"
        try:
            MetadataService.knowledge_base_metadata_lock_check(dataset.id, None)
            dataset.built_in_field_enabled = True
            db.session.add(dataset)
            documents = DocumentService.get_working_documents_by_dataset_id(dataset.id)
            if documents:
                for document in documents:
                    if not document.doc_metadata:
                        doc_metadata = {}
                    else:
                        doc_metadata = copy.deepcopy(document.doc_metadata)
                    doc_metadata[BuiltInField.document_name.value] = document.name
                    doc_metadata[BuiltInField.uploader.value] = document.uploader
                    doc_metadata[BuiltInField.upload_date.value] = document.upload_date.timestamp()
                    doc_metadata[BuiltInField.last_update_date.value] = document.last_update_date.timestamp()
                    doc_metadata[BuiltInField.source.value] = MetadataDataSource[document.data_source_type].value
                    document.doc_metadata = doc_metadata
                    db.session.add(document)
            db.session.commit()
        except Exception:
            logging.exception("Enable built-in field failed")
        finally:
            redis_client.delete(lock_key)

    @staticmethod
    def disable_built_in_field(dataset: Dataset):
        if not dataset.built_in_field_enabled:
            return
        lock_key = f"dataset_metadata_lock_{dataset.id}"
        try:
            MetadataService.knowledge_base_metadata_lock_check(dataset.id, None)
            dataset.built_in_field_enabled = False
            db.session.add(dataset)
            documents = DocumentService.get_working_documents_by_dataset_id(dataset.id)
            document_ids = []
            if documents:
                for document in documents:
                    doc_metadata = copy.deepcopy(document.doc_metadata)
                    doc_metadata.pop(BuiltInField.document_name.value, None)
                    doc_metadata.pop(BuiltInField.uploader.value, None)
                    doc_metadata.pop(BuiltInField.upload_date.value, None)
                    doc_metadata.pop(BuiltInField.last_update_date.value, None)
                    doc_metadata.pop(BuiltInField.source.value, None)
                    document.doc_metadata = doc_metadata
                    db.session.add(document)
                    document_ids.append(document.id)
            db.session.commit()
        except Exception:
            logging.exception("Disable built-in field failed")
        finally:
            redis_client.delete(lock_key)

    @staticmethod
    def update_documents_metadata(dataset: Dataset, metadata_args: MetadataOperationData):
        for operation in metadata_args.operation_data:
            lock_key = f"document_metadata_lock_{operation.document_id}"
            try:
                MetadataService.knowledge_base_metadata_lock_check(None, operation.document_id)
                document = DocumentService.get_document(dataset.id, operation.document_id)
                if document is None:
                    raise ValueError("Document not found.")
                doc_metadata = {}
                for metadata_value in operation.metadata_list:
                    doc_metadata[metadata_value.name] = metadata_value.value
                if dataset.built_in_field_enabled:
                    doc_metadata[BuiltInField.document_name.value] = document.name
                    doc_metadata[BuiltInField.uploader.value] = document.uploader
                    doc_metadata[BuiltInField.upload_date.value] = document.upload_date.timestamp()
                    doc_metadata[BuiltInField.last_update_date.value] = document.last_update_date.timestamp()
                    doc_metadata[BuiltInField.source.value] = MetadataDataSource[document.data_source_type].value
                document.doc_metadata = doc_metadata
                db.session.add(document)
                db.session.commit()
                # deal with metadata bindings
                DatasetMetadataBinding.query.filter_by(document_id=operation.document_id).delete()
                for metadata_value in operation.metadata_list:
                    dataset_metadata_binding = DatasetMetadataBinding(
                        tenant_id=current_user.current_tenant_id,
                        dataset_id=dataset.id,
                        document_id=operation.document_id,
                        metadata_id=metadata_value.id,
                        created_by=current_user.id,
                    )
                    db.session.add(dataset_metadata_binding)
                db.session.commit()
            except Exception:
                logging.exception("Update documents metadata failed")
            finally:
                redis_client.delete(lock_key)

    @staticmethod
    def knowledge_base_metadata_lock_check(dataset_id: Optional[str], document_id: Optional[str]):
        if dataset_id:
            lock_key = f"dataset_metadata_lock_{dataset_id}"
            if redis_client.get(lock_key):
                raise ValueError("Another knowledge base metadata operation is running, please wait a moment.")
            redis_client.set(lock_key, 1, ex=3600)
        if document_id:
            lock_key = f"document_metadata_lock_{document_id}"
            if redis_client.get(lock_key):
                raise ValueError("Another document metadata operation is running, please wait a moment.")
            redis_client.set(lock_key, 1, ex=3600)

    @staticmethod
    def get_dataset_metadatas(dataset: Dataset):
        return {
            "doc_metadata": [
                {
                    "id": item.get("id"),
                    "name": item.get("name"),
                    "type": item.get("type"),
                    "count": DatasetMetadataBinding.query.filter_by(
                        metadata_id=item.get("id"), dataset_id=dataset.id
                    ).count(),
                }
                for item in dataset.doc_metadata or []
                if item.get("id") != "built-in"
            ],
            "built_in_field_enabled": dataset.built_in_field_enabled,
        }
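A standalone sketch of the Redis lock pattern that knowledge_base_metadata_lock_check uses (get, then set with a one-hour expiry); it assumes a local Redis instance at the default port and a throwaway key name.

import redis

redis_client = redis.Redis(host="localhost", port=6379, db=0)  # assumed local instance


def acquire_metadata_lock(lock_key: str) -> None:
    # Same shape as the service above: refuse if the key exists, otherwise set it with a TTL.
    if redis_client.get(lock_key):
        raise ValueError("Another metadata operation is running, please wait a moment.")
    redis_client.set(lock_key, 1, ex=3600)


acquire_metadata_lock("dataset_metadata_lock_example")
redis_client.delete("dataset_metadata_lock_example")  # release, as the service does in its finally block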
@ -20,7 +20,7 @@ class TagService:
        )
        if keyword:
            query = query.filter(db.and_(Tag.name.ilike(f"%{keyword}%")))
-        query = query.group_by(Tag.id, Tag.type, Tag.name, Tag.created_at)
+        query = query.group_by(Tag.id, Tag.type, Tag.name)
        results: list = query.order_by(Tag.created_at.desc()).all()
        return results