Signed-off-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
Jyong 2024-12-26 00:16:35 +08:00 committed by GitHub
parent bb35818976
commit 84ac004772
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
20 changed files with 264 additions and 210 deletions

View File

@ -587,7 +587,7 @@ def upgrade_db():
click.echo(click.style("Starting database migration.", fg="green"))
# run db migration
import flask_migrate
import flask_migrate # type: ignore
flask_migrate.upgrade()

View File

@ -413,7 +413,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
indexing_runner = IndexingRunner()
try:
response = indexing_runner.indexing_estimate(
estimate_response = indexing_runner.indexing_estimate(
current_user.current_tenant_id,
[extract_setting],
data_process_rule_dict,
@ -421,6 +421,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
"English",
dataset_id,
)
return estimate_response.model_dump(), 200
except LLMBadRequestError:
raise ProviderNotInitializeError(
"No Embedding Model available. Please configure a valid provider "
@ -431,7 +432,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
except Exception as e:
raise IndexingEstimateError(str(e))
return response.model_dump(), 200
return response, 200
class DocumentBatchIndexingEstimateApi(DocumentResource):
@ -521,6 +522,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
"English",
dataset_id,
)
return response.model_dump(), 200
except LLMBadRequestError:
raise ProviderNotInitializeError(
"No Embedding Model available. Please configure a valid provider "
@ -530,7 +532,6 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
raise ProviderNotInitializeError(ex.description)
except Exception as e:
raise IndexingEstimateError(str(e))
return response.model_dump(), 200
class DocumentBatchIndexingStatusApi(DocumentResource):

View File

@ -22,6 +22,7 @@ from fields.document_fields import document_fields, document_status_fields
from libs.login import current_user
from models.dataset import Dataset, Document, DocumentSegment
from services.dataset_service import DocumentService
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig
from services.file_service import FileService
@ -67,13 +68,14 @@ class DocumentAddByTextApi(DatasetApiResource):
"info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
}
args["data_source"] = data_source
knowledge_config = KnowledgeConfig(**args)
# validate args
DocumentService.document_create_args_validate(args)
DocumentService.document_create_args_validate(knowledge_config)
try:
documents, batch = DocumentService.save_document_with_dataset_id(
dataset=dataset,
document_data=args,
knowledge_config=knowledge_config,
account=current_user,
dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
created_from="api",
@ -122,12 +124,13 @@ class DocumentUpdateByTextApi(DatasetApiResource):
args["data_source"] = data_source
# validate args
args["original_document_id"] = str(document_id)
DocumentService.document_create_args_validate(args)
knowledge_config = KnowledgeConfig(**args)
DocumentService.document_create_args_validate(knowledge_config)
try:
documents, batch = DocumentService.save_document_with_dataset_id(
dataset=dataset,
document_data=args,
knowledge_config=knowledge_config,
account=current_user,
dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
created_from="api",
@ -186,12 +189,13 @@ class DocumentAddByFileApi(DatasetApiResource):
data_source = {"type": "upload_file", "info_list": {"file_info_list": {"file_ids": [upload_file.id]}}}
args["data_source"] = data_source
# validate args
DocumentService.document_create_args_validate(args)
knowledge_config = KnowledgeConfig(**args)
DocumentService.document_create_args_validate(knowledge_config)
try:
documents, batch = DocumentService.save_document_with_dataset_id(
dataset=dataset,
document_data=args,
knowledge_config=knowledge_config,
account=dataset.created_by_account,
dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
created_from="api",
@ -245,12 +249,14 @@ class DocumentUpdateByFileApi(DatasetApiResource):
args["data_source"] = data_source
# validate args
args["original_document_id"] = str(document_id)
DocumentService.document_create_args_validate(args)
knowledge_config = KnowledgeConfig(**args)
DocumentService.document_create_args_validate(knowledge_config)
try:
documents, batch = DocumentService.save_document_with_dataset_id(
dataset=dataset,
document_data=args,
knowledge_config=knowledge_config,
account=dataset.created_by_account,
dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
created_from="api",

View File

@ -276,7 +276,7 @@ class IndexingRunner:
tenant_id=tenant_id,
model_type=ModelType.TEXT_EMBEDDING,
)
preview_texts = []
preview_texts = [] # type: ignore
total_segments = 0
index_type = doc_form
@ -300,13 +300,13 @@ class IndexingRunner:
if len(preview_texts) < 10:
if doc_form and doc_form == "qa_model":
preview_detail = QAPreviewDetail(
question=document.page_content, answer=document.metadata.get("answer")
question=document.page_content, answer=document.metadata.get("answer") or ""
)
preview_texts.append(preview_detail)
else:
preview_detail = PreviewDetail(content=document.page_content)
preview_detail = PreviewDetail(content=document.page_content) # type: ignore
if document.children:
preview_detail.child_chunks = [child.page_content for child in document.children]
preview_detail.child_chunks = [child.page_content for child in document.children] # type: ignore
preview_texts.append(preview_detail)
# delete image files and related db records
@ -325,7 +325,7 @@ class IndexingRunner:
if doc_form and doc_form == "qa_model":
return IndexingEstimate(total_segments=total_segments * 20, qa_preview=preview_texts, preview=[])
return IndexingEstimate(total_segments=total_segments, preview=preview_texts)
return IndexingEstimate(total_segments=total_segments, preview=preview_texts) # type: ignore
def _extract(
self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: dict
@ -454,7 +454,7 @@ class IndexingRunner:
embedding_model_instance=embedding_model_instance,
)
return character_splitter
return character_splitter # type: ignore
def _split_to_documents_for_estimate(
self, text_docs: list[Document], splitter: TextSplitter, processing_rule: DatasetProcessRule
@ -535,7 +535,7 @@ class IndexingRunner:
# create keyword index
create_keyword_thread = threading.Thread(
target=self._process_keyword_index,
args=(current_app._get_current_object(), dataset.id, dataset_document.id, documents),
args=(current_app._get_current_object(), dataset.id, dataset_document.id, documents), # type: ignore
)
create_keyword_thread.start()

View File

@ -258,78 +258,79 @@ class RetrievalService:
include_segment_ids = []
segment_child_map = {}
for document in documents:
document_id = document.metadata["document_id"]
document_id = document.metadata.get("document_id")
dataset_document = db.session.query(DatasetDocument).filter(DatasetDocument.id == document_id).first()
if dataset_document and dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
child_index_node_id = document.metadata["doc_id"]
result = (
db.session.query(ChildChunk, DocumentSegment)
.join(DocumentSegment, ChildChunk.segment_id == DocumentSegment.id)
.filter(
ChildChunk.index_node_id == child_index_node_id,
DocumentSegment.dataset_id == dataset_document.dataset_id,
DocumentSegment.enabled == True,
DocumentSegment.status == "completed",
if dataset_document:
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
child_index_node_id = document.metadata.get("doc_id")
result = (
db.session.query(ChildChunk, DocumentSegment)
.join(DocumentSegment, ChildChunk.segment_id == DocumentSegment.id)
.filter(
ChildChunk.index_node_id == child_index_node_id,
DocumentSegment.dataset_id == dataset_document.dataset_id,
DocumentSegment.enabled == True,
DocumentSegment.status == "completed",
)
.first()
)
.first()
)
if result:
child_chunk, segment = result
if result:
child_chunk, segment = result
if not segment:
continue
if segment.id not in include_segment_ids:
include_segment_ids.append(segment.id)
child_chunk_detail = {
"id": child_chunk.id,
"content": child_chunk.content,
"position": child_chunk.position,
"score": document.metadata.get("score", 0.0),
}
map_detail = {
"max_score": document.metadata.get("score", 0.0),
"child_chunks": [child_chunk_detail],
}
segment_child_map[segment.id] = map_detail
record = {
"segment": segment,
}
records.append(record)
else:
child_chunk_detail = {
"id": child_chunk.id,
"content": child_chunk.content,
"position": child_chunk.position,
"score": document.metadata.get("score", 0.0),
}
segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail)
segment_child_map[segment.id]["max_score"] = max(
segment_child_map[segment.id]["max_score"], document.metadata.get("score", 0.0)
)
else:
continue
else:
index_node_id = document.metadata["doc_id"]
segment = (
db.session.query(DocumentSegment)
.filter(
DocumentSegment.dataset_id == dataset_document.dataset_id,
DocumentSegment.enabled == True,
DocumentSegment.status == "completed",
DocumentSegment.index_node_id == index_node_id,
)
.first()
)
if not segment:
continue
if segment.id not in include_segment_ids:
include_segment_ids.append(segment.id)
child_chunk_detail = {
"id": child_chunk.id,
"content": child_chunk.content,
"position": child_chunk.position,
"score": document.metadata.get("score", 0.0),
}
map_detail = {
"max_score": document.metadata.get("score", 0.0),
"child_chunks": [child_chunk_detail],
}
segment_child_map[segment.id] = map_detail
record = {
"segment": segment,
}
records.append(record)
else:
child_chunk_detail = {
"id": child_chunk.id,
"content": child_chunk.content,
"position": child_chunk.position,
"score": document.metadata.get("score", 0.0),
}
segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail)
segment_child_map[segment.id]["max_score"] = max(
segment_child_map[segment.id]["max_score"], document.metadata.get("score", 0.0)
)
else:
continue
else:
index_node_id = document.metadata["doc_id"]
include_segment_ids.append(segment.id)
record = {
"segment": segment,
"score": document.metadata.get("score", None),
}
segment = (
db.session.query(DocumentSegment)
.filter(
DocumentSegment.dataset_id == dataset_document.dataset_id,
DocumentSegment.enabled == True,
DocumentSegment.status == "completed",
DocumentSegment.index_node_id == index_node_id,
)
.first()
)
if not segment:
continue
include_segment_ids.append(segment.id)
record = {
"segment": segment,
"score": document.metadata.get("score", None),
}
records.append(record)
records.append(record)
for record in records:
if record["segment"].id in segment_child_map:
record["child_chunks"] = segment_child_map[record["segment"].id].get("child_chunks", None)

View File

@ -122,26 +122,27 @@ class DatasetDocumentStore:
db.session.add(segment_document)
db.session.flush()
if save_child:
for postion, child in enumerate(doc.children, start=1):
child_segment = ChildChunk(
tenant_id=self._dataset.tenant_id,
dataset_id=self._dataset.id,
document_id=self._document_id,
segment_id=segment_document.id,
position=postion,
index_node_id=child.metadata["doc_id"],
index_node_hash=child.metadata["doc_hash"],
content=child.page_content,
word_count=len(child.page_content),
type="automatic",
created_by=self._user_id,
)
db.session.add(child_segment)
if doc.children:
for postion, child in enumerate(doc.children, start=1):
child_segment = ChildChunk(
tenant_id=self._dataset.tenant_id,
dataset_id=self._dataset.id,
document_id=self._document_id,
segment_id=segment_document.id,
position=postion,
index_node_id=child.metadata.get("doc_id"),
index_node_hash=child.metadata.get("doc_hash"),
content=child.page_content,
word_count=len(child.page_content),
type="automatic",
created_by=self._user_id,
)
db.session.add(child_segment)
else:
segment_document.content = doc.page_content
if doc.metadata.get("answer"):
segment_document.answer = doc.metadata.pop("answer", "")
segment_document.index_node_hash = doc.metadata["doc_hash"]
segment_document.index_node_hash = doc.metadata.get("doc_hash")
segment_document.word_count = len(doc.page_content)
segment_document.tokens = tokens
if save_child and doc.children:
@ -160,8 +161,8 @@ class DatasetDocumentStore:
document_id=self._document_id,
segment_id=segment_document.id,
position=position,
index_node_id=child.metadata["doc_id"],
index_node_hash=child.metadata["doc_hash"],
index_node_id=child.metadata.get("doc_id"),
index_node_hash=child.metadata.get("doc_hash"),
content=child.page_content,
word_count=len(child.page_content),
type="automatic",

View File

@ -4,7 +4,7 @@ import os
from typing import Optional, cast
import pandas as pd
from openpyxl import load_workbook
from openpyxl import load_workbook # type: ignore
from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document

View File

@ -81,4 +81,4 @@ class BaseIndexProcessor(ABC):
embedding_model_instance=embedding_model_instance,
)
return character_splitter
return character_splitter # type: ignore

View File

@ -30,12 +30,18 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
def transform(self, documents: list[Document], **kwargs) -> list[Document]:
process_rule = kwargs.get("process_rule")
if not process_rule:
raise ValueError("No process rule found.")
if process_rule.get("mode") == "automatic":
automatic_rule = DatasetProcessRule.AUTOMATIC_RULES
rules = Rule(**automatic_rule)
else:
if not process_rule.get("rules"):
raise ValueError("No rules found in process rule.")
rules = Rule(**process_rule.get("rules"))
# Split the text documents into nodes.
if not rules.segmentation:
raise ValueError("No segmentation found in rules.")
splitter = self._get_splitter(
processing_rule_mode=process_rule.get("mode"),
max_tokens=rules.segmentation.max_tokens,

View File

@ -30,8 +30,12 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
def transform(self, documents: list[Document], **kwargs) -> list[Document]:
process_rule = kwargs.get("process_rule")
if not process_rule:
raise ValueError("No process rule found.")
if not process_rule.get("rules"):
raise ValueError("No rules found in process rule.")
rules = Rule(**process_rule.get("rules"))
all_documents = []
all_documents = [] # type: ignore
if rules.parent_mode == ParentMode.PARAGRAPH:
# Split the text documents into nodes.
splitter = self._get_splitter(
@ -161,6 +165,8 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
process_rule_mode: str,
embedding_model_instance: Optional[ModelInstance],
) -> list[ChildDocument]:
if not rules.subchunk_segmentation:
raise ValueError("No subchunk segmentation found in rules.")
child_splitter = self._get_splitter(
processing_rule_mode=process_rule_mode,
max_tokens=rules.subchunk_segmentation.max_tokens,

View File

@ -37,12 +37,16 @@ class QAIndexProcessor(BaseIndexProcessor):
def transform(self, documents: list[Document], **kwargs) -> list[Document]:
preview = kwargs.get("preview")
process_rule = kwargs.get("process_rule")
if not process_rule:
raise ValueError("No process rule found.")
if not process_rule.get("rules"):
raise ValueError("No rules found in process rule.")
rules = Rule(**process_rule.get("rules"))
splitter = self._get_splitter(
processing_rule_mode=process_rule.get("mode"),
max_tokens=rules.segmentation.max_tokens,
chunk_overlap=rules.segmentation.chunk_overlap,
separator=rules.segmentation.separator,
max_tokens=rules.segmentation.max_tokens if rules.segmentation else 0,
chunk_overlap=rules.segmentation.chunk_overlap if rules.segmentation else 0,
separator=rules.segmentation.separator if rules.segmentation else "",
embedding_model_instance=kwargs.get("embedding_model_instance"),
)
@ -71,8 +75,8 @@ class QAIndexProcessor(BaseIndexProcessor):
all_documents.extend(split_documents)
if preview:
self._format_qa_document(
current_app._get_current_object(),
kwargs.get("tenant_id"),
current_app._get_current_object(), # type: ignore
kwargs.get("tenant_id"), # type: ignore
all_documents[0],
all_qa_documents,
kwargs.get("doc_language", "English"),
@ -85,8 +89,8 @@ class QAIndexProcessor(BaseIndexProcessor):
document_format_thread = threading.Thread(
target=self._format_qa_document,
kwargs={
"flask_app": current_app._get_current_object(),
"tenant_id": kwargs.get("tenant_id"),
"flask_app": current_app._get_current_object(), # type: ignore
"tenant_id": kwargs.get("tenant_id"), # type: ignore
"document_node": doc,
"all_qa_documents": all_qa_documents,
"document_language": kwargs.get("doc_language", "English"),

View File

@ -2,7 +2,7 @@ from abc import ABC, abstractmethod
from collections.abc import Sequence
from typing import Any, Optional
from pydantic import BaseModel, Field
from pydantic import BaseModel
class ChildDocument(BaseModel):
@ -15,7 +15,7 @@ class ChildDocument(BaseModel):
"""Arbitrary metadata about the page content (e.g., source, relationships to other
documents, etc.).
"""
metadata: Optional[dict] = Field(default_factory=dict)
metadata: dict = {}
class Document(BaseModel):
@ -28,7 +28,7 @@ class Document(BaseModel):
"""Arbitrary metadata about the page content (e.g., source, relationships to other
documents, etc.).
"""
metadata: Optional[dict] = Field(default_factory=dict)
metadata: dict = {}
provider: Optional[str] = "dify"

View File

@ -5,7 +5,7 @@ from dify_app import DifyApp
def init_app(app: DifyApp):
# register blueprint routers
from flask_cors import CORS
from flask_cors import CORS # type: ignore
from controllers.console import bp as console_app_bp
from controllers.files import bp as files_bp

View File

@ -1,9 +1,9 @@
import logging
import time
from collections import defaultdict
import click
from celery import shared_task # type: ignore
from flask import render_template
from extensions.ext_mail import mail
from models.account import Account, Tenant, TenantAccountJoin
@ -27,7 +27,7 @@ def send_document_clean_notify_task():
try:
dataset_auto_disable_logs = DatasetAutoDisableLog.query.filter(DatasetAutoDisableLog.notified == False).all()
# group by tenant_id
dataset_auto_disable_logs_map = {}
dataset_auto_disable_logs_map: dict[str, list[DatasetAutoDisableLog]] = defaultdict(list)
for dataset_auto_disable_log in dataset_auto_disable_logs:
dataset_auto_disable_logs_map[dataset_auto_disable_log.tenant_id].append(dataset_auto_disable_log)
@ -37,11 +37,13 @@ def send_document_clean_notify_task():
if not tenant:
continue
current_owner_join = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, role="owner").first()
if not current_owner_join:
continue
account = Account.query.filter(Account.id == current_owner_join.account_id).first()
if not account:
continue
dataset_auto_dataset_map = {}
dataset_auto_dataset_map = {} # type: ignore
for dataset_auto_disable_log in tenant_dataset_auto_disable_logs:
dataset_auto_dataset_map[dataset_auto_disable_log.dataset_id].append(
dataset_auto_disable_log.document_id
@ -53,14 +55,9 @@ def send_document_clean_notify_task():
document_count = len(document_ids)
knowledge_details.append(f"<li>Knowledge base {dataset.name}: {document_count} documents</li>")
html_content = render_template(
"clean_document_job_mail_template-US.html",
)
mail.send(to=to, subject="立即加入 Dify 工作空间", html=html_content)
end_at = time.perf_counter()
logging.info(
click.style("Send document clean notify mail succeeded: latency: {}".format(end_at - start_at), fg="green")
)
except Exception:
logging.exception("Send invite member mail to {} failed".format(to))
logging.exception("Send invite member mail to failed")

View File

@ -4,7 +4,7 @@ from enum import StrEnum
from typing import Optional, cast
from uuid import uuid4
import yaml
import yaml # type: ignore
from packaging import version
from pydantic import BaseModel
from sqlalchemy import select
@ -465,7 +465,7 @@ class AppDslService:
else:
cls._append_model_config_export_data(export_data, app_model)
return yaml.dump(export_data, allow_unicode=True)
return yaml.dump(export_data, allow_unicode=True) # type: ignore
@classmethod
def _append_workflow_export_data(cls, *, export_data: dict, app_model: App, include_secret: bool) -> None:

View File

@ -41,6 +41,7 @@ from models.source import DataSourceOauthBinding
from services.entities.knowledge_entities.knowledge_entities import (
ChildChunkUpdateArgs,
KnowledgeConfig,
RerankingModel,
RetrievalModel,
SegmentUpdateArgs,
)
@ -548,12 +549,14 @@ class DocumentService:
}
@staticmethod
def get_document(dataset_id: str, document_id: str) -> Optional[Document]:
document = (
db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first()
)
return document
def get_document(dataset_id: str, document_id: Optional[str] = None) -> Optional[Document]:
if document_id:
document = (
db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first()
)
return document
else:
return None
@staticmethod
def get_document_by_id(document_id: str) -> Optional[Document]:
@ -744,25 +747,26 @@ class DocumentService:
if features.billing.enabled:
if not knowledge_config.original_document_id:
count = 0
if knowledge_config.data_source.info_list.data_source_type == "upload_file":
upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids
count = len(upload_file_list)
elif knowledge_config.data_source.info_list.data_source_type == "notion_import":
notion_info_list = knowledge_config.data_source.info_list.notion_info_list
for notion_info in notion_info_list:
count = count + len(notion_info.pages)
elif knowledge_config.data_source.info_list.data_source_type == "website_crawl":
website_info = knowledge_config.data_source.info_list.website_info_list
count = len(website_info.urls)
batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
if count > batch_upload_limit:
raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
if knowledge_config.data_source:
if knowledge_config.data_source.info_list.data_source_type == "upload_file":
upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids # type: ignore
count = len(upload_file_list)
elif knowledge_config.data_source.info_list.data_source_type == "notion_import":
notion_info_list = knowledge_config.data_source.info_list.notion_info_list
for notion_info in notion_info_list: # type: ignore
count = count + len(notion_info.pages)
elif knowledge_config.data_source.info_list.data_source_type == "website_crawl":
website_info = knowledge_config.data_source.info_list.website_info_list
count = len(website_info.urls) # type: ignore
batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
if count > batch_upload_limit:
raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
DocumentService.check_documents_upload_quota(count, features)
DocumentService.check_documents_upload_quota(count, features)
# if dataset is empty, update dataset data_source_type
if not dataset.data_source_type:
dataset.data_source_type = knowledge_config.data_source.info_list.data_source_type
dataset.data_source_type = knowledge_config.data_source.info_list.data_source_type # type: ignore
if not dataset.indexing_technique:
if knowledge_config.indexing_technique not in Dataset.INDEXING_TECHNIQUE_LIST:
@ -789,7 +793,7 @@ class DocumentService:
"score_threshold_enabled": False,
}
dataset.retrieval_model = knowledge_config.retrieval_model.model_dump() or default_retrieval_model
dataset.retrieval_model = knowledge_config.retrieval_model.model_dump() or default_retrieval_model # type: ignore
documents = []
if knowledge_config.original_document_id:
@ -801,34 +805,35 @@ class DocumentService:
# save process rule
if not dataset_process_rule:
process_rule = knowledge_config.process_rule
if process_rule.mode in ("custom", "hierarchical"):
dataset_process_rule = DatasetProcessRule(
dataset_id=dataset.id,
mode=process_rule.mode,
rules=process_rule.rules.model_dump_json(),
created_by=account.id,
)
elif process_rule.mode == "automatic":
dataset_process_rule = DatasetProcessRule(
dataset_id=dataset.id,
mode=process_rule.mode,
rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES),
created_by=account.id,
)
else:
logging.warn(
f"Invalid process rule mode: {process_rule['mode']}, can not find dataset process rule"
)
return
db.session.add(dataset_process_rule)
db.session.commit()
if process_rule:
if process_rule.mode in ("custom", "hierarchical"):
dataset_process_rule = DatasetProcessRule(
dataset_id=dataset.id,
mode=process_rule.mode,
rules=process_rule.rules.model_dump_json() if process_rule.rules else None,
created_by=account.id,
)
elif process_rule.mode == "automatic":
dataset_process_rule = DatasetProcessRule(
dataset_id=dataset.id,
mode=process_rule.mode,
rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES),
created_by=account.id,
)
else:
logging.warn(
f"Invalid process rule mode: {process_rule.mode}, can not find dataset process rule"
)
return
db.session.add(dataset_process_rule)
db.session.commit()
lock_name = "add_document_lock_dataset_id_{}".format(dataset.id)
with redis_client.lock(lock_name, timeout=600):
position = DocumentService.get_documents_position(dataset.id)
document_ids = []
duplicate_document_ids = []
if knowledge_config.data_source.info_list.data_source_type == "upload_file":
upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids
upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids # type: ignore
for file_id in upload_file_list:
file = (
db.session.query(UploadFile)
@ -854,7 +859,7 @@ class DocumentService:
name=file_name,
).first()
if document:
document.dataset_process_rule_id = dataset_process_rule.id
document.dataset_process_rule_id = dataset_process_rule.id # type: ignore
document.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
document.created_from = created_from
document.doc_form = knowledge_config.doc_form
@ -868,7 +873,7 @@ class DocumentService:
continue
document = DocumentService.build_document(
dataset,
dataset_process_rule.id,
dataset_process_rule.id, # type: ignore
knowledge_config.data_source.info_list.data_source_type,
knowledge_config.doc_form,
knowledge_config.doc_language,
@ -886,6 +891,8 @@ class DocumentService:
position += 1
elif knowledge_config.data_source.info_list.data_source_type == "notion_import":
notion_info_list = knowledge_config.data_source.info_list.notion_info_list
if not notion_info_list:
raise ValueError("No notion info list found.")
exist_page_ids = []
exist_document = {}
documents = Document.query.filter_by(
@ -921,7 +928,7 @@ class DocumentService:
}
document = DocumentService.build_document(
dataset,
dataset_process_rule.id,
dataset_process_rule.id, # type: ignore
knowledge_config.data_source.info_list.data_source_type,
knowledge_config.doc_form,
knowledge_config.doc_language,
@ -944,6 +951,8 @@ class DocumentService:
clean_notion_document_task.delay(list(exist_document.values()), dataset.id)
elif knowledge_config.data_source.info_list.data_source_type == "website_crawl":
website_info = knowledge_config.data_source.info_list.website_info_list
if not website_info:
raise ValueError("No website info list found.")
urls = website_info.urls
for url in urls:
data_source_info = {
@ -959,7 +968,7 @@ class DocumentService:
document_name = url
document = DocumentService.build_document(
dataset,
dataset_process_rule.id,
dataset_process_rule.id, # type: ignore
knowledge_config.data_source.info_list.data_source_type,
knowledge_config.doc_form,
knowledge_config.doc_language,
@ -1054,7 +1063,7 @@ class DocumentService:
dataset_process_rule = DatasetProcessRule(
dataset_id=dataset.id,
mode=process_rule.mode,
rules=process_rule.rules.model_dump_json(),
rules=process_rule.rules.model_dump_json() if process_rule.rules else None,
created_by=account.id,
)
elif process_rule.mode == "automatic":
@ -1073,6 +1082,8 @@ class DocumentService:
file_name = ""
data_source_info = {}
if document_data.data_source.info_list.data_source_type == "upload_file":
if not document_data.data_source.info_list.file_info_list:
raise ValueError("No file info list found.")
upload_file_list = document_data.data_source.info_list.file_info_list.file_ids
for file_id in upload_file_list:
file = (
@ -1090,6 +1101,8 @@ class DocumentService:
"upload_file_id": file_id,
}
elif document_data.data_source.info_list.data_source_type == "notion_import":
if not document_data.data_source.info_list.notion_info_list:
raise ValueError("No notion info list found.")
notion_info_list = document_data.data_source.info_list.notion_info_list
for notion_info in notion_info_list:
workspace_id = notion_info.workspace_id
@ -1107,20 +1120,21 @@ class DocumentService:
data_source_info = {
"notion_workspace_id": workspace_id,
"notion_page_id": page.page_id,
"notion_page_icon": page.page_icon,
"notion_page_icon": page.page_icon.model_dump() if page.page_icon else None, # type: ignore
"type": page.type,
}
elif document_data.data_source.info_list.data_source_type == "website_crawl":
website_info = document_data.data_source.info_list.website_info_list
urls = website_info.urls
for url in urls:
data_source_info = {
"url": url,
"provider": website_info.provider,
"job_id": website_info.job_id,
"only_main_content": website_info.only_main_content,
"mode": "crawl",
}
if website_info:
urls = website_info.urls
for url in urls:
data_source_info = {
"url": url,
"provider": website_info.provider,
"job_id": website_info.job_id,
"only_main_content": website_info.only_main_content, # type: ignore
"mode": "crawl",
}
document.data_source_type = document_data.data_source.info_list.data_source_type
document.data_source_info = json.dumps(data_source_info)
document.name = file_name
@ -1155,15 +1169,21 @@ class DocumentService:
if features.billing.enabled:
count = 0
if knowledge_config.data_source.info_list.data_source_type == "upload_file":
upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids
upload_file_list = (
knowledge_config.data_source.info_list.file_info_list.file_ids
if knowledge_config.data_source.info_list.file_info_list
else []
)
count = len(upload_file_list)
elif knowledge_config.data_source.info_list.data_source_type == "notion_import":
notion_info_list = knowledge_config.data_source.info_list.notion_info_list
for notion_info in notion_info_list:
count = count + len(notion_info.pages)
if notion_info_list:
for notion_info in notion_info_list:
count = count + len(notion_info.pages)
elif knowledge_config.data_source.info_list.data_source_type == "website_crawl":
website_info = knowledge_config.data_source.info_list.website_info_list
count = len(website_info.urls)
if website_info:
count = len(website_info.urls)
batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
if count > batch_upload_limit:
raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
@ -1174,20 +1194,20 @@ class DocumentService:
retrieval_model = None
if knowledge_config.indexing_technique == "high_quality":
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
knowledge_config.embedding_model_provider, knowledge_config.embedding_model
knowledge_config.embedding_model_provider, # type: ignore
knowledge_config.embedding_model, # type: ignore
)
dataset_collection_binding_id = dataset_collection_binding.id
if knowledge_config.retrieval_model:
retrieval_model = knowledge_config.retrieval_model
else:
default_retrieval_model = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
"top_k": 2,
"score_threshold_enabled": False,
}
retrieval_model = RetrievalModel(**default_retrieval_model)
retrieval_model = RetrievalModel(
search_method=RetrievalMethod.SEMANTIC_SEARCH.value,
reranking_enable=False,
reranking_model=RerankingModel(reranking_provider_name="", reranking_model_name=""),
top_k=2,
score_threshold_enabled=False,
)
# save dataset
dataset = Dataset(
tenant_id=tenant_id,
@ -1557,12 +1577,12 @@ class SegmentService:
raise ValueError("Can't update disabled segment")
try:
word_count_change = segment.word_count
content = args.content
content = args.content or segment.content
if segment.content == content:
segment.word_count = len(content)
if document.doc_form == "qa_model":
segment.answer = args.answer
segment.word_count += len(args.answer)
segment.word_count += len(args.answer) if args.answer else 0
word_count_change = segment.word_count - word_count_change
if args.keywords:
segment.keywords = args.keywords
@ -1577,7 +1597,12 @@ class SegmentService:
db.session.add(document)
# update segment index task
if args.enabled:
VectorService.create_segments_vector([args.keywords], [segment], dataset)
VectorService.create_segments_vector(
[args.keywords] if args.keywords else None,
[segment],
dataset,
document.doc_form,
)
if document.doc_form == IndexType.PARENT_CHILD_INDEX and args.regenerate_child_chunks:
# regenerate child chunks
# get embedding model instance
@ -1605,6 +1630,8 @@ class SegmentService:
.filter(DatasetProcessRule.id == document.dataset_process_rule_id)
.first()
)
if not processing_rule:
raise ValueError("No processing rule found.")
VectorService.generate_child_chunks(
segment, document, dataset, embedding_model_instance, processing_rule, True
)
@ -1639,7 +1666,7 @@ class SegmentService:
segment.disabled_by = None
if document.doc_form == "qa_model":
segment.answer = args.answer
segment.word_count += len(args.answer)
segment.word_count += len(args.answer) if args.answer else 0
word_count_change = segment.word_count - word_count_change
# update document word count
if word_count_change != 0:
@ -1673,6 +1700,8 @@ class SegmentService:
.filter(DatasetProcessRule.id == document.dataset_process_rule_id)
.first()
)
if not processing_rule:
raise ValueError("No processing rule found.")
VectorService.generate_child_chunks(
segment, document, dataset, embedding_model_instance, processing_rule, True
)

View File

@ -97,7 +97,7 @@ class KnowledgeConfig(BaseModel):
original_document_id: Optional[str] = None
duplicate: bool = True
indexing_technique: Literal["high_quality", "economy"]
data_source: Optional[DataSource] = None
data_source: DataSource
process_rule: Optional[ProcessRule] = None
retrieval_model: Optional[RetrievalModel] = None
doc_form: str = "text_model"

View File

@ -69,7 +69,7 @@ class HitTestingService:
db.session.add(dataset_query)
db.session.commit()
return cls.compact_retrieve_response(query, all_documents)
return cls.compact_retrieve_response(query, all_documents) # type: ignore
@classmethod
def external_retrieve(

View File

@ -29,6 +29,8 @@ class VectorService:
.filter(DatasetProcessRule.id == document.dataset_process_rule_id)
.first()
)
if not processing_rule:
raise ValueError("No processing rule found.")
# get embedding model instance
if dataset.indexing_technique == "high_quality":
# check embedding model setting
@ -98,7 +100,7 @@ class VectorService:
def generate_child_chunks(
cls,
segment: DocumentSegment,
dataset_document: Document,
dataset_document: DatasetDocument,
dataset: Dataset,
embedding_model_instance: ModelInstance,
processing_rule: DatasetProcessRule,
@ -130,7 +132,7 @@ class VectorService:
doc_language=dataset_document.doc_language,
)
# save child chunks
if len(documents) > 0 and len(documents[0].children) > 0:
if documents and documents[0].children:
index_processor.load(dataset, documents)
for position, child_chunk in enumerate(documents[0].children, start=1):

View File

@ -44,7 +44,8 @@ def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form
for upload_file_id in image_upload_file_ids:
image_file = db.session.query(UploadFile).filter(UploadFile.id == upload_file_id).first()
try:
storage.delete(image_file.key)
if image_file and image_file.key:
storage.delete(image_file.key)
except Exception:
logging.exception(
"Delete image_files failed when storage deleted, \