Merge branch 'main' into fix/chore-fix

Yeuoly 2025-01-08 20:36:22 +08:00
commit fb309462ad
203 changed files with 3469 additions and 2327 deletions

View File

@ -23,6 +23,9 @@ FILES_ACCESS_TIMEOUT=300
# Access token expiration time in minutes
ACCESS_TOKEN_EXPIRE_MINUTES=60
# Refresh token expiration time in days
REFRESH_TOKEN_EXPIRE_DAYS=30
# celery configuration
CELERY_BROKER_URL=redis://:difyai123456@localhost:6379/1
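As a point of reference, a broker URL in this form is usually handed straight to the Celery constructor. A minimal hedged sketch (the app name and fallback value here are illustrative, not Dify's actual wiring):

import os

from celery import Celery

# Read the broker URL from the environment, falling back to the sample value above.
broker_url = os.environ.get("CELERY_BROKER_URL", "redis://:difyai123456@localhost:6379/1")
celery_app = Celery("dify", broker=broker_url)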

View File

@ -14,7 +14,10 @@ if is_db_command():
app = create_migrations_app()
else:
if os.environ.get("FLASK_DEBUG", "False") != "True":
# It seems that JetBrains Python debugger does not work well with gevent,
# so we need to disable gevent in debug mode.
# If you are using debugpy and set GEVENT_SUPPORT=True, you can debug with gevent.
if (flask_debug := os.environ.get("FLASK_DEBUG", "0")) and flask_debug.lower() in {"false", "0", "no"}:
from gevent import monkey # type: ignore
# gevent
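For clarity, a self-contained sketch of the debug gate added in this hunk; the trailing monkey.patch_all() call is assumed from the rest of the file, which the diff elides behind the "# gevent" comment:

import os

# Patch only when FLASK_DEBUG is unset or set to a false-like value, so debuggers
# that conflict with gevent (JetBrains, debugpy without GEVENT_SUPPORT) keep working.
if (flask_debug := os.environ.get("FLASK_DEBUG", "0")) and flask_debug.lower() in {"false", "0", "no"}:
    from gevent import monkey  # type: ignore

    monkey.patch_all()  # assumed: the patching that follows in the real file

Note that an explicitly empty FLASK_DEBUG also skips patching, since the walrus-bound value is falsy.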

View File

@ -546,6 +546,11 @@ class AuthConfig(BaseSettings):
default=60,
)
REFRESH_TOKEN_EXPIRE_DAYS: PositiveFloat = Field(
description="Expiration time for refresh tokens in days",
default=30,
)
LOGIN_LOCKOUT_DURATION: PositiveInt = Field(
description="Time (in seconds) a user must wait before retrying login after exceeding the rate limit.",
default=86400,
@ -725,6 +730,11 @@ class IndexingConfig(BaseSettings):
default=4000,
)
CHILD_CHUNKS_PREVIEW_NUMBER: PositiveInt = Field(
description="Maximum number of child chunks to preview",
default=50,
)
class MultiModalTransferConfig(BaseSettings):
MULTIMODAL_SEND_FORMAT: Literal["base64", "url"] = Field(

View File

@ -33,3 +33,9 @@ class MilvusConfig(BaseSettings):
description="Name of the Milvus database to connect to (default is 'default')",
default="default",
)
MILVUS_ENABLE_HYBRID_SEARCH: bool = Field(
description="Enable hybrid search features (requires Milvus >= 2.5.0). Set to false for compatibility with "
"older versions",
default=True,
)
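To illustrate how a flag like this is resolved at runtime, a hedged, self-contained pydantic-settings sketch; only the field itself mirrors the diff, the class wiring and import are assumptions:

from pydantic import Field
from pydantic_settings import BaseSettings


class MilvusConfig(BaseSettings):
    MILVUS_ENABLE_HYBRID_SEARCH: bool = Field(
        description="Enable hybrid search features (requires Milvus >= 2.5.0). "
        "Set to false for compatibility with older versions",
        default=True,
    )


# MILVUS_ENABLE_HYBRID_SEARCH=false in the environment (or a .env file) disables the
# feature; leaving it unset keeps the default of True.
print(MilvusConfig().MILVUS_ENABLE_HYBRID_SEARCH)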

View File

@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):
CURRENT_VERSION: str = Field(
description="Dify version",
default="0.14.2",
default="0.15.0",
)
COMMIT_SHA: str = Field(

View File

@ -57,12 +57,13 @@ class AppListApi(Resource):
)
parser.add_argument("name", type=str, location="args", required=False)
parser.add_argument("tag_ids", type=uuid_list, location="args", required=False)
parser.add_argument("is_created_by_me", type=inputs.boolean, location="args", required=False)
args = parser.parse_args()
# get app list
app_service = AppService()
app_pagination = app_service.get_paginate_apps(current_user.current_tenant_id, args)
app_pagination = app_service.get_paginate_apps(current_user.id, current_user.current_tenant_id, args)
if not app_pagination:
return {"data": [], "total": 0, "page": 1, "limit": 20, "has_more": False}

View File

@ -20,7 +20,6 @@ from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpErr
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import (
AppInvokeQuotaExceededError,
ModelCurrentlyNotSupportError,
ProviderTokenNotInitError,
QuotaExceededError,
@ -76,7 +75,7 @@ class CompletionMessageApi(Resource):
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
except (ValueError, AppInvokeQuotaExceededError) as e:
except ValueError as e:
raise e
except Exception as e:
logging.exception("internal server error.")
@ -141,7 +140,7 @@ class ChatMessageApi(Resource):
raise InvokeRateLimitHttpError(ex.description)
except InvokeError as e:
raise CompletionRequestError(e.description)
except (ValueError, AppInvokeQuotaExceededError) as e:
except ValueError as e:
raise e
except Exception as e:
logging.exception("internal server error.")

View File

@ -273,8 +273,7 @@ FROM
messages m
ON c.id = m.conversation_id
WHERE
c.override_model_configs IS NULL
AND c.app_id = :app_id"""
c.app_id = :app_id"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
timezone = pytz.timezone(account.timezone)

View File

@ -640,6 +640,7 @@ class DatasetRetrievalSettingApi(Resource):
| VectorType.MYSCALE
| VectorType.ORACLE
| VectorType.ELASTICSEARCH
| VectorType.ELASTICSEARCH_JA
| VectorType.PGVECTOR
| VectorType.TIDB_ON_QDRANT
| VectorType.LINDORM
@ -683,6 +684,7 @@ class DatasetRetrievalSettingMockApi(Resource):
| VectorType.MYSCALE
| VectorType.ORACLE
| VectorType.ELASTICSEARCH
| VectorType.ELASTICSEARCH_JA
| VectorType.COUCHBASE
| VectorType.PGVECTOR
| VectorType.LINDORM

View File

@ -269,7 +269,8 @@ class DatasetDocumentListApi(Resource):
parser.add_argument("original_document_id", type=str, required=False, location="json")
parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
parser.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
parser.add_argument(
"doc_language", type=str, default="English", required=False, nullable=False, location="json"
)

View File

@ -18,7 +18,11 @@ from controllers.console.explore.error import NotChatAppError, NotCompletionAppE
from controllers.console.explore.wraps import InstalledAppResource
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.errors.error import (
ModelCurrentlyNotSupportError,
ProviderTokenNotInitError,
QuotaExceededError,
)
from core.model_runtime.errors.invoke import InvokeError
from extensions.ext_database import db
from libs import helper

View File

@ -13,7 +13,11 @@ from controllers.console.explore.error import NotWorkflowAppError
from controllers.console.explore.wraps import InstalledAppResource
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.errors.error import (
ModelCurrentlyNotSupportError,
ProviderTokenNotInitError,
QuotaExceededError,
)
from core.model_runtime.errors.invoke import InvokeError
from libs import helper
from libs.login import current_user

View File

@ -18,7 +18,6 @@ from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import (
AppInvokeQuotaExceededError,
ModelCurrentlyNotSupportError,
ProviderTokenNotInitError,
QuotaExceededError,
@ -74,7 +73,7 @@ class CompletionApi(Resource):
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
except (ValueError, AppInvokeQuotaExceededError) as e:
except ValueError as e:
raise e
except Exception as e:
logging.exception("internal server error.")
@ -133,7 +132,7 @@ class ChatApi(Resource):
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
except (ValueError, AppInvokeQuotaExceededError) as e:
except ValueError as e:
raise e
except Exception as e:
logging.exception("internal server error.")

View File

@ -16,7 +16,6 @@ from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import (
AppInvokeQuotaExceededError,
ModelCurrentlyNotSupportError,
ProviderTokenNotInitError,
QuotaExceededError,
@ -94,7 +93,7 @@ class WorkflowRunApi(Resource):
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
except (ValueError, AppInvokeQuotaExceededError) as e:
except ValueError as e:
raise e
except Exception as e:
logging.exception("internal server error.")

View File

@ -190,7 +190,10 @@ class DocumentAddByFileApi(DatasetApiResource):
user=current_user,
source="datasets",
)
data_source = {"type": "upload_file", "info_list": {"file_info_list": {"file_ids": [upload_file.id]}}}
data_source = {
"type": "upload_file",
"info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
}
args["data_source"] = data_source
# validate args
knowledge_config = KnowledgeConfig(**args)
@ -254,7 +257,10 @@ class DocumentUpdateByFileApi(DatasetApiResource):
raise FileTooLargeError(file_too_large_error.description)
except services.errors.file.UnsupportedFileTypeError:
raise UnsupportedFileTypeError()
data_source = {"type": "upload_file", "info_list": {"file_info_list": {"file_ids": [upload_file.id]}}}
data_source = {
"type": "upload_file",
"info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
}
args["data_source"] = data_source
# validate args
args["original_document_id"] = str(document_id)

View File

@ -1,5 +1,5 @@
from collections.abc import Callable
from datetime import UTC, datetime
from datetime import UTC, datetime, timedelta
from enum import Enum
from functools import wraps
from typing import Optional
@ -8,6 +8,8 @@ from flask import current_app, request
from flask_login import user_logged_in # type: ignore
from flask_restful import Resource # type: ignore
from pydantic import BaseModel
from sqlalchemy import select, update
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, Unauthorized
from extensions.ext_database import db
@ -174,7 +176,7 @@ def validate_dataset_token(view=None):
return decorator
def validate_and_get_api_token(scope=None):
def validate_and_get_api_token(scope: str | None = None):
"""
Validate and get API token.
"""
@ -188,20 +190,25 @@ def validate_and_get_api_token(scope=None):
if auth_scheme != "bearer":
raise Unauthorized("Authorization scheme must be 'Bearer'")
api_token = (
db.session.query(ApiToken)
.filter(
ApiToken.token == auth_token,
ApiToken.type == scope,
current_time = datetime.now(UTC).replace(tzinfo=None)
cutoff_time = current_time - timedelta(minutes=1)
with Session(db.engine, expire_on_commit=False) as session:
update_stmt = (
update(ApiToken)
.where(ApiToken.token == auth_token, ApiToken.last_used_at < cutoff_time, ApiToken.type == scope)
.values(last_used_at=current_time)
.returning(ApiToken)
)
.first()
)
result = session.execute(update_stmt)
api_token = result.scalar_one_or_none()
if not api_token:
raise Unauthorized("Access token is invalid")
api_token.last_used_at = datetime.now(UTC).replace(tzinfo=None)
db.session.commit()
if not api_token:
stmt = select(ApiToken).where(ApiToken.token == auth_token, ApiToken.type == scope)
api_token = session.scalar(stmt)
if not api_token:
raise Unauthorized("Access token is invalid")
else:
session.commit()
return api_token
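The rewrite above replaces the old read-then-write (SELECT the token, then set last_used_at and commit on every request) with a single conditional UPDATE ... RETURNING, so last_used_at is written at most once per minute per token; a plain SELECT is used only as a fallback when the token was touched recently or does not exist. A minimal sketch of the same pattern with a hypothetical standalone ApiToken model (names and wiring are illustrative, and a backend with RETURNING support such as PostgreSQL is assumed):

from datetime import UTC, datetime, timedelta

from sqlalchemy import DateTime, String, select, update
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class ApiToken(Base):  # hypothetical stand-in for the real model
    __tablename__ = "api_tokens"
    token: Mapped[str] = mapped_column(String, primary_key=True)
    type: Mapped[str] = mapped_column(String)
    last_used_at: Mapped[datetime] = mapped_column(DateTime)


def validate_and_touch_token(engine, auth_token: str, scope: str) -> ApiToken | None:
    current_time = datetime.now(UTC).replace(tzinfo=None)
    cutoff_time = current_time - timedelta(minutes=1)
    with Session(engine, expire_on_commit=False) as session:
        # Atomic UPDATE ... RETURNING: only rows last used before the cutoff are
        # touched, so the write happens at most once per minute per token.
        stmt = (
            update(ApiToken)
            .where(ApiToken.token == auth_token, ApiToken.type == scope, ApiToken.last_used_at < cutoff_time)
            .values(last_used_at=current_time)
            .returning(ApiToken)
        )
        api_token = session.execute(stmt).scalar_one_or_none()
        if api_token is None:
            # Recently used (or unknown) token: fall back to a read-only lookup.
            api_token = session.scalar(select(ApiToken).where(ApiToken.token == auth_token, ApiToken.type == scope))
        else:
            session.commit()
        return api_token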

View File

@ -19,7 +19,11 @@ from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpErr
from controllers.web.wraps import WebApiResource
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.errors.error import (
ModelCurrentlyNotSupportError,
ProviderTokenNotInitError,
QuotaExceededError,
)
from core.model_runtime.errors.invoke import InvokeError
from libs import helper
from libs.helper import uuid_value

View File

@ -14,7 +14,11 @@ from controllers.web.error import (
from controllers.web.wraps import WebApiResource
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.errors.error import (
ModelCurrentlyNotSupportError,
ProviderTokenNotInitError,
QuotaExceededError,
)
from core.model_runtime.errors.invoke import InvokeError
from libs import helper
from models.model import App, AppMode, EndUser

View File

@ -21,7 +21,7 @@ from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.ops.ops_trace_manager import TraceQueueManager
from core.prompt.utils.get_thread_messages_length import get_thread_messages_length
from extensions.ext_database import db
@ -346,7 +346,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
except ValidationError as e:
logger.exception("Validation Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except (ValueError, InvokeError) as e:
except ValueError as e:
if dify_config.DEBUG:
logger.exception("Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)

View File

@ -68,24 +68,17 @@ from models.account import Account
from models.enums import CreatedByRole
from models.workflow import (
Workflow,
WorkflowNodeExecution,
WorkflowRunStatus,
)
logger = logging.getLogger(__name__)
class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleManage, MessageCycleManage):
class AdvancedChatAppGenerateTaskPipeline:
"""
AdvancedChatAppGenerateTaskPipeline is a class that generate stream output and state management for Application.
"""
_task_state: WorkflowTaskState
_application_generate_entity: AdvancedChatAppGenerateEntity
_workflow_system_variables: dict[SystemVariableKey, Any]
_wip_workflow_node_executions: dict[str, WorkflowNodeExecution]
_conversation_name_generate_thread: Optional[Thread] = None
def __init__(
self,
application_generate_entity: AdvancedChatAppGenerateEntity,
@ -97,7 +90,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
stream: bool,
dialogue_count: int,
) -> None:
super().__init__(
self._base_task_pipeline = BasedGenerateTaskPipeline(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
stream=stream,
@ -114,33 +107,35 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
else:
raise NotImplementedError(f"User type not supported: {type(user)}")
self._workflow_id = workflow.id
self._workflow_features_dict = workflow.features_dict
self._conversation_id = conversation.id
self._conversation_mode = conversation.mode
self._message_id = message.id
self._message_created_at = int(message.created_at.timestamp())
self._workflow_system_variables = {
SystemVariableKey.QUERY: message.query,
SystemVariableKey.FILES: application_generate_entity.files,
SystemVariableKey.CONVERSATION_ID: conversation.id,
SystemVariableKey.USER_ID: user_session_id,
SystemVariableKey.DIALOGUE_COUNT: dialogue_count,
SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
SystemVariableKey.WORKFLOW_ID: workflow.id,
SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
}
self._workflow_cycle_manager = WorkflowCycleManage(
application_generate_entity=application_generate_entity,
workflow_system_variables={
SystemVariableKey.QUERY: message.query,
SystemVariableKey.FILES: application_generate_entity.files,
SystemVariableKey.CONVERSATION_ID: conversation.id,
SystemVariableKey.USER_ID: user_session_id,
SystemVariableKey.DIALOGUE_COUNT: dialogue_count,
SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
SystemVariableKey.WORKFLOW_ID: workflow.id,
SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
},
)
self._task_state = WorkflowTaskState()
self._wip_workflow_node_executions = {}
self._wip_workflow_agent_logs = {}
self._message_cycle_manager = MessageCycleManage(
application_generate_entity=application_generate_entity, task_state=self._task_state
)
self._conversation_name_generate_thread = None
self._application_generate_entity = application_generate_entity
self._workflow_id = workflow.id
self._workflow_features_dict = workflow.features_dict
self._conversation_id = conversation.id
self._conversation_mode = conversation.mode
self._message_id = message.id
self._message_created_at = int(message.created_at.timestamp())
self._conversation_name_generate_thread: Thread | None = None
self._recorded_files: list[Mapping[str, Any]] = []
self._workflow_run_id = ""
self._workflow_run_id: str = ""
def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
"""
@ -148,13 +143,13 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
:return:
"""
# start generate conversation name thread
self._conversation_name_generate_thread = self._generate_conversation_name(
self._conversation_name_generate_thread = self._message_cycle_manager._generate_conversation_name(
conversation_id=self._conversation_id, query=self._application_generate_entity.query
)
generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
if self._stream:
if self._base_task_pipeline._stream:
return self._to_stream_response(generator)
else:
return self._to_blocking_response(generator)
@ -273,24 +268,26 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
# init fake graph runtime state
graph_runtime_state: Optional[GraphRuntimeState] = None
for queue_message in self._queue_manager.listen():
for queue_message in self._base_task_pipeline._queue_manager.listen():
event = queue_message.event
if isinstance(event, QueuePingEvent):
yield self._ping_stream_response()
yield self._base_task_pipeline._ping_stream_response()
elif isinstance(event, QueueErrorEvent):
with Session(db.engine) as session:
err = self._handle_error(event=event, session=session, message_id=self._message_id)
with Session(db.engine, expire_on_commit=False) as session:
err = self._base_task_pipeline._handle_error(
event=event, session=session, message_id=self._message_id
)
session.commit()
yield self._error_to_stream_response(err)
yield self._base_task_pipeline._error_to_stream_response(err)
break
elif isinstance(event, QueueWorkflowStartedEvent):
# override graph runtime state
graph_runtime_state = event.graph_runtime_state
with Session(db.engine) as session:
with Session(db.engine, expire_on_commit=False) as session:
# init workflow run
workflow_run = self._handle_workflow_run_start(
workflow_run = self._workflow_cycle_manager._handle_workflow_run_start(
session=session,
workflow_id=self._workflow_id,
user_id=self._user_id,
@ -301,7 +298,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
if not message:
raise ValueError(f"Message not found: {self._message_id}")
message.workflow_run_id = workflow_run.id
workflow_start_resp = self._workflow_start_to_stream_response(
workflow_start_resp = self._workflow_cycle_manager._workflow_start_to_stream_response(
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
session.commit()
@ -314,12 +311,14 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
with Session(db.engine) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
workflow_node_execution = self._handle_workflow_node_execution_retried(
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_retried(
session=session, workflow_run=workflow_run, event=event
)
node_retry_resp = self._workflow_node_retry_to_stream_response(
node_retry_resp = self._workflow_cycle_manager._workflow_node_retry_to_stream_response(
session=session,
event=event,
task_id=self._application_generate_entity.task_id,
@ -333,13 +332,15 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
with Session(db.engine) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
workflow_node_execution = self._handle_node_execution_start(
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
workflow_node_execution = self._workflow_cycle_manager._handle_node_execution_start(
session=session, workflow_run=workflow_run, event=event
)
node_start_resp = self._workflow_node_start_to_stream_response(
node_start_resp = self._workflow_cycle_manager._workflow_node_start_to_stream_response(
session=session,
event=event,
task_id=self._application_generate_entity.task_id,
@ -352,12 +353,16 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
elif isinstance(event, QueueNodeSucceededEvent):
# Record files if it's an answer node or end node
if event.node_type in [NodeType.ANSWER, NodeType.END]:
self._recorded_files.extend(self._fetch_files_from_node_outputs(event.outputs or {}))
self._recorded_files.extend(
self._workflow_cycle_manager._fetch_files_from_node_outputs(event.outputs or {})
)
with Session(db.engine) as session:
workflow_node_execution = self._handle_workflow_node_execution_success(session=session, event=event)
with Session(db.engine, expire_on_commit=False) as session:
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_success(
session=session, event=event
)
node_finish_resp = self._workflow_node_finish_to_stream_response(
node_finish_resp = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
session=session,
event=event,
task_id=self._application_generate_entity.task_id,
@ -368,10 +373,12 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
if node_finish_resp:
yield node_finish_resp
elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent):
with Session(db.engine) as session:
workflow_node_execution = self._handle_workflow_node_execution_failed(session=session, event=event)
with Session(db.engine, expire_on_commit=False) as session:
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_failed(
session=session, event=event
)
node_finish_resp = self._workflow_node_finish_to_stream_response(
node_finish_resp = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
session=session,
event=event,
task_id=self._application_generate_entity.task_id,
@ -385,13 +392,17 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
with Session(db.engine) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
parallel_start_resp = self._workflow_parallel_branch_start_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event,
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
parallel_start_resp = (
self._workflow_cycle_manager._workflow_parallel_branch_start_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event,
)
)
yield parallel_start_resp
@ -399,13 +410,17 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
with Session(db.engine) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
parallel_finish_resp = self._workflow_parallel_branch_finished_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event,
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
parallel_finish_resp = (
self._workflow_cycle_manager._workflow_parallel_branch_finished_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event,
)
)
yield parallel_finish_resp
@ -413,9 +428,11 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
with Session(db.engine) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
iter_start_resp = self._workflow_iteration_start_to_stream_response(
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
iter_start_resp = self._workflow_cycle_manager._workflow_iteration_start_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
@ -427,9 +444,11 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
with Session(db.engine) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
iter_next_resp = self._workflow_iteration_next_to_stream_response(
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
iter_next_resp = self._workflow_cycle_manager._workflow_iteration_next_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
@ -441,9 +460,11 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
with Session(db.engine) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
iter_finish_resp = self._workflow_iteration_completed_to_stream_response(
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
iter_finish_resp = self._workflow_cycle_manager._workflow_iteration_completed_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
@ -458,8 +479,8 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
if not graph_runtime_state:
raise ValueError("workflow run not initialized.")
with Session(db.engine) as session:
workflow_run = self._handle_workflow_run_success(
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._handle_workflow_run_success(
session=session,
workflow_run_id=self._workflow_run_id,
start_at=graph_runtime_state.start_at,
@ -470,21 +491,23 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
trace_manager=trace_manager,
)
workflow_finish_resp = self._workflow_finish_to_stream_response(
workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
session.commit()
yield workflow_finish_resp
self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
self._base_task_pipeline._queue_manager.publish(
QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE
)
elif isinstance(event, QueueWorkflowPartialSuccessEvent):
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
if not graph_runtime_state:
raise ValueError("graph runtime state not initialized.")
with Session(db.engine) as session:
workflow_run = self._handle_workflow_run_partial_success(
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._handle_workflow_run_partial_success(
session=session,
workflow_run_id=self._workflow_run_id,
start_at=graph_runtime_state.start_at,
@ -495,21 +518,23 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
conversation_id=None,
trace_manager=trace_manager,
)
workflow_finish_resp = self._workflow_finish_to_stream_response(
workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
session.commit()
yield workflow_finish_resp
self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
self._base_task_pipeline._queue_manager.publish(
QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE
)
elif isinstance(event, QueueWorkflowFailedEvent):
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
if not graph_runtime_state:
raise ValueError("graph runtime state not initialized.")
with Session(db.engine) as session:
workflow_run = self._handle_workflow_run_failed(
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._handle_workflow_run_failed(
session=session,
workflow_run_id=self._workflow_run_id,
start_at=graph_runtime_state.start_at,
@ -521,20 +546,22 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
trace_manager=trace_manager,
exceptions_count=event.exceptions_count,
)
workflow_finish_resp = self._workflow_finish_to_stream_response(
workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_run.error}"))
err = self._handle_error(event=err_event, session=session, message_id=self._message_id)
err = self._base_task_pipeline._handle_error(
event=err_event, session=session, message_id=self._message_id
)
session.commit()
yield workflow_finish_resp
yield self._error_to_stream_response(err)
yield self._base_task_pipeline._error_to_stream_response(err)
break
elif isinstance(event, QueueStopEvent):
if self._workflow_run_id and graph_runtime_state:
with Session(db.engine) as session:
workflow_run = self._handle_workflow_run_failed(
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._handle_workflow_run_failed(
session=session,
workflow_run_id=self._workflow_run_id,
start_at=graph_runtime_state.start_at,
@ -545,7 +572,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
conversation_id=self._conversation_id,
trace_manager=trace_manager,
)
workflow_finish_resp = self._workflow_finish_to_stream_response(
workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
@ -559,18 +586,18 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
yield self._message_end_to_stream_response()
break
elif isinstance(event, QueueRetrieverResourcesEvent):
self._handle_retriever_resources(event)
self._message_cycle_manager._handle_retriever_resources(event)
with Session(db.engine) as session:
with Session(db.engine, expire_on_commit=False) as session:
message = self._get_message(session=session)
message.message_metadata = (
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
)
session.commit()
elif isinstance(event, QueueAnnotationReplyEvent):
self._handle_annotation_reply(event)
self._message_cycle_manager._handle_annotation_reply(event)
with Session(db.engine) as session:
with Session(db.engine, expire_on_commit=False) as session:
message = self._get_message(session=session)
message.message_metadata = (
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
@ -591,23 +618,27 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
tts_publisher.publish(queue_message)
self._task_state.answer += delta_text
yield self._message_to_stream_response(
yield self._message_cycle_manager._message_to_stream_response(
answer=delta_text, message_id=self._message_id, from_variable_selector=event.from_variable_selector
)
elif isinstance(event, QueueMessageReplaceEvent):
# published by moderation
yield self._message_replace_to_stream_response(answer=event.text)
yield self._message_cycle_manager._message_replace_to_stream_response(answer=event.text)
elif isinstance(event, QueueAdvancedChatMessageEndEvent):
if not graph_runtime_state:
raise ValueError("graph runtime state not initialized.")
output_moderation_answer = self._handle_output_moderation_when_task_finished(self._task_state.answer)
output_moderation_answer = self._base_task_pipeline._handle_output_moderation_when_task_finished(
self._task_state.answer
)
if output_moderation_answer:
self._task_state.answer = output_moderation_answer
yield self._message_replace_to_stream_response(answer=output_moderation_answer)
yield self._message_cycle_manager._message_replace_to_stream_response(
answer=output_moderation_answer
)
# Save message
with Session(db.engine) as session:
with Session(db.engine, expire_on_commit=False) as session:
self._save_message(session=session, graph_runtime_state=graph_runtime_state)
session.commit()
@ -627,7 +658,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
def _save_message(self, *, session: Session, graph_runtime_state: Optional[GraphRuntimeState] = None) -> None:
message = self._get_message(session=session)
message.answer = self._task_state.answer
message.provider_response_latency = time.perf_counter() - self._start_at
message.provider_response_latency = time.perf_counter() - self._base_task_pipeline._start_at
message.message_metadata = (
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
)
@ -691,20 +722,20 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
:param text: text
:return: True if output moderation should direct output, otherwise False
"""
if self._output_moderation_handler:
if self._output_moderation_handler.should_direct_output():
if self._base_task_pipeline._output_moderation_handler:
if self._base_task_pipeline._output_moderation_handler.should_direct_output():
# stop subscribe new token when output moderation should direct output
self._task_state.answer = self._output_moderation_handler.get_final_output()
self._queue_manager.publish(
self._task_state.answer = self._base_task_pipeline._output_moderation_handler.get_final_output()
self._base_task_pipeline._queue_manager.publish(
QueueTextChunkEvent(text=self._task_state.answer), PublishFrom.TASK_PIPELINE
)
self._queue_manager.publish(
self._base_task_pipeline._queue_manager.publish(
QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION), PublishFrom.TASK_PIPELINE
)
return True
else:
self._output_moderation_handler.append_new_token(text)
self._base_task_pipeline._output_moderation_handler.append_new_token(text)
return False
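Both generate-task pipelines in this commit drop multiple inheritance from BasedGenerateTaskPipeline, WorkflowCycleManage, and MessageCycleManage and instead hold those helpers as explicit collaborators (self._base_task_pipeline, self._workflow_cycle_manager, self._message_cycle_manager). A stripped-down sketch of that composition-over-inheritance move, using stand-in classes rather than the real Dify ones:

class BaseTaskPipeline:  # stand-in for BasedGenerateTaskPipeline
    def __init__(self, queue_manager, stream: bool) -> None:
        self._queue_manager = queue_manager
        self._stream = stream

    def ping_stream_response(self) -> str:
        return "ping"


class WorkflowCycleManager:  # stand-in for WorkflowCycleManage
    def __init__(self, system_variables: dict) -> None:
        self._system_variables = system_variables

    def workflow_start_response(self, workflow_id: str) -> str:
        return f"workflow_started:{workflow_id}"


class AdvancedChatPipeline:
    # Composition instead of mixins: each collaborator keeps its own state and the
    # delegation at every call site is explicit.
    def __init__(self, queue_manager, workflow_id: str, stream: bool = True) -> None:
        self._base_task_pipeline = BaseTaskPipeline(queue_manager, stream)
        self._workflow_cycle_manager = WorkflowCycleManager({"workflow_id": workflow_id})
        self._workflow_id = workflow_id

    def process(self):
        yield self._base_task_pipeline.ping_stream_response()
        yield self._workflow_cycle_manager.workflow_start_response(self._workflow_id)


print(list(AdvancedChatPipeline(queue_manager=None, workflow_id="wf-1").process()))

The trade-off is longer call sites (self._workflow_cycle_manager._workflow_start_to_stream_response(...) instead of an inherited method) in exchange for helpers that can be constructed and tested independently, with no method-resolution-order ambiguity.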

View File

@ -19,7 +19,7 @@ from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskSt
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, InvokeFrom
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.ops.ops_trace_manager import TraceQueueManager
from extensions.ext_database import db
from factories import file_factory
@ -251,7 +251,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
except ValidationError as e:
logger.exception("Validation Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except (ValueError, InvokeError) as e:
except ValueError as e:
if dify_config.DEBUG:
logger.exception("Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)

View File

@ -18,7 +18,7 @@ from core.app.apps.chat.generate_response_converter import ChatAppGenerateRespon
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, InvokeFrom
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.ops.ops_trace_manager import TraceQueueManager
from extensions.ext_database import db
from factories import file_factory
@ -237,7 +237,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
except ValidationError as e:
logger.exception("Validation Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except (ValueError, InvokeError) as e:
except ValueError as e:
if dify_config.DEBUG:
logger.exception("Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)

View File

@ -17,7 +17,7 @@ from core.app.apps.completion.generate_response_converter import CompletionAppGe
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import CompletionAppGenerateEntity, InvokeFrom
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.ops.ops_trace_manager import TraceQueueManager
from extensions.ext_database import db
from factories import file_factory
@ -214,7 +214,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
except ValidationError as e:
logger.exception("Validation Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except (ValueError, InvokeError) as e:
except ValueError as e:
if dify_config.DEBUG:
logger.exception("Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)

View File

@ -20,7 +20,7 @@ from core.app.apps.workflow.generate_response_converter import WorkflowAppGenera
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.ops.ops_trace_manager import TraceQueueManager
from extensions.ext_database import db
from factories import file_factory
@ -235,6 +235,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
single_iteration_run=WorkflowAppGenerateEntity.SingleIterationRunEntity(
node_id=node_id, inputs=args["inputs"]
),
workflow_run_id=str(uuid.uuid4()),
)
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
contexts.plugin_tool_providers.set({})
@ -286,7 +287,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
except ValidationError as e:
logger.exception("Validation Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except (ValueError, InvokeError) as e:
except ValueError as e:
if dify_config.DEBUG:
logger.exception("Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)

View File

@ -1,7 +1,7 @@
import logging
import time
from collections.abc import Generator
from typing import Any, Optional, Union
from typing import Optional, Union
from sqlalchemy.orm import Session
@ -59,7 +59,6 @@ from models.workflow import (
Workflow,
WorkflowAppLog,
WorkflowAppLogCreatedFrom,
WorkflowNodeExecution,
WorkflowRun,
WorkflowRunStatus,
)
@ -67,16 +66,11 @@ from models.workflow import (
logger = logging.getLogger(__name__)
class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleManage):
class WorkflowAppGenerateTaskPipeline:
"""
WorkflowAppGenerateTaskPipeline is a class that generate stream output and state management for Application.
"""
_task_state: WorkflowTaskState
_application_generate_entity: WorkflowAppGenerateEntity
_workflow_system_variables: dict[SystemVariableKey, Any]
_wip_workflow_node_executions: dict[str, WorkflowNodeExecution]
def __init__(
self,
application_generate_entity: WorkflowAppGenerateEntity,
@ -85,7 +79,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
user: Union[Account, EndUser],
stream: bool,
) -> None:
super().__init__(
self._base_task_pipeline = BasedGenerateTaskPipeline(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
stream=stream,
@ -102,17 +96,20 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
else:
raise ValueError(f"Invalid user type: {type(user)}")
self._workflow_cycle_manager = WorkflowCycleManage(
application_generate_entity=application_generate_entity,
workflow_system_variables={
SystemVariableKey.FILES: application_generate_entity.files,
SystemVariableKey.USER_ID: user_session_id,
SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
SystemVariableKey.WORKFLOW_ID: workflow.id,
SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
},
)
self._application_generate_entity = application_generate_entity
self._workflow_id = workflow.id
self._workflow_features_dict = workflow.features_dict
self._workflow_system_variables = {
SystemVariableKey.FILES: application_generate_entity.files,
SystemVariableKey.USER_ID: user_session_id,
SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
SystemVariableKey.WORKFLOW_ID: workflow.id,
SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
}
self._task_state = WorkflowTaskState()
self._workflow_run_id = ""
@ -122,7 +119,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
:return:
"""
generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
if self._stream:
if self._base_task_pipeline._stream:
return self._to_stream_response(generator)
else:
return self._to_blocking_response(generator)
@ -239,29 +236,29 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
"""
graph_runtime_state = None
for queue_message in self._queue_manager.listen():
for queue_message in self._base_task_pipeline._queue_manager.listen():
event = queue_message.event
if isinstance(event, QueuePingEvent):
yield self._ping_stream_response()
yield self._base_task_pipeline._ping_stream_response()
elif isinstance(event, QueueErrorEvent):
err = self._handle_error(event=event)
yield self._error_to_stream_response(err)
err = self._base_task_pipeline._handle_error(event=event)
yield self._base_task_pipeline._error_to_stream_response(err)
break
elif isinstance(event, QueueWorkflowStartedEvent):
# override graph runtime state
graph_runtime_state = event.graph_runtime_state
with Session(db.engine) as session:
with Session(db.engine, expire_on_commit=False) as session:
# init workflow run
workflow_run = self._handle_workflow_run_start(
workflow_run = self._workflow_cycle_manager._handle_workflow_run_start(
session=session,
workflow_id=self._workflow_id,
user_id=self._user_id,
created_by_role=self._created_by_role,
)
self._workflow_run_id = workflow_run.id
start_resp = self._workflow_start_to_stream_response(
start_resp = self._workflow_cycle_manager._workflow_start_to_stream_response(
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
session.commit()
@ -273,12 +270,14 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
):
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
with Session(db.engine) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
workflow_node_execution = self._handle_workflow_node_execution_retried(
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_retried(
session=session, workflow_run=workflow_run, event=event
)
response = self._workflow_node_retry_to_stream_response(
response = self._workflow_cycle_manager._workflow_node_retry_to_stream_response(
session=session,
event=event,
task_id=self._application_generate_entity.task_id,
@ -292,12 +291,14 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
with Session(db.engine) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
workflow_node_execution = self._handle_node_execution_start(
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
workflow_node_execution = self._workflow_cycle_manager._handle_node_execution_start(
session=session, workflow_run=workflow_run, event=event
)
node_start_response = self._workflow_node_start_to_stream_response(
node_start_response = self._workflow_cycle_manager._workflow_node_start_to_stream_response(
session=session,
event=event,
task_id=self._application_generate_entity.task_id,
@ -308,9 +309,11 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if node_start_response:
yield node_start_response
elif isinstance(event, QueueNodeSucceededEvent):
with Session(db.engine) as session:
workflow_node_execution = self._handle_workflow_node_execution_success(session=session, event=event)
node_success_response = self._workflow_node_finish_to_stream_response(
with Session(db.engine, expire_on_commit=False) as session:
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_success(
session=session, event=event
)
node_success_response = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
session=session,
event=event,
task_id=self._application_generate_entity.task_id,
@ -321,12 +324,12 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if node_success_response:
yield node_success_response
elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent):
with Session(db.engine) as session:
workflow_node_execution = self._handle_workflow_node_execution_failed(
with Session(db.engine, expire_on_commit=False) as session:
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_failed(
session=session,
event=event,
)
node_failed_response = self._workflow_node_finish_to_stream_response(
node_failed_response = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
session=session,
event=event,
task_id=self._application_generate_entity.task_id,
@ -341,13 +344,17 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
with Session(db.engine) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
parallel_start_resp = self._workflow_parallel_branch_start_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event,
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
parallel_start_resp = (
self._workflow_cycle_manager._workflow_parallel_branch_start_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event,
)
)
yield parallel_start_resp
@ -356,13 +363,17 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
with Session(db.engine) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
parallel_finish_resp = self._workflow_parallel_branch_finished_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event,
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
parallel_finish_resp = (
self._workflow_cycle_manager._workflow_parallel_branch_finished_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event,
)
)
yield parallel_finish_resp
@ -371,9 +382,11 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
with Session(db.engine) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
iter_start_resp = self._workflow_iteration_start_to_stream_response(
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
iter_start_resp = self._workflow_cycle_manager._workflow_iteration_start_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
@ -386,9 +399,11 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
with Session(db.engine) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
iter_next_resp = self._workflow_iteration_next_to_stream_response(
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
iter_next_resp = self._workflow_cycle_manager._workflow_iteration_next_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
@ -401,9 +416,11 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
with Session(db.engine) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
iter_finish_resp = self._workflow_iteration_completed_to_stream_response(
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
iter_finish_resp = self._workflow_cycle_manager._workflow_iteration_completed_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
@ -418,8 +435,8 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if not graph_runtime_state:
raise ValueError("graph runtime state not initialized.")
with Session(db.engine) as session:
workflow_run = self._handle_workflow_run_success(
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._handle_workflow_run_success(
session=session,
workflow_run_id=self._workflow_run_id,
start_at=graph_runtime_state.start_at,
@ -433,7 +450,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
# save workflow app log
self._save_workflow_app_log(session=session, workflow_run=workflow_run)
workflow_finish_resp = self._workflow_finish_to_stream_response(
workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
@ -447,8 +464,8 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if not graph_runtime_state:
raise ValueError("graph runtime state not initialized.")
with Session(db.engine) as session:
workflow_run = self._handle_workflow_run_partial_success(
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._handle_workflow_run_partial_success(
session=session,
workflow_run_id=self._workflow_run_id,
start_at=graph_runtime_state.start_at,
@ -463,7 +480,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
# save workflow app log
self._save_workflow_app_log(session=session, workflow_run=workflow_run)
workflow_finish_resp = self._workflow_finish_to_stream_response(
workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
session.commit()
@ -475,8 +492,8 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if not graph_runtime_state:
raise ValueError("graph runtime state not initialized.")
with Session(db.engine) as session:
workflow_run = self._handle_workflow_run_failed(
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._handle_workflow_run_failed(
session=session,
workflow_run_id=self._workflow_run_id,
start_at=graph_runtime_state.start_at,
@ -494,7 +511,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
# save workflow app log
self._save_workflow_app_log(session=session, workflow_run=workflow_run)
workflow_finish_resp = self._workflow_finish_to_stream_response(
workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
session.commit()

View File

@ -195,7 +195,7 @@ class WorkflowAppGenerateEntity(AppGenerateEntity):
# app config
app_config: WorkflowUIBasedAppConfig
workflow_run_id: Optional[str] = None
workflow_run_id: str
class SingleIterationRunEntity(BaseModel):
"""

View File

@ -15,7 +15,6 @@ from core.app.entities.queue_entities import (
from core.app.entities.task_entities import (
ErrorStreamResponse,
PingStreamResponse,
TaskState,
)
from core.errors.error import QuotaExceededError
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
@ -30,22 +29,12 @@ class BasedGenerateTaskPipeline:
BasedGenerateTaskPipeline is a class that generate stream output and state management for Application.
"""
_task_state: TaskState
_application_generate_entity: AppGenerateEntity
def __init__(
self,
application_generate_entity: AppGenerateEntity,
queue_manager: AppQueueManager,
stream: bool,
) -> None:
"""
Initialize GenerateTaskPipeline.
:param application_generate_entity: application generate entity
:param queue_manager: queue manager
:param user: user
:param stream: stream
"""
self._application_generate_entity = application_generate_entity
self._queue_manager = queue_manager
self._start_at = time.perf_counter()

View File

@ -31,10 +31,19 @@ from services.annotation_service import AppAnnotationService
class MessageCycleManage:
_application_generate_entity: Union[
ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity, AdvancedChatAppGenerateEntity
]
_task_state: Union[EasyUITaskState, WorkflowTaskState]
def __init__(
self,
*,
application_generate_entity: Union[
ChatAppGenerateEntity,
CompletionAppGenerateEntity,
AgentChatAppGenerateEntity,
AdvancedChatAppGenerateEntity,
],
task_state: Union[EasyUITaskState, WorkflowTaskState],
) -> None:
self._application_generate_entity = application_generate_entity
self._task_state = task_state
def _generate_conversation_name(self, *, conversation_id: str, query: str) -> Optional[Thread]:
"""

View File

@ -36,7 +36,6 @@ from core.app.entities.task_entities import (
ParallelBranchStartStreamResponse,
WorkflowFinishStreamResponse,
WorkflowStartStreamResponse,
WorkflowTaskState,
)
from core.file import FILE_MODEL_IDENTITY, File
from core.model_runtime.utils.encoders import jsonable_encoder
@ -60,13 +59,20 @@ from models.workflow import (
WorkflowRunStatus,
)
from .exc import WorkflowNodeExecutionNotFoundError, WorkflowRunNotFoundError
from .exc import WorkflowRunNotFoundError
class WorkflowCycleManage:
_application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity]
_task_state: WorkflowTaskState
_workflow_system_variables: dict[SystemVariableKey, Any]
def __init__(
self,
*,
application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity],
workflow_system_variables: dict[SystemVariableKey, Any],
) -> None:
self._workflow_run: WorkflowRun | None = None
self._workflow_node_executions: dict[str, WorkflowNodeExecution] = {}
self._application_generate_entity = application_generate_entity
self._workflow_system_variables = workflow_system_variables
def _handle_workflow_run_start(
self,
@ -104,7 +110,8 @@ class WorkflowCycleManage:
inputs = dict(WorkflowEntry.handle_special_values(inputs) or {})
# init workflow run
workflow_run_id = str(self._workflow_system_variables.get(SystemVariableKey.WORKFLOW_RUN_ID, uuid4()))
# TODO: This workflow_run_id should always not be None, maybe we can use a more elegant way to handle this
workflow_run_id = str(self._workflow_system_variables.get(SystemVariableKey.WORKFLOW_RUN_ID) or uuid4())
workflow_run = WorkflowRun()
workflow_run.id = workflow_run_id
@ -241,7 +248,7 @@ class WorkflowCycleManage:
workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None)
workflow_run.exceptions_count = exceptions_count
stmt = select(WorkflowNodeExecution).where(
stmt = select(WorkflowNodeExecution.node_execution_id).where(
WorkflowNodeExecution.tenant_id == workflow_run.tenant_id,
WorkflowNodeExecution.app_id == workflow_run.app_id,
WorkflowNodeExecution.workflow_id == workflow_run.workflow_id,
@ -249,15 +256,18 @@ class WorkflowCycleManage:
WorkflowNodeExecution.workflow_run_id == workflow_run.id,
WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING.value,
)
running_workflow_node_executions = session.scalars(stmt).all()
ids = session.scalars(stmt).all()
# Use self._get_workflow_node_execution here to make sure the cache is updated
running_workflow_node_executions = [
self._get_workflow_node_execution(session=session, node_execution_id=id) for id in ids if id
]
for workflow_node_execution in running_workflow_node_executions:
now = datetime.now(UTC).replace(tzinfo=None)
workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
workflow_node_execution.error = error
finish_at = datetime.now(UTC).replace(tzinfo=None)
workflow_node_execution.finished_at = finish_at
workflow_node_execution.elapsed_time = (finish_at - workflow_node_execution.created_at).total_seconds()
workflow_node_execution.finished_at = now
workflow_node_execution.elapsed_time = (now - workflow_node_execution.created_at).total_seconds()
if trace_manager:
trace_manager.add_trace_task(
@ -299,6 +309,8 @@ class WorkflowCycleManage:
workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None)
session.add(workflow_node_execution)
self._workflow_node_executions[event.node_execution_id] = workflow_node_execution
return workflow_node_execution
def _handle_workflow_node_execution_success(
@ -325,6 +337,7 @@ class WorkflowCycleManage:
workflow_node_execution.finished_at = finished_at
workflow_node_execution.elapsed_time = elapsed_time
workflow_node_execution = session.merge(workflow_node_execution)
return workflow_node_execution
def _handle_workflow_node_execution_failed(
@ -364,6 +377,7 @@ class WorkflowCycleManage:
workflow_node_execution.elapsed_time = elapsed_time
workflow_node_execution.execution_metadata = execution_metadata
workflow_node_execution = session.merge(workflow_node_execution)
return workflow_node_execution
def _handle_workflow_node_execution_retried(
@ -415,6 +429,8 @@ class WorkflowCycleManage:
workflow_node_execution.index = event.node_run_index
session.add(workflow_node_execution)
self._workflow_node_executions[event.node_execution_id] = workflow_node_execution
return workflow_node_execution
#################################################
@ -811,25 +827,23 @@ class WorkflowCycleManage:
return None
def _get_workflow_run(self, *, session: Session, workflow_run_id: str) -> WorkflowRun:
"""
Refetch workflow run
:param workflow_run_id: workflow run id
:return:
"""
if self._workflow_run and self._workflow_run.id == workflow_run_id:
cached_workflow_run = self._workflow_run
cached_workflow_run = session.merge(cached_workflow_run)
return cached_workflow_run
stmt = select(WorkflowRun).where(WorkflowRun.id == workflow_run_id)
workflow_run = session.scalar(stmt)
if not workflow_run:
raise WorkflowRunNotFoundError(workflow_run_id)
self._workflow_run = workflow_run
return workflow_run
def _get_workflow_node_execution(self, session: Session, node_execution_id: str) -> WorkflowNodeExecution:
stmt = select(WorkflowNodeExecution).where(WorkflowNodeExecution.node_execution_id == node_execution_id)
workflow_node_execution = session.scalar(stmt)
if not workflow_node_execution:
raise WorkflowNodeExecutionNotFoundError(node_execution_id)
return workflow_node_execution
if node_execution_id not in self._workflow_node_executions:
raise ValueError(f"Workflow node execution not found: {node_execution_id}")
cached_workflow_node_execution = self._workflow_node_executions[node_execution_id]
return cached_workflow_node_execution
def _handle_agent_log(self, task_id: str, event: QueueAgentLogEvent) -> AgentLogStreamResponse:
"""

View File

@ -1,9 +1,47 @@
import tiktoken
from threading import Lock
from typing import Any
_tokenizer: Any = None
_lock = Lock()
class GPT2Tokenizer:
@staticmethod
def _get_num_tokens_by_gpt2(text: str) -> int:
"""
use gpt2 tokenizer to get num tokens
"""
_tokenizer = GPT2Tokenizer.get_encoder()
tokens = _tokenizer.encode(text)
return len(tokens)
@staticmethod
def get_num_tokens(text: str) -> int:
encoding = tiktoken.encoding_for_model("gpt2")
tiktoken_vec = encoding.encode(text)
return len(tiktoken_vec)
# Because this process needs more CPU resources, we fall back to the direct call until we find a better way to handle it.
#
# future = _executor.submit(GPT2Tokenizer._get_num_tokens_by_gpt2, text)
# result = future.result()
# return cast(int, result)
return GPT2Tokenizer._get_num_tokens_by_gpt2(text)
@staticmethod
def get_encoder() -> Any:
global _tokenizer, _lock
with _lock:
if _tokenizer is None:
# Try to use tiktoken to get the tokenizer because it is faster
#
try:
import tiktoken
_tokenizer = tiktoken.get_encoding("gpt2")
except Exception:
from os.path import abspath, dirname, join
from transformers import GPT2Tokenizer as TransformerGPT2Tokenizer # type: ignore
base_path = abspath(__file__)
gpt2_tokenizer_path = join(dirname(base_path), "gpt2")
_tokenizer = TransformerGPT2Tokenizer.from_pretrained(gpt2_tokenizer_path)
return _tokenizer
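A minimal usage sketch of the encoder cache above: the first call initializes the module-level _tokenizer under the lock (tiktoken if available, otherwise the bundled transformers GPT-2 tokenizer), and later calls reuse it. The sample texts are illustrative only and assume GPT2Tokenizer from this module is importable.
# Illustrative only; exact counts depend on which tokenizer backend is picked.
first = GPT2Tokenizer.get_num_tokens("Dify counts prompt tokens with a GPT-2 vocabulary.")
second = GPT2Tokenizer.get_num_tokens("A second call reuses the cached encoder.")
print(first, second)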

View File

@ -113,6 +113,8 @@ class BaiduVector(BaseVector):
return False
def delete_by_ids(self, ids: list[str]) -> None:
if not ids:
return
quoted_ids = [f"'{id}'" for id in ids]
self._db.table(self._collection_name).delete(filter=f"id IN({', '.join(quoted_ids)})")

View File

@ -83,6 +83,8 @@ class ChromaVector(BaseVector):
self._client.delete_collection(self._collection_name)
def delete_by_ids(self, ids: list[str]) -> None:
if not ids:
return
collection = self._client.get_or_create_collection(self._collection_name)
collection.delete(ids=ids)

View File

@ -0,0 +1,104 @@
import json
import logging
from typing import Any, Optional
from flask import current_app
from core.rag.datasource.vdb.elasticsearch.elasticsearch_vector import (
ElasticSearchConfig,
ElasticSearchVector,
ElasticSearchVectorFactory,
)
from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from extensions.ext_redis import redis_client
from models.dataset import Dataset
logger = logging.getLogger(__name__)
class ElasticSearchJaVector(ElasticSearchVector):
def create_collection(
self,
embeddings: list[list[float]],
metadatas: Optional[list[dict[Any, Any]]] = None,
index_params: Optional[dict] = None,
):
lock_name = f"vector_indexing_lock_{self._collection_name}"
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
if redis_client.get(collection_exist_cache_key):
logger.info(f"Collection {self._collection_name} already exists.")
return
if not self._client.indices.exists(index=self._collection_name):
dim = len(embeddings[0])
settings = {
"analysis": {
"analyzer": {
"ja_analyzer": {
"type": "custom",
"char_filter": [
"icu_normalizer",
"kuromoji_iteration_mark",
],
"tokenizer": "kuromoji_tokenizer",
"filter": [
"kuromoji_baseform",
"kuromoji_part_of_speech",
"ja_stop",
"kuromoji_number",
"kuromoji_stemmer",
],
}
}
}
}
mappings = {
"properties": {
Field.CONTENT_KEY.value: {
"type": "text",
"analyzer": "ja_analyzer",
"search_analyzer": "ja_analyzer",
},
Field.VECTOR.value: { # Make sure the dimension is correct here
"type": "dense_vector",
"dims": dim,
"index": True,
"similarity": "cosine",
},
Field.METADATA_KEY.value: {
"type": "object",
"properties": {
"doc_id": {"type": "keyword"} # Map doc_id to keyword type
},
},
}
}
self._client.indices.create(index=self._collection_name, settings=settings, mappings=mappings)
redis_client.set(collection_exist_cache_key, 1, ex=3600)
class ElasticSearchJaVectorFactory(ElasticSearchVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> ElasticSearchJaVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.ELASTICSEARCH, collection_name))
config = current_app.config
return ElasticSearchJaVector(
index_name=collection_name,
config=ElasticSearchConfig(
host=config.get("ELASTICSEARCH_HOST", "localhost"),
port=config.get("ELASTICSEARCH_PORT", 9200),
username=config.get("ELASTICSEARCH_USERNAME", ""),
password=config.get("ELASTICSEARCH_PASSWORD", ""),
),
attributes=[],
)

View File

@ -98,6 +98,8 @@ class ElasticSearchVector(BaseVector):
return bool(self._client.exists(index=self._collection_name, id=id))
def delete_by_ids(self, ids: list[str]) -> None:
if not ids:
return
for id in ids:
self._client.delete(index=self._collection_name, id=id)

View File

@ -6,6 +6,8 @@ class Field(Enum):
METADATA_KEY = "metadata"
GROUP_KEY = "group_id"
VECTOR = "vector"
# Sparse Vector aims to support full text search
SPARSE_VECTOR = "sparse_vector"
TEXT_KEY = "text"
PRIMARY_KEY = "id"
DOC_ID = "metadata.doc_id"

View File

@ -2,6 +2,7 @@ import json
import logging
from typing import Any, Optional
from packaging import version
from pydantic import BaseModel, model_validator
from pymilvus import MilvusClient, MilvusException # type: ignore
from pymilvus.milvus_client import IndexParams # type: ignore
@ -20,16 +21,25 @@ logger = logging.getLogger(__name__)
class MilvusConfig(BaseModel):
uri: str
token: Optional[str] = None
user: str
password: str
batch_size: int = 100
database: str = "default"
"""
Configuration class for Milvus connection.
"""
uri: str # Milvus server URI
token: Optional[str] = None # Optional token for authentication
user: str # Username for authentication
password: str # Password for authentication
batch_size: int = 100 # Batch size for operations
database: str = "default" # Database name
enable_hybrid_search: bool = False # Flag to enable hybrid search
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict) -> dict:
"""
Validate the configuration values.
Raises ValueError if required fields are missing.
"""
if not values.get("uri"):
raise ValueError("config MILVUS_URI is required")
if not values.get("user"):
@ -39,6 +49,9 @@ class MilvusConfig(BaseModel):
return values
def to_milvus_params(self):
"""
Convert the configuration to a dictionary of Milvus connection parameters.
"""
return {
"uri": self.uri,
"token": self.token,
@ -49,26 +62,57 @@ class MilvusConfig(BaseModel):
class MilvusVector(BaseVector):
"""
Milvus vector storage implementation.
"""
def __init__(self, collection_name: str, config: MilvusConfig):
super().__init__(collection_name)
self._client_config = config
self._client = self._init_client(config)
self._consistency_level = "Session"
self._fields: list[str] = []
self._consistency_level = "Session" # Consistency level for Milvus operations
self._fields: list[str] = [] # List of fields in the collection
self._hybrid_search_enabled = self._check_hybrid_search_support() # Check if hybrid search is supported
def _check_hybrid_search_support(self) -> bool:
"""
Check if the current Milvus version supports hybrid search.
Returns True if the version is >= 2.5.0, otherwise False.
"""
if not self._client_config.enable_hybrid_search:
return False
try:
milvus_version = self._client.get_server_version()
return version.parse(milvus_version).base_version >= version.parse("2.5.0").base_version
except Exception as e:
logger.warning(f"Failed to check Milvus version: {str(e)}. Disabling hybrid search.")
return False
def get_type(self) -> str:
"""
Get the type of vector storage (Milvus).
"""
return VectorType.MILVUS
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
"""
Create a collection and add texts with embeddings.
"""
index_params = {"metric_type": "IP", "index_type": "HNSW", "params": {"M": 8, "efConstruction": 64}}
metadatas = [d.metadata if d.metadata is not None else {} for d in texts]
self.create_collection(embeddings, metadatas, index_params)
self.add_texts(texts, embeddings)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
"""
Add texts and their embeddings to the collection.
"""
insert_dict_list = []
for i in range(len(documents)):
insert_dict = {
# Do not need to insert the sparse_vector field separately, as the text_bm25_emb
# function will automatically convert the native text into a sparse vector for us.
Field.CONTENT_KEY.value: documents[i].page_content,
Field.VECTOR.value: embeddings[i],
Field.METADATA_KEY.value: documents[i].metadata,
@ -76,12 +120,11 @@ class MilvusVector(BaseVector):
insert_dict_list.append(insert_dict)
# Total insert count
total_count = len(insert_dict_list)
pks: list[str] = []
for i in range(0, total_count, 1000):
batch_insert_list = insert_dict_list[i : i + 1000]
# Insert into the collection.
batch_insert_list = insert_dict_list[i : i + 1000]
try:
ids = self._client.insert(collection_name=self._collection_name, data=batch_insert_list)
pks.extend(ids)
@ -91,6 +134,9 @@ class MilvusVector(BaseVector):
return pks
def get_ids_by_metadata_field(self, key: str, value: str):
"""
Get document IDs by metadata field key and value.
"""
result = self._client.query(
collection_name=self._collection_name, filter=f'metadata["{key}"] == "{value}"', output_fields=["id"]
)
@ -100,12 +146,18 @@ class MilvusVector(BaseVector):
return None
def delete_by_metadata_field(self, key: str, value: str):
"""
Delete documents by metadata field key and value.
"""
if self._client.has_collection(self._collection_name):
ids = self.get_ids_by_metadata_field(key, value)
if ids:
self._client.delete(collection_name=self._collection_name, pks=ids)
def delete_by_ids(self, ids: list[str]) -> None:
"""
Delete documents by their IDs.
"""
if self._client.has_collection(self._collection_name):
result = self._client.query(
collection_name=self._collection_name, filter=f'metadata["doc_id"] in {ids}', output_fields=["id"]
@ -115,10 +167,16 @@ class MilvusVector(BaseVector):
self._client.delete(collection_name=self._collection_name, pks=ids)
def delete(self) -> None:
"""
Delete the entire collection.
"""
if self._client.has_collection(self._collection_name):
self._client.drop_collection(self._collection_name, None)
def text_exists(self, id: str) -> bool:
"""
Check if a text with the given ID exists in the collection.
"""
if not self._client.has_collection(self._collection_name):
return False
@ -128,32 +186,80 @@ class MilvusVector(BaseVector):
return len(result) > 0
def field_exists(self, field: str) -> bool:
"""
Check if a field exists in the collection.
"""
return field in self._fields
def _process_search_results(
self, results: list[Any], output_fields: list[str], score_threshold: float = 0.0
) -> list[Document]:
"""
Common method to process search results
:param results: Search results
:param output_fields: Fields to be output
:param score_threshold: Score threshold for filtering
:return: List of documents
"""
docs = []
for result in results[0]:
metadata = result["entity"].get(output_fields[1], {})
metadata["score"] = result["distance"]
if result["distance"] > score_threshold:
doc = Document(page_content=result["entity"].get(output_fields[0], ""), metadata=metadata)
docs.append(doc)
return docs
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
# Set search parameters.
"""
Search for documents by vector similarity.
"""
results = self._client.search(
collection_name=self._collection_name,
data=[query_vector],
anns_field=Field.VECTOR.value,
limit=kwargs.get("top_k", 4),
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
)
# Organize results.
docs = []
for result in results[0]:
metadata = result["entity"].get(Field.METADATA_KEY.value)
metadata["score"] = result["distance"]
score_threshold = float(kwargs.get("score_threshold") or 0.0)
if result["distance"] > score_threshold:
doc = Document(page_content=result["entity"].get(Field.CONTENT_KEY.value), metadata=metadata)
docs.append(doc)
return docs
return self._process_search_results(
results,
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
score_threshold=float(kwargs.get("score_threshold") or 0.0),
)
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
# milvus/zilliz doesn't support bm25 search
return []
"""
Search for documents by full-text search (if hybrid search is enabled).
"""
if not self._hybrid_search_enabled or not self.field_exists(Field.SPARSE_VECTOR.value):
logger.warning("Full-text search is not supported in current Milvus version (requires >= 2.5.0)")
return []
results = self._client.search(
collection_name=self._collection_name,
data=[query],
anns_field=Field.SPARSE_VECTOR.value,
limit=kwargs.get("top_k", 4),
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
)
return self._process_search_results(
results,
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
score_threshold=float(kwargs.get("score_threshold") or 0.0),
)
def create_collection(
self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None
):
"""
Create a new collection in Milvus with the specified schema and index parameters.
"""
lock_name = "vector_indexing_lock_{}".format(self._collection_name)
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
@ -161,7 +267,7 @@ class MilvusVector(BaseVector):
return
# Grab the existing collection if it exists
if not self._client.has_collection(self._collection_name):
from pymilvus import CollectionSchema, DataType, FieldSchema # type: ignore
from pymilvus import CollectionSchema, DataType, FieldSchema, Function, FunctionType # type: ignore
from pymilvus.orm.types import infer_dtype_bydata # type: ignore
# Determine embedding dim
@ -170,16 +276,36 @@ class MilvusVector(BaseVector):
if metadatas:
fields.append(FieldSchema(Field.METADATA_KEY.value, DataType.JSON, max_length=65_535))
# Create the text field
fields.append(FieldSchema(Field.CONTENT_KEY.value, DataType.VARCHAR, max_length=65_535))
# Create the text field, enable_analyzer will be set True to support milvus automatically
# transfer text to sparse_vector, reference: https://milvus.io/docs/full-text-search.md
fields.append(
FieldSchema(
Field.CONTENT_KEY.value,
DataType.VARCHAR,
max_length=65_535,
enable_analyzer=self._hybrid_search_enabled,
)
)
# Create the primary key field
fields.append(FieldSchema(Field.PRIMARY_KEY.value, DataType.INT64, is_primary=True, auto_id=True))
# Create the vector field, supports binary or float vectors
fields.append(FieldSchema(Field.VECTOR.value, infer_dtype_bydata(embeddings[0]), dim=dim))
# Create Sparse Vector Index for the collection
if self._hybrid_search_enabled:
fields.append(FieldSchema(Field.SPARSE_VECTOR.value, DataType.SPARSE_FLOAT_VECTOR))
# Create the schema for the collection
schema = CollectionSchema(fields)
# Create custom function to support text to sparse vector by BM25
if self._hybrid_search_enabled:
bm25_function = Function(
name="text_bm25_emb",
input_field_names=[Field.CONTENT_KEY.value],
output_field_names=[Field.SPARSE_VECTOR.value],
function_type=FunctionType.BM25,
)
schema.add_function(bm25_function)
for x in schema.fields:
self._fields.append(x.name)
# Since primary field is auto-id, no need to track it
@ -189,10 +315,15 @@ class MilvusVector(BaseVector):
index_params_obj = IndexParams()
index_params_obj.add_index(field_name=Field.VECTOR.value, **index_params)
# Create Sparse Vector Index for the collection
if self._hybrid_search_enabled:
index_params_obj.add_index(
field_name=Field.SPARSE_VECTOR.value, index_type="AUTOINDEX", metric_type="BM25"
)
# Create the collection
collection_name = self._collection_name
self._client.create_collection(
collection_name=collection_name,
collection_name=self._collection_name,
schema=schema,
index_params=index_params_obj,
consistency_level=self._consistency_level,
@ -200,12 +331,22 @@ class MilvusVector(BaseVector):
redis_client.set(collection_exist_cache_key, 1, ex=3600)
def _init_client(self, config) -> MilvusClient:
"""
Initialize and return a Milvus client.
"""
client = MilvusClient(uri=config.uri, user=config.user, password=config.password, db_name=config.database)
return client
class MilvusVectorFactory(AbstractVectorFactory):
"""
Factory class for creating MilvusVector instances.
"""
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> MilvusVector:
"""
Initialize a MilvusVector instance for the given dataset.
"""
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix
@ -222,5 +363,6 @@ class MilvusVectorFactory(AbstractVectorFactory):
user=dify_config.MILVUS_USER or "",
password=dify_config.MILVUS_PASSWORD or "",
database=dify_config.MILVUS_DATABASE or "",
enable_hybrid_search=dify_config.MILVUS_ENABLE_HYBRID_SEARCH or False,
),
)
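A simplified sketch of the version gate behind _check_hybrid_search_support, comparing parsed versions directly (the method above compares their base_version strings); it assumes only the packaging library imported at the top of this file. The same behaviour can be switched off via the MILVUS_ENABLE_HYBRID_SEARCH flag added elsewhere in this diff.
from packaging import version

def supports_bm25_full_text(server_version: str) -> bool:
    # BM25 full-text (hybrid) search is only attempted on Milvus >= 2.5.0.
    return version.parse(server_version) >= version.parse("2.5.0")

print(supports_bm25_full_text("2.4.9"))  # False -> search_by_full_text returns []
print(supports_bm25_full_text("2.5.1"))  # True  -> sparse BM25 field and index are created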

View File

@ -100,6 +100,8 @@ class MyScaleVector(BaseVector):
return results.row_count > 0
def delete_by_ids(self, ids: list[str]) -> None:
if not ids:
return
self._client.command(
f"DELETE FROM {self._config.database}.{self._collection_name} WHERE id IN {str(tuple(ids))}"
)

View File

@ -134,6 +134,8 @@ class OceanBaseVector(BaseVector):
return bool(cur.rowcount != 0)
def delete_by_ids(self, ids: list[str]) -> None:
if not ids:
return
self._client.delete(table_name=self._collection_name, ids=ids)
def get_ids_by_metadata_field(self, key: str, value: str) -> list[str]:

View File

@ -167,6 +167,8 @@ class OracleVector(BaseVector):
return docs
def delete_by_ids(self, ids: list[str]) -> None:
if not ids:
return
with self._get_cursor() as cur:
cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s" % (tuple(ids),))

View File

@ -129,6 +129,11 @@ class PGVector(BaseVector):
return docs
def delete_by_ids(self, ids: list[str]) -> None:
# Avoiding crashes caused by performing delete operations on empty lists in certain scenarios
# Scenario 1: extracting a document fails, resulting in a table not being created.
# Then clicking the retry button triggers a delete operation on an empty list.
if not ids:
return
with self._get_cursor() as cur:
cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s", (tuple(ids),))

View File

@ -140,6 +140,8 @@ class TencentVector(BaseVector):
return False
def delete_by_ids(self, ids: list[str]) -> None:
if not ids:
return
self._db.collection(self._collection_name).delete(document_ids=ids)
def delete_by_metadata_field(self, key: str, value: str) -> None:

View File

@ -409,27 +409,27 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory):
db.session.query(TidbAuthBinding).filter(TidbAuthBinding.tenant_id == dataset.tenant_id).one_or_none()
)
if not tidb_auth_binding:
idle_tidb_auth_binding = (
db.session.query(TidbAuthBinding)
.filter(TidbAuthBinding.active == False, TidbAuthBinding.status == "ACTIVE")
.limit(1)
.one_or_none()
)
if idle_tidb_auth_binding:
idle_tidb_auth_binding.active = True
idle_tidb_auth_binding.tenant_id = dataset.tenant_id
db.session.commit()
TIDB_ON_QDRANT_API_KEY = f"{idle_tidb_auth_binding.account}:{idle_tidb_auth_binding.password}"
else:
with redis_client.lock("create_tidb_serverless_cluster_lock", timeout=900):
tidb_auth_binding = (
with redis_client.lock("create_tidb_serverless_cluster_lock", timeout=900):
tidb_auth_binding = (
db.session.query(TidbAuthBinding)
.filter(TidbAuthBinding.tenant_id == dataset.tenant_id)
.one_or_none()
)
if tidb_auth_binding:
TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}"
else:
idle_tidb_auth_binding = (
db.session.query(TidbAuthBinding)
.filter(TidbAuthBinding.tenant_id == dataset.tenant_id)
.filter(TidbAuthBinding.active == False, TidbAuthBinding.status == "ACTIVE")
.limit(1)
.one_or_none()
)
if tidb_auth_binding:
TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}"
if idle_tidb_auth_binding:
idle_tidb_auth_binding.active = True
idle_tidb_auth_binding.tenant_id = dataset.tenant_id
db.session.commit()
TIDB_ON_QDRANT_API_KEY = f"{idle_tidb_auth_binding.account}:{idle_tidb_auth_binding.password}"
else:
new_cluster = TidbService.create_tidb_serverless_cluster(
dify_config.TIDB_PROJECT_ID or "",
@ -451,7 +451,6 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory):
db.session.add(new_tidb_auth_binding)
db.session.commit()
TIDB_ON_QDRANT_API_KEY = f"{new_tidb_auth_binding.account}:{new_tidb_auth_binding.password}"
else:
TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}"

View File

@ -90,6 +90,12 @@ class Vector:
from core.rag.datasource.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
return ElasticSearchVectorFactory
case VectorType.ELASTICSEARCH_JA:
from core.rag.datasource.vdb.elasticsearch.elasticsearch_ja_vector import (
ElasticSearchJaVectorFactory,
)
return ElasticSearchJaVectorFactory
case VectorType.TIDB_VECTOR:
from core.rag.datasource.vdb.tidb_vector.tidb_vector import TiDBVectorFactory

View File

@ -16,6 +16,7 @@ class VectorType(StrEnum):
TENCENT = "tencent"
ORACLE = "oracle"
ELASTICSEARCH = "elasticsearch"
ELASTICSEARCH_JA = "elasticsearch-ja"
LINDORM = "lindorm"
COUCHBASE = "couchbase"
BAIDU = "baidu"

View File

@ -23,7 +23,6 @@ class PdfExtractor(BaseExtractor):
self._file_cache_key = file_cache_key
def extract(self) -> list[Document]:
plaintext_file_key = ""
plaintext_file_exists = False
if self._file_cache_key:
try:
@ -39,8 +38,8 @@ class PdfExtractor(BaseExtractor):
text = "\n\n".join(text_list)
# save plaintext file for caching
if not plaintext_file_exists and plaintext_file_key:
storage.save(plaintext_file_key, text.encode("utf-8"))
if not plaintext_file_exists and self._file_cache_key:
storage.save(self._file_cache_key, text.encode("utf-8"))
return documents

View File

@ -3,6 +3,7 @@
import uuid
from typing import Optional
from configs import dify_config
from core.model_manager import ModelInstance
from core.rag.cleaner.clean_processor import CleanProcessor
from core.rag.datasource.retrieval_service import RetrievalService
@ -80,6 +81,10 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
child_nodes = self._split_child_nodes(
document, rules, process_rule.get("mode"), kwargs.get("embedding_model_instance")
)
if kwargs.get("preview"):
if len(child_nodes) > dify_config.CHILD_CHUNKS_PREVIEW_NUMBER:
child_nodes = child_nodes[: dify_config.CHILD_CHUNKS_PREVIEW_NUMBER]
document.children = child_nodes
doc_id = str(uuid.uuid4())
hash = helper.generate_text_hash(document.page_content)

View File

@ -212,8 +212,23 @@ class ApiTool(Tool):
else:
body = body
if method in {"get", "head", "post", "put", "delete", "patch"}:
response: httpx.Response = getattr(ssrf_proxy, method)(
if method in {
"get",
"head",
"post",
"put",
"delete",
"patch",
"options",
"GET",
"POST",
"PUT",
"PATCH",
"DELETE",
"HEAD",
"OPTIONS",
}:
response: httpx.Response = getattr(ssrf_proxy, method.lower())(
url,
params=params,
headers=headers,

View File

@ -2,14 +2,18 @@ import csv
import io
import json
import logging
import operator
import os
import tempfile
from typing import cast
from collections.abc import Mapping, Sequence
from typing import Any, cast
import docx
import pandas as pd
import pypdfium2 # type: ignore
import yaml # type: ignore
from docx.table import Table
from docx.text.paragraph import Paragraph
from configs import dify_config
from core.file import File, FileTransferMethod, file_manager
@ -78,6 +82,23 @@ class DocumentExtractorNode(BaseNode[DocumentExtractorNodeData]):
process_data=process_data,
)
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: DocumentExtractorNodeData,
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
:param graph_config: graph config
:param node_id: node id
:param node_data: node data
:return:
"""
return {node_id + ".files": node_data.variable_selector}
def _extract_text_by_mime_type(*, file_content: bytes, mime_type: str) -> str:
"""Extract text from a file based on its MIME type."""
@ -189,35 +210,56 @@ def _extract_text_from_doc(file_content: bytes) -> str:
doc_file = io.BytesIO(file_content)
doc = docx.Document(doc_file)
text = []
# Process paragraphs
for paragraph in doc.paragraphs:
if paragraph.text.strip():
text.append(paragraph.text)
# Process tables
for table in doc.tables:
# Table header
try:
# table maybe cause errors so ignore it.
if len(table.rows) > 0 and table.rows[0].cells is not None:
# Keep track of paragraph and table positions
content_items: list[tuple[int, str, Table | Paragraph]] = []
# Process paragraphs and tables
for i, paragraph in enumerate(doc.paragraphs):
if paragraph.text.strip():
content_items.append((i, "paragraph", paragraph))
for i, table in enumerate(doc.tables):
content_items.append((i, "table", table))
# Sort content items based on their original position
content_items.sort(key=operator.itemgetter(0))
# Process sorted content
for _, item_type, item in content_items:
if item_type == "paragraph":
if isinstance(item, Table):
continue
text.append(item.text)
elif item_type == "table":
# Process tables
if not isinstance(item, Table):
continue
try:
# Check if any cell in the table has text
has_content = False
for row in table.rows:
for row in item.rows:
if any(cell.text.strip() for cell in row.cells):
has_content = True
break
if has_content:
markdown_table = "| " + " | ".join(cell.text for cell in table.rows[0].cells) + " |\n"
markdown_table += "| " + " | ".join(["---"] * len(table.rows[0].cells)) + " |\n"
for row in table.rows[1:]:
markdown_table += "| " + " | ".join(cell.text for cell in row.cells) + " |\n"
cell_texts = [cell.text.replace("\n", "<br>") for cell in item.rows[0].cells]
markdown_table = f"| {' | '.join(cell_texts)} |\n"
markdown_table += f"| {' | '.join(['---'] * len(item.rows[0].cells))} |\n"
for row in item.rows[1:]:
# Replace newlines with <br> in each cell
row_cells = [cell.text.replace("\n", "<br>") for cell in row.cells]
markdown_table += "| " + " | ".join(row_cells) + " |\n"
text.append(markdown_table)
except Exception as e:
logger.warning(f"Failed to extract table from DOC/DOCX: {e}")
continue
except Exception as e:
logger.warning(f"Failed to extract table from DOC/DOCX: {e}")
continue
return "\n".join(text)
except Exception as e:
raise TextExtractionError(f"Failed to extract text from DOC/DOCX: {str(e)}") from e
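A hedged illustration of the table-to-Markdown conversion above, using plain lists in place of python-docx row objects; a cell containing a newline is rendered with <br> so the Markdown row stays on a single line.
rows = [["Name", "Notes"], ["dify", "line one\nline two"]]  # stand-in for item.rows cell texts
header = f"| {' | '.join(rows[0])} |"
divider = f"| {' | '.join(['---'] * len(rows[0]))} |"
body = "| " + " | ".join(cell.replace("\n", "<br>") for cell in rows[1]) + " |"
print("\n".join([header, divider, body]))
# | Name | Notes |
# | --- | --- |
# | dify | line one<br>line two |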

View File

@ -68,7 +68,22 @@ class HttpRequestNodeData(BaseNodeData):
Code Node Data.
"""
method: Literal["get", "post", "put", "patch", "delete", "head"]
method: Literal[
"get",
"post",
"put",
"patch",
"delete",
"head",
"options",
"GET",
"POST",
"PUT",
"PATCH",
"DELETE",
"HEAD",
"OPTIONS",
]
url: str
authorization: HttpRequestNodeAuthorization
headers: str

View File

@ -37,7 +37,22 @@ BODY_TYPE_TO_CONTENT_TYPE = {
class Executor:
method: Literal["get", "head", "post", "put", "delete", "patch"]
method: Literal[
"get",
"head",
"post",
"put",
"delete",
"patch",
"options",
"GET",
"POST",
"PUT",
"PATCH",
"DELETE",
"HEAD",
"OPTIONS",
]
url: str
params: list[tuple[str, str]] | None
content: str | bytes | None
@ -67,12 +82,6 @@ class Executor:
node_data.authorization.config.api_key
).text
# check if node_data.url is a valid URL
if not node_data.url:
raise InvalidURLError("url is required")
if not node_data.url.startswith(("http://", "https://")):
raise InvalidURLError("url should start with http:// or https://")
self.url: str = node_data.url
self.method = node_data.method
self.auth = node_data.authorization
@ -99,6 +108,12 @@ class Executor:
def _init_url(self):
self.url = self.variable_pool.convert_template(self.node_data.url).text
# check if url is a valid URL
if not self.url:
raise InvalidURLError("url is required")
if not self.url.startswith(("http://", "https://")):
raise InvalidURLError("url should start with http:// or https://")
def _init_params(self):
"""
Almost same as _init_headers(), difference:
@ -158,7 +173,10 @@ class Executor:
if len(data) != 1:
raise RequestBodyError("json body type should have exactly one item")
json_string = self.variable_pool.convert_template(data[0].value).text
json_object = json.loads(json_string, strict=False)
try:
json_object = json.loads(json_string, strict=False)
except json.JSONDecodeError as e:
raise RequestBodyError(f"Failed to parse JSON: {json_string}") from e
self.json = json_object
# self.json = self._parse_object_contains_variables(json_object)
case "binary":
@ -246,7 +264,22 @@ class Executor:
"""
do http request depending on api bundle
"""
if self.method not in {"get", "head", "post", "put", "delete", "patch"}:
if self.method not in {
"get",
"head",
"post",
"put",
"delete",
"patch",
"options",
"GET",
"POST",
"PUT",
"PATCH",
"DELETE",
"HEAD",
"OPTIONS",
}:
raise InvalidHttpMethodError(f"Invalid http method {self.method}")
request_args = {
@ -263,7 +296,7 @@ class Executor:
}
# request_args = {k: v for k, v in request_args.items() if v is not None}
try:
response = getattr(ssrf_proxy, self.method)(**request_args)
response = getattr(ssrf_proxy, self.method.lower())(**request_args)
except (ssrf_proxy.MaxRetriesExceededError, httpx.RequestError) as e:
raise HttpRequestNodeError(str(e))
# FIXME: fix type ignore, this maybe httpx type issue
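A small sketch of the case-insensitive dispatch introduced above; httpx stands in for the ssrf_proxy module purely for illustration, and the URL is a placeholder.
import httpx

method = "POST"  # any accepted verb, upper- or lower-case
response = getattr(httpx, method.lower())("https://example.com/api", timeout=10)
print(response.status_code)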

View File

@ -340,6 +340,10 @@ class WorkflowEntry:
):
raise ValueError(f"Variable key {node_variable} not found in user inputs.")
# environment variable already exist in variable pool, not from user inputs
if variable_pool.get(variable_selector):
continue
# fetch variable node id from variable selector
variable_node_id = variable_selector[0]
variable_key_list = variable_selector[1:]

View File

@ -33,6 +33,7 @@ else
--bind "${DIFY_BIND_ADDRESS:-0.0.0.0}:${DIFY_PORT:-5001}" \
--workers ${SERVER_WORKER_AMOUNT:-1} \
--worker-class ${SERVER_WORKER_CLASS:-gevent} \
--worker-connections ${SERVER_WORKER_CONNECTIONS:-10} \
--timeout ${GUNICORN_TIMEOUT:-200} \
app:app
fi

View File

@ -46,7 +46,7 @@ def init_app(app: DifyApp):
timezone = pytz.timezone(log_tz)
def time_converter(seconds):
return datetime.utcfromtimestamp(seconds).astimezone(timezone).timetuple()
return datetime.fromtimestamp(seconds, tz=timezone).timetuple()
for handler in logging.root.handlers:
if handler.formatter:
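The converter change above matters because datetime.utcfromtimestamp() returns a naive datetime, which astimezone() then treats as server-local time, while fromtimestamp(seconds, tz=...) converts the epoch value to the configured zone directly. A small sketch with an illustrative timestamp (pytz is already used in this module):
from datetime import datetime
import pytz

timezone = pytz.timezone("Asia/Shanghai")
seconds = 1736332800.0  # illustrative epoch timestamp
print(datetime.fromtimestamp(seconds, tz=timezone).timetuple())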

View File

@ -158,7 +158,7 @@ def _build_from_remote_url(
tenant_id: str,
transfer_method: FileTransferMethod,
) -> File:
url = mapping.get("url")
url = mapping.get("url") or mapping.get("remote_url")
if not url:
raise ValueError("Invalid file url")

View File

@ -255,7 +255,8 @@ class NotionOAuth(OAuthDataSource):
response = requests.get(url=f"{self._NOTION_BLOCK_SEARCH}/{block_id}", headers=headers)
response_json = response.json()
if response.status_code != 200:
raise ValueError(f"Error fetching block parent page ID: {response_json.message}")
message = response_json.get("message", "unknown error")
raise ValueError(f"Error fetching block parent page ID: {message}")
parent = response_json["parent"]
parent_type = parent["type"]
if parent_type == "block_id":

View File

@ -0,0 +1,41 @@
"""change workflow_runs.total_tokens to bigint
Revision ID: a91b476a53de
Revises: 923752d42eb6
Create Date: 2025-01-01 20:00:01.207369
"""
from alembic import op
import models as models
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = 'a91b476a53de'
down_revision = '923752d42eb6'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('workflow_runs', schema=None) as batch_op:
batch_op.alter_column('total_tokens',
existing_type=sa.INTEGER(),
type_=sa.BigInteger(),
existing_nullable=False,
existing_server_default=sa.text('0'))
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('workflow_runs', schema=None) as batch_op:
batch_op.alter_column('total_tokens',
existing_type=sa.BigInteger(),
type_=sa.INTEGER(),
existing_nullable=False,
existing_server_default=sa.text('0'))
# ### end Alembic commands ###

View File

@ -415,8 +415,8 @@ class WorkflowRun(Base):
status: Mapped[str] = mapped_column(db.String(255)) # running, succeeded, failed, stopped, partial-succeeded
outputs: Mapped[Optional[str]] = mapped_column(sa.Text, default="{}")
error: Mapped[Optional[str]] = mapped_column(db.Text)
elapsed_time = db.Column(db.Float, nullable=False, server_default=db.text("0"))
total_tokens: Mapped[int] = mapped_column(server_default=db.text("0"))
elapsed_time = db.Column(db.Float, nullable=False, server_default=sa.text("0"))
total_tokens: Mapped[int] = mapped_column(sa.BigInteger, server_default=sa.text("0"))
total_steps = db.Column(db.Integer, server_default=db.text("0"))
created_by_role: Mapped[str] = mapped_column(db.String(255)) # account, end_user
created_by = db.Column(StringUUID, nullable=False)

api/poetry.lock (generated, 2397 lines changed)

File diff suppressed because it is too large

View File

@ -71,7 +71,7 @@ pyjwt = "~2.8.0"
pypdfium2 = "~4.30.0"
python = ">=3.11,<3.13"
python-docx = "~1.1.0"
python-dotenv = "1.0.0"
python-dotenv = "1.0.1"
pyyaml = "~6.0.1"
readabilipy = "0.2.0"
redis = { version = "~5.0.3", extras = ["hiredis"] }
@ -82,7 +82,7 @@ scikit-learn = "~1.5.1"
sentry-sdk = { version = "~1.44.1", extras = ["flask"] }
sqlalchemy = "~2.0.29"
starlette = "0.41.0"
tencentcloud-sdk-python-hunyuan = "~3.0.1158"
tencentcloud-sdk-python-hunyuan = "~3.0.1294"
tiktoken = "~0.8.0"
tokenizers = "~0.15.0"
transformers = "~4.35.0"
@ -92,7 +92,7 @@ validators = "0.21.0"
volcengine-python-sdk = {extras = ["ark"], version = "~1.0.98"}
websocket-client = "~1.7.0"
xinference-client = "0.15.2"
yarl = "~1.9.4"
yarl = "~1.18.3"
youtube-transcript-api = "~0.6.2"
zhipuai = "~2.1.5"
# Before adding new dependency, consider place it in alphabet order (a-z) and suitable group.
@ -157,7 +157,7 @@ opensearch-py = "2.4.0"
oracledb = "~2.2.1"
pgvecto-rs = { version = "~0.2.1", extras = ['sqlalchemy'] }
pgvector = "0.2.5"
pymilvus = "~2.4.4"
pymilvus = "~2.5.0"
pymochow = "1.3.1"
pyobvector = "~0.1.6"
qdrant-client = "1.7.3"

View File

@ -168,23 +168,6 @@ def clean_unused_datasets_task():
else:
plan = plan_cache.decode()
if plan == "sandbox":
# add auto disable log
documents = (
db.session.query(Document)
.filter(
Document.dataset_id == dataset.id,
Document.enabled == True,
Document.archived == False,
)
.all()
)
for document in documents:
dataset_auto_disable_log = DatasetAutoDisableLog(
tenant_id=dataset.tenant_id,
dataset_id=dataset.id,
document_id=document.id,
)
db.session.add(dataset_auto_disable_log)
# remove index
index_processor = IndexProcessorFactory(dataset.doc_form).init_index_processor()
index_processor.clean(dataset, None)

View File

@ -66,7 +66,7 @@ class TokenPair(BaseModel):
REFRESH_TOKEN_PREFIX = "refresh_token:"
ACCOUNT_REFRESH_TOKEN_PREFIX = "account_refresh_token:"
REFRESH_TOKEN_EXPIRY = timedelta(days=30)
REFRESH_TOKEN_EXPIRY = timedelta(days=dify_config.REFRESH_TOKEN_EXPIRE_DAYS)
class AccountService:

View File

@ -2,6 +2,7 @@ import logging
import uuid
from enum import StrEnum
from typing import Optional, cast
from urllib.parse import urlparse
from uuid import uuid4
import yaml # type: ignore
@ -124,7 +125,7 @@ class AppDslService:
raise ValueError(f"Invalid import_mode: {import_mode}")
# Get YAML content
content: bytes | str = b""
content: str = ""
if mode == ImportMode.YAML_URL:
if not yaml_url:
return Import(
@ -133,13 +134,17 @@ class AppDslService:
error="yaml_url is required when import_mode is yaml-url",
)
try:
# tricky way to handle url from github to github raw url
if yaml_url.startswith("https://github.com") and yaml_url.endswith((".yml", ".yaml")):
parsed_url = urlparse(yaml_url)
if (
parsed_url.scheme == "https"
and parsed_url.netloc == "github.com"
and parsed_url.path.endswith((".yml", ".yaml"))
):
yaml_url = yaml_url.replace("https://github.com", "https://raw.githubusercontent.com")
yaml_url = yaml_url.replace("/blob/", "/")
response = ssrf_proxy.get(yaml_url.strip(), follow_redirects=True, timeout=(10, 10))
response.raise_for_status()
content = response.content
content = response.content.decode()
if len(content) > DSL_MAX_SIZE:
return Import(

View File

@ -26,9 +26,10 @@ from tasks.remove_app_and_related_data_task import remove_app_and_related_data_t
class AppService:
def get_paginate_apps(self, tenant_id: str, args: dict) -> Pagination | None:
def get_paginate_apps(self, user_id: str, tenant_id: str, args: dict) -> Pagination | None:
"""
Get app list with pagination
:param user_id: user id
:param tenant_id: tenant id
:param args: request args
:return:
@ -44,6 +45,8 @@ class AppService:
elif args["mode"] == "channel":
filters.append(App.mode == AppMode.CHANNEL.value)
if args.get("is_created_by_me", False):
filters.append(App.created_by == user_id)
if args.get("name"):
name = args["name"][:30]
filters.append(App.name.ilike(f"%{name}%"))

View File

@ -1,5 +1,5 @@
import os
from typing import Optional
from typing import Literal, Optional
import httpx
from tenacity import retry, retry_if_exception_type, stop_before_delay, wait_fixed
@ -17,7 +17,6 @@ class BillingService:
params = {"tenant_id": tenant_id}
billing_info = cls._send_request("GET", "/subscription/info", params=params)
return billing_info
@classmethod
@ -47,12 +46,13 @@ class BillingService:
retry=retry_if_exception_type(httpx.RequestError),
reraise=True,
)
def _send_request(cls, method, endpoint, json=None, params=None):
def _send_request(cls, method: Literal["GET", "POST", "DELETE"], endpoint: str, json=None, params=None):
headers = {"Content-Type": "application/json", "Billing-Api-Secret-Key": cls.secret_key}
url = f"{cls.base_url}{endpoint}"
response = httpx.request(method, url, json=json, params=params, headers=headers)
if method == "GET" and response.status_code != httpx.codes.OK:
raise ValueError("Unable to retrieve billing information. Please try again later or contact support.")
return response.json()
@staticmethod

View File

@ -86,7 +86,7 @@ class DatasetService:
else:
return [], 0
else:
if user.current_role not in (TenantAccountRole.OWNER, TenantAccountRole.ADMIN):
if user.current_role != TenantAccountRole.OWNER:
# show all datasets that the user has permission to access
if permitted_dataset_ids:
query = query.filter(
@ -382,7 +382,7 @@ class DatasetService:
if dataset.tenant_id != user.current_tenant_id:
logging.debug(f"User {user.id} does not have permission to access dataset {dataset.id}")
raise NoPermissionError("You do not have permission to access this dataset.")
if user.current_role not in (TenantAccountRole.OWNER, TenantAccountRole.ADMIN):
if user.current_role != TenantAccountRole.OWNER:
if dataset.permission == DatasetPermissionEnum.ONLY_ME and dataset.created_by != user.id:
logging.debug(f"User {user.id} does not have permission to access dataset {dataset.id}")
raise NoPermissionError("You do not have permission to access this dataset.")
@ -404,7 +404,7 @@ class DatasetService:
if not user:
raise ValueError("User not found")
if user.current_role not in (TenantAccountRole.OWNER, TenantAccountRole.ADMIN):
if user.current_role != TenantAccountRole.OWNER:
if dataset.permission == DatasetPermissionEnum.ONLY_ME:
if dataset.created_by != user.id:
raise NoPermissionError("You do not have permission to access this dataset.")
@ -434,6 +434,12 @@ class DatasetService:
@staticmethod
def get_dataset_auto_disable_logs(dataset_id: str) -> dict:
features = FeatureService.get_features(current_user.current_tenant_id)
if not features.billing.enabled or features.billing.subscription.plan == "sandbox":
return {
"document_ids": [],
"count": 0,
}
# get recent 30 days auto disable logs
start_date = datetime.datetime.now() - datetime.timedelta(days=30)
dataset_auto_disable_logs = DatasetAutoDisableLog.query.filter(
@ -786,13 +792,19 @@ class DocumentService:
dataset.indexing_technique = knowledge_config.indexing_technique
if knowledge_config.indexing_technique == "high_quality":
model_manager = ModelManager()
embedding_model = model_manager.get_default_model_instance(
tenant_id=current_user.current_tenant_id, model_type=ModelType.TEXT_EMBEDDING
)
dataset.embedding_model = embedding_model.model
dataset.embedding_model_provider = embedding_model.provider
if knowledge_config.embedding_model and knowledge_config.embedding_model_provider:
dataset_embedding_model = knowledge_config.embedding_model
dataset_embedding_model_provider = knowledge_config.embedding_model_provider
else:
embedding_model = model_manager.get_default_model_instance(
tenant_id=current_user.current_tenant_id, model_type=ModelType.TEXT_EMBEDDING
)
dataset_embedding_model = embedding_model.model
dataset_embedding_model_provider = embedding_model.provider
dataset.embedding_model = dataset_embedding_model
dataset.embedding_model_provider = dataset_embedding_model_provider
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
embedding_model.provider, embedding_model.model
dataset_embedding_model_provider, dataset_embedding_model
)
dataset.collection_binding_id = dataset_collection_binding.id
if not dataset.retrieval_model:
@ -804,7 +816,11 @@ class DocumentService:
"score_threshold_enabled": False,
}
dataset.retrieval_model = knowledge_config.retrieval_model.model_dump() or default_retrieval_model # type: ignore
dataset.retrieval_model = (
knowledge_config.retrieval_model.model_dump()
if knowledge_config.retrieval_model
else default_retrieval_model
) # type: ignore
documents = []
if knowledge_config.original_document_id:

View File

@ -27,7 +27,7 @@ class WorkflowAppService:
query = query.join(WorkflowRun, WorkflowRun.id == WorkflowAppLog.workflow_run_id)
if keyword:
keyword_like_val = f"%{args['keyword'][:30]}%"
keyword_like_val = f"%{keyword[:30].encode('unicode_escape').decode('utf-8')}%".replace(r"\u", r"\\u")
keyword_conditions = [
WorkflowRun.inputs.ilike(keyword_like_val),
WorkflowRun.outputs.ilike(keyword_like_val),
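A hedged illustration of the keyword escaping above: a non-ASCII keyword is converted to its \uXXXX form, presumably because workflow run inputs and outputs are stored as JSON-escaped text, and the single backslash is then doubled so that PostgreSQL ILIKE (where backslash is the escape character) matches it literally.
keyword = "你好"  # illustrative search keyword
escaped = keyword[:30].encode("unicode_escape").decode("utf-8")
print(escaped)  # \u4f60\u597d
keyword_like_val = f"%{escaped}%".replace(r"\u", r"\\u")
print(keyword_like_val)  # %\\u4f60\\u597d%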

View File

@ -28,7 +28,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
if not dataset:
raise Exception("Dataset not found")
index_type = dataset.doc_form
index_type = dataset.doc_form or IndexType.PARAGRAPH_INDEX
index_processor = IndexProcessorFactory(index_type).init_index_processor()
if action == "remove":
index_processor.clean(dataset, None, with_keywords=False)
@ -157,6 +157,9 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
{"indexing_status": "error", "error": str(e)}, synchronize_session=False
)
db.session.commit()
else:
# clean collection
index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False)
end_at = time.perf_counter()
logging.info(

View File

@ -0,0 +1,55 @@
import os
from pathlib import Path
import pytest
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.gpustack.speech2text.speech2text import GPUStackSpeech2TextModel
def test_validate_credentials():
model = GPUStackSpeech2TextModel()
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model="faster-whisper-medium",
credentials={
"endpoint_url": "invalid_url",
"api_key": "invalid_api_key",
},
)
model.validate_credentials(
model="faster-whisper-medium",
credentials={
"endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
"api_key": os.environ.get("GPUSTACK_API_KEY"),
},
)
def test_invoke_model():
model = GPUStackSpeech2TextModel()
# Get the directory of the current file
current_dir = os.path.dirname(os.path.abspath(__file__))
# Get assets directory
assets_dir = os.path.join(os.path.dirname(current_dir), "assets")
# Construct the path to the audio file
audio_file_path = os.path.join(assets_dir, "audio.mp3")
file = Path(audio_file_path).read_bytes()
result = model.invoke(
model="faster-whisper-medium",
credentials={
"endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
"api_key": os.environ.get("GPUSTACK_API_KEY"),
},
file=file,
)
assert isinstance(result, str)
assert result == "1, 2, 3, 4, 5, 6, 7, 8, 9, 10"

View File

@ -0,0 +1,24 @@
import os
from core.model_runtime.model_providers.gpustack.tts.tts import GPUStackText2SpeechModel
def test_invoke_model():
model = GPUStackText2SpeechModel()
result = model.invoke(
model="cosyvoice-300m-sft",
tenant_id="test",
credentials={
"endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
"api_key": os.environ.get("GPUSTACK_API_KEY"),
},
content_text="Hello world",
voice="Chinese Female",
)
content = b""
for chunk in result:
content += chunk
assert content != b""

View File

@ -19,9 +19,9 @@ class MilvusVectorTest(AbstractVectorTest):
)
def search_by_full_text(self):
# milvus does not support full text searching yet in < 2.3.x
# milvus support BM25 full text search after version 2.5.0-beta
hits_by_full_text = self.vector.search_by_full_text(query=get_example_text())
assert len(hits_by_full_text) == 0
assert len(hits_by_full_text) >= 0
def get_ids_by_metadata_field(self):
ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id)

View File

@ -2,7 +2,7 @@ version: '3'
services:
# API service
api:
image: langgenius/dify-api:0.14.2
image: langgenius/dify-api:0.15.0
restart: always
environment:
# Startup mode, 'api' starts the API server.
@ -227,7 +227,7 @@ services:
# worker service
# The Celery worker for processing the queue.
worker:
image: langgenius/dify-api:0.14.2
image: langgenius/dify-api:0.15.0
restart: always
environment:
CONSOLE_WEB_URL: ''
@ -397,7 +397,7 @@ services:
# Frontend web application.
web:
image: langgenius/dify-web:0.14.2
image: langgenius/dify-web:0.15.0
restart: always
environment:
# The base URL of console application api server, refers to the Console base URL of WEB service if console domain is

View File

@ -105,6 +105,9 @@ FILES_ACCESS_TIMEOUT=300
# Access token expiration time in minutes
ACCESS_TOKEN_EXPIRE_MINUTES=60
# Refresh token expiration time in days
REFRESH_TOKEN_EXPIRE_DAYS=30
# The maximum number of active requests for the application, where 0 means unlimited, should be a non-negative integer.
APP_MAX_ACTIVE_REQUESTS=0
APP_MAX_EXECUTION_TIME=1200
@ -123,10 +126,13 @@ DIFY_PORT=5001
# The number of API server workers, i.e., the number of workers.
# Formula: number of cpu cores x 2 + 1 for sync, 1 for Gevent
# Reference: https://docs.gunicorn.org/en/stable/design.html#how-many-workers
SERVER_WORKER_AMOUNT=
SERVER_WORKER_AMOUNT=1
# Defaults to gevent. If using windows, it can be switched to sync or solo.
SERVER_WORKER_CLASS=
SERVER_WORKER_CLASS=gevent
# Default number of worker connections, the default is 10.
SERVER_WORKER_CONNECTIONS=10
# Similar to SERVER_WORKER_CLASS.
# If using windows, it can be switched to sync or solo.
@ -377,7 +383,7 @@ SUPABASE_URL=your-server-url
# ------------------------------
# The type of vector store to use.
# Supported values are `weaviate`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `tidb_vector`, `oracle`, `tencent`, `elasticsearch`, `analyticdb`, `couchbase`, `vikingdb`, `oceanbase`.
# Supported values are `weaviate`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `tidb_vector`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `oceanbase`.
VECTOR_STORE=weaviate
# The Weaviate endpoint URL. Only available when VECTOR_STORE is `weaviate`.
@ -397,6 +403,7 @@ MILVUS_URI=http://127.0.0.1:19530
MILVUS_TOKEN=
MILVUS_USER=root
MILVUS_PASSWORD=Milvus
MILVUS_ENABLE_HYBRID_SEARCH=False
# MyScale configuration, only available when VECTOR_STORE is `myscale`
# For multi-language support, please set MYSCALE_FTS_PARAMS with referring to:
@ -506,7 +513,7 @@ TENCENT_VECTOR_DB_SHARD=1
TENCENT_VECTOR_DB_REPLICAS=2
# ElasticSearch configuration, only available when VECTOR_STORE is `elasticsearch`
ELASTICSEARCH_HOST=0.0.0.0
ELASTICSEARCH_HOST=elasticsearch
ELASTICSEARCH_PORT=9200
ELASTICSEARCH_USERNAME=elastic
ELASTICSEARCH_PASSWORD=elastic
@ -923,6 +930,9 @@ CREATE_TIDB_SERVICE_JOB_ENABLED=false
# Maximum number of submitted thread count in a ThreadPool for parallel node execution
MAX_SUBMIT_COUNT=100
# The maximum number of top-k value for RAG.
TOP_K_MAX_VALUE=10
# ------------------------------
# Plugin Daemon Configuration
# ------------------------------
@ -947,3 +957,4 @@ ENDPOINT_URL_TEMPLATE=http://localhost/e/{hook_id}
MARKETPLACE_ENABLED=true
MARKETPLACE_API_URL=https://marketplace-plugin.dify.dev

View File

@ -73,6 +73,7 @@ services:
CSP_WHITELIST: ${CSP_WHITELIST:-}
MARKETPLACE_API_URL: ${MARKETPLACE_API_URL:-https://marketplace-plugin.dify.dev}
MARKETPLACE_URL: ${MARKETPLACE_URL:-https://marketplace-plugin.dify.dev}
TOP_K_MAX_VALUE: ${TOP_K_MAX_VALUE:-}
# The postgres database.
db:
@ -92,7 +93,7 @@ services:
volumes:
- ./volumes/db/data:/var/lib/postgresql/data
healthcheck:
test: ['CMD', 'pg_isready']
test: [ 'CMD', 'pg_isready' ]
interval: 1s
timeout: 3s
retries: 30
@ -111,7 +112,7 @@ services:
# Set the redis password when startup redis server.
command: redis-server --requirepass ${REDIS_PASSWORD:-difyai123456}
healthcheck:
test: ['CMD', 'redis-cli', 'ping']
test: [ 'CMD', 'redis-cli', 'ping' ]
# The DifySandbox
sandbox:
@ -131,7 +132,7 @@ services:
volumes:
- ./volumes/sandbox/dependencies:/dependencies
healthcheck:
test: ['CMD', 'curl', '-f', 'http://localhost:8194/health']
test: [ 'CMD', 'curl', '-f', 'http://localhost:8194/health' ]
networks:
- ssrf_proxy_network
@ -167,12 +168,7 @@ services:
volumes:
- ./ssrf_proxy/squid.conf.template:/etc/squid/squid.conf.template
- ./ssrf_proxy/docker-entrypoint.sh:/docker-entrypoint-mount.sh
entrypoint:
[
'sh',
'-c',
"cp /docker-entrypoint-mount.sh /docker-entrypoint.sh && sed -i 's/\r$$//' /docker-entrypoint.sh && chmod +x /docker-entrypoint.sh && /docker-entrypoint.sh",
]
entrypoint: [ 'sh', '-c', "cp /docker-entrypoint-mount.sh /docker-entrypoint.sh && sed -i 's/\r$$//' /docker-entrypoint.sh && chmod +x /docker-entrypoint.sh && /docker-entrypoint.sh" ]
environment:
# pls clearly modify the squid env vars to fit your network environment.
HTTP_PORT: ${SSRF_HTTP_PORT:-3128}
@ -201,8 +197,8 @@ services:
- CERTBOT_EMAIL=${CERTBOT_EMAIL}
- CERTBOT_DOMAIN=${CERTBOT_DOMAIN}
- CERTBOT_OPTIONS=${CERTBOT_OPTIONS:-}
entrypoint: ['/docker-entrypoint.sh']
command: ['tail', '-f', '/dev/null']
entrypoint: [ '/docker-entrypoint.sh' ]
command: [ 'tail', '-f', '/dev/null' ]
# The nginx reverse proxy.
# used for reverse proxying the API service and Web service.
@ -219,12 +215,7 @@ services:
- ./volumes/certbot/conf/live:/etc/letsencrypt/live # cert dir (with certbot container)
- ./volumes/certbot/conf:/etc/letsencrypt
- ./volumes/certbot/www:/var/www/html
entrypoint:
[
'sh',
'-c',
"cp /docker-entrypoint-mount.sh /docker-entrypoint.sh && sed -i 's/\r$$//' /docker-entrypoint.sh && chmod +x /docker-entrypoint.sh && /docker-entrypoint.sh",
]
entrypoint: [ 'sh', '-c', "cp /docker-entrypoint-mount.sh /docker-entrypoint.sh && sed -i 's/\r$$//' /docker-entrypoint.sh && chmod +x /docker-entrypoint.sh && /docker-entrypoint.sh" ]
environment:
NGINX_SERVER_NAME: ${NGINX_SERVER_NAME:-_}
NGINX_HTTPS_ENABLED: ${NGINX_HTTPS_ENABLED:-false}
@ -316,7 +307,7 @@ services:
working_dir: /opt/couchbase
stdin_open: true
tty: true
entrypoint: [""]
entrypoint: [ "" ]
command: sh -c "/opt/couchbase/init/init-cbserver.sh"
volumes:
- ./volumes/couchbase/data:/opt/couchbase/var/lib/couchbase/data
@ -345,7 +336,7 @@ services:
volumes:
- ./volumes/pgvector/data:/var/lib/postgresql/data
healthcheck:
test: ['CMD', 'pg_isready']
test: [ 'CMD', 'pg_isready' ]
interval: 1s
timeout: 3s
retries: 30
@ -367,7 +358,7 @@ services:
volumes:
- ./volumes/pgvecto_rs/data:/var/lib/postgresql/data
healthcheck:
test: ['CMD', 'pg_isready']
test: [ 'CMD', 'pg_isready' ]
interval: 1s
timeout: 3s
retries: 30
@ -432,7 +423,7 @@ services:
- ./volumes/milvus/etcd:/etcd
command: etcd -advertise-client-urls=http://127.0.0.1:2379 -listen-client-urls http://0.0.0.0:2379 --data-dir /etcd
healthcheck:
test: ['CMD', 'etcdctl', 'endpoint', 'health']
test: [ 'CMD', 'etcdctl', 'endpoint', 'health' ]
interval: 30s
timeout: 20s
retries: 3
@ -451,7 +442,7 @@ services:
- ./volumes/milvus/minio:/minio_data
command: minio server /minio_data --console-address ":9001"
healthcheck:
test: ['CMD', 'curl', '-f', 'http://localhost:9000/minio/health/live']
test: [ 'CMD', 'curl', '-f', 'http://localhost:9000/minio/health/live' ]
interval: 30s
timeout: 20s
retries: 3
@ -463,7 +454,7 @@ services:
image: milvusdb/milvus:v2.3.1
profiles:
- milvus
command: ['milvus', 'run', 'standalone']
command: [ 'milvus', 'run', 'standalone' ]
environment:
ETCD_ENDPOINTS: ${ETCD_ENDPOINTS:-etcd:2379}
MINIO_ADDRESS: ${MINIO_ADDRESS:-minio:9000}
@ -471,7 +462,7 @@ services:
volumes:
- ./volumes/milvus/milvus:/var/lib/milvus
healthcheck:
test: ['CMD', 'curl', '-f', 'http://localhost:9091/healthz']
test: [ 'CMD', 'curl', '-f', 'http://localhost:9091/healthz' ]
interval: 30s
start_period: 90s
timeout: 20s
@ -559,7 +550,7 @@ services:
ports:
- ${ELASTICSEARCH_PORT:-9200}:9200
healthcheck:
test: ['CMD', 'curl', '-s', 'http://localhost:9200/_cluster/health?pretty']
test: [ 'CMD', 'curl', '-s', 'http://localhost:9200/_cluster/health?pretty' ]
interval: 30s
timeout: 10s
retries: 50
@ -587,7 +578,7 @@ services:
ports:
- ${KIBANA_PORT:-5601}:5601
healthcheck:
test: ['CMD-SHELL', 'curl -s http://localhost:5601 >/dev/null || exit 1']
test: [ 'CMD-SHELL', 'curl -s http://localhost:5601 >/dev/null || exit 1' ]
interval: 30s
timeout: 10s
retries: 3

View File

@ -27,12 +27,14 @@ x-shared-env: &shared-api-worker-env
MIGRATION_ENABLED: ${MIGRATION_ENABLED:-true}
FILES_ACCESS_TIMEOUT: ${FILES_ACCESS_TIMEOUT:-300}
ACCESS_TOKEN_EXPIRE_MINUTES: ${ACCESS_TOKEN_EXPIRE_MINUTES:-60}
REFRESH_TOKEN_EXPIRE_DAYS: ${REFRESH_TOKEN_EXPIRE_DAYS:-30}
APP_MAX_ACTIVE_REQUESTS: ${APP_MAX_ACTIVE_REQUESTS:-0}
APP_MAX_EXECUTION_TIME: ${APP_MAX_EXECUTION_TIME:-1200}
DIFY_BIND_ADDRESS: ${DIFY_BIND_ADDRESS:-0.0.0.0}
DIFY_PORT: ${DIFY_PORT:-5001}
SERVER_WORKER_AMOUNT: ${SERVER_WORKER_AMOUNT:-}
SERVER_WORKER_CLASS: ${SERVER_WORKER_CLASS:-}
SERVER_WORKER_AMOUNT: ${SERVER_WORKER_AMOUNT:-1}
SERVER_WORKER_CLASS: ${SERVER_WORKER_CLASS:-gevent}
SERVER_WORKER_CONNECTIONS: ${SERVER_WORKER_CONNECTIONS:-10}
CELERY_WORKER_CLASS: ${CELERY_WORKER_CLASS:-}
GUNICORN_TIMEOUT: ${GUNICORN_TIMEOUT:-360}
CELERY_WORKER_AMOUNT: ${CELERY_WORKER_AMOUNT:-}
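To show what the new worker defaults amount to, here is a rough sketch of the gunicorn command they would typically translate into. The flags are standard gunicorn options; the `app:app` module path and the exact way the container entrypoint assembles the command are assumptions, not taken from this diff.

```bash
# Hypothetical expansion of SERVER_WORKER_AMOUNT / SERVER_WORKER_CLASS / SERVER_WORKER_CONNECTIONS
gunicorn \
  --bind "${DIFY_BIND_ADDRESS:-0.0.0.0}:${DIFY_PORT:-5001}" \
  --workers "${SERVER_WORKER_AMOUNT:-1}" \
  --worker-class "${SERVER_WORKER_CLASS:-gevent}" \
  --worker-connections "${SERVER_WORKER_CONNECTIONS:-10}" \
  --timeout "${GUNICORN_TIMEOUT:-360}" \
  app:app  # module path assumed for illustration
```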
@ -136,6 +138,7 @@ x-shared-env: &shared-api-worker-env
MILVUS_TOKEN: ${MILVUS_TOKEN:-}
MILVUS_USER: ${MILVUS_USER:-root}
MILVUS_PASSWORD: ${MILVUS_PASSWORD:-Milvus}
MILVUS_ENABLE_HYBRID_SEARCH: ${MILVUS_ENABLE_HYBRID_SEARCH:-False}
MYSCALE_HOST: ${MYSCALE_HOST:-myscale}
MYSCALE_PORT: ${MYSCALE_PORT:-8123}
MYSCALE_USER: ${MYSCALE_USER:-default}
@ -401,6 +404,7 @@ x-shared-env: &shared-api-worker-env
ENDPOINT_URL_TEMPLATE: ${ENDPOINT_URL_TEMPLATE:-http://localhost/e/{hook_id}}
MARKETPLACE_ENABLED: ${MARKETPLACE_ENABLED:-true}
MARKETPLACE_API_URL: ${MARKETPLACE_API_URL:-https://marketplace-plugin.dify.dev}
TOP_K_MAX_VALUE: ${TOP_K_MAX_VALUE:-10}
services:
# API service
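Because these values live in the `x-shared-env` anchor, they are normally overridden through `docker/.env` rather than by editing the compose file directly. A hedged sketch of that workflow follows; the file names assume the usual repository layout, and the values are simply the defaults introduced in this diff.

```bash
cd docker
cp -n .env.example .env   # only if you do not already have a .env
cat >> .env <<'EOF'
REFRESH_TOKEN_EXPIRE_DAYS=30
MILVUS_ENABLE_HYBRID_SEARCH=False
TOP_K_MAX_VALUE=10
EOF
docker compose up -d      # recreate containers so the shared env is re-evaluated
```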
@ -476,6 +480,7 @@ services:
CSP_WHITELIST: ${CSP_WHITELIST:-}
MARKETPLACE_API_URL: ${MARKETPLACE_API_URL:-https://marketplace-plugin.dify.dev}
MARKETPLACE_URL: ${MARKETPLACE_URL:-https://marketplace-plugin.dify.dev}
TOP_K_MAX_VALUE: ${TOP_K_MAX_VALUE:-}
# The postgres database.
db:
@ -495,7 +500,7 @@ services:
volumes:
- ./volumes/db/data:/var/lib/postgresql/data
healthcheck:
test: ['CMD', 'pg_isready']
test: [ 'CMD', 'pg_isready' ]
interval: 1s
timeout: 3s
retries: 30
@ -514,7 +519,7 @@ services:
# Set the redis password when startup redis server.
command: redis-server --requirepass ${REDIS_PASSWORD:-difyai123456}
healthcheck:
test: ['CMD', 'redis-cli', 'ping']
test: [ 'CMD', 'redis-cli', 'ping' ]
# The DifySandbox
sandbox:
@ -534,7 +539,7 @@ services:
volumes:
- ./volumes/sandbox/dependencies:/dependencies
healthcheck:
test: ['CMD', 'curl', '-f', 'http://localhost:8194/health']
test: [ 'CMD', 'curl', '-f', 'http://localhost:8194/health' ]
networks:
- ssrf_proxy_network
@ -571,12 +576,7 @@ services:
volumes:
- ./ssrf_proxy/squid.conf.template:/etc/squid/squid.conf.template
- ./ssrf_proxy/docker-entrypoint.sh:/docker-entrypoint-mount.sh
entrypoint:
[
'sh',
'-c',
"cp /docker-entrypoint-mount.sh /docker-entrypoint.sh && sed -i 's/\r$$//' /docker-entrypoint.sh && chmod +x /docker-entrypoint.sh && /docker-entrypoint.sh",
]
entrypoint: [ 'sh', '-c', "cp /docker-entrypoint-mount.sh /docker-entrypoint.sh && sed -i 's/\r$$//' /docker-entrypoint.sh && chmod +x /docker-entrypoint.sh && /docker-entrypoint.sh" ]
environment:
# Please make sure to modify the squid env vars to fit your network environment.
HTTP_PORT: ${SSRF_HTTP_PORT:-3128}
@ -605,8 +605,8 @@ services:
- CERTBOT_EMAIL=${CERTBOT_EMAIL}
- CERTBOT_DOMAIN=${CERTBOT_DOMAIN}
- CERTBOT_OPTIONS=${CERTBOT_OPTIONS:-}
entrypoint: ['/docker-entrypoint.sh']
command: ['tail', '-f', '/dev/null']
entrypoint: [ '/docker-entrypoint.sh' ]
command: [ 'tail', '-f', '/dev/null' ]
# The nginx reverse proxy.
# used for reverse proxying the API service and Web service.
@ -623,12 +623,7 @@ services:
- ./volumes/certbot/conf/live:/etc/letsencrypt/live # cert dir (with certbot container)
- ./volumes/certbot/conf:/etc/letsencrypt
- ./volumes/certbot/www:/var/www/html
entrypoint:
[
'sh',
'-c',
"cp /docker-entrypoint-mount.sh /docker-entrypoint.sh && sed -i 's/\r$$//' /docker-entrypoint.sh && chmod +x /docker-entrypoint.sh && /docker-entrypoint.sh",
]
entrypoint: [ 'sh', '-c', "cp /docker-entrypoint-mount.sh /docker-entrypoint.sh && sed -i 's/\r$$//' /docker-entrypoint.sh && chmod +x /docker-entrypoint.sh && /docker-entrypoint.sh" ]
environment:
NGINX_SERVER_NAME: ${NGINX_SERVER_NAME:-_}
NGINX_HTTPS_ENABLED: ${NGINX_HTTPS_ENABLED:-false}
@ -720,7 +715,7 @@ services:
working_dir: /opt/couchbase
stdin_open: true
tty: true
entrypoint: [""]
entrypoint: [ "" ]
command: sh -c "/opt/couchbase/init/init-cbserver.sh"
volumes:
- ./volumes/couchbase/data:/opt/couchbase/var/lib/couchbase/data
@ -749,7 +744,7 @@ services:
volumes:
- ./volumes/pgvector/data:/var/lib/postgresql/data
healthcheck:
test: ['CMD', 'pg_isready']
test: [ 'CMD', 'pg_isready' ]
interval: 1s
timeout: 3s
retries: 30
@ -771,7 +766,7 @@ services:
volumes:
- ./volumes/pgvecto_rs/data:/var/lib/postgresql/data
healthcheck:
test: ['CMD', 'pg_isready']
test: [ 'CMD', 'pg_isready' ]
interval: 1s
timeout: 3s
retries: 30
@ -836,7 +831,7 @@ services:
- ./volumes/milvus/etcd:/etcd
command: etcd -advertise-client-urls=http://127.0.0.1:2379 -listen-client-urls http://0.0.0.0:2379 --data-dir /etcd
healthcheck:
test: ['CMD', 'etcdctl', 'endpoint', 'health']
test: [ 'CMD', 'etcdctl', 'endpoint', 'health' ]
interval: 30s
timeout: 20s
retries: 3
@ -855,7 +850,7 @@ services:
- ./volumes/milvus/minio:/minio_data
command: minio server /minio_data --console-address ":9001"
healthcheck:
test: ['CMD', 'curl', '-f', 'http://localhost:9000/minio/health/live']
test: [ 'CMD', 'curl', '-f', 'http://localhost:9000/minio/health/live' ]
interval: 30s
timeout: 20s
retries: 3
@ -864,10 +859,10 @@ services:
milvus-standalone:
container_name: milvus-standalone
image: milvusdb/milvus:v2.3.1
image: milvusdb/milvus:v2.5.0-beta
profiles:
- milvus
command: ['milvus', 'run', 'standalone']
command: [ 'milvus', 'run', 'standalone' ]
environment:
ETCD_ENDPOINTS: ${ETCD_ENDPOINTS:-etcd:2379}
MINIO_ADDRESS: ${MINIO_ADDRESS:-minio:9000}
@ -875,7 +870,7 @@ services:
volumes:
- ./volumes/milvus/milvus:/var/lib/milvus
healthcheck:
test: ['CMD', 'curl', '-f', 'http://localhost:9091/healthz']
test: [ 'CMD', 'curl', '-f', 'http://localhost:9091/healthz' ]
interval: 30s
start_period: 90s
timeout: 20s
@ -948,22 +943,30 @@ services:
container_name: elasticsearch
profiles:
- elasticsearch
- elasticsearch-ja
restart: always
volumes:
- ./elasticsearch/docker-entrypoint.sh:/docker-entrypoint-mount.sh
- dify_es01_data:/usr/share/elasticsearch/data
environment:
ELASTIC_PASSWORD: ${ELASTICSEARCH_PASSWORD:-elastic}
VECTOR_STORE: ${VECTOR_STORE:-}
cluster.name: dify-es-cluster
node.name: dify-es0
discovery.type: single-node
xpack.license.self_generated.type: trial
xpack.license.self_generated.type: basic
xpack.security.enabled: 'true'
xpack.security.enrollment.enabled: 'false'
xpack.security.http.ssl.enabled: 'false'
ports:
- ${ELASTICSEARCH_PORT:-9200}:9200
deploy:
resources:
limits:
memory: 2g
entrypoint: [ 'sh', '-c', "sh /docker-entrypoint-mount.sh" ]
healthcheck:
test: ['CMD', 'curl', '-s', 'http://localhost:9200/_cluster/health?pretty']
test: [ 'CMD', 'curl', '-s', 'http://localhost:9200/_cluster/health?pretty' ]
interval: 30s
timeout: 10s
retries: 50
@ -991,7 +994,7 @@ services:
ports:
- ${KIBANA_PORT:-5601}:5601
healthcheck:
test: ['CMD-SHELL', 'curl -s http://localhost:5601 >/dev/null || exit 1']
test: [ 'CMD-SHELL', 'curl -s http://localhost:5601 >/dev/null || exit 1' ]
interval: 30s
timeout: 10s
retries: 3

View File

@ -0,0 +1,25 @@
#!/bin/bash
set -e
if [ "${VECTOR_STORE}" = "elasticsearch-ja" ]; then
# Check if the ICU tokenizer plugin is installed
if ! /usr/share/elasticsearch/bin/elasticsearch-plugin list | grep -q analysis-icu; then
printf '%s\n' "Installing the ICU tokenizer plugin"
if ! /usr/share/elasticsearch/bin/elasticsearch-plugin install analysis-icu; then
printf '%s\n' "Failed to install the ICU tokenizer plugin"
exit 1
fi
fi
# Check if the Japanese language analyzer plugin is installed
if ! /usr/share/elasticsearch/bin/elasticsearch-plugin list | grep -q analysis-kuromoji; then
printf '%s\n' "Installing the Japanese language analyzer plugin"
if ! /usr/share/elasticsearch/bin/elasticsearch-plugin install analysis-kuromoji; then
printf '%s\n' "Failed to install the Japanese language analyzer plugin"
exit 1
fi
fi
fi
# Run the original entrypoint script
exec /bin/tini -- /usr/local/bin/docker-entrypoint.sh
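A quick way to check that this entrypoint installed the analyzers, assuming the `elasticsearch-ja` profile and default credentials from the compose file above (the service name `elasticsearch` is inferred from the container name); `_cat/plugins` is a standard Elasticsearch endpoint.

```bash
# Start Elasticsearch under the Japanese-analysis profile, then list installed plugins
VECTOR_STORE=elasticsearch-ja docker compose --profile elasticsearch-ja up -d elasticsearch
curl -s -u elastic:elastic http://localhost:9200/_cat/plugins
# The output should include analysis-icu and analysis-kuromoji
```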

View File

@ -25,3 +25,6 @@ NEXT_PUBLIC_TEXT_GENERATION_TIMEOUT_MS=60000
# CSP https://developer.mozilla.org/en-US/docs/Web/HTTP/CSP
NEXT_PUBLIC_CSP_WHITELIST=
# The maximum top-k value for RAG.
NEXT_PUBLIC_TOP_K_MAX_VALUE=10

View File

@ -25,16 +25,18 @@ import Input from '@/app/components/base/input'
import { useStore as useTagStore } from '@/app/components/base/tag-management/store'
import TagManagementModal from '@/app/components/base/tag-management'
import TagFilter from '@/app/components/base/tag-management/filter'
import CheckboxWithLabel from '@/app/components/datasets/create/website/base/checkbox-with-label'
const getKey = (
pageIndex: number,
previousPageData: AppListResponse,
activeTab: string,
isCreatedByMe: boolean,
tags: string[],
keywords: string,
) => {
if (!pageIndex || previousPageData.has_more) {
const params: any = { url: 'apps', params: { page: pageIndex + 1, limit: 30, name: keywords } }
const params: any = { url: 'apps', params: { page: pageIndex + 1, limit: 30, name: keywords, is_created_by_me: isCreatedByMe } }
if (activeTab !== 'all')
params.params.mode = activeTab
@ -58,6 +60,7 @@ const Apps = () => {
defaultTab: 'all',
})
const { query: { tagIDs = [], keywords = '' }, setQuery } = useAppsQueryState()
const [isCreatedByMe, setIsCreatedByMe] = useState(false)
const [tagFilterValue, setTagFilterValue] = useState<string[]>(tagIDs)
const [searchKeywords, setSearchKeywords] = useState(keywords)
const setKeywords = useCallback((keywords: string) => {
@ -68,7 +71,7 @@ const Apps = () => {
}, [setQuery])
const { data, isLoading, setSize, mutate } = useSWRInfinite(
(pageIndex: number, previousPageData: AppListResponse) => getKey(pageIndex, previousPageData, activeTab, tagIDs, searchKeywords),
(pageIndex: number, previousPageData: AppListResponse) => getKey(pageIndex, previousPageData, activeTab, isCreatedByMe, tagIDs, searchKeywords),
fetchAppList,
{ revalidateFirstPage: true },
)
@ -132,6 +135,12 @@ const Apps = () => {
options={options}
/>
<div className='flex items-center gap-2'>
<CheckboxWithLabel
className='mr-2'
label={t('app.showMyCreatedAppsOnly')}
isChecked={isCreatedByMe}
onChange={() => setIsCreatedByMe(!isCreatedByMe)}
/>
<TagFilter type='app' value={tagFilterValue} onChange={handleTagsChange} />
<Input
showLeftIcon

View File

@ -52,6 +52,15 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
- <code>high_quality</code> High quality: embedding using embedding model, built as vector database index
- <code>economy</code> Economy: Build using inverted index of keyword table index
</Property>
<Property name='doc_form' type='string' key='doc_form'>
Format of indexed content
- <code>text_model</code> Text documents are directly embedded; `economy` mode defaults to using this form
- <code>hierarchical_model</code> Parent-child mode
- <code>qa_model</code> Q&A Mode: Generates Q&A pairs for segmented documents and then embeds the questions
</Property>
<Property name='doc_language' type='string' key='doc_language'>
In Q&A mode, specify the language of the document, for example: <code>English</code>, <code>Chinese</code>
</Property>
<Property name='process_rule' type='object' key='process_rule'>
Processing rules
- <code>mode</code> (string) Cleaning, segmentation mode, automatic / custom
@ -65,6 +74,11 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
- <code>segmentation</code> (object) Segmentation rules
- <code>separator</code> Custom segment identifier, currently only allows one delimiter to be set. Default is \n
- <code>max_tokens</code> Maximum length (token) defaults to 1000
- <code>parent_mode</code> Retrieval mode of parent chunks: <code>full-doc</code> full text retrieval / <code>paragraph</code> paragraph retrieval
- <code>subchunk_segmentation</code> (object) Child chunk rules
- <code>separator</code> Segmentation identifier. Currently, only one delimiter is allowed. The default is <code>***</code>
      - <code>max_tokens</code> Maximum length (tokens); validated to be shorter than the parent chunk's maximum length
      - <code>chunk_overlap</code> Defines the overlap between adjacent chunks (optional); a request sketch using these options follows this property list
</Property>
</Properties>
</Col>
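To make the parent-child options concrete, here is a hedged request sketch for the create-by-text call this section documents. The endpoint path, header, and JSON nesting are inferred from the surrounding documentation rather than confirmed by this diff, and `$DATASET_ID` / `$API_KEY` are placeholders.

```bash
curl -X POST "https://api.dify.ai/v1/datasets/$DATASET_ID/document/create-by-text" \
  -H "Authorization: Bearer $API_KEY" \
  -H "Content-Type: application/json" \
  -d '{
    "name": "manual.txt",
    "text": "Example document text.",
    "indexing_technique": "high_quality",
    "doc_form": "hierarchical_model",
    "process_rule": {
      "mode": "custom",
      "rules": {
        "pre_processing_rules": [
          { "id": "remove_extra_spaces", "enabled": true }
        ],
        "segmentation": { "separator": "\n", "max_tokens": 1000 },
        "parent_mode": "paragraph",
        "subchunk_segmentation": { "separator": "***", "max_tokens": 512, "chunk_overlap": 50 }
      }
    }
  }'
```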
@ -155,6 +169,13 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
- <code>high_quality</code> High quality: embedding using embedding model, built as vector database index
- <code>economy</code> Economy: Build using inverted index of keyword table index
- <code>doc_form</code> Format of indexed content
- <code>text_model</code> Text documents are directly embedded; `economy` mode defaults to using this form
- <code>hierarchical_model</code> Parent-child mode
- <code>qa_model</code> Q&A Mode: Generates Q&A pairs for segmented documents and then embeds the questions
- <code>doc_language</code> In Q&A mode, specify the language of the document, for example: <code>English</code>, <code>Chinese</code>
- <code>process_rule</code> Processing rules
- <code>mode</code> (string) Cleaning, segmentation mode, automatic / custom
- <code>rules</code> (object) Custom rules (in automatic mode, this field is empty)
@ -167,6 +188,11 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
- <code>segmentation</code> (object) Segmentation rules
- <code>separator</code> Custom segment identifier, currently only allows one delimiter to be set. Default is \n
- <code>max_tokens</code> Maximum length (token) defaults to 1000
- <code>parent_mode</code> Retrieval mode of parent chunks: <code>full-doc</code> full text retrieval / <code>paragraph</code> paragraph retrieval
- <code>subchunk_segmentation</code> (object) Child chunk rules
- <code>separator</code> Segmentation identifier. Currently, only one delimiter is allowed. The default is <code>***</code>
        - <code>max_tokens</code> Maximum length (tokens); validated to be shorter than the parent chunk's maximum length
      - <code>chunk_overlap</code> Defines the overlap between adjacent chunks (optional)
</Property>
<Property name='file' type='multipart/form-data' key='file'>
Files that need to be uploaded.
@ -449,6 +475,11 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
- <code>segmentation</code> (object) Segmentation rules
- <code>separator</code> Custom segment identifier, currently only allows one delimiter to be set. Default is \n
- <code>max_tokens</code> Maximum length (token) defaults to 1000
- <code>parent_mode</code> Retrieval mode of parent chunks: <code>full-doc</code> full text retrieval / <code>paragraph</code> paragraph retrieval
- <code>subchunk_segmentation</code> (object) Child chunk rules
- <code>separator</code> Segmentation identifier. Currently, only one delimiter is allowed. The default is <code>***</code>
        - <code>max_tokens</code> Maximum length (tokens); validated to be shorter than the parent chunk's maximum length
      - <code>chunk_overlap</code> Defines the overlap between adjacent chunks (optional)
</Property>
</Properties>
</Col>
@ -546,6 +577,11 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
- <code>segmentation</code> (object) Segmentation rules
- <code>separator</code> Custom segment identifier, currently only allows one delimiter to be set. Default is \n
- <code>max_tokens</code> Maximum length (token) defaults to 1000
- <code>parent_mode</code> Retrieval mode of parent chunks: <code>full-doc</code> full text retrieval / <code>paragraph</code> paragraph retrieval
- <code>subchunk_segmentation</code> (object) Child chunk rules
- <code>separator</code> Segmentation identifier. Currently, only one delimiter is allowed. The default is <code>***</code>
        - <code>max_tokens</code> Maximum length (tokens); validated to be shorter than the parent chunk's maximum length
      - <code>chunk_overlap</code> Defines the overlap between adjacent chunks (optional)
</Property>
</Properties>
</Col>
@ -984,7 +1020,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
<Heading
url='/datasets/{dataset_id}/documents/{document_id}/segments/{segment_id}'
method='POST'
title='Update a Chunk in a Document '
title='Update a Chunk in a Document'
name='#update_segment'
/>
<Row>
@ -1009,6 +1045,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
- <code>answer</code> (text) Answer content, passed if the knowledge is in Q&A mode (optional)
- <code>keywords</code> (list) Keyword (optional)
- <code>enabled</code> (bool) False / true (optional)
- <code>regenerate_child_chunks</code> (bool) Whether to regenerate child chunks (optional)
</Property>
</Properties>
</Col>
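A hedged example of the new `regenerate_child_chunks` flag against the update-chunk endpoint shown above; wrapping the fields in a `segment` object mirrors the usual request shape of this API and, like the placeholder IDs, is an assumption rather than something confirmed by this diff.

```bash
curl -X POST "https://api.dify.ai/v1/datasets/$DATASET_ID/documents/$DOCUMENT_ID/segments/$SEGMENT_ID" \
  -H "Authorization: Bearer $API_KEY" \
  -H "Content-Type: application/json" \
  -d '{
    "segment": {
      "content": "Updated parent chunk content",
      "enabled": true,
      "regenerate_child_chunks": true
    }
  }'
```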

View File

@ -52,6 +52,15 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
- <code>high_quality</code> 高质量:使用 embedding 模型进行嵌入,构建为向量数据库索引
- <code>economy</code> 经济:使用 keyword table index 的倒排索引进行构建
</Property>
<Property name='doc_form' type='string' key='doc_form'>
索引内容的形式
- <code>text_model</code> text 文档直接 embedding经济模式默认为该模式
- <code>hierarchical_model</code> parent-child 模式
- <code>qa_model</code> Q&A 模式:为分片文档生成 Q&A 对,然后对问题进行 embedding
</Property>
<Property name='doc_language' type='string' key='doc_language'>
在 Q&A 模式下,指定文档的语言,例如:<code>English</code>、<code>Chinese</code>
</Property>
<Property name='process_rule' type='object' key='process_rule'>
处理规则
- <code>mode</code> (string) 清洗、分段模式 automatic 自动 / custom 自定义
@ -63,8 +72,13 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
- <code>remove_urls_emails</code> 删除 URL、电子邮件地址
- <code>enabled</code> (bool) 是否选中该规则,不传入文档 ID 时代表默认值
- <code>segmentation</code> (object) 分段规则
- <code>separator</code> 自定义分段标识符,目前仅允许设置一个分隔符。默认为 \n
- <code>separator</code> 自定义分段标识符,目前仅允许设置一个分隔符。默认为 <code>\n</code>
- <code>max_tokens</code> 最大长度token默认为 1000
- <code>parent_mode</code> 父分段的召回模式 <code>full-doc</code> 全文召回 / <code>paragraph</code> 段落召回
- <code>subchunk_segmentation</code> (object) 子分段规则
- <code>separator</code> 分段标识符,目前仅允许设置一个分隔符。默认为 <code>***</code>
- <code>max_tokens</code> 最大长度 (token) 需要校验小于父级的长度
- <code>chunk_overlap</code> 分段重叠指的是在对数据进行分段时,段与段之间存在一定的重叠部分(选填)
</Property>
</Properties>
</Col>
@ -155,6 +169,13 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
- <code>high_quality</code> 高质量:使用 embedding 模型进行嵌入,构建为向量数据库索引
- <code>economy</code> 经济:使用 keyword table index 的倒排索引进行构建
- <code>doc_form</code> 索引内容的形式
- <code>text_model</code> text 文档直接 embedding经济模式默认为该模式
- <code>hierarchical_model</code> parent-child 模式
- <code>qa_model</code> Q&A 模式:为分片文档生成 Q&A 对,然后对问题进行 embedding
- <code>doc_language</code> 在 Q&A 模式下,指定文档的语言,例如:<code>English</code>、<code>Chinese</code>
- <code>process_rule</code> 处理规则
- <code>mode</code> (string) 清洗、分段模式 automatic 自动 / custom 自定义
- <code>rules</code> (object) 自定义规则(自动模式下,该字段为空)
@ -167,6 +188,11 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
- <code>segmentation</code> (object) 分段规则
- <code>separator</code> 自定义分段标识符,目前仅允许设置一个分隔符。默认为 \n
- <code>max_tokens</code> 最大长度token默认为 1000
- <code>parent_mode</code> 父分段的召回模式 <code>full-doc</code> 全文召回 / <code>paragraph</code> 段落召回
- <code>subchunk_segmentation</code> (object) 子分段规则
- <code>separator</code> 分段标识符,目前仅允许设置一个分隔符。默认为 <code>***</code>
- <code>max_tokens</code> 最大长度 (token) 需要校验小于父级的长度
- <code>chunk_overlap</code> 分段重叠指的是在对数据进行分段时,段与段之间存在一定的重叠部分(选填)
</Property>
<Property name='file' type='multipart/form-data' key='file'>
需要上传的文件。
@ -411,7 +437,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
<Heading
url='/datasets/{dataset_id}/documents/{document_id}/update-by-text'
method='POST'
title='通过文本更新文档 '
title='通过文本更新文档'
name='#update-by-text'
/>
<Row>
@ -449,6 +475,11 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
- <code>segmentation</code> (object) 分段规则
- <code>separator</code> 自定义分段标识符,目前仅允许设置一个分隔符。默认为 \n
- <code>max_tokens</code> 最大长度token默认为 1000
- <code>parent_mode</code> 父分段的召回模式 <code>full-doc</code> 全文召回 / <code>paragraph</code> 段落召回
- <code>subchunk_segmentation</code> (object) 子分段规则
- <code>separator</code> 分段标识符,目前仅允许设置一个分隔符。默认为 <code>***</code>
- <code>max_tokens</code> 最大长度 (token) 需要校验小于父级的长度
- <code>chunk_overlap</code> 分段重叠指的是在对数据进行分段时,段与段之间存在一定的重叠部分(选填)
</Property>
</Properties>
</Col>
@ -508,7 +539,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
<Heading
url='/datasets/{dataset_id}/documents/{document_id}/update-by-file'
method='POST'
title='通过文件更新文档 '
title='通过文件更新文档'
name='#update-by-file'
/>
<Row>
@ -546,6 +577,11 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
- <code>segmentation</code> (object) 分段规则
- <code>separator</code> 自定义分段标识符,目前仅允许设置一个分隔符。默认为 \n
- <code>max_tokens</code> 最大长度token默认为 1000
- <code>parent_mode</code> 父分段的召回模式 <code>full-doc</code> 全文召回 / <code>paragraph</code> 段落召回
- <code>subchunk_segmentation</code> (object) 子分段规则
- <code>separator</code> 分段标识符,目前仅允许设置一个分隔符。默认为 <code>***</code>
- <code>max_tokens</code> 最大长度 (token) 需要校验小于父级的长度
- <code>chunk_overlap</code> 分段重叠指的是在对数据进行分段时,段与段之间存在一定的重叠部分(选填)
</Property>
</Properties>
</Col>
@ -1009,6 +1045,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
- <code>answer</code> (text) 答案内容,非必填,如果知识库的模式为 Q&A 模式则传值
- <code>keywords</code> (list) 关键字,非必填
- <code>enabled</code> (bool) false/true非必填
- <code>regenerate_child_chunks</code> (bool) 是否重新生成子分段,非必填
</Property>
</Properties>
</Col>

View File

@ -26,13 +26,15 @@ const PromptEditorHeightResizeWrap: FC<Props> = ({
const [clientY, setClientY] = useState(0)
const [isResizing, setIsResizing] = useState(false)
const [prevUserSelectStyle, setPrevUserSelectStyle] = useState(getComputedStyle(document.body).userSelect)
const [oldHeight, setOldHeight] = useState(height)
const handleStartResize = useCallback((e: React.MouseEvent<HTMLElement>) => {
setClientY(e.clientY)
setIsResizing(true)
setOldHeight(height)
setPrevUserSelectStyle(getComputedStyle(document.body).userSelect)
document.body.style.userSelect = 'none'
}, [])
}, [height])
const handleStopResize = useCallback(() => {
setIsResizing(false)
@ -44,8 +46,7 @@ const PromptEditorHeightResizeWrap: FC<Props> = ({
return
const offset = e.clientY - clientY
let newHeight = height + offset
setClientY(e.clientY)
let newHeight = oldHeight + offset
if (newHeight < minHeight)
newHeight = minHeight
onHeightChange(newHeight)

View File

@ -27,6 +27,7 @@ import { ADD_EXTERNAL_DATA_TOOL } from '@/app/components/app/configuration/confi
import { INSERT_VARIABLE_VALUE_BLOCK_COMMAND } from '@/app/components/base/prompt-editor/plugins/variable-block'
import { PROMPT_EDITOR_UPDATE_VALUE_BY_EVENT_EMITTER } from '@/app/components/base/prompt-editor/plugins/update-block'
import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints'
import { useFeaturesStore } from '@/app/components/base/features/hooks'
export type ISimplePromptInput = {
mode: AppType
@ -54,6 +55,11 @@ const Prompt: FC<ISimplePromptInput> = ({
const { t } = useTranslation()
const media = useBreakpoints()
const isMobile = media === MediaType.mobile
const featuresStore = useFeaturesStore()
const {
features,
setFeatures,
} = featuresStore!.getState()
const { eventEmitter } = useEventEmitterContextContext()
const {
@ -137,8 +143,18 @@ const Prompt: FC<ISimplePromptInput> = ({
})
setModelConfig(newModelConfig)
setPrevPromptConfig(modelConfig.configs)
if (mode !== AppType.completion)
if (mode !== AppType.completion) {
setIntroduction(res.opening_statement)
const newFeatures = produce(features, (draft) => {
draft.opening = {
...draft.opening,
enabled: !!res.opening_statement,
opening_statement: res.opening_statement,
}
})
setFeatures(newFeatures)
}
showAutomaticFalse()
}
const minHeight = initEditorHeight || 228

View File

@ -59,36 +59,24 @@ const ConfigContent: FC<Props> = ({
const {
modelList: rerankModelList,
defaultModel: rerankDefaultModel,
currentModel: isRerankDefaultModelValid,
} = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)
const {
currentModel: currentRerankModel,
} = useCurrentProviderAndModel(
rerankModelList,
rerankDefaultModel
? {
...rerankDefaultModel,
provider: rerankDefaultModel.provider.provider,
}
: undefined,
{
provider: datasetConfigs.reranking_model?.reranking_provider_name,
model: datasetConfigs.reranking_model?.reranking_model_name,
},
)
const rerankModel = (() => {
if (datasetConfigs.reranking_model?.reranking_provider_name) {
return {
provider_name: datasetConfigs.reranking_model.reranking_provider_name,
model_name: datasetConfigs.reranking_model.reranking_model_name,
}
const rerankModel = useMemo(() => {
return {
provider_name: datasetConfigs?.reranking_model?.reranking_provider_name ?? '',
model_name: datasetConfigs?.reranking_model?.reranking_model_name ?? '',
}
else if (rerankDefaultModel) {
return {
provider_name: rerankDefaultModel.provider.provider,
model_name: rerankDefaultModel.model,
}
}
})()
}, [datasetConfigs.reranking_model])
const handleParamChange = (key: string, value: number) => {
if (key === 'top_k') {
@ -133,6 +121,12 @@ const ConfigContent: FC<Props> = ({
}
const handleRerankModeChange = (mode: RerankingModeEnum) => {
if (mode === datasetConfigs.reranking_mode)
return
if (mode === RerankingModeEnum.RerankingModel && !currentRerankModel)
Toast.notify({ type: 'error', message: t('workflow.errorMsg.rerankModelRequired') })
onChange({
...datasetConfigs,
reranking_mode: mode,
@ -162,31 +156,25 @@ const ConfigContent: FC<Props> = ({
const canManuallyToggleRerank = useMemo(() => {
return (selectedDatasetsMode.allInternal && selectedDatasetsMode.allEconomic)
|| selectedDatasetsMode.allExternal
|| selectedDatasetsMode.allExternal
}, [selectedDatasetsMode.allEconomic, selectedDatasetsMode.allExternal, selectedDatasetsMode.allInternal])
const showRerankModel = useMemo(() => {
if (!canManuallyToggleRerank)
return true
else if (canManuallyToggleRerank && !isRerankDefaultModelValid)
return false
return datasetConfigs.reranking_enable
}, [canManuallyToggleRerank, datasetConfigs.reranking_enable, isRerankDefaultModelValid])
}, [datasetConfigs.reranking_enable, canManuallyToggleRerank])
const handleDisabledSwitchClick = useCallback(() => {
if (!currentRerankModel && !showRerankModel)
const handleDisabledSwitchClick = useCallback((enable: boolean) => {
if (!currentRerankModel && enable)
Toast.notify({ type: 'error', message: t('workflow.errorMsg.rerankModelRequired') })
}, [currentRerankModel, showRerankModel, t])
useEffect(() => {
if (canManuallyToggleRerank && showRerankModel !== datasetConfigs.reranking_enable) {
onChange({
...datasetConfigs,
reranking_enable: showRerankModel,
})
}
}, [canManuallyToggleRerank, showRerankModel, datasetConfigs, onChange])
onChange({
...datasetConfigs,
reranking_enable: enable,
})
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [currentRerankModel, datasetConfigs, onChange])
return (
<div>
@ -267,24 +255,12 @@ const ConfigContent: FC<Props> = ({
<div className='flex items-center'>
{
selectedDatasetsMode.allEconomic && !selectedDatasetsMode.mixtureInternalAndExternal && (
<div
className='flex items-center'
onClick={handleDisabledSwitchClick}
>
<Switch
size='md'
defaultValue={showRerankModel}
disabled={!currentRerankModel || !canManuallyToggleRerank}
onChange={(v) => {
if (canManuallyToggleRerank) {
onChange({
...datasetConfigs,
reranking_enable: v,
})
}
}}
/>
</div>
<Switch
size='md'
defaultValue={showRerankModel}
disabled={!canManuallyToggleRerank}
onChange={handleDisabledSwitchClick}
/>
)
}
<div className='leading-[32px] ml-1 text-text-secondary system-sm-semibold'>{t('common.modelProvider.rerankModel.key')}</div>
@ -298,21 +274,24 @@ const ConfigContent: FC<Props> = ({
triggerClassName='ml-1 w-4 h-4'
/>
</div>
<div>
<ModelSelector
defaultModel={rerankModel && { provider: rerankModel?.provider_name, model: rerankModel?.model_name }}
onSelect={(v) => {
onChange({
...datasetConfigs,
reranking_model: {
reranking_provider_name: v.provider,
reranking_model_name: v.model,
},
})
}}
modelList={rerankModelList}
/>
</div>
{
showRerankModel && (
<div>
<ModelSelector
defaultModel={rerankModel && { provider: rerankModel?.provider_name, model: rerankModel?.model_name }}
onSelect={(v) => {
onChange({
...datasetConfigs,
reranking_model: {
reranking_provider_name: v.provider,
reranking_model_name: v.model,
},
})
}}
modelList={rerankModelList}
/>
</div>
)}
</div>
)
}

View File

@ -10,7 +10,7 @@ import Modal from '@/app/components/base/modal'
import Button from '@/app/components/base/button'
import { RETRIEVE_TYPE } from '@/types/app'
import Toast from '@/app/components/base/toast'
import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
import { useCurrentProviderAndModel, useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
import { RerankingModeEnum } from '@/models/datasets'
import type { DataSet } from '@/models/datasets'
@ -41,17 +41,27 @@ const ParamsConfig = ({
}, [datasetConfigs])
const {
defaultModel: rerankDefaultModel,
currentModel: isRerankDefaultModelValid,
modelList: rerankModelList,
currentModel: rerankDefaultModel,
currentProvider: rerankDefaultProvider,
} = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)
const {
currentModel: isCurrentRerankModelValid,
} = useCurrentProviderAndModel(
rerankModelList,
{
provider: tempDataSetConfigs.reranking_model?.reranking_provider_name ?? '',
model: tempDataSetConfigs.reranking_model?.reranking_model_name ?? '',
},
)
const isValid = () => {
let errMsg = ''
if (tempDataSetConfigs.retrieval_model === RETRIEVE_TYPE.multiWay) {
if (tempDataSetConfigs.reranking_enable
&& tempDataSetConfigs.reranking_mode === RerankingModeEnum.RerankingModel
&& !isRerankDefaultModelValid
&& !isCurrentRerankModelValid
)
errMsg = t('appDebug.datasetConfig.rerankModelRequired')
}
@ -66,16 +76,7 @@ const ParamsConfig = ({
const handleSave = () => {
if (!isValid())
return
const config = { ...tempDataSetConfigs }
if (config.retrieval_model === RETRIEVE_TYPE.multiWay
&& config.reranking_mode === RerankingModeEnum.RerankingModel
&& !config.reranking_model) {
config.reranking_model = {
reranking_provider_name: rerankDefaultModel?.provider?.provider,
reranking_model_name: rerankDefaultModel?.model,
} as any
}
setDatasetConfigs(config)
setDatasetConfigs(tempDataSetConfigs)
setRerankSettingModalOpen(false)
}
@ -94,14 +95,14 @@ const ParamsConfig = ({
reranking_enable: restConfigs.reranking_enable,
}, selectedDatasets, selectedDatasets, {
provider: rerankDefaultProvider?.provider,
model: isRerankDefaultModelValid?.model,
model: rerankDefaultModel?.model,
})
setTempDataSetConfigs({
...retrievalConfig,
reranking_model: restConfigs.reranking_model && {
reranking_provider_name: restConfigs.reranking_model.reranking_provider_name,
reranking_model_name: restConfigs.reranking_model.reranking_model_name,
reranking_model: {
reranking_provider_name: retrievalConfig.reranking_model?.provider || '',
reranking_model_name: retrievalConfig.reranking_model?.model || '',
},
retrieval_model,
score_threshold_enabled,

View File

@ -29,7 +29,7 @@ const WeightedScore = ({
return (
<div>
<div className='px-3 pt-5 h-[52px] space-x-3 rounded-lg border border-components-panel-border'>
<div className='px-3 pt-5 pb-2 space-x-3 rounded-lg border border-components-panel-border'>
<Slider
className={cn('grow h-0.5 !bg-util-colors-teal-teal-500 rounded-full')}
max={1.0}
@ -39,7 +39,7 @@ const WeightedScore = ({
onChange={v => onChange({ value: [v, (10 - v * 10) / 10] })}
trackClassName='weightedScoreSliderTrack'
/>
<div className='flex justify-between mt-1'>
<div className='flex justify-between mt-3'>
<div className='shrink-0 flex items-center w-[90px] system-xs-semibold-uppercase text-util-colors-blue-light-blue-light-500'>
<div className='mr-1 truncate uppercase' title={t('dataset.weightedScore.semantic') || ''}>
{t('dataset.weightedScore.semantic')}

View File

@ -12,7 +12,7 @@ import Divider from '@/app/components/base/divider'
import Button from '@/app/components/base/button'
import Input from '@/app/components/base/input'
import Textarea from '@/app/components/base/textarea'
import { type DataSet, RerankingModeEnum } from '@/models/datasets'
import { type DataSet } from '@/models/datasets'
import { useToastContext } from '@/app/components/base/toast'
import { updateDatasetSetting } from '@/service/datasets'
import { useAppContext } from '@/context/app-context'
@ -21,7 +21,7 @@ import type { RetrievalConfig } from '@/types/app'
import RetrievalSettings from '@/app/components/datasets/external-knowledge-base/create/RetrievalSettings'
import RetrievalMethodConfig from '@/app/components/datasets/common/retrieval-method-config'
import EconomicalRetrievalMethodConfig from '@/app/components/datasets/common/economical-retrieval-method-config'
import { ensureRerankModelSelected, isReRankModelSelected } from '@/app/components/datasets/common/check-rerank-model'
import { isReRankModelSelected } from '@/app/components/datasets/common/check-rerank-model'
import { AlertTriangle } from '@/app/components/base/icons/src/vender/solid/alertsAndFeedback'
import PermissionSelector from '@/app/components/datasets/settings/permission-selector'
import ModelSelector from '@/app/components/header/account-setting/model-provider-page/model-selector'
@ -99,8 +99,6 @@ const SettingsModal: FC<SettingsModalProps> = ({
}
if (
!isReRankModelSelected({
rerankDefaultModel,
isRerankDefaultModelValid: !!isRerankDefaultModelValid,
rerankModelList,
retrievalConfig,
indexMethod,
@ -109,14 +107,6 @@ const SettingsModal: FC<SettingsModalProps> = ({
notify({ type: 'error', message: t('appDebug.datasetConfig.rerankModelRequired') })
return
}
const postRetrievalConfig = ensureRerankModelSelected({
rerankDefaultModel: rerankDefaultModel!,
retrievalConfig: {
...retrievalConfig,
reranking_enable: retrievalConfig.reranking_mode === RerankingModeEnum.RerankingModel,
},
indexMethod,
})
try {
setLoading(true)
const { id, name, description, permission } = localeCurrentDataset
@ -128,8 +118,8 @@ const SettingsModal: FC<SettingsModalProps> = ({
permission,
indexing_technique: indexMethod,
retrieval_model: {
...postRetrievalConfig,
score_threshold: postRetrievalConfig.score_threshold_enabled ? postRetrievalConfig.score_threshold : 0,
...retrievalConfig,
score_threshold: retrievalConfig.score_threshold_enabled ? retrievalConfig.score_threshold : 0,
},
embedding_model: localeCurrentDataset.embedding_model,
embedding_model_provider: localeCurrentDataset.embedding_model_provider,
@ -157,7 +147,7 @@ const SettingsModal: FC<SettingsModalProps> = ({
onSave({
...localeCurrentDataset,
indexing_technique: indexMethod,
retrieval_model_dict: postRetrievalConfig,
retrieval_model_dict: retrievalConfig,
})
}
catch (e) {

View File

@ -287,9 +287,9 @@ const Configuration: FC = () => {
setDatasetConfigs({
...retrievalConfig,
reranking_model: restConfigs.reranking_model && {
reranking_provider_name: restConfigs.reranking_model.reranking_provider_name,
reranking_model_name: restConfigs.reranking_model.reranking_model_name,
reranking_model: {
reranking_provider_name: retrievalConfig?.reranking_model?.provider || '',
reranking_model_name: retrievalConfig?.reranking_model?.model || '',
},
retrieval_model,
score_threshold_enabled,

View File

@ -39,6 +39,7 @@ type ChatInputAreaProps = {
inputs?: Record<string, any>
inputsForm?: InputForm[]
theme?: Theme | null
isResponding?: boolean
}
const ChatInputArea = ({
showFeatureBar,
@ -51,6 +52,7 @@ const ChatInputArea = ({
inputs = {},
inputsForm = [],
theme,
isResponding,
}: ChatInputAreaProps) => {
const { t } = useTranslation()
const { notify } = useToastContext()
@ -77,6 +79,11 @@ const ChatInputArea = ({
const historyRef = useRef([''])
const [currentIndex, setCurrentIndex] = useState(-1)
const handleSend = () => {
if (isResponding) {
notify({ type: 'info', message: t('appDebug.errorMessage.waitForResponse') })
return
}
if (onSend) {
const { files, setFiles } = filesStore.getState()
if (files.find(item => item.transferMethod === TransferMethod.local_file && !item.uploadedId)) {
@ -116,7 +123,7 @@ const ChatInputArea = ({
setQuery(historyRef.current[currentIndex + 1])
}
else if (currentIndex === historyRef.current.length - 1) {
// If it is the last element, clear the input box
// If it is the last element, clear the input box
setCurrentIndex(historyRef.current.length)
setQuery('')
}
@ -169,6 +176,7 @@ const ChatInputArea = ({
'p-1 w-full leading-6 body-lg-regular text-text-tertiary outline-none',
)}
placeholder={t('common.chat.inputPlaceholder') || ''}
autoFocus
autoSize={{ minRows: 1 }}
onResize={handleTextareaResize}
value={query}

View File

@ -292,6 +292,7 @@ const Chat: FC<ChatProps> = ({
inputs={inputs}
inputsForm={inputsForm}
theme={themeBuilder?.theme}
isResponding={isResponding}
/>
)
}

View File

@ -28,8 +28,8 @@ const Question: FC<QuestionProps> = ({
} = item
return (
<div className='flex justify-end mb-2 last:mb-0 pl-10'>
<div className='group relative mr-4'>
<div className='flex justify-end mb-2 last:mb-0 pl-14'>
<div className='group relative mr-4 max-w-full'>
<div
className='px-4 py-3 bg-[#D1E9FF]/50 rounded-2xl text-sm text-gray-900'
style={theme?.chatBubbleColorStyle ? CssTransform(theme.chatBubbleColorStyle) : {}}

View File

@ -111,9 +111,9 @@ const CodeBlock: CodeComponent = memo(({ inline, className, children, ...props }
}
else if (language === 'echarts') {
return (
<div style={{ minHeight: '350px', minWidth: '700px' }}>
<div style={{ minHeight: '350px', minWidth: '100%', overflowX: 'scroll' }}>
<ErrorBoundary>
<ReactEcharts option={chartData} />
<ReactEcharts option={chartData} style={{ minWidth: '700px' }} />
</ErrorBoundary>
</div>
)

View File

@ -11,11 +11,17 @@ type Props = {
enable: boolean
}
const maxTopK = (() => {
const configValue = parseInt(globalThis.document?.body?.getAttribute('data-public-top-k-max-value') || '', 10)
if (configValue && !isNaN(configValue))
return configValue
return 10
})()
const VALUE_LIMIT = {
default: 2,
step: 1,
min: 1,
max: 10,
max: maxTopK,
}
const key = 'top_k'

View File

@ -6,14 +6,10 @@ import type {
import { RerankingModeEnum } from '@/models/datasets'
export const isReRankModelSelected = ({
rerankDefaultModel,
isRerankDefaultModelValid,
retrievalConfig,
rerankModelList,
indexMethod,
}: {
rerankDefaultModel?: DefaultModelResponse
isRerankDefaultModelValid: boolean
retrievalConfig: RetrievalConfig
rerankModelList: Model[]
indexMethod?: string
@ -25,12 +21,17 @@ export const isReRankModelSelected = ({
return provider?.models.find(({ model }) => model === retrievalConfig.reranking_model?.reranking_model_name)
}
if (isRerankDefaultModelValid)
return !!rerankDefaultModel
return false
})()
if (
indexMethod === 'high_quality'
&& ([RETRIEVE_METHOD.semantic, RETRIEVE_METHOD.fullText].includes(retrievalConfig.search_method))
&& retrievalConfig.reranking_enable
&& !rerankModelSelected
)
return false
if (
indexMethod === 'high_quality'
&& (retrievalConfig.search_method === RETRIEVE_METHOD.hybrid && retrievalConfig.reranking_mode !== RerankingModeEnum.WeightedScore)

View File

@ -10,11 +10,13 @@ import { RETRIEVE_METHOD } from '@/types/app'
import type { RetrievalConfig } from '@/types/app'
type Props = {
disabled?: boolean
value: RetrievalConfig
onChange: (value: RetrievalConfig) => void
}
const EconomicalRetrievalMethodConfig: FC<Props> = ({
disabled = false,
value,
onChange,
}) => {
@ -22,7 +24,8 @@ const EconomicalRetrievalMethodConfig: FC<Props> = ({
return (
<div className='space-y-2'>
<OptionCard icon={<Image className='w-4 h-4' src={retrievalIcon.vector} alt='' />}
<OptionCard
disabled={disabled} icon={<Image className='w-4 h-4' src={retrievalIcon.vector} alt='' />}
title={t('dataset.retrieval.invertedIndex.title')}
description={t('dataset.retrieval.invertedIndex.description')} isActive
activeHeaderClassName='bg-dataset-option-card-purple-gradient'

View File

@ -1,6 +1,6 @@
'use client'
import type { FC } from 'react'
import React from 'react'
import React, { useCallback } from 'react'
import { useTranslation } from 'react-i18next'
import Image from 'next/image'
import RetrievalParamConfig from '../retrieval-param-config'
@ -10,7 +10,7 @@ import { retrievalIcon } from '../../create/icons'
import type { RetrievalConfig } from '@/types/app'
import { RETRIEVE_METHOD } from '@/types/app'
import { useProviderContext } from '@/context/provider-context'
import { useDefaultModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
import {
DEFAULT_WEIGHTED_SCORE,
@ -20,54 +20,87 @@ import {
import Badge from '@/app/components/base/badge'
type Props = {
disabled?: boolean
value: RetrievalConfig
onChange: (value: RetrievalConfig) => void
}
const RetrievalMethodConfig: FC<Props> = ({
value: passValue,
disabled = false,
value,
onChange,
}) => {
const { t } = useTranslation()
const { supportRetrievalMethods } = useProviderContext()
const { data: rerankDefaultModel } = useDefaultModel(ModelTypeEnum.rerank)
const value = (() => {
if (!passValue.reranking_model.reranking_model_name) {
return {
...passValue,
reranking_model: {
reranking_provider_name: rerankDefaultModel?.provider.provider || '',
reranking_model_name: rerankDefaultModel?.model || '',
},
reranking_mode: passValue.reranking_mode || (rerankDefaultModel ? RerankingModeEnum.RerankingModel : RerankingModeEnum.WeightedScore),
weights: passValue.weights || {
weight_type: WeightedScoreEnum.Customized,
vector_setting: {
vector_weight: DEFAULT_WEIGHTED_SCORE.other.semantic,
embedding_provider_name: '',
embedding_model_name: '',
},
keyword_setting: {
keyword_weight: DEFAULT_WEIGHTED_SCORE.other.keyword,
},
},
}
const {
defaultModel: rerankDefaultModel,
currentModel: isRerankDefaultModelValid,
} = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)
const onSwitch = useCallback((retrieveMethod: RETRIEVE_METHOD) => {
if ([RETRIEVE_METHOD.semantic, RETRIEVE_METHOD.fullText].includes(retrieveMethod)) {
onChange({
...value,
search_method: retrieveMethod,
...(!value.reranking_model.reranking_model_name
? {
reranking_model: {
reranking_provider_name: isRerankDefaultModelValid ? rerankDefaultModel?.provider?.provider ?? '' : '',
reranking_model_name: isRerankDefaultModelValid ? rerankDefaultModel?.model ?? '' : '',
},
reranking_enable: !!isRerankDefaultModelValid,
}
: {
reranking_enable: true,
}),
})
}
return passValue
})()
if (retrieveMethod === RETRIEVE_METHOD.hybrid) {
onChange({
...value,
search_method: retrieveMethod,
...(!value.reranking_model.reranking_model_name
? {
reranking_model: {
reranking_provider_name: isRerankDefaultModelValid ? rerankDefaultModel?.provider?.provider ?? '' : '',
reranking_model_name: isRerankDefaultModelValid ? rerankDefaultModel?.model ?? '' : '',
},
reranking_enable: !!isRerankDefaultModelValid,
reranking_mode: isRerankDefaultModelValid ? RerankingModeEnum.RerankingModel : RerankingModeEnum.WeightedScore,
}
: {
reranking_enable: true,
reranking_mode: RerankingModeEnum.RerankingModel,
}),
...(!value.weights
? {
weights: {
weight_type: WeightedScoreEnum.Customized,
vector_setting: {
vector_weight: DEFAULT_WEIGHTED_SCORE.other.semantic,
embedding_provider_name: '',
embedding_model_name: '',
},
keyword_setting: {
keyword_weight: DEFAULT_WEIGHTED_SCORE.other.keyword,
},
},
}
: {}),
})
}
}, [value, rerankDefaultModel, isRerankDefaultModelValid, onChange])
return (
<div className='space-y-2'>
{supportRetrievalMethods.includes(RETRIEVE_METHOD.semantic) && (
<OptionCard icon={<Image className='w-4 h-4' src={retrievalIcon.vector} alt='' />}
<OptionCard disabled={disabled} icon={<Image className='w-4 h-4' src={retrievalIcon.vector} alt='' />}
title={t('dataset.retrieval.semantic_search.title')}
description={t('dataset.retrieval.semantic_search.description')}
isActive={
value.search_method === RETRIEVE_METHOD.semantic
}
onSwitched={() => onChange({
...value,
search_method: RETRIEVE_METHOD.semantic,
})}
onSwitched={() => onSwitch(RETRIEVE_METHOD.semantic)}
effectImg={Effect.src}
activeHeaderClassName='bg-dataset-option-card-purple-gradient'
>
@ -78,17 +111,14 @@ const RetrievalMethodConfig: FC<Props> = ({
/>
</OptionCard>
)}
{supportRetrievalMethods.includes(RETRIEVE_METHOD.semantic) && (
<OptionCard icon={<Image className='w-4 h-4' src={retrievalIcon.fullText} alt='' />}
{supportRetrievalMethods.includes(RETRIEVE_METHOD.fullText) && (
<OptionCard disabled={disabled} icon={<Image className='w-4 h-4' src={retrievalIcon.fullText} alt='' />}
title={t('dataset.retrieval.full_text_search.title')}
description={t('dataset.retrieval.full_text_search.description')}
isActive={
value.search_method === RETRIEVE_METHOD.fullText
}
onSwitched={() => onChange({
...value,
search_method: RETRIEVE_METHOD.fullText,
})}
onSwitched={() => onSwitch(RETRIEVE_METHOD.fullText)}
effectImg={Effect.src}
activeHeaderClassName='bg-dataset-option-card-purple-gradient'
>
@ -99,8 +129,8 @@ const RetrievalMethodConfig: FC<Props> = ({
/>
</OptionCard>
)}
{supportRetrievalMethods.includes(RETRIEVE_METHOD.semantic) && (
<OptionCard icon={<Image className='w-4 h-4' src={retrievalIcon.hybrid} alt='' />}
{supportRetrievalMethods.includes(RETRIEVE_METHOD.hybrid) && (
<OptionCard disabled={disabled} icon={<Image className='w-4 h-4' src={retrievalIcon.hybrid} alt='' />}
title={
<div className='flex items-center space-x-1'>
<div>{t('dataset.retrieval.hybrid_search.title')}</div>
@ -110,11 +140,7 @@ const RetrievalMethodConfig: FC<Props> = ({
description={t('dataset.retrieval.hybrid_search.description')} isActive={
value.search_method === RETRIEVE_METHOD.hybrid
}
onSwitched={() => onChange({
...value,
search_method: RETRIEVE_METHOD.hybrid,
reranking_enable: true,
})}
onSwitched={() => onSwitch(RETRIEVE_METHOD.hybrid)}
effectImg={Effect.src}
activeHeaderClassName='bg-dataset-option-card-purple-gradient'
>

View File

@ -1,6 +1,6 @@
'use client'
import type { FC } from 'react'
import React, { useCallback } from 'react'
import React, { useCallback, useMemo } from 'react'
import { useTranslation } from 'react-i18next'
import Image from 'next/image'
@ -39,8 +39,8 @@ const RetrievalParamConfig: FC<Props> = ({
const { t } = useTranslation()
const canToggleRerankModalEnable = type !== RETRIEVE_METHOD.hybrid
const isEconomical = type === RETRIEVE_METHOD.invertedIndex
const isHybridSearch = type === RETRIEVE_METHOD.hybrid
const {
defaultModel: rerankDefaultModel,
modelList: rerankModelList,
} = useModelListAndDefaultModel(ModelTypeEnum.rerank)
@ -48,35 +48,28 @@ const RetrievalParamConfig: FC<Props> = ({
currentModel,
} = useCurrentProviderAndModel(
rerankModelList,
rerankDefaultModel
? {
...rerankDefaultModel,
provider: rerankDefaultModel.provider.provider,
}
: undefined,
{
provider: value.reranking_model?.reranking_provider_name ?? '',
model: value.reranking_model?.reranking_model_name ?? '',
},
)
const handleDisabledSwitchClick = useCallback(() => {
if (!currentModel)
const handleDisabledSwitchClick = useCallback((enable: boolean) => {
if (enable && !currentModel)
Toast.notify({ type: 'error', message: t('workflow.errorMsg.rerankModelRequired') })
}, [currentModel, rerankDefaultModel, t])
onChange({
...value,
reranking_enable: enable,
})
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [currentModel, onChange, value])
const isHybridSearch = type === RETRIEVE_METHOD.hybrid
const rerankModel = (() => {
if (value.reranking_model) {
return {
provider_name: value.reranking_model.reranking_provider_name,
model_name: value.reranking_model.reranking_model_name,
}
const rerankModel = useMemo(() => {
return {
provider_name: value.reranking_model.reranking_provider_name,
model_name: value.reranking_model.reranking_model_name,
}
else if (rerankDefaultModel) {
return {
provider_name: rerankDefaultModel.provider.provider,
model_name: rerankDefaultModel.model,
}
}
})()
}, [value.reranking_model])
const handleChangeRerankMode = (v: RerankingModeEnum) => {
if (v === value.reranking_mode)
@ -100,6 +93,8 @@ const RetrievalParamConfig: FC<Props> = ({
},
}
}
if (v === RerankingModeEnum.RerankingModel && !currentModel)
Toast.notify({ type: 'error', message: t('workflow.errorMsg.rerankModelRequired') })
onChange(result)
}
@ -122,22 +117,11 @@ const RetrievalParamConfig: FC<Props> = ({
<div>
<div className='flex items-center space-x-2 mb-2'>
{canToggleRerankModalEnable && (
<div
className='flex items-center'
onClick={handleDisabledSwitchClick}
>
<Switch
size='md'
defaultValue={currentModel ? value.reranking_enable : false}
onChange={(v) => {
onChange({
...value,
reranking_enable: v,
})
}}
disabled={!currentModel}
/>
</div>
<Switch
size='md'
defaultValue={value.reranking_enable}
onChange={handleDisabledSwitchClick}
/>
)}
<div className='flex items-center'>
<span className='mr-0.5 system-sm-semibold text-text-secondary'>{t('common.modelProvider.rerankModel.key')}</span>
@ -148,21 +132,23 @@ const RetrievalParamConfig: FC<Props> = ({
/>
</div>
</div>
<ModelSelector
triggerClassName={`${!value.reranking_enable && '!opacity-60 !cursor-not-allowed'}`}
defaultModel={rerankModel && { provider: rerankModel.provider_name, model: rerankModel.model_name }}
modelList={rerankModelList}
readonly={!value.reranking_enable}
onSelect={(v) => {
onChange({
...value,
reranking_model: {
reranking_provider_name: v.provider,
reranking_model_name: v.model,
},
})
}}
/>
{
value.reranking_enable && (
<ModelSelector
defaultModel={rerankModel && { provider: rerankModel.provider_name, model: rerankModel.model_name }}
modelList={rerankModelList}
onSelect={(v) => {
onChange({
...value,
reranking_model: {
reranking_provider_name: v.provider,
reranking_model_name: v.model,
},
})
}}
/>
)
}
</div>
)}
{
@ -255,10 +241,8 @@ const RetrievalParamConfig: FC<Props> = ({
{
value.reranking_mode !== RerankingModeEnum.WeightedScore && (
<ModelSelector
triggerClassName={`${!value.reranking_enable && '!opacity-60 !cursor-not-allowed'}`}
defaultModel={rerankModel && { provider: rerankModel.provider_name, model: rerankModel.model_name }}
modelList={rerankModelList}
readonly={!value.reranking_enable}
onSelect={(v) => {
onChange({
...value,

View File

@ -30,6 +30,7 @@ import { useProviderContext } from '@/context/provider-context'
import { sleep } from '@/utils'
import { RETRIEVE_METHOD } from '@/types/app'
import Tooltip from '@/app/components/base/tooltip'
import { useInvalidDocumentList } from '@/service/knowledge/use-document'
type Props = {
datasetId: string
@ -207,7 +208,9 @@ const EmbeddingProcess: FC<Props> = ({ datasetId, batchId, documents = [], index
})
const router = useRouter()
const invalidDocumentList = useInvalidDocumentList()
const navToDocumentList = () => {
invalidDocumentList()
router.push(`/datasets/${datasetId}/documents`)
}
const navToApiDocs = () => {

View File

@ -31,17 +31,17 @@ import LanguageSelect from './language-select'
import { DelimiterInput, MaxLengthInput, OverlapInput } from './inputs'
import cn from '@/utils/classnames'
import type { CrawlOptions, CrawlResultItem, CreateDocumentReq, CustomFile, DocumentItem, FullDocumentDetail, ParentMode, PreProcessingRule, ProcessRule, Rules, createDocumentResponse } from '@/models/datasets'
import { ChunkingMode, DataSourceType, ProcessMode } from '@/models/datasets'
import Button from '@/app/components/base/button'
import FloatRightContainer from '@/app/components/base/float-right-container'
import RetrievalMethodConfig from '@/app/components/datasets/common/retrieval-method-config'
import EconomicalRetrievalMethodConfig from '@/app/components/datasets/common/economical-retrieval-method-config'
import { type RetrievalConfig } from '@/types/app'
import { ensureRerankModelSelected, isReRankModelSelected } from '@/app/components/datasets/common/check-rerank-model'
import { isReRankModelSelected } from '@/app/components/datasets/common/check-rerank-model'
import Toast from '@/app/components/base/toast'
import type { NotionPage } from '@/models/common'
import { DataSourceProvider } from '@/models/common'
import { ChunkingMode, DataSourceType, RerankingModeEnum } from '@/models/datasets'
import { useDatasetDetailContext } from '@/context/dataset-detail'
import I18n from '@/context/i18n'
import { RETRIEVE_METHOD } from '@/types/app'
@ -53,7 +53,7 @@ import type { DefaultModel } from '@/app/components/header/account-setting/model
import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
import Checkbox from '@/app/components/base/checkbox'
import RadioCard from '@/app/components/base/radio-card'
import { IS_CE_EDITION } from '@/config'
import { FULL_DOC_PREVIEW_LENGTH, IS_CE_EDITION } from '@/config'
import Divider from '@/app/components/base/divider'
import { getNotionInfo, getWebsiteInfo, useCreateDocument, useCreateFirstDocument, useFetchDefaultProcessRule, useFetchFileIndexingEstimateForFile, useFetchFileIndexingEstimateForNotion, useFetchFileIndexingEstimateForWeb } from '@/service/knowledge/use-create-dataset'
import Badge from '@/app/components/base/badge'
@ -90,17 +90,13 @@ type StepTwoProps = {
onCancel?: () => void
}
export enum SegmentType {
AUTO = 'automatic',
CUSTOM = 'custom',
}
export enum IndexingType {
QUALIFIED = 'high_quality',
ECONOMICAL = 'economy',
}
const DEFAULT_SEGMENT_IDENTIFIER = '\\n\\n'
const DEFAULT_MAXMIMUM_CHUNK_LENGTH = 500
const DEFAULT_MAXIMUM_CHUNK_LENGTH = 500
const DEFAULT_OVERLAP = 50
type ParentChildConfig = {
@ -131,7 +127,6 @@ const StepTwo = ({
isSetting,
documentDetail,
isAPIKeySet,
onSetting,
datasetId,
indexingType,
dataSourceType: inCreatePageDataSourceType,
@ -162,12 +157,12 @@ const StepTwo = ({
const isInCreatePage = !datasetId || (datasetId && !currentDataset?.data_source_type)
const dataSourceType = isInCreatePage ? inCreatePageDataSourceType : currentDataset?.data_source_type
const [segmentationType, setSegmentationType] = useState<SegmentType>(SegmentType.CUSTOM)
const [segmentationType, setSegmentationType] = useState<ProcessMode>(ProcessMode.general)
const [segmentIdentifier, doSetSegmentIdentifier] = useState(DEFAULT_SEGMENT_IDENTIFIER)
const setSegmentIdentifier = useCallback((value: string, canEmpty?: boolean) => {
doSetSegmentIdentifier(value ? escape(value) : (canEmpty ? '' : DEFAULT_SEGMENT_IDENTIFIER))
}, [])
const [maxChunkLength, setMaxChunkLength] = useState(DEFAULT_MAXMIMUM_CHUNK_LENGTH) // default chunk length
const [maxChunkLength, setMaxChunkLength] = useState(DEFAULT_MAXIMUM_CHUNK_LENGTH) // default chunk length
const [limitMaxChunkLength, setLimitMaxChunkLength] = useState(4000)
const [overlap, setOverlap] = useState(DEFAULT_OVERLAP)
const [rules, setRules] = useState<PreProcessingRule[]>([])
@ -198,7 +193,6 @@ const StepTwo = ({
)
// QA Related
const [isLanguageSelectDisabled, _setIsLanguageSelectDisabled] = useState(false)
const [isQAConfirmDialogOpen, setIsQAConfirmDialogOpen] = useState(false)
const [docForm, setDocForm] = useState<ChunkingMode>(
(datasetId && documentDetail) ? documentDetail.doc_form as ChunkingMode : ChunkingMode.text,
@ -348,7 +342,7 @@ const StepTwo = ({
}
const updatePreview = () => {
if (segmentationType === SegmentType.CUSTOM && maxChunkLength > 4000) {
if (segmentationType === ProcessMode.general && maxChunkLength > 4000) {
Toast.notify({ type: 'error', message: t('datasetCreation.stepTwo.maxLengthCheck') })
return
}
@ -373,13 +367,42 @@ const StepTwo = ({
model: defaultEmbeddingModel?.model || '',
},
)
const [retrievalConfig, setRetrievalConfig] = useState(currentDataset?.retrieval_model_dict || {
search_method: RETRIEVE_METHOD.semantic,
reranking_enable: false,
reranking_model: {
reranking_provider_name: '',
reranking_model_name: '',
},
top_k: 3,
score_threshold_enabled: false,
score_threshold: 0.5,
} as RetrievalConfig)
useEffect(() => {
if (currentDataset?.retrieval_model_dict)
return
setRetrievalConfig({
search_method: RETRIEVE_METHOD.semantic,
reranking_enable: !!isRerankDefaultModelValid,
reranking_model: {
reranking_provider_name: isRerankDefaultModelValid ? rerankDefaultModel?.provider.provider ?? '' : '',
reranking_model_name: isRerankDefaultModelValid ? rerankDefaultModel?.model ?? '' : '',
},
top_k: 3,
score_threshold_enabled: false,
score_threshold: 0.5,
})
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [rerankDefaultModel, isRerankDefaultModelValid])
const getCreationParams = () => {
let params
if (segmentationType === SegmentType.CUSTOM && overlap > maxChunkLength) {
if (segmentationType === ProcessMode.general && overlap > maxChunkLength) {
Toast.notify({ type: 'error', message: t('datasetCreation.stepTwo.overlapCheck') })
return
}
if (segmentationType === SegmentType.CUSTOM && maxChunkLength > limitMaxChunkLength) {
if (segmentationType === ProcessMode.general && maxChunkLength > limitMaxChunkLength) {
Toast.notify({ type: 'error', message: t('datasetCreation.stepTwo.maxLengthCheck', { limit: limitMaxChunkLength }) })
return
}
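Note on the retrieval-config hunk above: the state is now seeded up front — an existing dataset keeps its saved retrieval_model_dict, while a new dataset starts from semantic search and switches reranking on once a valid default rerank model resolves. The following is a standalone sketch of that seeding pattern, not the author's code; the hook name and the simplified rerank-model shape are illustrative, while RETRIEVE_METHOD and RetrievalConfig come from '@/types/app' as imported in this file.

    import { useEffect, useState } from 'react'
    import { RETRIEVE_METHOD } from '@/types/app'
    import type { RetrievalConfig } from '@/types/app'

    // Illustrative hook: keep saved settings untouched, otherwise seed
    // semantic-search defaults and enable reranking only when a default
    // rerank model is available.
    export const useSeededRetrievalConfig = (
      saved: RetrievalConfig | undefined,
      defaultRerankModel?: { provider: string; model: string },
    ) => {
      const [config, setConfig] = useState<RetrievalConfig>(saved ?? ({
        search_method: RETRIEVE_METHOD.semantic,
        reranking_enable: false,
        reranking_model: { reranking_provider_name: '', reranking_model_name: '' },
        top_k: 3,
        score_threshold_enabled: false,
        score_threshold: 0.5,
      } as RetrievalConfig))

      useEffect(() => {
        if (saved)
          return // an existing dataset's saved settings are never overwritten
        setConfig(prev => ({
          ...prev,
          reranking_enable: !!defaultRerankModel,
          reranking_model: {
            reranking_provider_name: defaultRerankModel?.provider ?? '',
            reranking_model_name: defaultRerankModel?.model ?? '',
          },
        }))
      }, [saved, defaultRerankModel])

      return [config, setConfig] as const
    }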
@ -389,7 +412,6 @@ const StepTwo = ({
doc_form: currentDocForm,
doc_language: docLanguage,
process_rule: getProcessRule(),
// eslint-disable-next-line @typescript-eslint/no-use-before-define
retrieval_model: retrievalConfig, // Readonly. If want to changed, just go to settings page.
embedding_model: embeddingModel.model, // Readonly
embedding_model_provider: embeddingModel.provider, // Readonly
@ -400,10 +422,7 @@ const StepTwo = ({
const indexMethod = getIndexing_technique()
if (
!isReRankModelSelected({
rerankDefaultModel,
isRerankDefaultModelValid: !!isRerankDefaultModelValid,
rerankModelList,
// eslint-disable-next-line @typescript-eslint/no-use-before-define
retrievalConfig,
indexMethod: indexMethod as string,
})
@ -411,16 +430,6 @@ const StepTwo = ({
Toast.notify({ type: 'error', message: t('appDebug.datasetConfig.rerankModelRequired') })
return
}
const postRetrievalConfig = ensureRerankModelSelected({
rerankDefaultModel: rerankDefaultModel!,
retrievalConfig: {
// eslint-disable-next-line @typescript-eslint/no-use-before-define
...retrievalConfig,
// eslint-disable-next-line @typescript-eslint/no-use-before-define
reranking_enable: retrievalConfig.reranking_mode === RerankingModeEnum.RerankingModel,
},
indexMethod: indexMethod as string,
})
params = {
data_source: {
type: dataSourceType,
@ -432,8 +441,7 @@ const StepTwo = ({
process_rule: getProcessRule(),
doc_form: currentDocForm,
doc_language: docLanguage,
retrieval_model: postRetrievalConfig,
retrieval_model: retrievalConfig,
embedding_model: embeddingModel.model,
embedding_model_provider: embeddingModel.provider,
} as CreateDocumentReq
@ -490,7 +498,6 @@ const StepTwo = ({
const getDefaultMode = () => {
if (documentDetail)
// @ts-expect-error fix after api refactored
setSegmentationType(documentDetail.dataset_process_rule.mode)
}
@ -525,7 +532,6 @@ const StepTwo = ({
onSuccess(data) {
updateIndexingTypeCache && updateIndexingTypeCache(indexType as string)
updateResultCache && updateResultCache(data)
// eslint-disable-next-line @typescript-eslint/no-use-before-define
updateRetrievalMethodCache && updateRetrievalMethodCache(retrievalConfig.search_method as string)
},
},
@ -545,14 +551,6 @@ const StepTwo = ({
isSetting && onSave && onSave()
}
const changeToEconomicalType = () => {
if (docForm !== ChunkingMode.text)
return
if (!hasSetIndexType)
setIndexType(IndexingType.ECONOMICAL)
}
useEffect(() => {
// fetch rules
if (!isSetting) {
@ -574,18 +572,6 @@ const StepTwo = ({
setIndexType(isAPIKeySet ? IndexingType.QUALIFIED : IndexingType.ECONOMICAL)
}, [isAPIKeySet, indexingType, datasetId])
const [retrievalConfig, setRetrievalConfig] = useState(currentDataset?.retrieval_model_dict || {
search_method: RETRIEVE_METHOD.semantic,
reranking_enable: false,
reranking_model: {
reranking_provider_name: rerankDefaultModel?.provider.provider,
reranking_model_name: rerankDefaultModel?.model,
},
top_k: 3,
score_threshold_enabled: false,
score_threshold: 0.5,
} as RetrievalConfig)
const economyDomRef = useRef<HTMLDivElement>(null)
const isHoveringEconomy = useHover(economyDomRef)
@ -946,6 +932,7 @@ const StepTwo = ({
<div className={cn('system-md-semibold mb-1', datasetId && 'flex justify-between items-center')}>{t('datasetSettings.form.embeddingModel')}</div>
<ModelSelector
readonly={!!datasetId}
triggerClassName={datasetId ? 'opacity-50' : ''}
defaultModel={embeddingModel}
modelList={embeddingModelList}
onSelect={(model: DefaultModel) => {
@ -984,12 +971,14 @@ const StepTwo = ({
getIndexing_technique() === IndexingType.QUALIFIED
? (
<RetrievalMethodConfig
disabled={!!datasetId}
value={retrievalConfig}
onChange={setRetrievalConfig}
/>
)
: (
<EconomicalRetrievalMethodConfig
disabled={!!datasetId}
value={retrievalConfig}
onChange={setRetrievalConfig}
/>
@ -1010,7 +999,7 @@ const StepTwo = ({
)
: (
<div className='flex items-center mt-8 py-2'>
<Button loading={isCreating} variant='primary' onClick={createHandle}>{t('datasetCreation.stepTwo.save')}</Button>
{!datasetId && <Button loading={isCreating} variant='primary' onClick={createHandle}>{t('datasetCreation.stepTwo.save')}</Button>}
<Button className='ml-2' onClick={onCancel}>{t('datasetCreation.stepTwo.cancel')}</Button>
</div>
)}
@ -1081,11 +1070,11 @@ const StepTwo = ({
}
{
currentDocForm !== ChunkingMode.qa
&& <Badge text={t(
'datasetCreation.stepTwo.previewChunkCount', {
count: estimate?.total_segments || 0,
}) as string}
/>
&& <Badge text={t(
'datasetCreation.stepTwo.previewChunkCount', {
count: estimate?.total_segments || 0,
}) as string}
/>
}
</div>
</PreviewHeader>}
@ -1117,6 +1106,9 @@ const StepTwo = ({
{currentDocForm === ChunkingMode.parentChild && currentEstimateMutation.data?.preview && (
estimate?.preview?.map((item, index) => {
const indexForLabel = index + 1
const childChunks = parentChildConfig.chunkForContext === 'full-doc'
? item.child_chunks.slice(0, FULL_DOC_PREVIEW_LENGTH)
: item.child_chunks
return (
<ChunkContainer
key={item.content}
@ -1124,7 +1116,7 @@ const StepTwo = ({
characterCount={item.content.length}
>
<FormattedText>
{item.child_chunks.map((child, index) => {
{childChunks.map((child, index) => {
const indexForLabel = index + 1
return (
<PreviewSlice
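Note on the preview hunk above: in full-doc parent mode only the first FULL_DOC_PREVIEW_LENGTH child chunks are rendered, since a single parent can carry a very large number of children. A minimal sketch of that cap, assuming only the FULL_DOC_PREVIEW_LENGTH constant imported from '@/config' in this file; the helper name is illustrative.

    import { FULL_DOC_PREVIEW_LENGTH } from '@/config'

    // Illustrative: cap the previewed child chunks only in full-doc mode.
    const previewChildChunks = <T,>(childChunks: T[], isFullDocMode: boolean): T[] =>
      isFullDocMode ? childChunks.slice(0, FULL_DOC_PREVIEW_LENGTH) : childChunks

In the hunk above the same check is done inline via parentChildConfig.chunkForContext === 'full-doc'.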

View File

@ -4,7 +4,7 @@ import classNames from '@/utils/classnames'
const TriangleArrow: FC<ComponentProps<'svg'>> = props => (
<svg xmlns="http://www.w3.org/2000/svg" width="24" height="11" viewBox="0 0 24 11" fill="none" {...props}>
<path d="M9.87868 1.12132C11.0503 -0.0502525 12.9497 -0.0502525 14.1213 1.12132L23.3137 10.3137H0.686292L9.87868 1.12132Z" fill="currentColor"/>
<path d="M9.87868 1.12132C11.0503 -0.0502525 12.9497 -0.0502525 14.1213 1.12132L23.3137 10.3137H0.686292L9.87868 1.12132Z" fill="currentColor" />
</svg>
)
@ -65,7 +65,7 @@ export const OptionCard: FC<OptionCardProps> = forwardRef((props, ref) => {
(isActive && !noHighlight)
? 'border-[1.5px] border-components-option-card-option-selected-border'
: 'border border-components-option-card-option-border',
disabled && 'opacity-50 cursor-not-allowed',
disabled && 'opacity-50 pointer-events-none',
className,
)}
style={{

View File

@ -32,6 +32,9 @@ import Checkbox from '@/app/components/base/checkbox'
import {
useChildSegmentList,
useChildSegmentListKey,
useChunkListAllKey,
useChunkListDisabledKey,
useChunkListEnabledKey,
useDeleteChildSegment,
useDeleteSegment,
useDisableSegment,
@ -156,18 +159,18 @@ const Completed: FC<ICompletedProps> = ({
page: isFullDocMode ? 1 : currentPage,
limit: isFullDocMode ? 10 : limit,
keyword: isFullDocMode ? '' : searchValue,
enabled: selectedStatus === 'all' ? 'all' : !!selectedStatus,
enabled: selectedStatus,
},
},
currentPage === 0,
)
const invalidSegmentList = useInvalid(useSegmentListKey)
useEffect(() => {
if (segmentListData) {
setSegments(segmentListData.data || [])
if (segmentListData.total_pages < currentPage)
setCurrentPage(segmentListData.total_pages)
const totalPages = segmentListData.total_pages
if (totalPages < currentPage)
setCurrentPage(totalPages === 0 ? 1 : totalPages)
}
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [segmentListData])
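Note on the hunk above: after deletions the API can report total_pages === 0, and clamping the current page to that value would leave it at 0, which this component appears to use as the "not yet initialised" state that keeps the list query disabled, so it is mapped to page 1 instead. An illustrative helper capturing the same rule (the name is not from the diff):

    // Illustrative: never let the current page exceed the reported total,
    // and treat an empty result (total_pages === 0) as page 1 rather than 0.
    const clampCurrentPage = (current: number, totalPages: number): number => {
      if (totalPages === 0)
        return 1
      return totalPages < current ? totalPages : current
    }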
@ -185,12 +188,12 @@ const Completed: FC<ICompletedProps> = ({
documentId,
segmentId: segments[0]?.id || '',
params: {
page: currentPage,
page: currentPage === 0 ? 1 : currentPage,
limit,
keyword: searchValue,
},
},
!isFullDocMode || segments.length === 0 || currentPage === 0,
!isFullDocMode || segments.length === 0,
)
const invalidChildSegmentList = useInvalid(useChildSegmentListKey)
@ -204,21 +207,20 @@ const Completed: FC<ICompletedProps> = ({
useEffect(() => {
if (childChunkListData) {
setChildSegments(childChunkListData.data || [])
if (childChunkListData.total_pages < currentPage)
setCurrentPage(childChunkListData.total_pages)
const totalPages = childChunkListData.total_pages
if (totalPages < currentPage)
setCurrentPage(totalPages === 0 ? 1 : totalPages)
}
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [childChunkListData])
const resetList = useCallback(() => {
setSegments([])
setSelectedSegmentIds([])
invalidSegmentList()
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [])
const resetChildList = useCallback(() => {
setChildSegments([])
invalidChildSegmentList()
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [])
@ -232,8 +234,32 @@ const Completed: FC<ICompletedProps> = ({
setFullScreen(false)
}, [])
const onCloseNewSegmentModal = useCallback(() => {
onNewSegmentModalChange(false)
setFullScreen(false)
}, [onNewSegmentModalChange])
const onCloseNewChildChunkModal = useCallback(() => {
setShowNewChildSegmentModal(false)
setFullScreen(false)
}, [])
const { mutateAsync: enableSegment } = useEnableSegment()
const { mutateAsync: disableSegment } = useDisableSegment()
const invalidChunkListAll = useInvalid(useChunkListAllKey)
const invalidChunkListEnabled = useInvalid(useChunkListEnabledKey)
const invalidChunkListDisabled = useInvalid(useChunkListDisabledKey)
const refreshChunkListWithStatusChanged = () => {
switch (selectedStatus) {
case 'all':
invalidChunkListDisabled()
invalidChunkListEnabled()
break
default:
invalidSegmentList()
}
}
const onChangeSwitch = useCallback(async (enable: boolean, segId?: string) => {
const operationApi = enable ? enableSegment : disableSegment
@ -245,6 +271,7 @@ const Completed: FC<ICompletedProps> = ({
seg.enabled = enable
}
setSegments([...segments])
refreshChunkListWithStatusChanged()
},
onError: () => {
notify({ type: 'error', message: t('common.actionMsg.modifiedUnsuccessfully') })
@ -271,6 +298,23 @@ const Completed: FC<ICompletedProps> = ({
const { mutateAsync: updateSegment } = useUpdateSegment()
const refreshChunkListDataWithDetailChanged = () => {
switch (selectedStatus) {
case 'all':
invalidChunkListDisabled()
invalidChunkListEnabled()
break
case true:
invalidChunkListAll()
invalidChunkListDisabled()
break
case false:
invalidChunkListAll()
invalidChunkListEnabled()
break
}
}
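Note on the two helpers added above: the component keeps separate query caches for the "all", enabled-only and disabled-only chunk lists. After an edit, the list currently on screen is patched in component state, so only the other status-filtered caches need to be marked stale; the enable/disable toggle uses the analogous refreshChunkListWithStatusChanged, which falls back to invalidating the plain segment list. A rough sketch of that routing under assumed, placeholder cache names:

    // Illustrative routing of cache invalidation after a chunk is edited.
    type StatusFilter = 'all' | boolean

    const staleListsAfterDetailChange = (status: StatusFilter): string[] => {
      if (status === 'all')
        return ['chunkList/enabled', 'chunkList/disabled'] // 'all' view already patched in place
      if (status === true)
        return ['chunkList/all', 'chunkList/disabled']     // currently viewing enabled chunks
      return ['chunkList/all', 'chunkList/enabled']        // currently viewing disabled chunks
    }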
const handleUpdateSegment = useCallback(async (
segmentId: string,
question: string,
@ -320,6 +364,7 @@ const Completed: FC<ICompletedProps> = ({
}
}
setSegments([...segments])
refreshChunkListDataWithDetailChanged()
eventEmitter?.emit('update-segment-success')
},
onSettled() {
@ -432,6 +477,7 @@ const Completed: FC<ICompletedProps> = ({
seg.child_chunks?.push(newChildChunk!)
}
setSegments([...segments])
refreshChunkListDataWithDetailChanged()
}
else {
resetChildList()
@ -496,17 +542,10 @@ const Completed: FC<ICompletedProps> = ({
}
}
setSegments([...segments])
refreshChunkListDataWithDetailChanged()
}
else {
for (const childSeg of childSegments) {
if (childSeg.id === childChunkId) {
childSeg.content = res.data.content
childSeg.type = res.data.type
childSeg.word_count = res.data.word_count
childSeg.updated_at = res.data.updated_at
}
}
setChildSegments([...childSegments])
resetChildList()
}
},
onSettled: () => {
@ -544,12 +583,13 @@ const Completed: FC<ICompletedProps> = ({
<SimpleSelect
onSelect={onChangeStatus}
items={statusList.current}
defaultValue={'all'}
defaultValue={selectedStatus === 'all' ? 'all' : selectedStatus ? 1 : 0}
className={s.select}
wrapperClassName='h-fit mr-2'
optionWrapClassName='w-[160px]'
optionClassName='p-0'
renderOption={({ item, selected }) => <StatusItem item={item} selected={selected} />}
notClearable
/>
<Input
showLeftIcon
@ -623,6 +663,7 @@ const Completed: FC<ICompletedProps> = ({
<FullScreenDrawer
isOpen={currSegment.showModal}
fullScreen={fullScreen}
onClose={onCloseSegmentDetail}
>
<SegmentDetail
segInfo={currSegment.segInfo ?? { id: '' }}
@ -636,13 +677,11 @@ const Completed: FC<ICompletedProps> = ({
<FullScreenDrawer
isOpen={showNewSegmentModal}
fullScreen={fullScreen}
onClose={onCloseNewSegmentModal}
>
<NewSegment
docForm={docForm}
onCancel={() => {
onNewSegmentModalChange(false)
setFullScreen(false)
}}
onCancel={onCloseNewSegmentModal}
onSave={resetList}
viewNewlyAddedChunk={viewNewlyAddedChunk}
/>
@ -651,6 +690,7 @@ const Completed: FC<ICompletedProps> = ({
<FullScreenDrawer
isOpen={currChildChunk.showModal}
fullScreen={fullScreen}
onClose={onCloseChildSegmentDetail}
>
<ChildSegmentDetail
chunkId={currChunkId}
@ -664,13 +704,11 @@ const Completed: FC<ICompletedProps> = ({
<FullScreenDrawer
isOpen={showNewChildSegmentModal}
fullScreen={fullScreen}
onClose={onCloseNewChildChunkModal}
>
<NewChildSegment
chunkId={currChunkId}
onCancel={() => {
setShowNewChildSegmentModal(false)
setFullScreen(false)
}}
onCancel={onCloseNewChildChunkModal}
onSave={onSaveNewChildChunk}
viewNewlyAddedChildChunk={viewNewlyAddedChildChunk}
/>

Some files were not shown because too many files have changed in this diff.