mirror of https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-14 15:15:56 +08:00

py lint (#12102)

Signed-off-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: -LAN- <laipz8200@outlook.com>

parent: bb35818976
commit: 84ac004772
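
The hunks below are a batch of Python lint and type-check fixes across the API code: "# type: ignore" markers on imports and calls that mypy cannot resolve, dict .get() lookups instead of direct indexing for metadata keys that may be absent, and explicit guards before touching values that may be None. A minimal, self-contained sketch of these recurring patterns follows (illustrative only, not code from the commit; the helper names are invented):

from typing import Optional

# Pattern 1: silence mypy on third-party imports that ship without type stubs,
# as the diff does for flask_migrate, openpyxl, yaml and flask_cors:
#     import flask_migrate  # type: ignore

# Pattern 2: read metadata keys with .get() so a missing key yields None
# instead of raising KeyError.
def doc_id_of(metadata: Optional[dict]) -> Optional[str]:
    return (metadata or {}).get("doc_id")

# Pattern 3: guard possibly-None values before calling len() on them.
def answer_length(answer: Optional[str]) -> int:
    return len(answer) if answer else 0

if __name__ == "__main__":
    print(doc_id_of({"doc_id": "abc"}))  # abc
    print(doc_id_of(None))               # None
    print(answer_length(None))           # 0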
@@ -587,7 +587,7 @@ def upgrade_db():
 click.echo(click.style("Starting database migration.", fg="green"))

 # run db migration
-import flask_migrate
+import flask_migrate # type: ignore

 flask_migrate.upgrade()

@@ -413,7 +413,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
 indexing_runner = IndexingRunner()

 try:
-response = indexing_runner.indexing_estimate(
+estimate_response = indexing_runner.indexing_estimate(
 current_user.current_tenant_id,
 [extract_setting],
 data_process_rule_dict,
@@ -421,6 +421,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
 "English",
 dataset_id,
 )
+return estimate_response.model_dump(), 200
 except LLMBadRequestError:
 raise ProviderNotInitializeError(
 "No Embedding Model available. Please configure a valid provider "
@@ -431,7 +432,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
 except Exception as e:
 raise IndexingEstimateError(str(e))

-return response.model_dump(), 200
+return response, 200


 class DocumentBatchIndexingEstimateApi(DocumentResource):
@@ -521,6 +522,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
 "English",
 dataset_id,
 )
+return response.model_dump(), 200
 except LLMBadRequestError:
 raise ProviderNotInitializeError(
 "No Embedding Model available. Please configure a valid provider "
@@ -530,7 +532,6 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
 raise ProviderNotInitializeError(ex.description)
 except Exception as e:
 raise IndexingEstimateError(str(e))
-return response.model_dump(), 200


 class DocumentBatchIndexingStatusApi(DocumentResource):
@@ -22,6 +22,7 @@ from fields.document_fields import document_fields, document_status_fields
 from libs.login import current_user
 from models.dataset import Dataset, Document, DocumentSegment
 from services.dataset_service import DocumentService
+from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig
 from services.file_service import FileService


@@ -67,13 +68,14 @@ class DocumentAddByTextApi(DatasetApiResource):
 "info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
 }
 args["data_source"] = data_source
+knowledge_config = KnowledgeConfig(**args)
 # validate args
-DocumentService.document_create_args_validate(args)
+DocumentService.document_create_args_validate(knowledge_config)

 try:
 documents, batch = DocumentService.save_document_with_dataset_id(
 dataset=dataset,
-document_data=args,
+knowledge_config=knowledge_config,
 account=current_user,
 dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
 created_from="api",
@@ -122,12 +124,13 @@ class DocumentUpdateByTextApi(DatasetApiResource):
 args["data_source"] = data_source
 # validate args
 args["original_document_id"] = str(document_id)
-DocumentService.document_create_args_validate(args)
+knowledge_config = KnowledgeConfig(**args)
+DocumentService.document_create_args_validate(knowledge_config)

 try:
 documents, batch = DocumentService.save_document_with_dataset_id(
 dataset=dataset,
-document_data=args,
+knowledge_config=knowledge_config,
 account=current_user,
 dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
 created_from="api",
@@ -186,12 +189,13 @@ class DocumentAddByFileApi(DatasetApiResource):
 data_source = {"type": "upload_file", "info_list": {"file_info_list": {"file_ids": [upload_file.id]}}}
 args["data_source"] = data_source
 # validate args
-DocumentService.document_create_args_validate(args)
+knowledge_config = KnowledgeConfig(**args)
+DocumentService.document_create_args_validate(knowledge_config)

 try:
 documents, batch = DocumentService.save_document_with_dataset_id(
 dataset=dataset,
-document_data=args,
+knowledge_config=knowledge_config,
 account=dataset.created_by_account,
 dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
 created_from="api",
@@ -245,12 +249,14 @@ class DocumentUpdateByFileApi(DatasetApiResource):
 args["data_source"] = data_source
 # validate args
 args["original_document_id"] = str(document_id)
-DocumentService.document_create_args_validate(args)
+knowledge_config = KnowledgeConfig(**args)
+DocumentService.document_create_args_validate(knowledge_config)

 try:
 documents, batch = DocumentService.save_document_with_dataset_id(
 dataset=dataset,
-document_data=args,
+knowledge_config=knowledge_config,
 account=dataset.created_by_account,
 dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
 created_from="api",
@@ -276,7 +276,7 @@ class IndexingRunner:
 tenant_id=tenant_id,
 model_type=ModelType.TEXT_EMBEDDING,
 )
-preview_texts = []
+preview_texts = [] # type: ignore

 total_segments = 0
 index_type = doc_form
@@ -300,13 +300,13 @@ class IndexingRunner:
 if len(preview_texts) < 10:
 if doc_form and doc_form == "qa_model":
 preview_detail = QAPreviewDetail(
-question=document.page_content, answer=document.metadata.get("answer")
+question=document.page_content, answer=document.metadata.get("answer") or ""
 )
 preview_texts.append(preview_detail)
 else:
-preview_detail = PreviewDetail(content=document.page_content)
+preview_detail = PreviewDetail(content=document.page_content) # type: ignore
 if document.children:
-preview_detail.child_chunks = [child.page_content for child in document.children]
+preview_detail.child_chunks = [child.page_content for child in document.children] # type: ignore
 preview_texts.append(preview_detail)

 # delete image files and related db records
@@ -325,7 +325,7 @@ class IndexingRunner:

 if doc_form and doc_form == "qa_model":
 return IndexingEstimate(total_segments=total_segments * 20, qa_preview=preview_texts, preview=[])
-return IndexingEstimate(total_segments=total_segments, preview=preview_texts)
+return IndexingEstimate(total_segments=total_segments, preview=preview_texts) # type: ignore

 def _extract(
 self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: dict
@@ -454,7 +454,7 @@ class IndexingRunner:
 embedding_model_instance=embedding_model_instance,
 )

-return character_splitter
+return character_splitter # type: ignore

 def _split_to_documents_for_estimate(
 self, text_docs: list[Document], splitter: TextSplitter, processing_rule: DatasetProcessRule
@@ -535,7 +535,7 @@ class IndexingRunner:
 # create keyword index
 create_keyword_thread = threading.Thread(
 target=self._process_keyword_index,
-args=(current_app._get_current_object(), dataset.id, dataset_document.id, documents),
+args=(current_app._get_current_object(), dataset.id, dataset_document.id, documents), # type: ignore
 )
 create_keyword_thread.start()

@@ -258,10 +258,11 @@ class RetrievalService:
 include_segment_ids = []
 segment_child_map = {}
 for document in documents:
-document_id = document.metadata["document_id"]
+document_id = document.metadata.get("document_id")
 dataset_document = db.session.query(DatasetDocument).filter(DatasetDocument.id == document_id).first()
-if dataset_document and dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
-child_index_node_id = document.metadata["doc_id"]
+if dataset_document:
+if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
+child_index_node_id = document.metadata.get("doc_id")
 result = (
 db.session.query(ChildChunk, DocumentSegment)
 .join(DocumentSegment, ChildChunk.segment_id == DocumentSegment.id)
@@ -122,6 +122,7 @@ class DatasetDocumentStore:
 db.session.add(segment_document)
 db.session.flush()
 if save_child:
+if doc.children:
 for postion, child in enumerate(doc.children, start=1):
 child_segment = ChildChunk(
 tenant_id=self._dataset.tenant_id,
@@ -129,8 +130,8 @@ class DatasetDocumentStore:
 document_id=self._document_id,
 segment_id=segment_document.id,
 position=postion,
-index_node_id=child.metadata["doc_id"],
-index_node_hash=child.metadata["doc_hash"],
+index_node_id=child.metadata.get("doc_id"),
+index_node_hash=child.metadata.get("doc_hash"),
 content=child.page_content,
 word_count=len(child.page_content),
 type="automatic",
@@ -141,7 +142,7 @@ class DatasetDocumentStore:
 segment_document.content = doc.page_content
 if doc.metadata.get("answer"):
 segment_document.answer = doc.metadata.pop("answer", "")
-segment_document.index_node_hash = doc.metadata["doc_hash"]
+segment_document.index_node_hash = doc.metadata.get("doc_hash")
 segment_document.word_count = len(doc.page_content)
 segment_document.tokens = tokens
 if save_child and doc.children:
@@ -160,8 +161,8 @@ class DatasetDocumentStore:
 document_id=self._document_id,
 segment_id=segment_document.id,
 position=position,
-index_node_id=child.metadata["doc_id"],
-index_node_hash=child.metadata["doc_hash"],
+index_node_id=child.metadata.get("doc_id"),
+index_node_hash=child.metadata.get("doc_hash"),
 content=child.page_content,
 word_count=len(child.page_content),
 type="automatic",
@@ -4,7 +4,7 @@ import os
 from typing import Optional, cast

 import pandas as pd
-from openpyxl import load_workbook
+from openpyxl import load_workbook # type: ignore

 from core.rag.extractor.extractor_base import BaseExtractor
 from core.rag.models.document import Document
@@ -81,4 +81,4 @@ class BaseIndexProcessor(ABC):
 embedding_model_instance=embedding_model_instance,
 )

-return character_splitter
+return character_splitter # type: ignore
@@ -30,12 +30,18 @@ class ParagraphIndexProcessor(BaseIndexProcessor):

 def transform(self, documents: list[Document], **kwargs) -> list[Document]:
 process_rule = kwargs.get("process_rule")
+if not process_rule:
+raise ValueError("No process rule found.")
 if process_rule.get("mode") == "automatic":
 automatic_rule = DatasetProcessRule.AUTOMATIC_RULES
 rules = Rule(**automatic_rule)
 else:
+if not process_rule.get("rules"):
+raise ValueError("No rules found in process rule.")
 rules = Rule(**process_rule.get("rules"))
 # Split the text documents into nodes.
+if not rules.segmentation:
+raise ValueError("No segmentation found in rules.")
 splitter = self._get_splitter(
 processing_rule_mode=process_rule.get("mode"),
 max_tokens=rules.segmentation.max_tokens,
@@ -30,8 +30,12 @@ class ParentChildIndexProcessor(BaseIndexProcessor):

 def transform(self, documents: list[Document], **kwargs) -> list[Document]:
 process_rule = kwargs.get("process_rule")
+if not process_rule:
+raise ValueError("No process rule found.")
+if not process_rule.get("rules"):
+raise ValueError("No rules found in process rule.")
 rules = Rule(**process_rule.get("rules"))
-all_documents = []
+all_documents = [] # type: ignore
 if rules.parent_mode == ParentMode.PARAGRAPH:
 # Split the text documents into nodes.
 splitter = self._get_splitter(
@@ -161,6 +165,8 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
 process_rule_mode: str,
 embedding_model_instance: Optional[ModelInstance],
 ) -> list[ChildDocument]:
+if not rules.subchunk_segmentation:
+raise ValueError("No subchunk segmentation found in rules.")
 child_splitter = self._get_splitter(
 processing_rule_mode=process_rule_mode,
 max_tokens=rules.subchunk_segmentation.max_tokens,
@@ -37,12 +37,16 @@ class QAIndexProcessor(BaseIndexProcessor):
 def transform(self, documents: list[Document], **kwargs) -> list[Document]:
 preview = kwargs.get("preview")
 process_rule = kwargs.get("process_rule")
+if not process_rule:
+raise ValueError("No process rule found.")
+if not process_rule.get("rules"):
+raise ValueError("No rules found in process rule.")
 rules = Rule(**process_rule.get("rules"))
 splitter = self._get_splitter(
 processing_rule_mode=process_rule.get("mode"),
-max_tokens=rules.segmentation.max_tokens,
-chunk_overlap=rules.segmentation.chunk_overlap,
-separator=rules.segmentation.separator,
+max_tokens=rules.segmentation.max_tokens if rules.segmentation else 0,
+chunk_overlap=rules.segmentation.chunk_overlap if rules.segmentation else 0,
+separator=rules.segmentation.separator if rules.segmentation else "",
 embedding_model_instance=kwargs.get("embedding_model_instance"),
 )

@@ -71,8 +75,8 @@ class QAIndexProcessor(BaseIndexProcessor):
 all_documents.extend(split_documents)
 if preview:
 self._format_qa_document(
-current_app._get_current_object(),
-kwargs.get("tenant_id"),
+current_app._get_current_object(), # type: ignore
+kwargs.get("tenant_id"), # type: ignore
 all_documents[0],
 all_qa_documents,
 kwargs.get("doc_language", "English"),
@@ -85,8 +89,8 @@ class QAIndexProcessor(BaseIndexProcessor):
 document_format_thread = threading.Thread(
 target=self._format_qa_document,
 kwargs={
-"flask_app": current_app._get_current_object(),
-"tenant_id": kwargs.get("tenant_id"),
+"flask_app": current_app._get_current_object(), # type: ignore
+"tenant_id": kwargs.get("tenant_id"), # type: ignore
 "document_node": doc,
 "all_qa_documents": all_qa_documents,
 "document_language": kwargs.get("doc_language", "English"),
@@ -2,7 +2,7 @@ from abc import ABC, abstractmethod
 from collections.abc import Sequence
 from typing import Any, Optional

-from pydantic import BaseModel, Field
+from pydantic import BaseModel


 class ChildDocument(BaseModel):
@@ -15,7 +15,7 @@ class ChildDocument(BaseModel):
 """Arbitrary metadata about the page content (e.g., source, relationships to other
 documents, etc.).
 """
-metadata: Optional[dict] = Field(default_factory=dict)
+metadata: dict = {}


 class Document(BaseModel):
@@ -28,7 +28,7 @@ class Document(BaseModel):
 """Arbitrary metadata about the page content (e.g., source, relationships to other
 documents, etc.).
 """
-metadata: Optional[dict] = Field(default_factory=dict)
+metadata: dict = {}

 provider: Optional[str] = "dify"

@@ -5,7 +5,7 @@ from dify_app import DifyApp
 def init_app(app: DifyApp):
 # register blueprint routers

-from flask_cors import CORS
+from flask_cors import CORS # type: ignore

 from controllers.console import bp as console_app_bp
 from controllers.files import bp as files_bp
@@ -1,9 +1,9 @@
 import logging
 import time
+from collections import defaultdict

 import click
 from celery import shared_task # type: ignore
-from flask import render_template

 from extensions.ext_mail import mail
 from models.account import Account, Tenant, TenantAccountJoin
@@ -27,7 +27,7 @@ def send_document_clean_notify_task():
 try:
 dataset_auto_disable_logs = DatasetAutoDisableLog.query.filter(DatasetAutoDisableLog.notified == False).all()
 # group by tenant_id
-dataset_auto_disable_logs_map = {}
+dataset_auto_disable_logs_map: dict[str, list[DatasetAutoDisableLog]] = defaultdict(list)
 for dataset_auto_disable_log in dataset_auto_disable_logs:
 dataset_auto_disable_logs_map[dataset_auto_disable_log.tenant_id].append(dataset_auto_disable_log)

@@ -37,11 +37,13 @@ def send_document_clean_notify_task():
 if not tenant:
 continue
 current_owner_join = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, role="owner").first()
+if not current_owner_join:
+continue
 account = Account.query.filter(Account.id == current_owner_join.account_id).first()
 if not account:
 continue

-dataset_auto_dataset_map = {}
+dataset_auto_dataset_map = {} # type: ignore
 for dataset_auto_disable_log in tenant_dataset_auto_disable_logs:
 dataset_auto_dataset_map[dataset_auto_disable_log.dataset_id].append(
 dataset_auto_disable_log.document_id
@@ -53,14 +55,9 @@ def send_document_clean_notify_task():
 document_count = len(document_ids)
 knowledge_details.append(f"<li>Knowledge base {dataset.name}: {document_count} documents</li>")

-html_content = render_template(
-"clean_document_job_mail_template-US.html",
-)
-mail.send(to=to, subject="立即加入 Dify 工作空间", html=html_content)

 end_at = time.perf_counter()
 logging.info(
 click.style("Send document clean notify mail succeeded: latency: {}".format(end_at - start_at), fg="green")
 )
 except Exception:
-logging.exception("Send invite member mail to {} failed".format(to))
+logging.exception("Send invite member mail to failed")
@@ -4,7 +4,7 @@ from enum import StrEnum
 from typing import Optional, cast
 from uuid import uuid4

-import yaml
+import yaml # type: ignore
 from packaging import version
 from pydantic import BaseModel
 from sqlalchemy import select
@@ -465,7 +465,7 @@ class AppDslService:
 else:
 cls._append_model_config_export_data(export_data, app_model)

-return yaml.dump(export_data, allow_unicode=True)
+return yaml.dump(export_data, allow_unicode=True) # type: ignore

 @classmethod
 def _append_workflow_export_data(cls, *, export_data: dict, app_model: App, include_secret: bool) -> None:
@@ -41,6 +41,7 @@ from models.source import DataSourceOauthBinding
 from services.entities.knowledge_entities.knowledge_entities import (
 ChildChunkUpdateArgs,
 KnowledgeConfig,
+RerankingModel,
 RetrievalModel,
 SegmentUpdateArgs,
 )
@@ -548,12 +549,14 @@ class DocumentService:
 }

 @staticmethod
-def get_document(dataset_id: str, document_id: str) -> Optional[Document]:
+def get_document(dataset_id: str, document_id: Optional[str] = None) -> Optional[Document]:
+if document_id:
 document = (
 db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first()
 )

 return document
+else:
+return None

 @staticmethod
 def get_document_by_id(document_id: str) -> Optional[Document]:
@@ -744,16 +747,17 @@ class DocumentService:
 if features.billing.enabled:
 if not knowledge_config.original_document_id:
 count = 0
+if knowledge_config.data_source:
 if knowledge_config.data_source.info_list.data_source_type == "upload_file":
-upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids
+upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids # type: ignore
 count = len(upload_file_list)
 elif knowledge_config.data_source.info_list.data_source_type == "notion_import":
 notion_info_list = knowledge_config.data_source.info_list.notion_info_list
-for notion_info in notion_info_list:
+for notion_info in notion_info_list: # type: ignore
 count = count + len(notion_info.pages)
 elif knowledge_config.data_source.info_list.data_source_type == "website_crawl":
 website_info = knowledge_config.data_source.info_list.website_info_list
-count = len(website_info.urls)
+count = len(website_info.urls) # type: ignore
 batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
 if count > batch_upload_limit:
 raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
@@ -762,7 +766,7 @@ class DocumentService:

 # if dataset is empty, update dataset data_source_type
 if not dataset.data_source_type:
-dataset.data_source_type = knowledge_config.data_source.info_list.data_source_type
+dataset.data_source_type = knowledge_config.data_source.info_list.data_source_type # type: ignore

 if not dataset.indexing_technique:
 if knowledge_config.indexing_technique not in Dataset.INDEXING_TECHNIQUE_LIST:
@@ -789,7 +793,7 @@ class DocumentService:
 "score_threshold_enabled": False,
 }

-dataset.retrieval_model = knowledge_config.retrieval_model.model_dump() or default_retrieval_model
+dataset.retrieval_model = knowledge_config.retrieval_model.model_dump() or default_retrieval_model # type: ignore

 documents = []
 if knowledge_config.original_document_id:
@@ -801,11 +805,12 @@ class DocumentService:
 # save process rule
 if not dataset_process_rule:
 process_rule = knowledge_config.process_rule
+if process_rule:
 if process_rule.mode in ("custom", "hierarchical"):
 dataset_process_rule = DatasetProcessRule(
 dataset_id=dataset.id,
 mode=process_rule.mode,
-rules=process_rule.rules.model_dump_json(),
+rules=process_rule.rules.model_dump_json() if process_rule.rules else None,
 created_by=account.id,
 )
 elif process_rule.mode == "automatic":
@@ -817,7 +822,7 @@ class DocumentService:
 )
 else:
 logging.warn(
-f"Invalid process rule mode: {process_rule['mode']}, can not find dataset process rule"
+f"Invalid process rule mode: {process_rule.mode}, can not find dataset process rule"
 )
 return
 db.session.add(dataset_process_rule)
@@ -828,7 +833,7 @@ class DocumentService:
 document_ids = []
 duplicate_document_ids = []
 if knowledge_config.data_source.info_list.data_source_type == "upload_file":
-upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids
+upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids # type: ignore
 for file_id in upload_file_list:
 file = (
 db.session.query(UploadFile)
@@ -854,7 +859,7 @@ class DocumentService:
 name=file_name,
 ).first()
 if document:
-document.dataset_process_rule_id = dataset_process_rule.id
+document.dataset_process_rule_id = dataset_process_rule.id # type: ignore
 document.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
 document.created_from = created_from
 document.doc_form = knowledge_config.doc_form
@@ -868,7 +873,7 @@ class DocumentService:
 continue
 document = DocumentService.build_document(
 dataset,
-dataset_process_rule.id,
+dataset_process_rule.id, # type: ignore
 knowledge_config.data_source.info_list.data_source_type,
 knowledge_config.doc_form,
 knowledge_config.doc_language,
@@ -886,6 +891,8 @@ class DocumentService:
 position += 1
 elif knowledge_config.data_source.info_list.data_source_type == "notion_import":
 notion_info_list = knowledge_config.data_source.info_list.notion_info_list
+if not notion_info_list:
+raise ValueError("No notion info list found.")
 exist_page_ids = []
 exist_document = {}
 documents = Document.query.filter_by(
@@ -921,7 +928,7 @@ class DocumentService:
 }
 document = DocumentService.build_document(
 dataset,
-dataset_process_rule.id,
+dataset_process_rule.id, # type: ignore
 knowledge_config.data_source.info_list.data_source_type,
 knowledge_config.doc_form,
 knowledge_config.doc_language,
@@ -944,6 +951,8 @@ class DocumentService:
 clean_notion_document_task.delay(list(exist_document.values()), dataset.id)
 elif knowledge_config.data_source.info_list.data_source_type == "website_crawl":
 website_info = knowledge_config.data_source.info_list.website_info_list
+if not website_info:
+raise ValueError("No website info list found.")
 urls = website_info.urls
 for url in urls:
 data_source_info = {
@@ -959,7 +968,7 @@ class DocumentService:
 document_name = url
 document = DocumentService.build_document(
 dataset,
-dataset_process_rule.id,
+dataset_process_rule.id, # type: ignore
 knowledge_config.data_source.info_list.data_source_type,
 knowledge_config.doc_form,
 knowledge_config.doc_language,
@@ -1054,7 +1063,7 @@ class DocumentService:
 dataset_process_rule = DatasetProcessRule(
 dataset_id=dataset.id,
 mode=process_rule.mode,
-rules=process_rule.rules.model_dump_json(),
+rules=process_rule.rules.model_dump_json() if process_rule.rules else None,
 created_by=account.id,
 )
 elif process_rule.mode == "automatic":
@@ -1073,6 +1082,8 @@ class DocumentService:
 file_name = ""
 data_source_info = {}
 if document_data.data_source.info_list.data_source_type == "upload_file":
+if not document_data.data_source.info_list.file_info_list:
+raise ValueError("No file info list found.")
 upload_file_list = document_data.data_source.info_list.file_info_list.file_ids
 for file_id in upload_file_list:
 file = (
@@ -1090,6 +1101,8 @@ class DocumentService:
 "upload_file_id": file_id,
 }
 elif document_data.data_source.info_list.data_source_type == "notion_import":
+if not document_data.data_source.info_list.notion_info_list:
+raise ValueError("No notion info list found.")
 notion_info_list = document_data.data_source.info_list.notion_info_list
 for notion_info in notion_info_list:
 workspace_id = notion_info.workspace_id
@@ -1107,18 +1120,19 @@ class DocumentService:
 data_source_info = {
 "notion_workspace_id": workspace_id,
 "notion_page_id": page.page_id,
-"notion_page_icon": page.page_icon,
+"notion_page_icon": page.page_icon.model_dump() if page.page_icon else None, # type: ignore
 "type": page.type,
 }
 elif document_data.data_source.info_list.data_source_type == "website_crawl":
 website_info = document_data.data_source.info_list.website_info_list
+if website_info:
 urls = website_info.urls
 for url in urls:
 data_source_info = {
 "url": url,
 "provider": website_info.provider,
 "job_id": website_info.job_id,
-"only_main_content": website_info.only_main_content,
+"only_main_content": website_info.only_main_content, # type: ignore
 "mode": "crawl",
 }
 document.data_source_type = document_data.data_source.info_list.data_source_type
@@ -1155,14 +1169,20 @@ class DocumentService:
 if features.billing.enabled:
 count = 0
 if knowledge_config.data_source.info_list.data_source_type == "upload_file":
-upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids
+upload_file_list = (
+knowledge_config.data_source.info_list.file_info_list.file_ids
+if knowledge_config.data_source.info_list.file_info_list
+else []
+)
 count = len(upload_file_list)
 elif knowledge_config.data_source.info_list.data_source_type == "notion_import":
 notion_info_list = knowledge_config.data_source.info_list.notion_info_list
+if notion_info_list:
 for notion_info in notion_info_list:
 count = count + len(notion_info.pages)
 elif knowledge_config.data_source.info_list.data_source_type == "website_crawl":
 website_info = knowledge_config.data_source.info_list.website_info_list
+if website_info:
 count = len(website_info.urls)
 batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
 if count > batch_upload_limit:
@@ -1174,20 +1194,20 @@ class DocumentService:
 retrieval_model = None
 if knowledge_config.indexing_technique == "high_quality":
 dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
-knowledge_config.embedding_model_provider, knowledge_config.embedding_model
+knowledge_config.embedding_model_provider, # type: ignore
+knowledge_config.embedding_model, # type: ignore
 )
 dataset_collection_binding_id = dataset_collection_binding.id
 if knowledge_config.retrieval_model:
 retrieval_model = knowledge_config.retrieval_model
 else:
-default_retrieval_model = {
-"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
-"reranking_enable": False,
-"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
-"top_k": 2,
-"score_threshold_enabled": False,
-}
-retrieval_model = RetrievalModel(**default_retrieval_model)
+retrieval_model = RetrievalModel(
+search_method=RetrievalMethod.SEMANTIC_SEARCH.value,
+reranking_enable=False,
+reranking_model=RerankingModel(reranking_provider_name="", reranking_model_name=""),
+top_k=2,
+score_threshold_enabled=False,
+)
 # save dataset
 dataset = Dataset(
 tenant_id=tenant_id,
@@ -1557,12 +1577,12 @@ class SegmentService:
 raise ValueError("Can't update disabled segment")
 try:
 word_count_change = segment.word_count
-content = args.content
+content = args.content or segment.content
 if segment.content == content:
 segment.word_count = len(content)
 if document.doc_form == "qa_model":
 segment.answer = args.answer
-segment.word_count += len(args.answer)
+segment.word_count += len(args.answer) if args.answer else 0
 word_count_change = segment.word_count - word_count_change
 if args.keywords:
 segment.keywords = args.keywords
@@ -1577,7 +1597,12 @@ class SegmentService:
 db.session.add(document)
 # update segment index task
 if args.enabled:
-VectorService.create_segments_vector([args.keywords], [segment], dataset)
+VectorService.create_segments_vector(
+[args.keywords] if args.keywords else None,
+[segment],
+dataset,
+document.doc_form,
+)
 if document.doc_form == IndexType.PARENT_CHILD_INDEX and args.regenerate_child_chunks:
 # regenerate child chunks
 # get embedding model instance
@@ -1605,6 +1630,8 @@ class SegmentService:
 .filter(DatasetProcessRule.id == document.dataset_process_rule_id)
 .first()
 )
+if not processing_rule:
+raise ValueError("No processing rule found.")
 VectorService.generate_child_chunks(
 segment, document, dataset, embedding_model_instance, processing_rule, True
 )
@@ -1639,7 +1666,7 @@ class SegmentService:
 segment.disabled_by = None
 if document.doc_form == "qa_model":
 segment.answer = args.answer
-segment.word_count += len(args.answer)
+segment.word_count += len(args.answer) if args.answer else 0
 word_count_change = segment.word_count - word_count_change
 # update document word count
 if word_count_change != 0:
@@ -1673,6 +1700,8 @@ class SegmentService:
 .filter(DatasetProcessRule.id == document.dataset_process_rule_id)
 .first()
 )
+if not processing_rule:
+raise ValueError("No processing rule found.")
 VectorService.generate_child_chunks(
 segment, document, dataset, embedding_model_instance, processing_rule, True
 )
@@ -97,7 +97,7 @@ class KnowledgeConfig(BaseModel):
 original_document_id: Optional[str] = None
 duplicate: bool = True
 indexing_technique: Literal["high_quality", "economy"]
-data_source: Optional[DataSource] = None
+data_source: DataSource
 process_rule: Optional[ProcessRule] = None
 retrieval_model: Optional[RetrievalModel] = None
 doc_form: str = "text_model"
@@ -69,7 +69,7 @@ class HitTestingService:
 db.session.add(dataset_query)
 db.session.commit()

-return cls.compact_retrieve_response(query, all_documents)
+return cls.compact_retrieve_response(query, all_documents) # type: ignore

 @classmethod
 def external_retrieve(
@@ -29,6 +29,8 @@ class VectorService:
 .filter(DatasetProcessRule.id == document.dataset_process_rule_id)
 .first()
 )
+if not processing_rule:
+raise ValueError("No processing rule found.")
 # get embedding model instance
 if dataset.indexing_technique == "high_quality":
 # check embedding model setting
@@ -98,7 +100,7 @@ class VectorService:
 def generate_child_chunks(
 cls,
 segment: DocumentSegment,
-dataset_document: Document,
+dataset_document: DatasetDocument,
 dataset: Dataset,
 embedding_model_instance: ModelInstance,
 processing_rule: DatasetProcessRule,
@@ -130,7 +132,7 @@ class VectorService:
 doc_language=dataset_document.doc_language,
 )
 # save child chunks
-if len(documents) > 0 and len(documents[0].children) > 0:
+if documents and documents[0].children:
 index_processor.load(dataset, documents)

 for position, child_chunk in enumerate(documents[0].children, start=1):
@@ -44,6 +44,7 @@ def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form
 for upload_file_id in image_upload_file_ids:
 image_file = db.session.query(UploadFile).filter(UploadFile.id == upload_file_id).first()
 try:
+if image_file and image_file.key:
 storage.delete(image_file.key)
 except Exception:
 logging.exception(