py lint (#12102)

Signed-off-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: -LAN- <laipz8200@outlook.com>

parent bb35818976
commit 84ac004772
@@ -587,7 +587,7 @@ def upgrade_db():
             click.echo(click.style("Starting database migration.", fg="green"))

             # run db migration
-            import flask_migrate
+            import flask_migrate  # type: ignore

             flask_migrate.upgrade()
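Most of what follows is this same pattern: inline mypy suppressions on imports and expressions the checker cannot type. A minimal sketch of how the comment behaves (module names are taken from the hunks in this commit; the bracketed error-code form is an optional, assumed refinement):

# Untyped third-party imports: the comment silences mypy's missing-stubs
# diagnostic for this line only; runtime behaviour is unchanged.
import flask_migrate  # type: ignore
import yaml  # type: ignore

# The scoped form limits the suppression to one error code, so any other
# problem on the same line is still reported, e.g.:
# import yaml  # type: ignore[import-untyped]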
@@ -413,7 +413,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
                 indexing_runner = IndexingRunner()

                 try:
-                    response = indexing_runner.indexing_estimate(
+                    estimate_response = indexing_runner.indexing_estimate(
                         current_user.current_tenant_id,
                         [extract_setting],
                         data_process_rule_dict,
@@ -421,6 +421,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
                         "English",
                         dataset_id,
                     )
+                    return estimate_response.model_dump(), 200
                 except LLMBadRequestError:
                     raise ProviderNotInitializeError(
                         "No Embedding Model available. Please configure a valid provider "
@@ -431,7 +432,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
                 except Exception as e:
                     raise IndexingEstimateError(str(e))

-        return response.model_dump(), 200
+        return response, 200


 class DocumentBatchIndexingEstimateApi(DocumentResource):
@@ -521,6 +522,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
                 "English",
                 dataset_id,
             )
+            return response.model_dump(), 200
         except LLMBadRequestError:
             raise ProviderNotInitializeError(
                 "No Embedding Model available. Please configure a valid provider "
@@ -530,7 +532,6 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
             raise ProviderNotInitializeError(ex.description)
         except Exception as e:
             raise IndexingEstimateError(str(e))
-        return response.model_dump(), 200


 class DocumentBatchIndexingStatusApi(DocumentResource):
@@ -22,6 +22,7 @@ from fields.document_fields import document_fields, document_status_fields
 from libs.login import current_user
 from models.dataset import Dataset, Document, DocumentSegment
 from services.dataset_service import DocumentService
+from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig
 from services.file_service import FileService


@@ -67,13 +68,14 @@ class DocumentAddByTextApi(DatasetApiResource):
             "info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
         }
         args["data_source"] = data_source
+        knowledge_config = KnowledgeConfig(**args)
         # validate args
-        DocumentService.document_create_args_validate(args)
+        DocumentService.document_create_args_validate(knowledge_config)

         try:
             documents, batch = DocumentService.save_document_with_dataset_id(
                 dataset=dataset,
-                document_data=args,
+                knowledge_config=knowledge_config,
                 account=current_user,
                 dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
                 created_from="api",
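The recurring change in these controller hunks is to wrap the raw request dict in a typed KnowledgeConfig and validate that object instead of the dict. A minimal sketch of the pattern with pydantic; the field names below are illustrative placeholders, not the real KnowledgeConfig definition:

from typing import Optional

from pydantic import BaseModel, ValidationError


class KnowledgeConfigSketch(BaseModel):
    # Illustrative fields only; the real KnowledgeConfig lives in
    # services/entities/knowledge_entities/knowledge_entities.py.
    indexing_technique: str
    doc_form: str = "text_model"
    original_document_id: Optional[str] = None


def parse_args(args: dict) -> KnowledgeConfigSketch:
    # Constructing the model validates and types the raw request payload,
    # so downstream services work with attributes instead of dict lookups.
    try:
        return KnowledgeConfigSketch(**args)
    except ValidationError as exc:
        raise ValueError(f"invalid document args: {exc}") from exc


config = parse_args({"indexing_technique": "high_quality"})
print(config.doc_form)  # -> text_model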
@@ -122,12 +124,13 @@ class DocumentUpdateByTextApi(DatasetApiResource):
             args["data_source"] = data_source
         # validate args
         args["original_document_id"] = str(document_id)
-        DocumentService.document_create_args_validate(args)
+        knowledge_config = KnowledgeConfig(**args)
+        DocumentService.document_create_args_validate(knowledge_config)

         try:
             documents, batch = DocumentService.save_document_with_dataset_id(
                 dataset=dataset,
-                document_data=args,
+                knowledge_config=knowledge_config,
                 account=current_user,
                 dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
                 created_from="api",
@@ -186,12 +189,13 @@ class DocumentAddByFileApi(DatasetApiResource):
         data_source = {"type": "upload_file", "info_list": {"file_info_list": {"file_ids": [upload_file.id]}}}
         args["data_source"] = data_source
         # validate args
-        DocumentService.document_create_args_validate(args)
+        knowledge_config = KnowledgeConfig(**args)
+        DocumentService.document_create_args_validate(knowledge_config)

         try:
             documents, batch = DocumentService.save_document_with_dataset_id(
                 dataset=dataset,
-                document_data=args,
+                knowledge_config=knowledge_config,
                 account=dataset.created_by_account,
                 dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
                 created_from="api",
@@ -245,12 +249,14 @@ class DocumentUpdateByFileApi(DatasetApiResource):
             args["data_source"] = data_source
         # validate args
         args["original_document_id"] = str(document_id)
-        DocumentService.document_create_args_validate(args)
+
+        knowledge_config = KnowledgeConfig(**args)
+        DocumentService.document_create_args_validate(knowledge_config)

         try:
             documents, batch = DocumentService.save_document_with_dataset_id(
                 dataset=dataset,
-                document_data=args,
+                knowledge_config=knowledge_config,
                 account=dataset.created_by_account,
                 dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
                 created_from="api",
@@ -276,7 +276,7 @@ class IndexingRunner:
                 tenant_id=tenant_id,
                 model_type=ModelType.TEXT_EMBEDDING,
             )
-        preview_texts = []
+        preview_texts = []  # type: ignore

         total_segments = 0
         index_type = doc_form
@@ -300,13 +300,13 @@ class IndexingRunner:
                 if len(preview_texts) < 10:
                     if doc_form and doc_form == "qa_model":
                         preview_detail = QAPreviewDetail(
-                            question=document.page_content, answer=document.metadata.get("answer")
+                            question=document.page_content, answer=document.metadata.get("answer") or ""
                         )
                         preview_texts.append(preview_detail)
                     else:
-                        preview_detail = PreviewDetail(content=document.page_content)
+                        preview_detail = PreviewDetail(content=document.page_content)  # type: ignore
                         if document.children:
-                            preview_detail.child_chunks = [child.page_content for child in document.children]
+                            preview_detail.child_chunks = [child.page_content for child in document.children]  # type: ignore
                         preview_texts.append(preview_detail)

             # delete image files and related db records
@@ -325,7 +325,7 @@ class IndexingRunner:

         if doc_form and doc_form == "qa_model":
             return IndexingEstimate(total_segments=total_segments * 20, qa_preview=preview_texts, preview=[])
-        return IndexingEstimate(total_segments=total_segments, preview=preview_texts)
+        return IndexingEstimate(total_segments=total_segments, preview=preview_texts)  # type: ignore

     def _extract(
         self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: dict
@@ -454,7 +454,7 @@ class IndexingRunner:
             embedding_model_instance=embedding_model_instance,
         )

-        return character_splitter
+        return character_splitter  # type: ignore

     def _split_to_documents_for_estimate(
         self, text_docs: list[Document], splitter: TextSplitter, processing_rule: DatasetProcessRule
@@ -535,7 +535,7 @@ class IndexingRunner:
             # create keyword index
             create_keyword_thread = threading.Thread(
                 target=self._process_keyword_index,
-                args=(current_app._get_current_object(), dataset.id, dataset_document.id, documents),
+                args=(current_app._get_current_object(), dataset.id, dataset_document.id, documents),  # type: ignore
             )
             create_keyword_thread.start()

@@ -258,78 +258,79 @@ class RetrievalService:
         include_segment_ids = []
         segment_child_map = {}
         for document in documents:
-            document_id = document.metadata["document_id"]
+            document_id = document.metadata.get("document_id")
             dataset_document = db.session.query(DatasetDocument).filter(DatasetDocument.id == document_id).first()
-            if dataset_document and dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
-                child_index_node_id = document.metadata["doc_id"]
-                result = (
-                    db.session.query(ChildChunk, DocumentSegment)
-                    .join(DocumentSegment, ChildChunk.segment_id == DocumentSegment.id)
-                    .filter(
-                        ChildChunk.index_node_id == child_index_node_id,
-                        DocumentSegment.dataset_id == dataset_document.dataset_id,
-                        DocumentSegment.enabled == True,
-                        DocumentSegment.status == "completed",
-                    )
-                    .first()
-                )
-                if result:
-                    child_chunk, segment = result
-                    if not segment:
-                        continue
-                    if segment.id not in include_segment_ids:
-                        include_segment_ids.append(segment.id)
-                        child_chunk_detail = {
-                            "id": child_chunk.id,
-                            "content": child_chunk.content,
-                            "position": child_chunk.position,
-                            "score": document.metadata.get("score", 0.0),
-                        }
-                        map_detail = {
-                            "max_score": document.metadata.get("score", 0.0),
-                            "child_chunks": [child_chunk_detail],
-                        }
-                        segment_child_map[segment.id] = map_detail
-                        record = {
-                            "segment": segment,
-                        }
-                        records.append(record)
-                    else:
-                        child_chunk_detail = {
-                            "id": child_chunk.id,
-                            "content": child_chunk.content,
-                            "position": child_chunk.position,
-                            "score": document.metadata.get("score", 0.0),
-                        }
-                        segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail)
-                        segment_child_map[segment.id]["max_score"] = max(
-                            segment_child_map[segment.id]["max_score"], document.metadata.get("score", 0.0)
-                        )
-                else:
-                    continue
-            else:
-                index_node_id = document.metadata["doc_id"]
-
-                segment = (
-                    db.session.query(DocumentSegment)
-                    .filter(
-                        DocumentSegment.dataset_id == dataset_document.dataset_id,
-                        DocumentSegment.enabled == True,
-                        DocumentSegment.status == "completed",
-                        DocumentSegment.index_node_id == index_node_id,
-                    )
-                    .first()
-                )
-
-                if not segment:
-                    continue
-                include_segment_ids.append(segment.id)
-                record = {
-                    "segment": segment,
-                    "score": document.metadata.get("score", None),
-                }
-
-                records.append(record)
+            if dataset_document:
+                if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
+                    child_index_node_id = document.metadata.get("doc_id")
+                    result = (
+                        db.session.query(ChildChunk, DocumentSegment)
+                        .join(DocumentSegment, ChildChunk.segment_id == DocumentSegment.id)
+                        .filter(
+                            ChildChunk.index_node_id == child_index_node_id,
+                            DocumentSegment.dataset_id == dataset_document.dataset_id,
+                            DocumentSegment.enabled == True,
+                            DocumentSegment.status == "completed",
+                        )
+                        .first()
+                    )
+                    if result:
+                        child_chunk, segment = result
+                        if not segment:
+                            continue
+                        if segment.id not in include_segment_ids:
+                            include_segment_ids.append(segment.id)
+                            child_chunk_detail = {
+                                "id": child_chunk.id,
+                                "content": child_chunk.content,
+                                "position": child_chunk.position,
+                                "score": document.metadata.get("score", 0.0),
+                            }
+                            map_detail = {
+                                "max_score": document.metadata.get("score", 0.0),
+                                "child_chunks": [child_chunk_detail],
+                            }
+                            segment_child_map[segment.id] = map_detail
+                            record = {
+                                "segment": segment,
+                            }
+                            records.append(record)
+                        else:
+                            child_chunk_detail = {
+                                "id": child_chunk.id,
+                                "content": child_chunk.content,
+                                "position": child_chunk.position,
+                                "score": document.metadata.get("score", 0.0),
+                            }
+                            segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail)
+                            segment_child_map[segment.id]["max_score"] = max(
+                                segment_child_map[segment.id]["max_score"], document.metadata.get("score", 0.0)
+                            )
+                    else:
+                        continue
+                else:
+                    index_node_id = document.metadata["doc_id"]
+
+                    segment = (
+                        db.session.query(DocumentSegment)
+                        .filter(
+                            DocumentSegment.dataset_id == dataset_document.dataset_id,
+                            DocumentSegment.enabled == True,
+                            DocumentSegment.status == "completed",
+                            DocumentSegment.index_node_id == index_node_id,
+                        )
+                        .first()
+                    )
+
+                    if not segment:
+                        continue
+                    include_segment_ids.append(segment.id)
+                    record = {
+                        "segment": segment,
+                        "score": document.metadata.get("score", None),
+                    }
+
+                    records.append(record)
         for record in records:
             if record["segment"].id in segment_child_map:
                 record["child_chunks"] = segment_child_map[record["segment"].id].get("child_chunks", None)
@@ -122,26 +122,27 @@ class DatasetDocumentStore:
                 db.session.add(segment_document)
                 db.session.flush()
                 if save_child:
-                    for postion, child in enumerate(doc.children, start=1):
-                        child_segment = ChildChunk(
-                            tenant_id=self._dataset.tenant_id,
-                            dataset_id=self._dataset.id,
-                            document_id=self._document_id,
-                            segment_id=segment_document.id,
-                            position=postion,
-                            index_node_id=child.metadata["doc_id"],
-                            index_node_hash=child.metadata["doc_hash"],
-                            content=child.page_content,
-                            word_count=len(child.page_content),
-                            type="automatic",
-                            created_by=self._user_id,
-                        )
-                        db.session.add(child_segment)
+                    if doc.children:
+                        for postion, child in enumerate(doc.children, start=1):
+                            child_segment = ChildChunk(
+                                tenant_id=self._dataset.tenant_id,
+                                dataset_id=self._dataset.id,
+                                document_id=self._document_id,
+                                segment_id=segment_document.id,
+                                position=postion,
+                                index_node_id=child.metadata.get("doc_id"),
+                                index_node_hash=child.metadata.get("doc_hash"),
+                                content=child.page_content,
+                                word_count=len(child.page_content),
+                                type="automatic",
+                                created_by=self._user_id,
+                            )
+                            db.session.add(child_segment)
             else:
                 segment_document.content = doc.page_content
                 if doc.metadata.get("answer"):
                     segment_document.answer = doc.metadata.pop("answer", "")
-                segment_document.index_node_hash = doc.metadata["doc_hash"]
+                segment_document.index_node_hash = doc.metadata.get("doc_hash")
                 segment_document.word_count = len(doc.page_content)
                 segment_document.tokens = tokens
                 if save_child and doc.children:
@@ -160,8 +161,8 @@ class DatasetDocumentStore:
                             document_id=self._document_id,
                             segment_id=segment_document.id,
                             position=position,
-                            index_node_id=child.metadata["doc_id"],
-                            index_node_hash=child.metadata["doc_hash"],
+                            index_node_id=child.metadata.get("doc_id"),
+                            index_node_hash=child.metadata.get("doc_hash"),
                             content=child.page_content,
                             word_count=len(child.page_content),
                             type="automatic",
@@ -4,7 +4,7 @@ import os
 from typing import Optional, cast

 import pandas as pd
-from openpyxl import load_workbook
+from openpyxl import load_workbook  # type: ignore

 from core.rag.extractor.extractor_base import BaseExtractor
 from core.rag.models.document import Document
@@ -81,4 +81,4 @@ class BaseIndexProcessor(ABC):
             embedding_model_instance=embedding_model_instance,
         )

-        return character_splitter
+        return character_splitter  # type: ignore
@@ -30,12 +30,18 @@ class ParagraphIndexProcessor(BaseIndexProcessor):

     def transform(self, documents: list[Document], **kwargs) -> list[Document]:
         process_rule = kwargs.get("process_rule")
+        if not process_rule:
+            raise ValueError("No process rule found.")
         if process_rule.get("mode") == "automatic":
             automatic_rule = DatasetProcessRule.AUTOMATIC_RULES
             rules = Rule(**automatic_rule)
         else:
+            if not process_rule.get("rules"):
+                raise ValueError("No rules found in process rule.")
             rules = Rule(**process_rule.get("rules"))
         # Split the text documents into nodes.
+        if not rules.segmentation:
+            raise ValueError("No segmentation found in rules.")
         splitter = self._get_splitter(
             processing_rule_mode=process_rule.get("mode"),
             max_tokens=rules.segmentation.max_tokens,
@@ -30,8 +30,12 @@ class ParentChildIndexProcessor(BaseIndexProcessor):

     def transform(self, documents: list[Document], **kwargs) -> list[Document]:
         process_rule = kwargs.get("process_rule")
+        if not process_rule:
+            raise ValueError("No process rule found.")
+        if not process_rule.get("rules"):
+            raise ValueError("No rules found in process rule.")
         rules = Rule(**process_rule.get("rules"))
-        all_documents = []
+        all_documents = []  # type: ignore
         if rules.parent_mode == ParentMode.PARAGRAPH:
             # Split the text documents into nodes.
             splitter = self._get_splitter(
@@ -161,6 +165,8 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
         process_rule_mode: str,
         embedding_model_instance: Optional[ModelInstance],
     ) -> list[ChildDocument]:
+        if not rules.subchunk_segmentation:
+            raise ValueError("No subchunk segmentation found in rules.")
         child_splitter = self._get_splitter(
             processing_rule_mode=process_rule_mode,
             max_tokens=rules.subchunk_segmentation.max_tokens,
@@ -37,12 +37,16 @@ class QAIndexProcessor(BaseIndexProcessor):
     def transform(self, documents: list[Document], **kwargs) -> list[Document]:
         preview = kwargs.get("preview")
         process_rule = kwargs.get("process_rule")
+        if not process_rule:
+            raise ValueError("No process rule found.")
+        if not process_rule.get("rules"):
+            raise ValueError("No rules found in process rule.")
         rules = Rule(**process_rule.get("rules"))
         splitter = self._get_splitter(
             processing_rule_mode=process_rule.get("mode"),
-            max_tokens=rules.segmentation.max_tokens,
-            chunk_overlap=rules.segmentation.chunk_overlap,
-            separator=rules.segmentation.separator,
+            max_tokens=rules.segmentation.max_tokens if rules.segmentation else 0,
+            chunk_overlap=rules.segmentation.chunk_overlap if rules.segmentation else 0,
+            separator=rules.segmentation.separator if rules.segmentation else "",
             embedding_model_instance=kwargs.get("embedding_model_instance"),
         )

@@ -71,8 +75,8 @@ class QAIndexProcessor(BaseIndexProcessor):
                 all_documents.extend(split_documents)
         if preview:
             self._format_qa_document(
-                current_app._get_current_object(),
-                kwargs.get("tenant_id"),
+                current_app._get_current_object(),  # type: ignore
+                kwargs.get("tenant_id"),  # type: ignore
                 all_documents[0],
                 all_qa_documents,
                 kwargs.get("doc_language", "English"),
@@ -85,8 +89,8 @@ class QAIndexProcessor(BaseIndexProcessor):
                 document_format_thread = threading.Thread(
                     target=self._format_qa_document,
                     kwargs={
-                        "flask_app": current_app._get_current_object(),
-                        "tenant_id": kwargs.get("tenant_id"),
+                        "flask_app": current_app._get_current_object(),  # type: ignore
+                        "tenant_id": kwargs.get("tenant_id"),  # type: ignore
                         "document_node": doc,
                         "all_qa_documents": all_qa_documents,
                         "document_language": kwargs.get("doc_language", "English"),
@@ -2,7 +2,7 @@ from abc import ABC, abstractmethod
 from collections.abc import Sequence
 from typing import Any, Optional

-from pydantic import BaseModel, Field
+from pydantic import BaseModel


 class ChildDocument(BaseModel):
@@ -15,7 +15,7 @@ class ChildDocument(BaseModel):
     """Arbitrary metadata about the page content (e.g., source, relationships to other
     documents, etc.).
     """
-    metadata: Optional[dict] = Field(default_factory=dict)
+    metadata: dict = {}


 class Document(BaseModel):
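One note on the metadata default changing from Field(default_factory=dict) to a literal {}: unlike a plain class attribute, a mutable default on a pydantic model is copied for each new instance, so documents still do not share state. A small sketch of that behaviour, assuming pydantic's documented handling of mutable defaults:

from pydantic import BaseModel


class Doc(BaseModel):
    metadata: dict = {}  # pydantic copies this default for every new instance


a = Doc()
b = Doc()
a.metadata["source"] = "upload_file"
print(a.metadata)  # {'source': 'upload_file'}
print(b.metadata)  # {} -- b keeps its own dict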
@@ -28,7 +28,7 @@ class Document(BaseModel):
     """Arbitrary metadata about the page content (e.g., source, relationships to other
     documents, etc.).
     """
-    metadata: Optional[dict] = Field(default_factory=dict)
+    metadata: dict = {}

     provider: Optional[str] = "dify"

@@ -5,7 +5,7 @@ from dify_app import DifyApp
 def init_app(app: DifyApp):
     # register blueprint routers

-    from flask_cors import CORS
+    from flask_cors import CORS  # type: ignore

     from controllers.console import bp as console_app_bp
     from controllers.files import bp as files_bp
@@ -1,9 +1,9 @@
 import logging
 import time
+from collections import defaultdict

 import click
 from celery import shared_task  # type: ignore
-from flask import render_template

 from extensions.ext_mail import mail
 from models.account import Account, Tenant, TenantAccountJoin
@@ -27,7 +27,7 @@ def send_document_clean_notify_task():
     try:
         dataset_auto_disable_logs = DatasetAutoDisableLog.query.filter(DatasetAutoDisableLog.notified == False).all()
         # group by tenant_id
-        dataset_auto_disable_logs_map = {}
+        dataset_auto_disable_logs_map: dict[str, list[DatasetAutoDisableLog]] = defaultdict(list)
        for dataset_auto_disable_log in dataset_auto_disable_logs:
            dataset_auto_disable_logs_map[dataset_auto_disable_log.tenant_id].append(dataset_auto_disable_log)

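The defaultdict(list) above is what lets the loop append on first access without key checks. A self-contained sketch of the grouping pattern (the sample records are illustrative):

from collections import defaultdict

logs = [
    {"tenant_id": "t1", "dataset_id": "d1"},
    {"tenant_id": "t2", "dataset_id": "d2"},
    {"tenant_id": "t1", "dataset_id": "d3"},
]

# Missing keys are created as empty lists on first access, so the loop can
# append without an explicit "if key not in mapping" check.
logs_by_tenant: dict[str, list[dict]] = defaultdict(list)
for log in logs:
    logs_by_tenant[log["tenant_id"]].append(log)

print(dict(logs_by_tenant))
# {'t1': [{'tenant_id': 't1', ...}, {'tenant_id': 't1', ...}], 't2': [{'tenant_id': 't2', ...}]}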
@@ -37,11 +37,13 @@ def send_document_clean_notify_task():
             if not tenant:
                 continue
             current_owner_join = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, role="owner").first()
+            if not current_owner_join:
+                continue
             account = Account.query.filter(Account.id == current_owner_join.account_id).first()
             if not account:
                 continue

-            dataset_auto_dataset_map = {}
+            dataset_auto_dataset_map = {}  # type: ignore
             for dataset_auto_disable_log in tenant_dataset_auto_disable_logs:
                 dataset_auto_dataset_map[dataset_auto_disable_log.dataset_id].append(
                     dataset_auto_disable_log.document_id
@@ -53,14 +55,9 @@ def send_document_clean_notify_task():
                 document_count = len(document_ids)
                 knowledge_details.append(f"<li>Knowledge base {dataset.name}: {document_count} documents</li>")

-            html_content = render_template(
-                "clean_document_job_mail_template-US.html",
-            )
-            mail.send(to=to, subject="立即加入 Dify 工作空间", html=html_content)
-
         end_at = time.perf_counter()
         logging.info(
             click.style("Send document clean notify mail succeeded: latency: {}".format(end_at - start_at), fg="green")
         )
     except Exception:
-        logging.exception("Send invite member mail to {} failed".format(to))
+        logging.exception("Send invite member mail to failed")
@@ -4,7 +4,7 @@ from enum import StrEnum
 from typing import Optional, cast
 from uuid import uuid4

-import yaml
+import yaml  # type: ignore
 from packaging import version
 from pydantic import BaseModel
 from sqlalchemy import select
@@ -465,7 +465,7 @@ class AppDslService:
         else:
             cls._append_model_config_export_data(export_data, app_model)

-        return yaml.dump(export_data, allow_unicode=True)
+        return yaml.dump(export_data, allow_unicode=True)  # type: ignore

     @classmethod
     def _append_workflow_export_data(cls, *, export_data: dict, app_model: App, include_secret: bool) -> None:
@@ -41,6 +41,7 @@ from models.source import DataSourceOauthBinding
 from services.entities.knowledge_entities.knowledge_entities import (
     ChildChunkUpdateArgs,
     KnowledgeConfig,
+    RerankingModel,
     RetrievalModel,
     SegmentUpdateArgs,
 )
@@ -548,12 +549,14 @@ class DocumentService:
         }

     @staticmethod
-    def get_document(dataset_id: str, document_id: str) -> Optional[Document]:
-        document = (
-            db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first()
-        )
-
-        return document
+    def get_document(dataset_id: str, document_id: Optional[str] = None) -> Optional[Document]:
+        if document_id:
+            document = (
+                db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first()
+            )
+            return document
+        else:
+            return None

     @staticmethod
     def get_document_by_id(document_id: str) -> Optional[Document]:
@@ -744,25 +747,26 @@ class DocumentService:
         if features.billing.enabled:
             if not knowledge_config.original_document_id:
                 count = 0
-                if knowledge_config.data_source.info_list.data_source_type == "upload_file":
-                    upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids
-                    count = len(upload_file_list)
-                elif knowledge_config.data_source.info_list.data_source_type == "notion_import":
-                    notion_info_list = knowledge_config.data_source.info_list.notion_info_list
-                    for notion_info in notion_info_list:
-                        count = count + len(notion_info.pages)
-                elif knowledge_config.data_source.info_list.data_source_type == "website_crawl":
-                    website_info = knowledge_config.data_source.info_list.website_info_list
-                    count = len(website_info.urls)
-                batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
-                if count > batch_upload_limit:
-                    raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
-
-                DocumentService.check_documents_upload_quota(count, features)
+                if knowledge_config.data_source:
+                    if knowledge_config.data_source.info_list.data_source_type == "upload_file":
+                        upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids  # type: ignore
+                        count = len(upload_file_list)
+                    elif knowledge_config.data_source.info_list.data_source_type == "notion_import":
+                        notion_info_list = knowledge_config.data_source.info_list.notion_info_list
+                        for notion_info in notion_info_list:  # type: ignore
+                            count = count + len(notion_info.pages)
+                    elif knowledge_config.data_source.info_list.data_source_type == "website_crawl":
+                        website_info = knowledge_config.data_source.info_list.website_info_list
+                        count = len(website_info.urls)  # type: ignore
+                    batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
+                    if count > batch_upload_limit:
+                        raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
+
+                    DocumentService.check_documents_upload_quota(count, features)

         # if dataset is empty, update dataset data_source_type
         if not dataset.data_source_type:
-            dataset.data_source_type = knowledge_config.data_source.info_list.data_source_type
+            dataset.data_source_type = knowledge_config.data_source.info_list.data_source_type  # type: ignore

         if not dataset.indexing_technique:
             if knowledge_config.indexing_technique not in Dataset.INDEXING_TECHNIQUE_LIST:
@@ -789,7 +793,7 @@ class DocumentService:
                 "score_threshold_enabled": False,
             }

-            dataset.retrieval_model = knowledge_config.retrieval_model.model_dump() or default_retrieval_model
+            dataset.retrieval_model = knowledge_config.retrieval_model.model_dump() or default_retrieval_model  # type: ignore

         documents = []
         if knowledge_config.original_document_id:
@@ -801,34 +805,35 @@ class DocumentService:
             # save process rule
             if not dataset_process_rule:
                 process_rule = knowledge_config.process_rule
-                if process_rule.mode in ("custom", "hierarchical"):
-                    dataset_process_rule = DatasetProcessRule(
-                        dataset_id=dataset.id,
-                        mode=process_rule.mode,
-                        rules=process_rule.rules.model_dump_json(),
-                        created_by=account.id,
-                    )
-                elif process_rule.mode == "automatic":
-                    dataset_process_rule = DatasetProcessRule(
-                        dataset_id=dataset.id,
-                        mode=process_rule.mode,
-                        rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES),
-                        created_by=account.id,
-                    )
-                else:
-                    logging.warn(
-                        f"Invalid process rule mode: {process_rule['mode']}, can not find dataset process rule"
-                    )
-                    return
-                db.session.add(dataset_process_rule)
-                db.session.commit()
+                if process_rule:
+                    if process_rule.mode in ("custom", "hierarchical"):
+                        dataset_process_rule = DatasetProcessRule(
+                            dataset_id=dataset.id,
+                            mode=process_rule.mode,
+                            rules=process_rule.rules.model_dump_json() if process_rule.rules else None,
+                            created_by=account.id,
+                        )
+                    elif process_rule.mode == "automatic":
+                        dataset_process_rule = DatasetProcessRule(
+                            dataset_id=dataset.id,
+                            mode=process_rule.mode,
+                            rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES),
+                            created_by=account.id,
+                        )
+                    else:
+                        logging.warn(
+                            f"Invalid process rule mode: {process_rule.mode}, can not find dataset process rule"
+                        )
+                        return
+                    db.session.add(dataset_process_rule)
+                    db.session.commit()
             lock_name = "add_document_lock_dataset_id_{}".format(dataset.id)
             with redis_client.lock(lock_name, timeout=600):
                 position = DocumentService.get_documents_position(dataset.id)
                 document_ids = []
                 duplicate_document_ids = []
                 if knowledge_config.data_source.info_list.data_source_type == "upload_file":
-                    upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids
+                    upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids  # type: ignore
                     for file_id in upload_file_list:
                         file = (
                             db.session.query(UploadFile)
@@ -854,7 +859,7 @@ class DocumentService:
                                 name=file_name,
                             ).first()
                             if document:
-                                document.dataset_process_rule_id = dataset_process_rule.id
+                                document.dataset_process_rule_id = dataset_process_rule.id  # type: ignore
                                 document.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
                                 document.created_from = created_from
                                 document.doc_form = knowledge_config.doc_form
@@ -868,7 +873,7 @@ class DocumentService:
                             continue
                         document = DocumentService.build_document(
                             dataset,
-                            dataset_process_rule.id,
+                            dataset_process_rule.id,  # type: ignore
                             knowledge_config.data_source.info_list.data_source_type,
                             knowledge_config.doc_form,
                             knowledge_config.doc_language,
@@ -886,6 +891,8 @@ class DocumentService:
                         position += 1
                 elif knowledge_config.data_source.info_list.data_source_type == "notion_import":
                     notion_info_list = knowledge_config.data_source.info_list.notion_info_list
+                    if not notion_info_list:
+                        raise ValueError("No notion info list found.")
                     exist_page_ids = []
                     exist_document = {}
                     documents = Document.query.filter_by(
@@ -921,7 +928,7 @@ class DocumentService:
                             }
                             document = DocumentService.build_document(
                                 dataset,
-                                dataset_process_rule.id,
+                                dataset_process_rule.id,  # type: ignore
                                 knowledge_config.data_source.info_list.data_source_type,
                                 knowledge_config.doc_form,
                                 knowledge_config.doc_language,
@@ -944,6 +951,8 @@ class DocumentService:
                         clean_notion_document_task.delay(list(exist_document.values()), dataset.id)
                 elif knowledge_config.data_source.info_list.data_source_type == "website_crawl":
                     website_info = knowledge_config.data_source.info_list.website_info_list
+                    if not website_info:
+                        raise ValueError("No website info list found.")
                     urls = website_info.urls
                     for url in urls:
                         data_source_info = {
@@ -959,7 +968,7 @@ class DocumentService:
                         document_name = url
                         document = DocumentService.build_document(
                             dataset,
-                            dataset_process_rule.id,
+                            dataset_process_rule.id,  # type: ignore
                             knowledge_config.data_source.info_list.data_source_type,
                             knowledge_config.doc_form,
                             knowledge_config.doc_language,
@@ -1054,7 +1063,7 @@ class DocumentService:
             dataset_process_rule = DatasetProcessRule(
                 dataset_id=dataset.id,
                 mode=process_rule.mode,
-                rules=process_rule.rules.model_dump_json(),
+                rules=process_rule.rules.model_dump_json() if process_rule.rules else None,
                 created_by=account.id,
             )
         elif process_rule.mode == "automatic":
@@ -1073,6 +1082,8 @@ class DocumentService:
             file_name = ""
             data_source_info = {}
             if document_data.data_source.info_list.data_source_type == "upload_file":
+                if not document_data.data_source.info_list.file_info_list:
+                    raise ValueError("No file info list found.")
                 upload_file_list = document_data.data_source.info_list.file_info_list.file_ids
                 for file_id in upload_file_list:
                     file = (
@@ -1090,6 +1101,8 @@ class DocumentService:
                         "upload_file_id": file_id,
                     }
             elif document_data.data_source.info_list.data_source_type == "notion_import":
+                if not document_data.data_source.info_list.notion_info_list:
+                    raise ValueError("No notion info list found.")
                 notion_info_list = document_data.data_source.info_list.notion_info_list
                 for notion_info in notion_info_list:
                     workspace_id = notion_info.workspace_id
@@ -1107,20 +1120,21 @@ class DocumentService:
                     data_source_info = {
                         "notion_workspace_id": workspace_id,
                         "notion_page_id": page.page_id,
-                        "notion_page_icon": page.page_icon,
+                        "notion_page_icon": page.page_icon.model_dump() if page.page_icon else None,  # type: ignore
                         "type": page.type,
                     }
             elif document_data.data_source.info_list.data_source_type == "website_crawl":
                 website_info = document_data.data_source.info_list.website_info_list
-                urls = website_info.urls
-                for url in urls:
-                    data_source_info = {
-                        "url": url,
-                        "provider": website_info.provider,
-                        "job_id": website_info.job_id,
-                        "only_main_content": website_info.only_main_content,
-                        "mode": "crawl",
-                    }
+                if website_info:
+                    urls = website_info.urls
+                    for url in urls:
+                        data_source_info = {
+                            "url": url,
+                            "provider": website_info.provider,
+                            "job_id": website_info.job_id,
+                            "only_main_content": website_info.only_main_content,  # type: ignore
+                            "mode": "crawl",
+                        }
             document.data_source_type = document_data.data_source.info_list.data_source_type
             document.data_source_info = json.dumps(data_source_info)
             document.name = file_name
@@ -1155,15 +1169,21 @@ class DocumentService:
         if features.billing.enabled:
             count = 0
             if knowledge_config.data_source.info_list.data_source_type == "upload_file":
-                upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids
+                upload_file_list = (
+                    knowledge_config.data_source.info_list.file_info_list.file_ids
+                    if knowledge_config.data_source.info_list.file_info_list
+                    else []
+                )
                 count = len(upload_file_list)
             elif knowledge_config.data_source.info_list.data_source_type == "notion_import":
                 notion_info_list = knowledge_config.data_source.info_list.notion_info_list
-                for notion_info in notion_info_list:
-                    count = count + len(notion_info.pages)
+                if notion_info_list:
+                    for notion_info in notion_info_list:
+                        count = count + len(notion_info.pages)
             elif knowledge_config.data_source.info_list.data_source_type == "website_crawl":
                 website_info = knowledge_config.data_source.info_list.website_info_list
-                count = len(website_info.urls)
+                if website_info:
+                    count = len(website_info.urls)
             batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
             if count > batch_upload_limit:
                 raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
@@ -1174,20 +1194,20 @@ class DocumentService:
         retrieval_model = None
         if knowledge_config.indexing_technique == "high_quality":
             dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
-                knowledge_config.embedding_model_provider, knowledge_config.embedding_model
+                knowledge_config.embedding_model_provider,  # type: ignore
+                knowledge_config.embedding_model,  # type: ignore
             )
             dataset_collection_binding_id = dataset_collection_binding.id
             if knowledge_config.retrieval_model:
                 retrieval_model = knowledge_config.retrieval_model
             else:
-                default_retrieval_model = {
-                    "search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
-                    "reranking_enable": False,
-                    "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
-                    "top_k": 2,
-                    "score_threshold_enabled": False,
-                }
-                retrieval_model = RetrievalModel(**default_retrieval_model)
+                retrieval_model = RetrievalModel(
+                    search_method=RetrievalMethod.SEMANTIC_SEARCH.value,
+                    reranking_enable=False,
+                    reranking_model=RerankingModel(reranking_provider_name="", reranking_model_name=""),
+                    top_k=2,
+                    score_threshold_enabled=False,
+                )
         # save dataset
         dataset = Dataset(
             tenant_id=tenant_id,
@@ -1557,12 +1577,12 @@ class SegmentService:
             raise ValueError("Can't update disabled segment")
         try:
             word_count_change = segment.word_count
-            content = args.content
+            content = args.content or segment.content
             if segment.content == content:
                 segment.word_count = len(content)
                 if document.doc_form == "qa_model":
                     segment.answer = args.answer
-                    segment.word_count += len(args.answer)
+                    segment.word_count += len(args.answer) if args.answer else 0
                 word_count_change = segment.word_count - word_count_change
                 if args.keywords:
                     segment.keywords = args.keywords
@@ -1577,7 +1597,12 @@ class SegmentService:
                 db.session.add(document)
             # update segment index task
             if args.enabled:
-                VectorService.create_segments_vector([args.keywords], [segment], dataset)
+                VectorService.create_segments_vector(
+                    [args.keywords] if args.keywords else None,
+                    [segment],
+                    dataset,
+                    document.doc_form,
+                )
             if document.doc_form == IndexType.PARENT_CHILD_INDEX and args.regenerate_child_chunks:
                 # regenerate child chunks
                 # get embedding model instance
@@ -1605,6 +1630,8 @@ class SegmentService:
                     .filter(DatasetProcessRule.id == document.dataset_process_rule_id)
                     .first()
                 )
+                if not processing_rule:
+                    raise ValueError("No processing rule found.")
                 VectorService.generate_child_chunks(
                     segment, document, dataset, embedding_model_instance, processing_rule, True
                 )
@@ -1639,7 +1666,7 @@ class SegmentService:
                 segment.disabled_by = None
                 if document.doc_form == "qa_model":
                     segment.answer = args.answer
-                    segment.word_count += len(args.answer)
+                    segment.word_count += len(args.answer) if args.answer else 0
                 word_count_change = segment.word_count - word_count_change
                 # update document word count
                 if word_count_change != 0:
@@ -1673,6 +1700,8 @@ class SegmentService:
                     .filter(DatasetProcessRule.id == document.dataset_process_rule_id)
                     .first()
                 )
+                if not processing_rule:
+                    raise ValueError("No processing rule found.")
                 VectorService.generate_child_chunks(
                     segment, document, dataset, embedding_model_instance, processing_rule, True
                 )
@@ -97,7 +97,7 @@ class KnowledgeConfig(BaseModel):
     original_document_id: Optional[str] = None
     duplicate: bool = True
     indexing_technique: Literal["high_quality", "economy"]
-    data_source: Optional[DataSource] = None
+    data_source: DataSource
     process_rule: Optional[ProcessRule] = None
     retrieval_model: Optional[RetrievalModel] = None
     doc_form: str = "text_model"
@@ -69,7 +69,7 @@ class HitTestingService:
         db.session.add(dataset_query)
         db.session.commit()

-        return cls.compact_retrieve_response(query, all_documents)
+        return cls.compact_retrieve_response(query, all_documents)  # type: ignore

     @classmethod
     def external_retrieve(
@@ -29,6 +29,8 @@ class VectorService:
                 .filter(DatasetProcessRule.id == document.dataset_process_rule_id)
                 .first()
             )
+            if not processing_rule:
+                raise ValueError("No processing rule found.")
             # get embedding model instance
             if dataset.indexing_technique == "high_quality":
                 # check embedding model setting
@@ -98,7 +100,7 @@ class VectorService:
     def generate_child_chunks(
         cls,
         segment: DocumentSegment,
-        dataset_document: Document,
+        dataset_document: DatasetDocument,
         dataset: Dataset,
         embedding_model_instance: ModelInstance,
         processing_rule: DatasetProcessRule,
@@ -130,7 +132,7 @@ class VectorService:
             doc_language=dataset_document.doc_language,
         )
         # save child chunks
-        if len(documents) > 0 and len(documents[0].children) > 0:
+        if documents and documents[0].children:
             index_processor.load(dataset, documents)

             for position, child_chunk in enumerate(documents[0].children, start=1):
@@ -44,7 +44,8 @@ def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form
         for upload_file_id in image_upload_file_ids:
             image_file = db.session.query(UploadFile).filter(UploadFile.id == upload_file_id).first()
             try:
-                storage.delete(image_file.key)
+                if image_file and image_file.key:
+                    storage.delete(image_file.key)
             except Exception:
                 logging.exception(
                     "Delete image_files failed when storage deleted, \