Merge branch 'main' into fix/chore-fix

Yeuoly 2025-01-08 20:36:22 +08:00
commit fb309462ad
203 changed files with 3469 additions and 2327 deletions

View File

@@ -23,6 +23,9 @@ FILES_ACCESS_TIMEOUT=300
 # Access token expiration time in minutes
 ACCESS_TOKEN_EXPIRE_MINUTES=60
+# Refresh token expiration time in days
+REFRESH_TOKEN_EXPIRE_DAYS=30
 # celery configuration
 CELERY_BROKER_URL=redis://:difyai123456@localhost:6379/1

View File

@@ -14,7 +14,10 @@ if is_db_command():
     app = create_migrations_app()
 else:
-    if os.environ.get("FLASK_DEBUG", "False") != "True":
+    # It seems that JetBrains Python debugger does not work well with gevent,
+    # so we need to disable gevent in debug mode.
+    # If you are using debugpy and set GEVENT_SUPPORT=True, you can debug with gevent.
+    if (flask_debug := os.environ.get("FLASK_DEBUG", "0")) and flask_debug.lower() in {"false", "0", "no"}:
         from gevent import monkey  # type: ignore
         # gevent
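
Note (not part of the diff): a minimal sketch of how the new guard behaves. The helper below is hypothetical and only mirrors the predicate: gevent is monkey-patched only when FLASK_DEBUG is unset or reads as "false"/"0"/"no" (case-insensitive), whereas the old check patched for any value other than the exact string "True".

import os

def should_patch_gevent(env: dict[str, str]) -> bool:
    # Mirrors the walrus-operator check introduced above (illustrative helper name).
    flask_debug = env.get("FLASK_DEBUG", "0")
    return bool(flask_debug) and flask_debug.lower() in {"false", "0", "no"}

assert should_patch_gevent({}) is True                     # unset -> gevent patched
assert should_patch_gevent({"FLASK_DEBUG": "1"}) is False  # debugging -> no patching
assert should_patch_gevent({"FLASK_DEBUG": "no"}) is True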

View File

@@ -546,6 +546,11 @@ class AuthConfig(BaseSettings):
         default=60,
     )

+    REFRESH_TOKEN_EXPIRE_DAYS: PositiveFloat = Field(
+        description="Expiration time for refresh tokens in days",
+        default=30,
+    )
+
     LOGIN_LOCKOUT_DURATION: PositiveInt = Field(
         description="Time (in seconds) a user must wait before retrying login after exceeding the rate limit.",
         default=86400,
@@ -725,6 +730,11 @@ class IndexingConfig(BaseSettings):
         default=4000,
     )

+    CHILD_CHUNKS_PREVIEW_NUMBER: PositiveInt = Field(
+        description="Maximum number of child chunks to preview",
+        default=50,
+    )
+
 class MultiModalTransferConfig(BaseSettings):
     MULTIMODAL_SEND_FORMAT: Literal["base64", "url"] = Field(
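
Note (not part of the diff): these are plain pydantic-settings fields, so the defaults above are overridden by same-named environment variables, such as the REFRESH_TOKEN_EXPIRE_DAYS entry added to the .env example earlier in this commit. A minimal, standalone sketch (the class name here is illustrative, not the project's):

import os
from pydantic import Field, PositiveFloat
from pydantic_settings import BaseSettings

class AuthConfigSketch(BaseSettings):
    REFRESH_TOKEN_EXPIRE_DAYS: PositiveFloat = Field(
        description="Expiration time for refresh tokens in days",
        default=30,
    )

os.environ["REFRESH_TOKEN_EXPIRE_DAYS"] = "7"
print(AuthConfigSketch().REFRESH_TOKEN_EXPIRE_DAYS)  # 7.0 -- the env value wins over the default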

View File

@@ -33,3 +33,9 @@ class MilvusConfig(BaseSettings):
         description="Name of the Milvus database to connect to (default is 'default')",
         default="default",
     )
+
+    MILVUS_ENABLE_HYBRID_SEARCH: bool = Field(
+        description="Enable hybrid search features (requires Milvus >= 2.5.0). Set to false for compatibility with "
+        "older versions",
+        default=True,
+    )

View File

@@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):
     CURRENT_VERSION: str = Field(
         description="Dify version",
-        default="0.14.2",
+        default="0.15.0",
     )

     COMMIT_SHA: str = Field(

View File

@@ -57,12 +57,13 @@ class AppListApi(Resource):
         )
         parser.add_argument("name", type=str, location="args", required=False)
         parser.add_argument("tag_ids", type=uuid_list, location="args", required=False)
+        parser.add_argument("is_created_by_me", type=inputs.boolean, location="args", required=False)
         args = parser.parse_args()

         # get app list
         app_service = AppService()
-        app_pagination = app_service.get_paginate_apps(current_user.current_tenant_id, args)
+        app_pagination = app_service.get_paginate_apps(current_user.id, current_user.current_tenant_id, args)
         if not app_pagination:
             return {"data": [], "total": 0, "page": 1, "limit": 20, "has_more": False}

View File

@@ -20,7 +20,6 @@ from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpErr
 from core.app.apps.base_app_queue_manager import AppQueueManager
 from core.app.entities.app_invoke_entities import InvokeFrom
 from core.errors.error import (
-    AppInvokeQuotaExceededError,
     ModelCurrentlyNotSupportError,
     ProviderTokenNotInitError,
     QuotaExceededError,
@@ -76,7 +75,7 @@ class CompletionMessageApi(Resource):
             raise ProviderModelCurrentlyNotSupportError()
         except InvokeError as e:
             raise CompletionRequestError(e.description)
-        except (ValueError, AppInvokeQuotaExceededError) as e:
+        except ValueError as e:
            raise e
         except Exception as e:
             logging.exception("internal server error.")
@@ -141,7 +140,7 @@ class ChatMessageApi(Resource):
             raise InvokeRateLimitHttpError(ex.description)
         except InvokeError as e:
             raise CompletionRequestError(e.description)
-        except (ValueError, AppInvokeQuotaExceededError) as e:
+        except ValueError as e:
            raise e
         except Exception as e:
             logging.exception("internal server error.")

View File

@@ -273,8 +273,7 @@ FROM
         messages m
         ON c.id = m.conversation_id
     WHERE
-        c.override_model_configs IS NULL
-        AND c.app_id = :app_id"""
+        c.app_id = :app_id"""
         arg_dict = {"tz": account.timezone, "app_id": app_model.id}
         timezone = pytz.timezone(account.timezone)

View File

@@ -640,6 +640,7 @@ class DatasetRetrievalSettingApi(Resource):
                 | VectorType.MYSCALE
                 | VectorType.ORACLE
                 | VectorType.ELASTICSEARCH
+                | VectorType.ELASTICSEARCH_JA
                 | VectorType.PGVECTOR
                 | VectorType.TIDB_ON_QDRANT
                 | VectorType.LINDORM
@@ -683,6 +684,7 @@ class DatasetRetrievalSettingMockApi(Resource):
                 | VectorType.MYSCALE
                 | VectorType.ORACLE
                 | VectorType.ELASTICSEARCH
+                | VectorType.ELASTICSEARCH_JA
                 | VectorType.COUCHBASE
                 | VectorType.PGVECTOR
                 | VectorType.LINDORM

View File

@@ -269,7 +269,8 @@ class DatasetDocumentListApi(Resource):
         parser.add_argument("original_document_id", type=str, required=False, location="json")
         parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
         parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
+        parser.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
+        parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
         parser.add_argument(
             "doc_language", type=str, default="English", required=False, nullable=False, location="json"
         )

View File

@@ -18,7 +18,11 @@ from controllers.console.explore.error import NotChatAppError, NotCompletionAppE
 from controllers.console.explore.wraps import InstalledAppResource
 from core.app.apps.base_app_queue_manager import AppQueueManager
 from core.app.entities.app_invoke_entities import InvokeFrom
-from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
+from core.errors.error import (
+    ModelCurrentlyNotSupportError,
+    ProviderTokenNotInitError,
+    QuotaExceededError,
+)
 from core.model_runtime.errors.invoke import InvokeError
 from extensions.ext_database import db
 from libs import helper

View File

@@ -13,7 +13,11 @@ from controllers.console.explore.error import NotWorkflowAppError
 from controllers.console.explore.wraps import InstalledAppResource
 from core.app.apps.base_app_queue_manager import AppQueueManager
 from core.app.entities.app_invoke_entities import InvokeFrom
-from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
+from core.errors.error import (
+    ModelCurrentlyNotSupportError,
+    ProviderTokenNotInitError,
+    QuotaExceededError,
+)
 from core.model_runtime.errors.invoke import InvokeError
 from libs import helper
 from libs.login import current_user

View File

@@ -18,7 +18,6 @@ from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate
 from core.app.apps.base_app_queue_manager import AppQueueManager
 from core.app.entities.app_invoke_entities import InvokeFrom
 from core.errors.error import (
-    AppInvokeQuotaExceededError,
     ModelCurrentlyNotSupportError,
     ProviderTokenNotInitError,
     QuotaExceededError,
@@ -74,7 +73,7 @@ class CompletionApi(Resource):
             raise ProviderModelCurrentlyNotSupportError()
         except InvokeError as e:
             raise CompletionRequestError(e.description)
-        except (ValueError, AppInvokeQuotaExceededError) as e:
+        except ValueError as e:
            raise e
         except Exception as e:
             logging.exception("internal server error.")
@@ -133,7 +132,7 @@ class ChatApi(Resource):
             raise ProviderModelCurrentlyNotSupportError()
         except InvokeError as e:
             raise CompletionRequestError(e.description)
-        except (ValueError, AppInvokeQuotaExceededError) as e:
+        except ValueError as e:
            raise e
         except Exception as e:
             logging.exception("internal server error.")

View File

@@ -16,7 +16,6 @@ from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate
 from core.app.apps.base_app_queue_manager import AppQueueManager
 from core.app.entities.app_invoke_entities import InvokeFrom
 from core.errors.error import (
-    AppInvokeQuotaExceededError,
     ModelCurrentlyNotSupportError,
     ProviderTokenNotInitError,
     QuotaExceededError,
@@ -94,7 +93,7 @@ class WorkflowRunApi(Resource):
             raise ProviderModelCurrentlyNotSupportError()
         except InvokeError as e:
             raise CompletionRequestError(e.description)
-        except (ValueError, AppInvokeQuotaExceededError) as e:
+        except ValueError as e:
            raise e
         except Exception as e:
             logging.exception("internal server error.")

View File

@@ -190,7 +190,10 @@ class DocumentAddByFileApi(DatasetApiResource):
                 user=current_user,
                 source="datasets",
             )
-        data_source = {"type": "upload_file", "info_list": {"file_info_list": {"file_ids": [upload_file.id]}}}
+        data_source = {
+            "type": "upload_file",
+            "info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
+        }
         args["data_source"] = data_source
         # validate args
         knowledge_config = KnowledgeConfig(**args)
@@ -254,7 +257,10 @@ class DocumentUpdateByFileApi(DatasetApiResource):
             raise FileTooLargeError(file_too_large_error.description)
         except services.errors.file.UnsupportedFileTypeError:
             raise UnsupportedFileTypeError()
-        data_source = {"type": "upload_file", "info_list": {"file_info_list": {"file_ids": [upload_file.id]}}}
+        data_source = {
+            "type": "upload_file",
+            "info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
+        }
         args["data_source"] = data_source
         # validate args
         args["original_document_id"] = str(document_id)

View File

@@ -1,5 +1,5 @@
 from collections.abc import Callable
-from datetime import UTC, datetime
+from datetime import UTC, datetime, timedelta
 from enum import Enum
 from functools import wraps
 from typing import Optional
@@ -8,6 +8,8 @@ from flask import current_app, request
 from flask_login import user_logged_in  # type: ignore
 from flask_restful import Resource  # type: ignore
 from pydantic import BaseModel
+from sqlalchemy import select, update
+from sqlalchemy.orm import Session
 from werkzeug.exceptions import Forbidden, Unauthorized

 from extensions.ext_database import db
@@ -174,7 +176,7 @@ def validate_dataset_token(view=None):
     return decorator

-def validate_and_get_api_token(scope=None):
+def validate_and_get_api_token(scope: str | None = None):
     """
     Validate and get API token.
     """
@@ -188,20 +190,25 @@ def validate_and_get_api_token(scope=None):
     if auth_scheme != "bearer":
         raise Unauthorized("Authorization scheme must be 'Bearer'")

-    api_token = (
-        db.session.query(ApiToken)
-        .filter(
-            ApiToken.token == auth_token,
-            ApiToken.type == scope,
-        )
-        .first()
-    )
-    if not api_token:
-        raise Unauthorized("Access token is invalid")
-    api_token.last_used_at = datetime.now(UTC).replace(tzinfo=None)
-    db.session.commit()
+    current_time = datetime.now(UTC).replace(tzinfo=None)
+    cutoff_time = current_time - timedelta(minutes=1)
+    with Session(db.engine, expire_on_commit=False) as session:
+        update_stmt = (
+            update(ApiToken)
+            .where(ApiToken.token == auth_token, ApiToken.last_used_at < cutoff_time, ApiToken.type == scope)
+            .values(last_used_at=current_time)
+            .returning(ApiToken)
+        )
+        result = session.execute(update_stmt)
+        api_token = result.scalar_one_or_none()
+        if not api_token:
+            stmt = select(ApiToken).where(ApiToken.token == auth_token, ApiToken.type == scope)
+            api_token = session.scalar(stmt)
+            if not api_token:
+                raise Unauthorized("Access token is invalid")
+        else:
+            session.commit()

     return api_token
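
Note (not part of the diff): the rewrite above replaces a write-on-every-request update of ApiToken.last_used_at with a conditional UPDATE ... RETURNING that only touches rows older than one minute, falling back to a plain SELECT when nothing was updated. A standalone sketch of the same pattern under those assumptions, using an illustrative Token model rather than the project's ApiToken:

from datetime import datetime, timedelta

from sqlalchemy import DateTime, String, create_engine, select, update
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

class Base(DeclarativeBase):
    pass

class Token(Base):
    __tablename__ = "tokens"
    token: Mapped[str] = mapped_column(String, primary_key=True)
    last_used_at: Mapped[datetime] = mapped_column(DateTime)

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

def touch(token_value: str) -> Token | None:
    now = datetime.now()
    cutoff = now - timedelta(minutes=1)
    with Session(engine, expire_on_commit=False) as session:
        # Only rows not touched in the last minute match, so the write is throttled.
        stmt = (
            update(Token)
            .where(Token.token == token_value, Token.last_used_at < cutoff)
            .values(last_used_at=now)
            .returning(Token)
        )
        touched = session.execute(stmt).scalar_one_or_none()
        if touched is None:
            # Row missing or touched recently: read without writing.
            return session.scalar(select(Token).where(Token.token == token_value))
        session.commit()
        return touched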

View File

@@ -19,7 +19,11 @@ from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpErr
 from controllers.web.wraps import WebApiResource
 from core.app.apps.base_app_queue_manager import AppQueueManager
 from core.app.entities.app_invoke_entities import InvokeFrom
-from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
+from core.errors.error import (
+    ModelCurrentlyNotSupportError,
+    ProviderTokenNotInitError,
+    QuotaExceededError,
+)
 from core.model_runtime.errors.invoke import InvokeError
 from libs import helper
 from libs.helper import uuid_value

View File

@@ -14,7 +14,11 @@ from controllers.web.error import (
 from controllers.web.wraps import WebApiResource
 from core.app.apps.base_app_queue_manager import AppQueueManager
 from core.app.entities.app_invoke_entities import InvokeFrom
-from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
+from core.errors.error import (
+    ModelCurrentlyNotSupportError,
+    ProviderTokenNotInitError,
+    QuotaExceededError,
+)
 from core.model_runtime.errors.invoke import InvokeError
 from libs import helper
 from models.model import App, AppMode, EndUser

View File

@@ -21,7 +21,7 @@ from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
 from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
 from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
 from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse
-from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
+from core.model_runtime.errors.invoke import InvokeAuthorizationError
 from core.ops.ops_trace_manager import TraceQueueManager
 from core.prompt.utils.get_thread_messages_length import get_thread_messages_length
 from extensions.ext_database import db
@@ -346,7 +346,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
         except ValidationError as e:
             logger.exception("Validation Error when generating")
             queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
-        except (ValueError, InvokeError) as e:
+        except ValueError as e:
             if dify_config.DEBUG:
                 logger.exception("Error when generating")
             queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)

View File

@@ -68,24 +68,17 @@ from models.account import Account
 from models.enums import CreatedByRole
 from models.workflow import (
     Workflow,
-    WorkflowNodeExecution,
     WorkflowRunStatus,
 )

 logger = logging.getLogger(__name__)

-class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleManage, MessageCycleManage):
+class AdvancedChatAppGenerateTaskPipeline:
     """
     AdvancedChatAppGenerateTaskPipeline is a class that generate stream output and state management for Application.
     """

-    _task_state: WorkflowTaskState
-    _application_generate_entity: AdvancedChatAppGenerateEntity
-    _workflow_system_variables: dict[SystemVariableKey, Any]
-    _wip_workflow_node_executions: dict[str, WorkflowNodeExecution]
-    _conversation_name_generate_thread: Optional[Thread] = None
-
     def __init__(
         self,
         application_generate_entity: AdvancedChatAppGenerateEntity,
@@ -97,7 +90,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
         stream: bool,
         dialogue_count: int,
     ) -> None:
-        super().__init__(
+        self._base_task_pipeline = BasedGenerateTaskPipeline(
            application_generate_entity=application_generate_entity,
            queue_manager=queue_manager,
            stream=stream,
@@ -114,33 +107,35 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
         else:
             raise NotImplementedError(f"User type not supported: {type(user)}")

-        self._workflow_id = workflow.id
-        self._workflow_features_dict = workflow.features_dict
-        self._conversation_id = conversation.id
-        self._conversation_mode = conversation.mode
-        self._message_id = message.id
-        self._message_created_at = int(message.created_at.timestamp())
-        self._workflow_system_variables = {
-            SystemVariableKey.QUERY: message.query,
-            SystemVariableKey.FILES: application_generate_entity.files,
-            SystemVariableKey.CONVERSATION_ID: conversation.id,
-            SystemVariableKey.USER_ID: user_session_id,
-            SystemVariableKey.DIALOGUE_COUNT: dialogue_count,
-            SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
-            SystemVariableKey.WORKFLOW_ID: workflow.id,
-            SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
-        }
+        self._workflow_cycle_manager = WorkflowCycleManage(
+            application_generate_entity=application_generate_entity,
+            workflow_system_variables={
+                SystemVariableKey.QUERY: message.query,
+                SystemVariableKey.FILES: application_generate_entity.files,
+                SystemVariableKey.CONVERSATION_ID: conversation.id,
+                SystemVariableKey.USER_ID: user_session_id,
+                SystemVariableKey.DIALOGUE_COUNT: dialogue_count,
+                SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
+                SystemVariableKey.WORKFLOW_ID: workflow.id,
+                SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
+            },
+        )

         self._task_state = WorkflowTaskState()
-        self._wip_workflow_node_executions = {}
-        self._wip_workflow_agent_logs = {}
-        self._conversation_name_generate_thread = None
+        self._message_cycle_manager = MessageCycleManage(
+            application_generate_entity=application_generate_entity, task_state=self._task_state
+        )
+
+        self._application_generate_entity = application_generate_entity
+        self._workflow_id = workflow.id
+        self._workflow_features_dict = workflow.features_dict
+        self._conversation_id = conversation.id
+        self._conversation_mode = conversation.mode
+        self._message_id = message.id
+        self._message_created_at = int(message.created_at.timestamp())
+        self._conversation_name_generate_thread: Thread | None = None
         self._recorded_files: list[Mapping[str, Any]] = []
-        self._workflow_run_id = ""
+        self._workflow_run_id: str = ""

     def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
         """
@@ -148,13 +143,13 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
         :return:
         """
         # start generate conversation name thread
-        self._conversation_name_generate_thread = self._generate_conversation_name(
+        self._conversation_name_generate_thread = self._message_cycle_manager._generate_conversation_name(
             conversation_id=self._conversation_id, query=self._application_generate_entity.query
         )

         generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)

-        if self._stream:
+        if self._base_task_pipeline._stream:
             return self._to_stream_response(generator)
         else:
             return self._to_blocking_response(generator)
@@ -273,24 +268,26 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
         # init fake graph runtime state
         graph_runtime_state: Optional[GraphRuntimeState] = None

-        for queue_message in self._queue_manager.listen():
+        for queue_message in self._base_task_pipeline._queue_manager.listen():
             event = queue_message.event

             if isinstance(event, QueuePingEvent):
-                yield self._ping_stream_response()
+                yield self._base_task_pipeline._ping_stream_response()
             elif isinstance(event, QueueErrorEvent):
-                with Session(db.engine) as session:
-                    err = self._handle_error(event=event, session=session, message_id=self._message_id)
+                with Session(db.engine, expire_on_commit=False) as session:
+                    err = self._base_task_pipeline._handle_error(
+                        event=event, session=session, message_id=self._message_id
+                    )
                     session.commit()
-                yield self._error_to_stream_response(err)
+                yield self._base_task_pipeline._error_to_stream_response(err)
                 break
             elif isinstance(event, QueueWorkflowStartedEvent):
                 # override graph runtime state
                 graph_runtime_state = event.graph_runtime_state

-                with Session(db.engine) as session:
+                with Session(db.engine, expire_on_commit=False) as session:
                     # init workflow run
-                    workflow_run = self._handle_workflow_run_start(
+                    workflow_run = self._workflow_cycle_manager._handle_workflow_run_start(
                         session=session,
                         workflow_id=self._workflow_id,
                         user_id=self._user_id,
@@ -301,7 +298,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                     if not message:
                         raise ValueError(f"Message not found: {self._message_id}")
                     message.workflow_run_id = workflow_run.id
-                    workflow_start_resp = self._workflow_start_to_stream_response(
+                    workflow_start_resp = self._workflow_cycle_manager._workflow_start_to_stream_response(
                         session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
                     )
                     session.commit()
@@ -314,12 +311,14 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                 if not self._workflow_run_id:
                     raise ValueError("workflow run not initialized.")

-                with Session(db.engine) as session:
-                    workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
-                    workflow_node_execution = self._handle_workflow_node_execution_retried(
+                with Session(db.engine, expire_on_commit=False) as session:
+                    workflow_run = self._workflow_cycle_manager._get_workflow_run(
+                        session=session, workflow_run_id=self._workflow_run_id
+                    )
+                    workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_retried(
                         session=session, workflow_run=workflow_run, event=event
                     )
-                    node_retry_resp = self._workflow_node_retry_to_stream_response(
+                    node_retry_resp = self._workflow_cycle_manager._workflow_node_retry_to_stream_response(
                         session=session,
                         event=event,
                         task_id=self._application_generate_entity.task_id,
@@ -333,13 +332,15 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                 if not self._workflow_run_id:
                     raise ValueError("workflow run not initialized.")

-                with Session(db.engine) as session:
-                    workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
-                    workflow_node_execution = self._handle_node_execution_start(
+                with Session(db.engine, expire_on_commit=False) as session:
+                    workflow_run = self._workflow_cycle_manager._get_workflow_run(
+                        session=session, workflow_run_id=self._workflow_run_id
+                    )
+                    workflow_node_execution = self._workflow_cycle_manager._handle_node_execution_start(
                         session=session, workflow_run=workflow_run, event=event
                     )
-                    node_start_resp = self._workflow_node_start_to_stream_response(
+                    node_start_resp = self._workflow_cycle_manager._workflow_node_start_to_stream_response(
                         session=session,
                         event=event,
                         task_id=self._application_generate_entity.task_id,
@@ -352,12 +353,16 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
             elif isinstance(event, QueueNodeSucceededEvent):
                 # Record files if it's an answer node or end node
                 if event.node_type in [NodeType.ANSWER, NodeType.END]:
-                    self._recorded_files.extend(self._fetch_files_from_node_outputs(event.outputs or {}))
+                    self._recorded_files.extend(
+                        self._workflow_cycle_manager._fetch_files_from_node_outputs(event.outputs or {})
+                    )

-                with Session(db.engine) as session:
-                    workflow_node_execution = self._handle_workflow_node_execution_success(session=session, event=event)
+                with Session(db.engine, expire_on_commit=False) as session:
+                    workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_success(
+                        session=session, event=event
+                    )

-                    node_finish_resp = self._workflow_node_finish_to_stream_response(
+                    node_finish_resp = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
                         session=session,
                         event=event,
                         task_id=self._application_generate_entity.task_id,
@@ -368,10 +373,12 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                     if node_finish_resp:
                         yield node_finish_resp
             elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent):
-                with Session(db.engine) as session:
-                    workflow_node_execution = self._handle_workflow_node_execution_failed(session=session, event=event)
+                with Session(db.engine, expire_on_commit=False) as session:
+                    workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_failed(
+                        session=session, event=event
+                    )

-                    node_finish_resp = self._workflow_node_finish_to_stream_response(
+                    node_finish_resp = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
                         session=session,
                         event=event,
                         task_id=self._application_generate_entity.task_id,
@@ -385,13 +392,17 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                 if not self._workflow_run_id:
                     raise ValueError("workflow run not initialized.")

-                with Session(db.engine) as session:
-                    workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
-                    parallel_start_resp = self._workflow_parallel_branch_start_to_stream_response(
-                        session=session,
-                        task_id=self._application_generate_entity.task_id,
-                        workflow_run=workflow_run,
-                        event=event,
+                with Session(db.engine, expire_on_commit=False) as session:
+                    workflow_run = self._workflow_cycle_manager._get_workflow_run(
+                        session=session, workflow_run_id=self._workflow_run_id
+                    )
+                    parallel_start_resp = (
+                        self._workflow_cycle_manager._workflow_parallel_branch_start_to_stream_response(
+                            session=session,
+                            task_id=self._application_generate_entity.task_id,
+                            workflow_run=workflow_run,
+                            event=event,
+                        )
                     )

                 yield parallel_start_resp
@@ -399,13 +410,17 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                 if not self._workflow_run_id:
                     raise ValueError("workflow run not initialized.")

-                with Session(db.engine) as session:
-                    workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
-                    parallel_finish_resp = self._workflow_parallel_branch_finished_to_stream_response(
-                        session=session,
-                        task_id=self._application_generate_entity.task_id,
-                        workflow_run=workflow_run,
-                        event=event,
+                with Session(db.engine, expire_on_commit=False) as session:
+                    workflow_run = self._workflow_cycle_manager._get_workflow_run(
+                        session=session, workflow_run_id=self._workflow_run_id
+                    )
+                    parallel_finish_resp = (
+                        self._workflow_cycle_manager._workflow_parallel_branch_finished_to_stream_response(
+                            session=session,
+                            task_id=self._application_generate_entity.task_id,
+                            workflow_run=workflow_run,
+                            event=event,
+                        )
                     )

                 yield parallel_finish_resp
@@ -413,9 +428,11 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                 if not self._workflow_run_id:
                     raise ValueError("workflow run not initialized.")

-                with Session(db.engine) as session:
-                    workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
-                    iter_start_resp = self._workflow_iteration_start_to_stream_response(
+                with Session(db.engine, expire_on_commit=False) as session:
+                    workflow_run = self._workflow_cycle_manager._get_workflow_run(
+                        session=session, workflow_run_id=self._workflow_run_id
+                    )
+                    iter_start_resp = self._workflow_cycle_manager._workflow_iteration_start_to_stream_response(
                         session=session,
                         task_id=self._application_generate_entity.task_id,
                         workflow_run=workflow_run,
@@ -427,9 +444,11 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                 if not self._workflow_run_id:
                     raise ValueError("workflow run not initialized.")

-                with Session(db.engine) as session:
-                    workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
-                    iter_next_resp = self._workflow_iteration_next_to_stream_response(
+                with Session(db.engine, expire_on_commit=False) as session:
+                    workflow_run = self._workflow_cycle_manager._get_workflow_run(
+                        session=session, workflow_run_id=self._workflow_run_id
+                    )
+                    iter_next_resp = self._workflow_cycle_manager._workflow_iteration_next_to_stream_response(
                         session=session,
                         task_id=self._application_generate_entity.task_id,
                         workflow_run=workflow_run,
@@ -441,9 +460,11 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                 if not self._workflow_run_id:
                     raise ValueError("workflow run not initialized.")

-                with Session(db.engine) as session:
-                    workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
-                    iter_finish_resp = self._workflow_iteration_completed_to_stream_response(
+                with Session(db.engine, expire_on_commit=False) as session:
+                    workflow_run = self._workflow_cycle_manager._get_workflow_run(
+                        session=session, workflow_run_id=self._workflow_run_id
+                    )
+                    iter_finish_resp = self._workflow_cycle_manager._workflow_iteration_completed_to_stream_response(
                         session=session,
                         task_id=self._application_generate_entity.task_id,
                         workflow_run=workflow_run,
@@ -458,8 +479,8 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                 if not graph_runtime_state:
                     raise ValueError("workflow run not initialized.")

-                with Session(db.engine) as session:
-                    workflow_run = self._handle_workflow_run_success(
+                with Session(db.engine, expire_on_commit=False) as session:
+                    workflow_run = self._workflow_cycle_manager._handle_workflow_run_success(
                         session=session,
                         workflow_run_id=self._workflow_run_id,
                         start_at=graph_runtime_state.start_at,
@@ -470,21 +491,23 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                         trace_manager=trace_manager,
                     )

-                    workflow_finish_resp = self._workflow_finish_to_stream_response(
+                    workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
                         session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
                     )
                     session.commit()

                 yield workflow_finish_resp
-                self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
+                self._base_task_pipeline._queue_manager.publish(
+                    QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE
+                )
             elif isinstance(event, QueueWorkflowPartialSuccessEvent):
                 if not self._workflow_run_id:
                     raise ValueError("workflow run not initialized.")
                 if not graph_runtime_state:
                     raise ValueError("graph runtime state not initialized.")

-                with Session(db.engine) as session:
-                    workflow_run = self._handle_workflow_run_partial_success(
+                with Session(db.engine, expire_on_commit=False) as session:
+                    workflow_run = self._workflow_cycle_manager._handle_workflow_run_partial_success(
                         session=session,
                         workflow_run_id=self._workflow_run_id,
                         start_at=graph_runtime_state.start_at,
@@ -495,21 +518,23 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                         conversation_id=None,
                         trace_manager=trace_manager,
                     )

-                    workflow_finish_resp = self._workflow_finish_to_stream_response(
+                    workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
                         session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
                     )
                     session.commit()

                 yield workflow_finish_resp
-                self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
+                self._base_task_pipeline._queue_manager.publish(
+                    QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE
+                )
             elif isinstance(event, QueueWorkflowFailedEvent):
                 if not self._workflow_run_id:
                     raise ValueError("workflow run not initialized.")
                 if not graph_runtime_state:
                     raise ValueError("graph runtime state not initialized.")

-                with Session(db.engine) as session:
-                    workflow_run = self._handle_workflow_run_failed(
+                with Session(db.engine, expire_on_commit=False) as session:
+                    workflow_run = self._workflow_cycle_manager._handle_workflow_run_failed(
                         session=session,
                         workflow_run_id=self._workflow_run_id,
                         start_at=graph_runtime_state.start_at,
@@ -521,20 +546,22 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                         trace_manager=trace_manager,
                         exceptions_count=event.exceptions_count,
                     )
-                    workflow_finish_resp = self._workflow_finish_to_stream_response(
+                    workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
                         session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
                     )
                     err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_run.error}"))
-                    err = self._handle_error(event=err_event, session=session, message_id=self._message_id)
+                    err = self._base_task_pipeline._handle_error(
+                        event=err_event, session=session, message_id=self._message_id
+                    )
                     session.commit()

                 yield workflow_finish_resp
-                yield self._error_to_stream_response(err)
+                yield self._base_task_pipeline._error_to_stream_response(err)
                 break
             elif isinstance(event, QueueStopEvent):
                 if self._workflow_run_id and graph_runtime_state:
-                    with Session(db.engine) as session:
-                        workflow_run = self._handle_workflow_run_failed(
+                    with Session(db.engine, expire_on_commit=False) as session:
+                        workflow_run = self._workflow_cycle_manager._handle_workflow_run_failed(
                             session=session,
                             workflow_run_id=self._workflow_run_id,
                             start_at=graph_runtime_state.start_at,
@@ -545,7 +572,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                             conversation_id=self._conversation_id,
                             trace_manager=trace_manager,
                         )
-                        workflow_finish_resp = self._workflow_finish_to_stream_response(
+                        workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
                             session=session,
                             task_id=self._application_generate_entity.task_id,
                             workflow_run=workflow_run,
@@ -559,18 +586,18 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                 yield self._message_end_to_stream_response()
                 break
             elif isinstance(event, QueueRetrieverResourcesEvent):
-                self._handle_retriever_resources(event)
+                self._message_cycle_manager._handle_retriever_resources(event)

-                with Session(db.engine) as session:
+                with Session(db.engine, expire_on_commit=False) as session:
                     message = self._get_message(session=session)
                     message.message_metadata = (
                         json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
                     )
                     session.commit()
             elif isinstance(event, QueueAnnotationReplyEvent):
-                self._handle_annotation_reply(event)
+                self._message_cycle_manager._handle_annotation_reply(event)

-                with Session(db.engine) as session:
+                with Session(db.engine, expire_on_commit=False) as session:
                     message = self._get_message(session=session)
                     message.message_metadata = (
                         json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
@@ -591,23 +618,27 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                     tts_publisher.publish(queue_message)

                 self._task_state.answer += delta_text
-                yield self._message_to_stream_response(
+                yield self._message_cycle_manager._message_to_stream_response(
                     answer=delta_text, message_id=self._message_id, from_variable_selector=event.from_variable_selector
                 )
             elif isinstance(event, QueueMessageReplaceEvent):
                 # published by moderation
-                yield self._message_replace_to_stream_response(answer=event.text)
+                yield self._message_cycle_manager._message_replace_to_stream_response(answer=event.text)
             elif isinstance(event, QueueAdvancedChatMessageEndEvent):
                 if not graph_runtime_state:
                     raise ValueError("graph runtime state not initialized.")

-                output_moderation_answer = self._handle_output_moderation_when_task_finished(self._task_state.answer)
+                output_moderation_answer = self._base_task_pipeline._handle_output_moderation_when_task_finished(
+                    self._task_state.answer
+                )
                 if output_moderation_answer:
                     self._task_state.answer = output_moderation_answer
-                    yield self._message_replace_to_stream_response(answer=output_moderation_answer)
+                    yield self._message_cycle_manager._message_replace_to_stream_response(
+                        answer=output_moderation_answer
+                    )

                 # Save message
-                with Session(db.engine) as session:
+                with Session(db.engine, expire_on_commit=False) as session:
                     self._save_message(session=session, graph_runtime_state=graph_runtime_state)
                     session.commit()
@@ -627,7 +658,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
     def _save_message(self, *, session: Session, graph_runtime_state: Optional[GraphRuntimeState] = None) -> None:
         message = self._get_message(session=session)
         message.answer = self._task_state.answer
-        message.provider_response_latency = time.perf_counter() - self._start_at
+        message.provider_response_latency = time.perf_counter() - self._base_task_pipeline._start_at
         message.message_metadata = (
             json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
         )
@@ -691,20 +722,20 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
         :param text: text
         :return: True if output moderation should direct output, otherwise False
         """
-        if self._output_moderation_handler:
-            if self._output_moderation_handler.should_direct_output():
+        if self._base_task_pipeline._output_moderation_handler:
+            if self._base_task_pipeline._output_moderation_handler.should_direct_output():
                 # stop subscribe new token when output moderation should direct output
-                self._task_state.answer = self._output_moderation_handler.get_final_output()
-                self._queue_manager.publish(
+                self._task_state.answer = self._base_task_pipeline._output_moderation_handler.get_final_output()
+                self._base_task_pipeline._queue_manager.publish(
                     QueueTextChunkEvent(text=self._task_state.answer), PublishFrom.TASK_PIPELINE
                 )
-                self._queue_manager.publish(
+                self._base_task_pipeline._queue_manager.publish(
                     QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION), PublishFrom.TASK_PIPELINE
                 )
                 return True
             else:
-                self._output_moderation_handler.append_new_token(text)
+                self._base_task_pipeline._output_moderation_handler.append_new_token(text)

         return False
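
Note (not part of the diff): the change in this file swaps multiple inheritance from BasedGenerateTaskPipeline / WorkflowCycleManage / MessageCycleManage for plain composition, so every helper call now goes through an explicit collaborator attribute. A toy sketch of the pattern under that reading, with generic names rather than the project's classes:

class BasePipeline:
    def ping(self) -> str:
        return "ping"

class CycleManager:
    def start_cycle(self) -> str:
        return "cycle started"

class TaskPipeline:
    # Composition instead of mixin inheritance: helpers are held as attributes.
    def __init__(self) -> None:
        self._base = BasePipeline()
        self._cycle = CycleManager()

    def run(self) -> list[str]:
        # Each call names the collaborator it goes through, mirroring the diff above.
        return [self._base.ping(), self._cycle.start_cycle()]

print(TaskPipeline().run())  # ['ping', 'cycle started']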

View File

@@ -19,7 +19,7 @@ from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskSt
 from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
 from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
 from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, InvokeFrom
-from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
+from core.model_runtime.errors.invoke import InvokeAuthorizationError
 from core.ops.ops_trace_manager import TraceQueueManager
 from extensions.ext_database import db
 from factories import file_factory
@@ -251,7 +251,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
         except ValidationError as e:
             logger.exception("Validation Error when generating")
             queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
-        except (ValueError, InvokeError) as e:
+        except ValueError as e:
             if dify_config.DEBUG:
                 logger.exception("Error when generating")
             queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)

View File

@@ -18,7 +18,7 @@ from core.app.apps.chat.generate_response_converter import ChatAppGenerateRespon
 from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
 from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
 from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, InvokeFrom
-from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
+from core.model_runtime.errors.invoke import InvokeAuthorizationError
 from core.ops.ops_trace_manager import TraceQueueManager
 from extensions.ext_database import db
 from factories import file_factory
@@ -237,7 +237,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
         except ValidationError as e:
             logger.exception("Validation Error when generating")
             queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
-        except (ValueError, InvokeError) as e:
+        except ValueError as e:
             if dify_config.DEBUG:
                 logger.exception("Error when generating")
             queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)

View File

@@ -17,7 +17,7 @@ from core.app.apps.completion.generate_response_converter import CompletionAppGe
 from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
 from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
 from core.app.entities.app_invoke_entities import CompletionAppGenerateEntity, InvokeFrom
-from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
+from core.model_runtime.errors.invoke import InvokeAuthorizationError
 from core.ops.ops_trace_manager import TraceQueueManager
 from extensions.ext_database import db
 from factories import file_factory
@@ -214,7 +214,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
         except ValidationError as e:
             logger.exception("Validation Error when generating")
             queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
-        except (ValueError, InvokeError) as e:
+        except ValueError as e:
             if dify_config.DEBUG:
                 logger.exception("Error when generating")
             queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)

View File

@@ -20,7 +20,7 @@ from core.app.apps.workflow.generate_response_converter import WorkflowAppGenera
 from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
 from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
 from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
-from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
+from core.model_runtime.errors.invoke import InvokeAuthorizationError
 from core.ops.ops_trace_manager import TraceQueueManager
 from extensions.ext_database import db
 from factories import file_factory
@@ -235,6 +235,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
             single_iteration_run=WorkflowAppGenerateEntity.SingleIterationRunEntity(
                 node_id=node_id, inputs=args["inputs"]
             ),
+            workflow_run_id=str(uuid.uuid4()),
         )
         contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
         contexts.plugin_tool_providers.set({})
@@ -286,7 +287,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
         except ValidationError as e:
             logger.exception("Validation Error when generating")
             queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
-        except (ValueError, InvokeError) as e:
+        except ValueError as e:
             if dify_config.DEBUG:
                 logger.exception("Error when generating")
             queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)

View File

@ -1,7 +1,7 @@
import logging import logging
import time import time
from collections.abc import Generator from collections.abc import Generator
from typing import Any, Optional, Union from typing import Optional, Union
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@ -59,7 +59,6 @@ from models.workflow import (
Workflow, Workflow,
WorkflowAppLog, WorkflowAppLog,
WorkflowAppLogCreatedFrom, WorkflowAppLogCreatedFrom,
WorkflowNodeExecution,
WorkflowRun, WorkflowRun,
WorkflowRunStatus, WorkflowRunStatus,
) )
@ -67,16 +66,11 @@ from models.workflow import (
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleManage): class WorkflowAppGenerateTaskPipeline:
""" """
WorkflowAppGenerateTaskPipeline is a class that generates stream output and manages state for the application. WorkflowAppGenerateTaskPipeline is a class that generates stream output and manages state for the application.
""" """
_task_state: WorkflowTaskState
_application_generate_entity: WorkflowAppGenerateEntity
_workflow_system_variables: dict[SystemVariableKey, Any]
_wip_workflow_node_executions: dict[str, WorkflowNodeExecution]
def __init__( def __init__(
self, self,
application_generate_entity: WorkflowAppGenerateEntity, application_generate_entity: WorkflowAppGenerateEntity,
@ -85,7 +79,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
user: Union[Account, EndUser], user: Union[Account, EndUser],
stream: bool, stream: bool,
) -> None: ) -> None:
super().__init__( self._base_task_pipeline = BasedGenerateTaskPipeline(
application_generate_entity=application_generate_entity, application_generate_entity=application_generate_entity,
queue_manager=queue_manager, queue_manager=queue_manager,
stream=stream, stream=stream,
@ -102,17 +96,20 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
else: else:
raise ValueError(f"Invalid user type: {type(user)}") raise ValueError(f"Invalid user type: {type(user)}")
self._workflow_cycle_manager = WorkflowCycleManage(
application_generate_entity=application_generate_entity,
workflow_system_variables={
SystemVariableKey.FILES: application_generate_entity.files,
SystemVariableKey.USER_ID: user_session_id,
SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
SystemVariableKey.WORKFLOW_ID: workflow.id,
SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
},
)
self._application_generate_entity = application_generate_entity
self._workflow_id = workflow.id self._workflow_id = workflow.id
self._workflow_features_dict = workflow.features_dict self._workflow_features_dict = workflow.features_dict
self._workflow_system_variables = {
SystemVariableKey.FILES: application_generate_entity.files,
SystemVariableKey.USER_ID: user_session_id,
SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
SystemVariableKey.WORKFLOW_ID: workflow.id,
SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
}
self._task_state = WorkflowTaskState() self._task_state = WorkflowTaskState()
self._workflow_run_id = "" self._workflow_run_id = ""
@ -122,7 +119,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
:return: :return:
""" """
generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager) generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
if self._stream: if self._base_task_pipeline._stream:
return self._to_stream_response(generator) return self._to_stream_response(generator)
else: else:
return self._to_blocking_response(generator) return self._to_blocking_response(generator)
@ -239,29 +236,29 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
""" """
graph_runtime_state = None graph_runtime_state = None
for queue_message in self._queue_manager.listen(): for queue_message in self._base_task_pipeline._queue_manager.listen():
event = queue_message.event event = queue_message.event
if isinstance(event, QueuePingEvent): if isinstance(event, QueuePingEvent):
yield self._ping_stream_response() yield self._base_task_pipeline._ping_stream_response()
elif isinstance(event, QueueErrorEvent): elif isinstance(event, QueueErrorEvent):
err = self._handle_error(event=event) err = self._base_task_pipeline._handle_error(event=event)
yield self._error_to_stream_response(err) yield self._base_task_pipeline._error_to_stream_response(err)
break break
elif isinstance(event, QueueWorkflowStartedEvent): elif isinstance(event, QueueWorkflowStartedEvent):
# override graph runtime state # override graph runtime state
graph_runtime_state = event.graph_runtime_state graph_runtime_state = event.graph_runtime_state
with Session(db.engine) as session: with Session(db.engine, expire_on_commit=False) as session:
# init workflow run # init workflow run
workflow_run = self._handle_workflow_run_start( workflow_run = self._workflow_cycle_manager._handle_workflow_run_start(
session=session, session=session,
workflow_id=self._workflow_id, workflow_id=self._workflow_id,
user_id=self._user_id, user_id=self._user_id,
created_by_role=self._created_by_role, created_by_role=self._created_by_role,
) )
self._workflow_run_id = workflow_run.id self._workflow_run_id = workflow_run.id
start_resp = self._workflow_start_to_stream_response( start_resp = self._workflow_cycle_manager._workflow_start_to_stream_response(
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
) )
session.commit() session.commit()
@ -273,12 +270,14 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
): ):
if not self._workflow_run_id: if not self._workflow_run_id:
raise ValueError("workflow run not initialized.") raise ValueError("workflow run not initialized.")
with Session(db.engine) as session: with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) workflow_run = self._workflow_cycle_manager._get_workflow_run(
workflow_node_execution = self._handle_workflow_node_execution_retried( session=session, workflow_run_id=self._workflow_run_id
)
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_retried(
session=session, workflow_run=workflow_run, event=event session=session, workflow_run=workflow_run, event=event
) )
response = self._workflow_node_retry_to_stream_response( response = self._workflow_cycle_manager._workflow_node_retry_to_stream_response(
session=session, session=session,
event=event, event=event,
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id,
@ -292,12 +291,14 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if not self._workflow_run_id: if not self._workflow_run_id:
raise ValueError("workflow run not initialized.") raise ValueError("workflow run not initialized.")
with Session(db.engine) as session: with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) workflow_run = self._workflow_cycle_manager._get_workflow_run(
workflow_node_execution = self._handle_node_execution_start( session=session, workflow_run_id=self._workflow_run_id
)
workflow_node_execution = self._workflow_cycle_manager._handle_node_execution_start(
session=session, workflow_run=workflow_run, event=event session=session, workflow_run=workflow_run, event=event
) )
node_start_response = self._workflow_node_start_to_stream_response( node_start_response = self._workflow_cycle_manager._workflow_node_start_to_stream_response(
session=session, session=session,
event=event, event=event,
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id,
@ -308,9 +309,11 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if node_start_response: if node_start_response:
yield node_start_response yield node_start_response
elif isinstance(event, QueueNodeSucceededEvent): elif isinstance(event, QueueNodeSucceededEvent):
with Session(db.engine) as session: with Session(db.engine, expire_on_commit=False) as session:
workflow_node_execution = self._handle_workflow_node_execution_success(session=session, event=event) workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_success(
node_success_response = self._workflow_node_finish_to_stream_response( session=session, event=event
)
node_success_response = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
session=session, session=session,
event=event, event=event,
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id,
@ -321,12 +324,12 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if node_success_response: if node_success_response:
yield node_success_response yield node_success_response
elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent): elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent):
with Session(db.engine) as session: with Session(db.engine, expire_on_commit=False) as session:
workflow_node_execution = self._handle_workflow_node_execution_failed( workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_failed(
session=session, session=session,
event=event, event=event,
) )
node_failed_response = self._workflow_node_finish_to_stream_response( node_failed_response = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
session=session, session=session,
event=event, event=event,
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id,
@ -341,13 +344,17 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if not self._workflow_run_id: if not self._workflow_run_id:
raise ValueError("workflow run not initialized.") raise ValueError("workflow run not initialized.")
with Session(db.engine) as session: with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) workflow_run = self._workflow_cycle_manager._get_workflow_run(
parallel_start_resp = self._workflow_parallel_branch_start_to_stream_response( session=session, workflow_run_id=self._workflow_run_id
session=session, )
task_id=self._application_generate_entity.task_id, parallel_start_resp = (
workflow_run=workflow_run, self._workflow_cycle_manager._workflow_parallel_branch_start_to_stream_response(
event=event, session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event,
)
) )
yield parallel_start_resp yield parallel_start_resp
@ -356,13 +363,17 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if not self._workflow_run_id: if not self._workflow_run_id:
raise ValueError("workflow run not initialized.") raise ValueError("workflow run not initialized.")
with Session(db.engine) as session: with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) workflow_run = self._workflow_cycle_manager._get_workflow_run(
parallel_finish_resp = self._workflow_parallel_branch_finished_to_stream_response( session=session, workflow_run_id=self._workflow_run_id
session=session, )
task_id=self._application_generate_entity.task_id, parallel_finish_resp = (
workflow_run=workflow_run, self._workflow_cycle_manager._workflow_parallel_branch_finished_to_stream_response(
event=event, session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event,
)
) )
yield parallel_finish_resp yield parallel_finish_resp
@ -371,9 +382,11 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if not self._workflow_run_id: if not self._workflow_run_id:
raise ValueError("workflow run not initialized.") raise ValueError("workflow run not initialized.")
with Session(db.engine) as session: with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) workflow_run = self._workflow_cycle_manager._get_workflow_run(
iter_start_resp = self._workflow_iteration_start_to_stream_response( session=session, workflow_run_id=self._workflow_run_id
)
iter_start_resp = self._workflow_cycle_manager._workflow_iteration_start_to_stream_response(
session=session, session=session,
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run, workflow_run=workflow_run,
@ -386,9 +399,11 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if not self._workflow_run_id: if not self._workflow_run_id:
raise ValueError("workflow run not initialized.") raise ValueError("workflow run not initialized.")
with Session(db.engine) as session: with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) workflow_run = self._workflow_cycle_manager._get_workflow_run(
iter_next_resp = self._workflow_iteration_next_to_stream_response( session=session, workflow_run_id=self._workflow_run_id
)
iter_next_resp = self._workflow_cycle_manager._workflow_iteration_next_to_stream_response(
session=session, session=session,
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run, workflow_run=workflow_run,
@ -401,9 +416,11 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if not self._workflow_run_id: if not self._workflow_run_id:
raise ValueError("workflow run not initialized.") raise ValueError("workflow run not initialized.")
with Session(db.engine) as session: with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) workflow_run = self._workflow_cycle_manager._get_workflow_run(
iter_finish_resp = self._workflow_iteration_completed_to_stream_response( session=session, workflow_run_id=self._workflow_run_id
)
iter_finish_resp = self._workflow_cycle_manager._workflow_iteration_completed_to_stream_response(
session=session, session=session,
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run, workflow_run=workflow_run,
@ -418,8 +435,8 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if not graph_runtime_state: if not graph_runtime_state:
raise ValueError("graph runtime state not initialized.") raise ValueError("graph runtime state not initialized.")
with Session(db.engine) as session: with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._handle_workflow_run_success( workflow_run = self._workflow_cycle_manager._handle_workflow_run_success(
session=session, session=session,
workflow_run_id=self._workflow_run_id, workflow_run_id=self._workflow_run_id,
start_at=graph_runtime_state.start_at, start_at=graph_runtime_state.start_at,
@ -433,7 +450,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
# save workflow app log # save workflow app log
self._save_workflow_app_log(session=session, workflow_run=workflow_run) self._save_workflow_app_log(session=session, workflow_run=workflow_run)
workflow_finish_resp = self._workflow_finish_to_stream_response( workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
session=session, session=session,
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run, workflow_run=workflow_run,
@ -447,8 +464,8 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if not graph_runtime_state: if not graph_runtime_state:
raise ValueError("graph runtime state not initialized.") raise ValueError("graph runtime state not initialized.")
with Session(db.engine) as session: with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._handle_workflow_run_partial_success( workflow_run = self._workflow_cycle_manager._handle_workflow_run_partial_success(
session=session, session=session,
workflow_run_id=self._workflow_run_id, workflow_run_id=self._workflow_run_id,
start_at=graph_runtime_state.start_at, start_at=graph_runtime_state.start_at,
@ -463,7 +480,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
# save workflow app log # save workflow app log
self._save_workflow_app_log(session=session, workflow_run=workflow_run) self._save_workflow_app_log(session=session, workflow_run=workflow_run)
workflow_finish_resp = self._workflow_finish_to_stream_response( workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
) )
session.commit() session.commit()
@ -475,8 +492,8 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if not graph_runtime_state: if not graph_runtime_state:
raise ValueError("graph runtime state not initialized.") raise ValueError("graph runtime state not initialized.")
with Session(db.engine) as session: with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._handle_workflow_run_failed( workflow_run = self._workflow_cycle_manager._handle_workflow_run_failed(
session=session, session=session,
workflow_run_id=self._workflow_run_id, workflow_run_id=self._workflow_run_id,
start_at=graph_runtime_state.start_at, start_at=graph_runtime_state.start_at,
@ -494,7 +511,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
# save workflow app log # save workflow app log
self._save_workflow_app_log(session=session, workflow_run=workflow_run) self._save_workflow_app_log(session=session, workflow_run=workflow_run)
workflow_finish_resp = self._workflow_finish_to_stream_response( workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
) )
session.commit() session.commit()
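A recurring detail in this refactor is that database sessions are now opened with expire_on_commit=False, so the WorkflowRun and related objects stay readable after session.commit() and after the session is closed. A minimal sketch of that SQLAlchemy behaviour, using a hypothetical Run model and an in-memory SQLite database (illustrative only, not part of this commit):

from sqlalchemy import create_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

class Base(DeclarativeBase):
    pass

class Run(Base):
    __tablename__ = "runs"
    id: Mapped[int] = mapped_column(primary_key=True)
    status: Mapped[str]

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

# With the default expire_on_commit=True, commit() expires all attributes, and reading
# run.status after the session is closed raises DetachedInstanceError.
with Session(engine, expire_on_commit=False) as session:
    run = Run(id=1, status="running")
    session.add(run)
    session.commit()

# expire_on_commit=False leaves the attributes loaded on the now-detached object,
# which is what lets the pipeline keep using workflow_run after commit().
print(run.status)  # -> "running", no extra query, no DetachedInstanceError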

View File

@ -195,7 +195,7 @@ class WorkflowAppGenerateEntity(AppGenerateEntity):
# app config # app config
app_config: WorkflowUIBasedAppConfig app_config: WorkflowUIBasedAppConfig
workflow_run_id: Optional[str] = None workflow_run_id: str
class SingleIterationRunEntity(BaseModel): class SingleIterationRunEntity(BaseModel):
""" """

View File

@ -15,7 +15,6 @@ from core.app.entities.queue_entities import (
from core.app.entities.task_entities import ( from core.app.entities.task_entities import (
ErrorStreamResponse, ErrorStreamResponse,
PingStreamResponse, PingStreamResponse,
TaskState,
) )
from core.errors.error import QuotaExceededError from core.errors.error import QuotaExceededError
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
@ -30,22 +29,12 @@ class BasedGenerateTaskPipeline:
BasedGenerateTaskPipeline is a class that generates stream output and manages state for the application. BasedGenerateTaskPipeline is a class that generates stream output and manages state for the application.
""" """
_task_state: TaskState
_application_generate_entity: AppGenerateEntity
def __init__( def __init__(
self, self,
application_generate_entity: AppGenerateEntity, application_generate_entity: AppGenerateEntity,
queue_manager: AppQueueManager, queue_manager: AppQueueManager,
stream: bool, stream: bool,
) -> None: ) -> None:
"""
Initialize GenerateTaskPipeline.
:param application_generate_entity: application generate entity
:param queue_manager: queue manager
:param user: user
:param stream: stream
"""
self._application_generate_entity = application_generate_entity self._application_generate_entity = application_generate_entity
self._queue_manager = queue_manager self._queue_manager = queue_manager
self._start_at = time.perf_counter() self._start_at = time.perf_counter()

View File

@ -31,10 +31,19 @@ from services.annotation_service import AppAnnotationService
class MessageCycleManage: class MessageCycleManage:
_application_generate_entity: Union[ def __init__(
ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity, AdvancedChatAppGenerateEntity self,
] *,
_task_state: Union[EasyUITaskState, WorkflowTaskState] application_generate_entity: Union[
ChatAppGenerateEntity,
CompletionAppGenerateEntity,
AgentChatAppGenerateEntity,
AdvancedChatAppGenerateEntity,
],
task_state: Union[EasyUITaskState, WorkflowTaskState],
) -> None:
self._application_generate_entity = application_generate_entity
self._task_state = task_state
def _generate_conversation_name(self, *, conversation_id: str, query: str) -> Optional[Thread]: def _generate_conversation_name(self, *, conversation_id: str, query: str) -> Optional[Thread]:
""" """

View File

@ -36,7 +36,6 @@ from core.app.entities.task_entities import (
ParallelBranchStartStreamResponse, ParallelBranchStartStreamResponse,
WorkflowFinishStreamResponse, WorkflowFinishStreamResponse,
WorkflowStartStreamResponse, WorkflowStartStreamResponse,
WorkflowTaskState,
) )
from core.file import FILE_MODEL_IDENTITY, File from core.file import FILE_MODEL_IDENTITY, File
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
@ -60,13 +59,20 @@ from models.workflow import (
WorkflowRunStatus, WorkflowRunStatus,
) )
from .exc import WorkflowNodeExecutionNotFoundError, WorkflowRunNotFoundError from .exc import WorkflowRunNotFoundError
class WorkflowCycleManage: class WorkflowCycleManage:
_application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity] def __init__(
_task_state: WorkflowTaskState self,
_workflow_system_variables: dict[SystemVariableKey, Any] *,
application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity],
workflow_system_variables: dict[SystemVariableKey, Any],
) -> None:
self._workflow_run: WorkflowRun | None = None
self._workflow_node_executions: dict[str, WorkflowNodeExecution] = {}
self._application_generate_entity = application_generate_entity
self._workflow_system_variables = workflow_system_variables
def _handle_workflow_run_start( def _handle_workflow_run_start(
self, self,
@ -104,7 +110,8 @@ class WorkflowCycleManage:
inputs = dict(WorkflowEntry.handle_special_values(inputs) or {}) inputs = dict(WorkflowEntry.handle_special_values(inputs) or {})
# init workflow run # init workflow run
workflow_run_id = str(self._workflow_system_variables.get(SystemVariableKey.WORKFLOW_RUN_ID, uuid4())) # TODO: This workflow_run_id should never be None; maybe we can find a more elegant way to handle this
workflow_run_id = str(self._workflow_system_variables.get(SystemVariableKey.WORKFLOW_RUN_ID) or uuid4())
workflow_run = WorkflowRun() workflow_run = WorkflowRun()
workflow_run.id = workflow_run_id workflow_run.id = workflow_run_id
@ -241,7 +248,7 @@ class WorkflowCycleManage:
workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None) workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None)
workflow_run.exceptions_count = exceptions_count workflow_run.exceptions_count = exceptions_count
stmt = select(WorkflowNodeExecution).where( stmt = select(WorkflowNodeExecution.node_execution_id).where(
WorkflowNodeExecution.tenant_id == workflow_run.tenant_id, WorkflowNodeExecution.tenant_id == workflow_run.tenant_id,
WorkflowNodeExecution.app_id == workflow_run.app_id, WorkflowNodeExecution.app_id == workflow_run.app_id,
WorkflowNodeExecution.workflow_id == workflow_run.workflow_id, WorkflowNodeExecution.workflow_id == workflow_run.workflow_id,
@ -249,15 +256,18 @@ class WorkflowCycleManage:
WorkflowNodeExecution.workflow_run_id == workflow_run.id, WorkflowNodeExecution.workflow_run_id == workflow_run.id,
WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING.value, WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING.value,
) )
ids = session.scalars(stmt).all()
running_workflow_node_executions = session.scalars(stmt).all() # Use self._get_workflow_node_execution here to make sure the cache is updated
running_workflow_node_executions = [
self._get_workflow_node_execution(session=session, node_execution_id=id) for id in ids if id
]
for workflow_node_execution in running_workflow_node_executions: for workflow_node_execution in running_workflow_node_executions:
now = datetime.now(UTC).replace(tzinfo=None)
workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
workflow_node_execution.error = error workflow_node_execution.error = error
finish_at = datetime.now(UTC).replace(tzinfo=None) workflow_node_execution.finished_at = now
workflow_node_execution.finished_at = finish_at workflow_node_execution.elapsed_time = (now - workflow_node_execution.created_at).total_seconds()
workflow_node_execution.elapsed_time = (finish_at - workflow_node_execution.created_at).total_seconds()
if trace_manager: if trace_manager:
trace_manager.add_trace_task( trace_manager.add_trace_task(
@ -299,6 +309,8 @@ class WorkflowCycleManage:
workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None) workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None)
session.add(workflow_node_execution) session.add(workflow_node_execution)
self._workflow_node_executions[event.node_execution_id] = workflow_node_execution
return workflow_node_execution return workflow_node_execution
def _handle_workflow_node_execution_success( def _handle_workflow_node_execution_success(
@ -325,6 +337,7 @@ class WorkflowCycleManage:
workflow_node_execution.finished_at = finished_at workflow_node_execution.finished_at = finished_at
workflow_node_execution.elapsed_time = elapsed_time workflow_node_execution.elapsed_time = elapsed_time
workflow_node_execution = session.merge(workflow_node_execution)
return workflow_node_execution return workflow_node_execution
def _handle_workflow_node_execution_failed( def _handle_workflow_node_execution_failed(
@ -364,6 +377,7 @@ class WorkflowCycleManage:
workflow_node_execution.elapsed_time = elapsed_time workflow_node_execution.elapsed_time = elapsed_time
workflow_node_execution.execution_metadata = execution_metadata workflow_node_execution.execution_metadata = execution_metadata
workflow_node_execution = session.merge(workflow_node_execution)
return workflow_node_execution return workflow_node_execution
def _handle_workflow_node_execution_retried( def _handle_workflow_node_execution_retried(
@ -415,6 +429,8 @@ class WorkflowCycleManage:
workflow_node_execution.index = event.node_run_index workflow_node_execution.index = event.node_run_index
session.add(workflow_node_execution) session.add(workflow_node_execution)
self._workflow_node_executions[event.node_execution_id] = workflow_node_execution
return workflow_node_execution return workflow_node_execution
################################################# #################################################
@ -811,25 +827,23 @@ class WorkflowCycleManage:
return None return None
def _get_workflow_run(self, *, session: Session, workflow_run_id: str) -> WorkflowRun: def _get_workflow_run(self, *, session: Session, workflow_run_id: str) -> WorkflowRun:
""" if self._workflow_run and self._workflow_run.id == workflow_run_id:
Refetch workflow run cached_workflow_run = self._workflow_run
:param workflow_run_id: workflow run id cached_workflow_run = session.merge(cached_workflow_run)
:return: return cached_workflow_run
"""
stmt = select(WorkflowRun).where(WorkflowRun.id == workflow_run_id) stmt = select(WorkflowRun).where(WorkflowRun.id == workflow_run_id)
workflow_run = session.scalar(stmt) workflow_run = session.scalar(stmt)
if not workflow_run: if not workflow_run:
raise WorkflowRunNotFoundError(workflow_run_id) raise WorkflowRunNotFoundError(workflow_run_id)
self._workflow_run = workflow_run
return workflow_run return workflow_run
def _get_workflow_node_execution(self, session: Session, node_execution_id: str) -> WorkflowNodeExecution: def _get_workflow_node_execution(self, session: Session, node_execution_id: str) -> WorkflowNodeExecution:
stmt = select(WorkflowNodeExecution).where(WorkflowNodeExecution.node_execution_id == node_execution_id) if node_execution_id not in self._workflow_node_executions:
workflow_node_execution = session.scalar(stmt) raise ValueError(f"Workflow node execution not found: {node_execution_id}")
if not workflow_node_execution: cached_workflow_node_execution = self._workflow_node_executions[node_execution_id]
raise WorkflowNodeExecutionNotFoundError(node_execution_id) return cached_workflow_node_execution
return workflow_node_execution
def _handle_agent_log(self, task_id: str, event: QueueAgentLogEvent) -> AgentLogStreamResponse: def _handle_agent_log(self, task_id: str, event: QueueAgentLogEvent) -> AgentLogStreamResponse:
""" """

View File

@ -1,9 +1,47 @@
import tiktoken from threading import Lock
from typing import Any
_tokenizer: Any = None
_lock = Lock()
class GPT2Tokenizer: class GPT2Tokenizer:
@staticmethod
def _get_num_tokens_by_gpt2(text: str) -> int:
"""
use gpt2 tokenizer to get num tokens
"""
_tokenizer = GPT2Tokenizer.get_encoder()
tokens = _tokenizer.encode(text)
return len(tokens)
@staticmethod @staticmethod
def get_num_tokens(text: str) -> int: def get_num_tokens(text: str) -> int:
encoding = tiktoken.encoding_for_model("gpt2") # Because this process needs more CPU resources, we revert to this approach until we find a better way to handle it.
tiktoken_vec = encoding.encode(text) #
return len(tiktoken_vec) # future = _executor.submit(GPT2Tokenizer._get_num_tokens_by_gpt2, text)
# result = future.result()
# return cast(int, result)
return GPT2Tokenizer._get_num_tokens_by_gpt2(text)
@staticmethod
def get_encoder() -> Any:
global _tokenizer, _lock
with _lock:
if _tokenizer is None:
# Try to use tiktoken to get the tokenizer because it is faster
#
try:
import tiktoken
_tokenizer = tiktoken.get_encoding("gpt2")
except Exception:
from os.path import abspath, dirname, join
from transformers import GPT2Tokenizer as TransformerGPT2Tokenizer # type: ignore
base_path = abspath(__file__)
gpt2_tokenizer_path = join(dirname(base_path), "gpt2")
_tokenizer = TransformerGPT2Tokenizer.from_pretrained(gpt2_tokenizer_path)
return _tokenizer
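For reference, the fast path above uses tiktoken's bundled "gpt2" encoding; counting tokens with it directly is a one-liner (illustrative sketch, requires the tiktoken package). When tiktoken is unavailable, get_encoder() falls back to the transformers GPT2Tokenizer loaded from the local gpt2 directory, as shown above.

import tiktoken

# tiktoken ships the GPT-2 BPE vocabulary, so no model download is needed.
encoding = tiktoken.get_encoding("gpt2")

text = "Dify counts prompt length with a GPT-2 style tokenizer."
print(len(encoding.encode(text)))  # number of GPT-2 BPE tokens in `text`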

View File

@ -113,6 +113,8 @@ class BaiduVector(BaseVector):
return False return False
def delete_by_ids(self, ids: list[str]) -> None: def delete_by_ids(self, ids: list[str]) -> None:
if not ids:
return
quoted_ids = [f"'{id}'" for id in ids] quoted_ids = [f"'{id}'" for id in ids]
self._db.table(self._collection_name).delete(filter=f"id IN({', '.join(quoted_ids)})") self._db.table(self._collection_name).delete(filter=f"id IN({', '.join(quoted_ids)})")

View File

@ -83,6 +83,8 @@ class ChromaVector(BaseVector):
self._client.delete_collection(self._collection_name) self._client.delete_collection(self._collection_name)
def delete_by_ids(self, ids: list[str]) -> None: def delete_by_ids(self, ids: list[str]) -> None:
if not ids:
return
collection = self._client.get_or_create_collection(self._collection_name) collection = self._client.get_or_create_collection(self._collection_name)
collection.delete(ids=ids) collection.delete(ids=ids)

View File

@ -0,0 +1,104 @@
import json
import logging
from typing import Any, Optional
from flask import current_app
from core.rag.datasource.vdb.elasticsearch.elasticsearch_vector import (
ElasticSearchConfig,
ElasticSearchVector,
ElasticSearchVectorFactory,
)
from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from extensions.ext_redis import redis_client
from models.dataset import Dataset
logger = logging.getLogger(__name__)
class ElasticSearchJaVector(ElasticSearchVector):
def create_collection(
self,
embeddings: list[list[float]],
metadatas: Optional[list[dict[Any, Any]]] = None,
index_params: Optional[dict] = None,
):
lock_name = f"vector_indexing_lock_{self._collection_name}"
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
if redis_client.get(collection_exist_cache_key):
logger.info(f"Collection {self._collection_name} already exists.")
return
if not self._client.indices.exists(index=self._collection_name):
dim = len(embeddings[0])
settings = {
"analysis": {
"analyzer": {
"ja_analyzer": {
"type": "custom",
"char_filter": [
"icu_normalizer",
"kuromoji_iteration_mark",
],
"tokenizer": "kuromoji_tokenizer",
"filter": [
"kuromoji_baseform",
"kuromoji_part_of_speech",
"ja_stop",
"kuromoji_number",
"kuromoji_stemmer",
],
}
}
}
}
mappings = {
"properties": {
Field.CONTENT_KEY.value: {
"type": "text",
"analyzer": "ja_analyzer",
"search_analyzer": "ja_analyzer",
},
Field.VECTOR.value: { # Make sure the dimension is correct here
"type": "dense_vector",
"dims": dim,
"index": True,
"similarity": "cosine",
},
Field.METADATA_KEY.value: {
"type": "object",
"properties": {
"doc_id": {"type": "keyword"} # Map doc_id to keyword type
},
},
}
}
self._client.indices.create(index=self._collection_name, settings=settings, mappings=mappings)
redis_client.set(collection_exist_cache_key, 1, ex=3600)
class ElasticSearchJaVectorFactory(ElasticSearchVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> ElasticSearchJaVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.ELASTICSEARCH, collection_name))
config = current_app.config
return ElasticSearchJaVector(
index_name=collection_name,
config=ElasticSearchConfig(
host=config.get("ELASTICSEARCH_HOST", "localhost"),
port=config.get("ELASTICSEARCH_PORT", 9200),
username=config.get("ELASTICSEARCH_USERNAME", ""),
password=config.get("ELASTICSEARCH_PASSWORD", ""),
),
attributes=[],
)
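The ja_analyzer above depends on the analysis-kuromoji and analysis-icu Elasticsearch plugins. As a point of reference, a standalone sketch of creating an equivalent index with the plain Python client could look like this (local cluster, made-up index name and embedding dimension; field names are assumptions mirroring the Field enum):

from elasticsearch import Elasticsearch

# Assumes a local cluster with the analysis-kuromoji and analysis-icu plugins installed.
es = Elasticsearch("http://localhost:9200")

settings = {
    "analysis": {
        "analyzer": {
            "ja_analyzer": {
                "type": "custom",
                "char_filter": ["icu_normalizer", "kuromoji_iteration_mark"],
                "tokenizer": "kuromoji_tokenizer",
                "filter": [
                    "kuromoji_baseform",
                    "kuromoji_part_of_speech",
                    "ja_stop",
                    "kuromoji_number",
                    "kuromoji_stemmer",
                ],
            }
        }
    }
}
mappings = {
    "properties": {
        # assuming Field.CONTENT_KEY resolves to "page_content"
        "page_content": {"type": "text", "analyzer": "ja_analyzer", "search_analyzer": "ja_analyzer"},
        # dims must match the embedding model; 768 is just an example value
        "vector": {"type": "dense_vector", "dims": 768, "index": True, "similarity": "cosine"},
        "metadata": {"type": "object", "properties": {"doc_id": {"type": "keyword"}}},
    }
}

# Same call shape as ElasticSearchJaVector.create_collection above; the index name is hypothetical.
es.indices.create(index="ja-docs-demo", settings=settings, mappings=mappings)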

View File

@ -98,6 +98,8 @@ class ElasticSearchVector(BaseVector):
return bool(self._client.exists(index=self._collection_name, id=id)) return bool(self._client.exists(index=self._collection_name, id=id))
def delete_by_ids(self, ids: list[str]) -> None: def delete_by_ids(self, ids: list[str]) -> None:
if not ids:
return
for id in ids: for id in ids:
self._client.delete(index=self._collection_name, id=id) self._client.delete(index=self._collection_name, id=id)

View File

@ -6,6 +6,8 @@ class Field(Enum):
METADATA_KEY = "metadata" METADATA_KEY = "metadata"
GROUP_KEY = "group_id" GROUP_KEY = "group_id"
VECTOR = "vector" VECTOR = "vector"
# Sparse Vector aims to support full text search
SPARSE_VECTOR = "sparse_vector"
TEXT_KEY = "text" TEXT_KEY = "text"
PRIMARY_KEY = "id" PRIMARY_KEY = "id"
DOC_ID = "metadata.doc_id" DOC_ID = "metadata.doc_id"

View File

@ -2,6 +2,7 @@ import json
import logging import logging
from typing import Any, Optional from typing import Any, Optional
from packaging import version
from pydantic import BaseModel, model_validator from pydantic import BaseModel, model_validator
from pymilvus import MilvusClient, MilvusException # type: ignore from pymilvus import MilvusClient, MilvusException # type: ignore
from pymilvus.milvus_client import IndexParams # type: ignore from pymilvus.milvus_client import IndexParams # type: ignore
@ -20,16 +21,25 @@ logger = logging.getLogger(__name__)
class MilvusConfig(BaseModel): class MilvusConfig(BaseModel):
uri: str """
token: Optional[str] = None Configuration class for Milvus connection.
user: str """
password: str
batch_size: int = 100 uri: str # Milvus server URI
database: str = "default" token: Optional[str] = None # Optional token for authentication
user: str # Username for authentication
password: str # Password for authentication
batch_size: int = 100 # Batch size for operations
database: str = "default" # Database name
enable_hybrid_search: bool = False # Flag to enable hybrid search
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def validate_config(cls, values: dict) -> dict: def validate_config(cls, values: dict) -> dict:
"""
Validate the configuration values.
Raises ValueError if required fields are missing.
"""
if not values.get("uri"): if not values.get("uri"):
raise ValueError("config MILVUS_URI is required") raise ValueError("config MILVUS_URI is required")
if not values.get("user"): if not values.get("user"):
@ -39,6 +49,9 @@ class MilvusConfig(BaseModel):
return values return values
def to_milvus_params(self): def to_milvus_params(self):
"""
Convert the configuration to a dictionary of Milvus connection parameters.
"""
return { return {
"uri": self.uri, "uri": self.uri,
"token": self.token, "token": self.token,
@ -49,26 +62,57 @@ class MilvusConfig(BaseModel):
class MilvusVector(BaseVector): class MilvusVector(BaseVector):
"""
Milvus vector storage implementation.
"""
def __init__(self, collection_name: str, config: MilvusConfig): def __init__(self, collection_name: str, config: MilvusConfig):
super().__init__(collection_name) super().__init__(collection_name)
self._client_config = config self._client_config = config
self._client = self._init_client(config) self._client = self._init_client(config)
self._consistency_level = "Session" self._consistency_level = "Session" # Consistency level for Milvus operations
self._fields: list[str] = [] self._fields: list[str] = [] # List of fields in the collection
self._hybrid_search_enabled = self._check_hybrid_search_support() # Check if hybrid search is supported
def _check_hybrid_search_support(self) -> bool:
"""
Check if the current Milvus version supports hybrid search.
Returns True if the version is >= 2.5.0, otherwise False.
"""
if not self._client_config.enable_hybrid_search:
return False
try:
milvus_version = self._client.get_server_version()
return version.parse(milvus_version).release >= (2, 5, 0)  # compare numerically; comparing base_version strings breaks for versions like 2.10.x
except Exception as e:
logger.warning(f"Failed to check Milvus version: {str(e)}. Disabling hybrid search.")
return False
def get_type(self) -> str: def get_type(self) -> str:
"""
Get the type of vector storage (Milvus).
"""
return VectorType.MILVUS return VectorType.MILVUS
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
"""
Create a collection and add texts with embeddings.
"""
index_params = {"metric_type": "IP", "index_type": "HNSW", "params": {"M": 8, "efConstruction": 64}} index_params = {"metric_type": "IP", "index_type": "HNSW", "params": {"M": 8, "efConstruction": 64}}
metadatas = [d.metadata if d.metadata is not None else {} for d in texts] metadatas = [d.metadata if d.metadata is not None else {} for d in texts]
self.create_collection(embeddings, metadatas, index_params) self.create_collection(embeddings, metadatas, index_params)
self.add_texts(texts, embeddings) self.add_texts(texts, embeddings)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
"""
Add texts and their embeddings to the collection.
"""
insert_dict_list = [] insert_dict_list = []
for i in range(len(documents)): for i in range(len(documents)):
insert_dict = { insert_dict = {
# Do not need to insert the sparse_vector field separately, as the text_bm25_emb
# function will automatically convert the native text into a sparse vector for us.
Field.CONTENT_KEY.value: documents[i].page_content, Field.CONTENT_KEY.value: documents[i].page_content,
Field.VECTOR.value: embeddings[i], Field.VECTOR.value: embeddings[i],
Field.METADATA_KEY.value: documents[i].metadata, Field.METADATA_KEY.value: documents[i].metadata,
@ -76,12 +120,11 @@ class MilvusVector(BaseVector):
insert_dict_list.append(insert_dict) insert_dict_list.append(insert_dict)
# Total insert count # Total insert count
total_count = len(insert_dict_list) total_count = len(insert_dict_list)
pks: list[str] = [] pks: list[str] = []
for i in range(0, total_count, 1000): for i in range(0, total_count, 1000):
batch_insert_list = insert_dict_list[i : i + 1000]
# Insert into the collection. # Insert into the collection.
batch_insert_list = insert_dict_list[i : i + 1000]
try: try:
ids = self._client.insert(collection_name=self._collection_name, data=batch_insert_list) ids = self._client.insert(collection_name=self._collection_name, data=batch_insert_list)
pks.extend(ids) pks.extend(ids)
@ -91,6 +134,9 @@ class MilvusVector(BaseVector):
return pks return pks
def get_ids_by_metadata_field(self, key: str, value: str): def get_ids_by_metadata_field(self, key: str, value: str):
"""
Get document IDs by metadata field key and value.
"""
result = self._client.query( result = self._client.query(
collection_name=self._collection_name, filter=f'metadata["{key}"] == "{value}"', output_fields=["id"] collection_name=self._collection_name, filter=f'metadata["{key}"] == "{value}"', output_fields=["id"]
) )
@ -100,12 +146,18 @@ class MilvusVector(BaseVector):
return None return None
def delete_by_metadata_field(self, key: str, value: str): def delete_by_metadata_field(self, key: str, value: str):
"""
Delete documents by metadata field key and value.
"""
if self._client.has_collection(self._collection_name): if self._client.has_collection(self._collection_name):
ids = self.get_ids_by_metadata_field(key, value) ids = self.get_ids_by_metadata_field(key, value)
if ids: if ids:
self._client.delete(collection_name=self._collection_name, pks=ids) self._client.delete(collection_name=self._collection_name, pks=ids)
def delete_by_ids(self, ids: list[str]) -> None: def delete_by_ids(self, ids: list[str]) -> None:
"""
Delete documents by their IDs.
"""
if self._client.has_collection(self._collection_name): if self._client.has_collection(self._collection_name):
result = self._client.query( result = self._client.query(
collection_name=self._collection_name, filter=f'metadata["doc_id"] in {ids}', output_fields=["id"] collection_name=self._collection_name, filter=f'metadata["doc_id"] in {ids}', output_fields=["id"]
@ -115,10 +167,16 @@ class MilvusVector(BaseVector):
self._client.delete(collection_name=self._collection_name, pks=ids) self._client.delete(collection_name=self._collection_name, pks=ids)
def delete(self) -> None: def delete(self) -> None:
"""
Delete the entire collection.
"""
if self._client.has_collection(self._collection_name): if self._client.has_collection(self._collection_name):
self._client.drop_collection(self._collection_name, None) self._client.drop_collection(self._collection_name, None)
def text_exists(self, id: str) -> bool: def text_exists(self, id: str) -> bool:
"""
Check if a text with the given ID exists in the collection.
"""
if not self._client.has_collection(self._collection_name): if not self._client.has_collection(self._collection_name):
return False return False
@ -128,32 +186,80 @@ class MilvusVector(BaseVector):
return len(result) > 0 return len(result) > 0
def field_exists(self, field: str) -> bool:
"""
Check if a field exists in the collection.
"""
return field in self._fields
def _process_search_results(
self, results: list[Any], output_fields: list[str], score_threshold: float = 0.0
) -> list[Document]:
"""
Common method to process search results
:param results: Search results
:param output_fields: Fields to be output
:param score_threshold: Score threshold for filtering
:return: List of documents
"""
docs = []
for result in results[0]:
metadata = result["entity"].get(output_fields[1], {})
metadata["score"] = result["distance"]
if result["distance"] > score_threshold:
doc = Document(page_content=result["entity"].get(output_fields[0], ""), metadata=metadata)
docs.append(doc)
return docs
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
# Set search parameters. """
Search for documents by vector similarity.
"""
results = self._client.search( results = self._client.search(
collection_name=self._collection_name, collection_name=self._collection_name,
data=[query_vector], data=[query_vector],
anns_field=Field.VECTOR.value,
limit=kwargs.get("top_k", 4), limit=kwargs.get("top_k", 4),
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value], output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
) )
# Organize results.
docs = [] return self._process_search_results(
for result in results[0]: results,
metadata = result["entity"].get(Field.METADATA_KEY.value) output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
metadata["score"] = result["distance"] score_threshold=float(kwargs.get("score_threshold") or 0.0),
score_threshold = float(kwargs.get("score_threshold") or 0.0) )
if result["distance"] > score_threshold:
doc = Document(page_content=result["entity"].get(Field.CONTENT_KEY.value), metadata=metadata)
docs.append(doc)
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
# milvus/zilliz doesn't support bm25 search """
return [] Search for documents by full-text search (if hybrid search is enabled).
"""
if not self._hybrid_search_enabled or not self.field_exists(Field.SPARSE_VECTOR.value):
logger.warning("Full-text search is not supported in current Milvus version (requires >= 2.5.0)")
return []
results = self._client.search(
collection_name=self._collection_name,
data=[query],
anns_field=Field.SPARSE_VECTOR.value,
limit=kwargs.get("top_k", 4),
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
)
return self._process_search_results(
results,
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
score_threshold=float(kwargs.get("score_threshold") or 0.0),
)
def create_collection( def create_collection(
self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None
): ):
"""
Create a new collection in Milvus with the specified schema and index parameters.
"""
lock_name = "vector_indexing_lock_{}".format(self._collection_name) lock_name = "vector_indexing_lock_{}".format(self._collection_name)
with redis_client.lock(lock_name, timeout=20): with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name) collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
@ -161,7 +267,7 @@ class MilvusVector(BaseVector):
return return
# Grab the existing collection if it exists # Grab the existing collection if it exists
if not self._client.has_collection(self._collection_name): if not self._client.has_collection(self._collection_name):
from pymilvus import CollectionSchema, DataType, FieldSchema # type: ignore from pymilvus import CollectionSchema, DataType, FieldSchema, Function, FunctionType # type: ignore
from pymilvus.orm.types import infer_dtype_bydata # type: ignore from pymilvus.orm.types import infer_dtype_bydata # type: ignore
# Determine embedding dim # Determine embedding dim
@ -170,16 +276,36 @@ class MilvusVector(BaseVector):
if metadatas: if metadatas:
fields.append(FieldSchema(Field.METADATA_KEY.value, DataType.JSON, max_length=65_535)) fields.append(FieldSchema(Field.METADATA_KEY.value, DataType.JSON, max_length=65_535))
# Create the text field # Create the text field; enable_analyzer is set to True so that Milvus can automatically
fields.append(FieldSchema(Field.CONTENT_KEY.value, DataType.VARCHAR, max_length=65_535)) # convert the text to sparse_vector, reference: https://milvus.io/docs/full-text-search.md
fields.append(
FieldSchema(
Field.CONTENT_KEY.value,
DataType.VARCHAR,
max_length=65_535,
enable_analyzer=self._hybrid_search_enabled,
)
)
# Create the primary key field # Create the primary key field
fields.append(FieldSchema(Field.PRIMARY_KEY.value, DataType.INT64, is_primary=True, auto_id=True)) fields.append(FieldSchema(Field.PRIMARY_KEY.value, DataType.INT64, is_primary=True, auto_id=True))
# Create the vector field, supports binary or float vectors # Create the vector field, supports binary or float vectors
fields.append(FieldSchema(Field.VECTOR.value, infer_dtype_bydata(embeddings[0]), dim=dim)) fields.append(FieldSchema(Field.VECTOR.value, infer_dtype_bydata(embeddings[0]), dim=dim))
# Create Sparse Vector Index for the collection
if self._hybrid_search_enabled:
fields.append(FieldSchema(Field.SPARSE_VECTOR.value, DataType.SPARSE_FLOAT_VECTOR))
# Create the schema for the collection
schema = CollectionSchema(fields) schema = CollectionSchema(fields)
# Create custom function to support text to sparse vector by BM25
if self._hybrid_search_enabled:
bm25_function = Function(
name="text_bm25_emb",
input_field_names=[Field.CONTENT_KEY.value],
output_field_names=[Field.SPARSE_VECTOR.value],
function_type=FunctionType.BM25,
)
schema.add_function(bm25_function)
for x in schema.fields: for x in schema.fields:
self._fields.append(x.name) self._fields.append(x.name)
# Since primary field is auto-id, no need to track it # Since primary field is auto-id, no need to track it
@ -189,10 +315,15 @@ class MilvusVector(BaseVector):
index_params_obj = IndexParams() index_params_obj = IndexParams()
index_params_obj.add_index(field_name=Field.VECTOR.value, **index_params) index_params_obj.add_index(field_name=Field.VECTOR.value, **index_params)
# Create Sparse Vector Index for the collection
if self._hybrid_search_enabled:
index_params_obj.add_index(
field_name=Field.SPARSE_VECTOR.value, index_type="AUTOINDEX", metric_type="BM25"
)
# Create the collection # Create the collection
collection_name = self._collection_name
self._client.create_collection( self._client.create_collection(
collection_name=collection_name, collection_name=self._collection_name,
schema=schema, schema=schema,
index_params=index_params_obj, index_params=index_params_obj,
consistency_level=self._consistency_level, consistency_level=self._consistency_level,
@ -200,12 +331,22 @@ class MilvusVector(BaseVector):
redis_client.set(collection_exist_cache_key, 1, ex=3600) redis_client.set(collection_exist_cache_key, 1, ex=3600)
def _init_client(self, config) -> MilvusClient: def _init_client(self, config) -> MilvusClient:
"""
Initialize and return a Milvus client.
"""
client = MilvusClient(uri=config.uri, user=config.user, password=config.password, db_name=config.database) client = MilvusClient(uri=config.uri, user=config.user, password=config.password, db_name=config.database)
return client return client
class MilvusVectorFactory(AbstractVectorFactory): class MilvusVectorFactory(AbstractVectorFactory):
"""
Factory class for creating MilvusVector instances.
"""
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> MilvusVector: def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> MilvusVector:
"""
Initialize a MilvusVector instance for the given dataset.
"""
if dataset.index_struct_dict: if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix collection_name = class_prefix
@ -222,5 +363,6 @@ class MilvusVectorFactory(AbstractVectorFactory):
user=dify_config.MILVUS_USER or "", user=dify_config.MILVUS_USER or "",
password=dify_config.MILVUS_PASSWORD or "", password=dify_config.MILVUS_PASSWORD or "",
database=dify_config.MILVUS_DATABASE or "", database=dify_config.MILVUS_DATABASE or "",
enable_hybrid_search=dify_config.MILVUS_ENABLE_HYBRID_SEARCH or False,
), ),
) )
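With MILVUS_ENABLE_HYBRID_SEARCH turned on and a server at 2.5.0 or newer, the new search_by_full_text path runs a BM25 query against the sparse_vector field. A hedged sketch of the same call made with the raw pymilvus client (assumes a reachable Milvus >= 2.5.0 and a collection created by the code above; the URI, credentials, collection name, and output field names are placeholders):

from packaging import version
from pymilvus import MilvusClient

client = MilvusClient(uri="http://localhost:19530", user="root", password="Milvus", db_name="default")

# A version gate comparable to _check_hybrid_search_support: BM25 full-text search
# needs Milvus 2.5.0 or newer.
if version.parse(client.get_server_version()).release < (2, 5, 0):
    raise RuntimeError("Full-text search requires Milvus >= 2.5.0")

# The raw query text goes in `data`; the BM25 function registered on the collection
# converts it into a sparse vector on the server side.
results = client.search(
    collection_name="my_dify_collection",
    data=["how do I rotate my API key?"],
    anns_field="sparse_vector",
    limit=4,
    output_fields=["page_content", "metadata"],  # assuming Field.CONTENT_KEY is "page_content"
)
for hit in results[0]:
    print(hit["distance"], hit["entity"].get("page_content", ""))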

View File

@ -100,6 +100,8 @@ class MyScaleVector(BaseVector):
return results.row_count > 0 return results.row_count > 0
def delete_by_ids(self, ids: list[str]) -> None: def delete_by_ids(self, ids: list[str]) -> None:
if not ids:
return
self._client.command( self._client.command(
f"DELETE FROM {self._config.database}.{self._collection_name} WHERE id IN {str(tuple(ids))}" f"DELETE FROM {self._config.database}.{self._collection_name} WHERE id IN {str(tuple(ids))}"
) )

View File

@ -134,6 +134,8 @@ class OceanBaseVector(BaseVector):
return bool(cur.rowcount != 0) return bool(cur.rowcount != 0)
def delete_by_ids(self, ids: list[str]) -> None: def delete_by_ids(self, ids: list[str]) -> None:
if not ids:
return
self._client.delete(table_name=self._collection_name, ids=ids) self._client.delete(table_name=self._collection_name, ids=ids)
def get_ids_by_metadata_field(self, key: str, value: str) -> list[str]: def get_ids_by_metadata_field(self, key: str, value: str) -> list[str]:

View File

@ -167,6 +167,8 @@ class OracleVector(BaseVector):
return docs return docs
def delete_by_ids(self, ids: list[str]) -> None: def delete_by_ids(self, ids: list[str]) -> None:
if not ids:
return
with self._get_cursor() as cur: with self._get_cursor() as cur:
cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s" % (tuple(ids),)) cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s" % (tuple(ids),))

View File

@ -129,6 +129,11 @@ class PGVector(BaseVector):
return docs return docs
def delete_by_ids(self, ids: list[str]) -> None: def delete_by_ids(self, ids: list[str]) -> None:
# Avoiding crashes caused by performing delete operations on empty lists in certain scenarios
# Scenario 1: extracting a document fails, resulting in the table not being created.
# Then clicking the retry button triggers a delete operation on an empty list.
if not ids:
return
with self._get_cursor() as cur: with self._get_cursor() as cur:
cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s", (tuple(ids),)) cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s", (tuple(ids),))
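This guard, mirrored in the other vector stores above, exists because an empty id list would otherwise expand into invalid SQL. A tiny illustrative sketch of what the statement would amount to without it (pure string formatting, hypothetical table name, no database required):

def render_delete(table_name: str, ids: list[str]) -> str | None:
    """Roughly what the DELETE amounts to; psycopg2 adapts the tuple into an IN list."""
    if not ids:
        # Without this guard the statement would end up as `... WHERE id IN ()`,
        # which PostgreSQL rejects as a syntax error.
        return None
    rendered = ", ".join(f"'{id_}'" for id_ in ids)
    return f"DELETE FROM {table_name} WHERE id IN ({rendered})"

print(render_delete("embedding_demo", []))          # None -> nothing is sent to the database
print(render_delete("embedding_demo", ["a", "b"]))  # DELETE FROM embedding_demo WHERE id IN ('a', 'b')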

View File

@ -140,6 +140,8 @@ class TencentVector(BaseVector):
return False return False
def delete_by_ids(self, ids: list[str]) -> None: def delete_by_ids(self, ids: list[str]) -> None:
if not ids:
return
self._db.collection(self._collection_name).delete(document_ids=ids) self._db.collection(self._collection_name).delete(document_ids=ids)
def delete_by_metadata_field(self, key: str, value: str) -> None: def delete_by_metadata_field(self, key: str, value: str) -> None:

View File

@ -409,27 +409,27 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory):
db.session.query(TidbAuthBinding).filter(TidbAuthBinding.tenant_id == dataset.tenant_id).one_or_none() db.session.query(TidbAuthBinding).filter(TidbAuthBinding.tenant_id == dataset.tenant_id).one_or_none()
) )
if not tidb_auth_binding: if not tidb_auth_binding:
idle_tidb_auth_binding = ( with redis_client.lock("create_tidb_serverless_cluster_lock", timeout=900):
db.session.query(TidbAuthBinding) tidb_auth_binding = (
.filter(TidbAuthBinding.active == False, TidbAuthBinding.status == "ACTIVE") db.session.query(TidbAuthBinding)
.limit(1) .filter(TidbAuthBinding.tenant_id == dataset.tenant_id)
.one_or_none() .one_or_none()
) )
if idle_tidb_auth_binding: if tidb_auth_binding:
idle_tidb_auth_binding.active = True TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}"
idle_tidb_auth_binding.tenant_id = dataset.tenant_id
db.session.commit() else:
TIDB_ON_QDRANT_API_KEY = f"{idle_tidb_auth_binding.account}:{idle_tidb_auth_binding.password}" idle_tidb_auth_binding = (
else:
with redis_client.lock("create_tidb_serverless_cluster_lock", timeout=900):
tidb_auth_binding = (
db.session.query(TidbAuthBinding) db.session.query(TidbAuthBinding)
.filter(TidbAuthBinding.tenant_id == dataset.tenant_id) .filter(TidbAuthBinding.active == False, TidbAuthBinding.status == "ACTIVE")
.limit(1)
.one_or_none() .one_or_none()
) )
if tidb_auth_binding: if idle_tidb_auth_binding:
TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}" idle_tidb_auth_binding.active = True
idle_tidb_auth_binding.tenant_id = dataset.tenant_id
db.session.commit()
TIDB_ON_QDRANT_API_KEY = f"{idle_tidb_auth_binding.account}:{idle_tidb_auth_binding.password}"
else: else:
new_cluster = TidbService.create_tidb_serverless_cluster( new_cluster = TidbService.create_tidb_serverless_cluster(
dify_config.TIDB_PROJECT_ID or "", dify_config.TIDB_PROJECT_ID or "",
@ -451,7 +451,6 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory):
db.session.add(new_tidb_auth_binding) db.session.add(new_tidb_auth_binding)
db.session.commit() db.session.commit()
TIDB_ON_QDRANT_API_KEY = f"{new_tidb_auth_binding.account}:{new_tidb_auth_binding.password}" TIDB_ON_QDRANT_API_KEY = f"{new_tidb_auth_binding.account}:{new_tidb_auth_binding.password}"
else: else:
TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}" TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}"

View File

@ -90,6 +90,12 @@ class Vector:
from core.rag.datasource.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory from core.rag.datasource.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
return ElasticSearchVectorFactory return ElasticSearchVectorFactory
case VectorType.ELASTICSEARCH_JA:
from core.rag.datasource.vdb.elasticsearch.elasticsearch_ja_vector import (
ElasticSearchJaVectorFactory,
)
return ElasticSearchJaVectorFactory
case VectorType.TIDB_VECTOR: case VectorType.TIDB_VECTOR:
from core.rag.datasource.vdb.tidb_vector.tidb_vector import TiDBVectorFactory from core.rag.datasource.vdb.tidb_vector.tidb_vector import TiDBVectorFactory

View File

@ -16,6 +16,7 @@ class VectorType(StrEnum):
TENCENT = "tencent" TENCENT = "tencent"
ORACLE = "oracle" ORACLE = "oracle"
ELASTICSEARCH = "elasticsearch" ELASTICSEARCH = "elasticsearch"
ELASTICSEARCH_JA = "elasticsearch-ja"
LINDORM = "lindorm" LINDORM = "lindorm"
COUCHBASE = "couchbase" COUCHBASE = "couchbase"
BAIDU = "baidu" BAIDU = "baidu"

View File

@ -23,7 +23,6 @@ class PdfExtractor(BaseExtractor):
self._file_cache_key = file_cache_key self._file_cache_key = file_cache_key
def extract(self) -> list[Document]: def extract(self) -> list[Document]:
plaintext_file_key = ""
plaintext_file_exists = False plaintext_file_exists = False
if self._file_cache_key: if self._file_cache_key:
try: try:
@ -39,8 +38,8 @@ class PdfExtractor(BaseExtractor):
text = "\n\n".join(text_list) text = "\n\n".join(text_list)
# save plaintext file for caching # save plaintext file for caching
if not plaintext_file_exists and plaintext_file_key: if not plaintext_file_exists and self._file_cache_key:
storage.save(plaintext_file_key, text.encode("utf-8")) storage.save(self._file_cache_key, text.encode("utf-8"))
return documents return documents

View File

@ -3,6 +3,7 @@
import uuid import uuid
from typing import Optional from typing import Optional
from configs import dify_config
from core.model_manager import ModelInstance from core.model_manager import ModelInstance
from core.rag.cleaner.clean_processor import CleanProcessor from core.rag.cleaner.clean_processor import CleanProcessor
from core.rag.datasource.retrieval_service import RetrievalService from core.rag.datasource.retrieval_service import RetrievalService
@ -80,6 +81,10 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
child_nodes = self._split_child_nodes( child_nodes = self._split_child_nodes(
document, rules, process_rule.get("mode"), kwargs.get("embedding_model_instance") document, rules, process_rule.get("mode"), kwargs.get("embedding_model_instance")
) )
if kwargs.get("preview"):
if len(child_nodes) > dify_config.CHILD_CHUNKS_PREVIEW_NUMBER:
child_nodes = child_nodes[: dify_config.CHILD_CHUNKS_PREVIEW_NUMBER]
document.children = child_nodes document.children = child_nodes
doc_id = str(uuid.uuid4()) doc_id = str(uuid.uuid4())
hash = helper.generate_text_hash(document.page_content) hash = helper.generate_text_hash(document.page_content)

View File

@ -212,8 +212,23 @@ class ApiTool(Tool):
else: else:
body = body body = body
if method in {"get", "head", "post", "put", "delete", "patch"}: if method in {
response: httpx.Response = getattr(ssrf_proxy, method)( "get",
"head",
"post",
"put",
"delete",
"patch",
"options",
"GET",
"POST",
"PUT",
"PATCH",
"DELETE",
"HEAD",
"OPTIONS",
}:
response: httpx.Response = getattr(ssrf_proxy, method.lower())(
url, url,
params=params, params=params,
headers=headers, headers=headers,
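The widened method set plus the .lower() call makes the tool forgiving about however the tool's schema spells the verb, while still dispatching to the lower-case helper functions that getattr(ssrf_proxy, method.lower()) relies on. A hedged sketch of the same normalization using httpx directly (the ssrf_proxy module mirrors these per-verb functions):

import httpx

ALLOWED_METHODS = {"get", "head", "post", "put", "delete", "patch", "options"}

def do_request(method: str, url: str, **kwargs) -> httpx.Response:
    normalized = method.lower()
    if normalized not in ALLOWED_METHODS:
        raise ValueError(f"Invalid http method {method}")
    # httpx exposes one function per verb (httpx.get, httpx.post, ...), so dispatch by name.
    return getattr(httpx, normalized)(url, **kwargs)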

View File

@ -2,14 +2,18 @@ import csv
import io import io
import json import json
import logging import logging
import operator
import os import os
import tempfile import tempfile
from typing import cast from collections.abc import Mapping, Sequence
from typing import Any, cast
import docx import docx
import pandas as pd import pandas as pd
import pypdfium2 # type: ignore import pypdfium2 # type: ignore
import yaml # type: ignore import yaml # type: ignore
from docx.table import Table
from docx.text.paragraph import Paragraph
from configs import dify_config from configs import dify_config
from core.file import File, FileTransferMethod, file_manager from core.file import File, FileTransferMethod, file_manager
@ -78,6 +82,23 @@ class DocumentExtractorNode(BaseNode[DocumentExtractorNodeData]):
process_data=process_data, process_data=process_data,
) )
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: DocumentExtractorNodeData,
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
:param graph_config: graph config
:param node_id: node id
:param node_data: node data
:return:
"""
return {node_id + ".files": node_data.variable_selector}
def _extract_text_by_mime_type(*, file_content: bytes, mime_type: str) -> str: def _extract_text_by_mime_type(*, file_content: bytes, mime_type: str) -> str:
"""Extract text from a file based on its MIME type.""" """Extract text from a file based on its MIME type."""
@ -189,35 +210,56 @@ def _extract_text_from_doc(file_content: bytes) -> str:
doc_file = io.BytesIO(file_content) doc_file = io.BytesIO(file_content)
doc = docx.Document(doc_file) doc = docx.Document(doc_file)
text = [] text = []
# Process paragraphs
for paragraph in doc.paragraphs:
if paragraph.text.strip():
text.append(paragraph.text)
# Process tables # Keep track of paragraph and table positions
for table in doc.tables: content_items: list[tuple[int, str, Table | Paragraph]] = []
# Table header
try: # Process paragraphs and tables
# table maybe cause errors so ignore it. for i, paragraph in enumerate(doc.paragraphs):
if len(table.rows) > 0 and table.rows[0].cells is not None: if paragraph.text.strip():
content_items.append((i, "paragraph", paragraph))
for i, table in enumerate(doc.tables):
content_items.append((i, "table", table))
# Sort content items based on their original position
content_items.sort(key=operator.itemgetter(0))
# Process sorted content
for _, item_type, item in content_items:
if item_type == "paragraph":
if isinstance(item, Table):
continue
text.append(item.text)
elif item_type == "table":
# Process tables
if not isinstance(item, Table):
continue
try:
# Check if any cell in the table has text # Check if any cell in the table has text
has_content = False has_content = False
for row in table.rows: for row in item.rows:
if any(cell.text.strip() for cell in row.cells): if any(cell.text.strip() for cell in row.cells):
has_content = True has_content = True
break break
if has_content: if has_content:
markdown_table = "| " + " | ".join(cell.text for cell in table.rows[0].cells) + " |\n" cell_texts = [cell.text.replace("\n", "<br>") for cell in item.rows[0].cells]
markdown_table += "| " + " | ".join(["---"] * len(table.rows[0].cells)) + " |\n" markdown_table = f"| {' | '.join(cell_texts)} |\n"
for row in table.rows[1:]: markdown_table += f"| {' | '.join(['---'] * len(item.rows[0].cells))} |\n"
markdown_table += "| " + " | ".join(cell.text for cell in row.cells) + " |\n"
for row in item.rows[1:]:
# Replace newlines with <br> in each cell
row_cells = [cell.text.replace("\n", "<br>") for cell in row.cells]
markdown_table += "| " + " | ".join(row_cells) + " |\n"
text.append(markdown_table) text.append(markdown_table)
except Exception as e: except Exception as e:
logger.warning(f"Failed to extract table from DOC/DOCX: {e}") logger.warning(f"Failed to extract table from DOC/DOCX: {e}")
continue continue
return "\n".join(text) return "\n".join(text)
except Exception as e: except Exception as e:
raise TextExtractionError(f"Failed to extract text from DOC/DOCX: {str(e)}") from e raise TextExtractionError(f"Failed to extract text from DOC/DOCX: {str(e)}") from e
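The rewritten extractor tags every paragraph and table with its enumeration index, walks the merged list, and renders each non-empty table as a markdown table, swapping in-cell newlines for <br> so a multi-line cell stays on one markdown row. A self-contained sketch of just the table-rendering step, assuming python-docx's Table type:

from docx.table import Table

def table_to_markdown(table: Table) -> str:
    # Replace raw newlines so multi-line cells do not break the markdown row structure.
    rows = [[cell.text.replace("\n", "<br>") for cell in row.cells] for row in table.rows]
    if not rows:
        return ""
    lines = ["| " + " | ".join(rows[0]) + " |",
             "| " + " | ".join(["---"] * len(rows[0])) + " |"]
    lines += ["| " + " | ".join(row) + " |" for row in rows[1:]]
    return "\n".join(lines) + "\n"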

View File

@ -68,7 +68,22 @@ class HttpRequestNodeData(BaseNodeData):
Code Node Data. Code Node Data.
""" """
method: Literal["get", "post", "put", "patch", "delete", "head"] method: Literal[
"get",
"post",
"put",
"patch",
"delete",
"head",
"options",
"GET",
"POST",
"PUT",
"PATCH",
"DELETE",
"HEAD",
"OPTIONS",
]
url: str url: str
authorization: HttpRequestNodeAuthorization authorization: HttpRequestNodeAuthorization
headers: str headers: str

View File

@ -37,7 +37,22 @@ BODY_TYPE_TO_CONTENT_TYPE = {
class Executor: class Executor:
method: Literal["get", "head", "post", "put", "delete", "patch"] method: Literal[
"get",
"head",
"post",
"put",
"delete",
"patch",
"options",
"GET",
"POST",
"PUT",
"PATCH",
"DELETE",
"HEAD",
"OPTIONS",
]
url: str url: str
params: list[tuple[str, str]] | None params: list[tuple[str, str]] | None
content: str | bytes | None content: str | bytes | None
@ -67,12 +82,6 @@ class Executor:
node_data.authorization.config.api_key node_data.authorization.config.api_key
).text ).text
# check if node_data.url is a valid URL
if not node_data.url:
raise InvalidURLError("url is required")
if not node_data.url.startswith(("http://", "https://")):
raise InvalidURLError("url should start with http:// or https://")
self.url: str = node_data.url self.url: str = node_data.url
self.method = node_data.method self.method = node_data.method
self.auth = node_data.authorization self.auth = node_data.authorization
@ -99,6 +108,12 @@ class Executor:
def _init_url(self): def _init_url(self):
self.url = self.variable_pool.convert_template(self.node_data.url).text self.url = self.variable_pool.convert_template(self.node_data.url).text
# check if url is a valid URL
if not self.url:
raise InvalidURLError("url is required")
if not self.url.startswith(("http://", "https://")):
raise InvalidURLError("url should start with http:// or https://")
def _init_params(self): def _init_params(self):
""" """
Almost same as _init_headers(), difference: Almost same as _init_headers(), difference:
@ -158,7 +173,10 @@ class Executor:
if len(data) != 1: if len(data) != 1:
raise RequestBodyError("json body type should have exactly one item") raise RequestBodyError("json body type should have exactly one item")
json_string = self.variable_pool.convert_template(data[0].value).text json_string = self.variable_pool.convert_template(data[0].value).text
json_object = json.loads(json_string, strict=False) try:
json_object = json.loads(json_string, strict=False)
except json.JSONDecodeError as e:
raise RequestBodyError(f"Failed to parse JSON: {json_string}") from e
self.json = json_object self.json = json_object
# self.json = self._parse_object_contains_variables(json_object) # self.json = self._parse_object_contains_variables(json_object)
case "binary": case "binary":
@ -246,7 +264,22 @@ class Executor:
""" """
do http request depending on api bundle do http request depending on api bundle
""" """
if self.method not in {"get", "head", "post", "put", "delete", "patch"}: if self.method not in {
"get",
"head",
"post",
"put",
"delete",
"patch",
"options",
"GET",
"POST",
"PUT",
"PATCH",
"DELETE",
"HEAD",
"OPTIONS",
}:
raise InvalidHttpMethodError(f"Invalid http method {self.method}") raise InvalidHttpMethodError(f"Invalid http method {self.method}")
request_args = { request_args = {
@ -263,7 +296,7 @@ class Executor:
} }
# request_args = {k: v for k, v in request_args.items() if v is not None} # request_args = {k: v for k, v in request_args.items() if v is not None}
try: try:
response = getattr(ssrf_proxy, self.method)(**request_args) response = getattr(ssrf_proxy, self.method.lower())(**request_args)
except (ssrf_proxy.MaxRetriesExceededError, httpx.RequestError) as e: except (ssrf_proxy.MaxRetriesExceededError, httpx.RequestError) as e:
raise HttpRequestNodeError(str(e)) raise HttpRequestNodeError(str(e))
# FIXME: fix type ignore, this maybe httpx type issue # FIXME: fix type ignore, this maybe httpx type issue
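Two of the Executor changes above are about failing earlier with a clearer message: the URL check moves out of the constructor into _init_url, presumably because node_data.url may still contain template placeholders there and only the rendered string can be validated, and json.loads is wrapped so a malformed body surfaces as a RequestBodyError instead of a bare JSONDecodeError. A minimal sketch of the URL part (ValueError stands in for the node's InvalidURLError to keep the sketch self-contained):

def init_url(url_template: str, render) -> str:
    # Render variables such as {{#start.url#}} first; only the concrete URL is worth validating.
    url = render(url_template)
    if not url:
        raise ValueError("url is required")
    if not url.startswith(("http://", "https://")):
        raise ValueError("url should start with http:// or https://")
    return url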

View File

@ -340,6 +340,10 @@ class WorkflowEntry:
): ):
raise ValueError(f"Variable key {node_variable} not found in user inputs.") raise ValueError(f"Variable key {node_variable} not found in user inputs.")
# environment variables already exist in the variable pool and are not taken from user inputs
if variable_pool.get(variable_selector):
continue
# fetch variable node id from variable selector # fetch variable node id from variable selector
variable_node_id = variable_selector[0] variable_node_id = variable_selector[0]
variable_key_list = variable_selector[1:] variable_key_list = variable_selector[1:]
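The added check means values that are already present in the variable pool (environment variables in particular) are never demanded from user inputs. A rough sketch of the lookup order, with the pool reduced to a plain dict for illustration:

def resolve_value(pool: dict, variable_selector: list[str], user_inputs: dict):
    # Environment variables and other pre-populated values win over user inputs.
    key = tuple(variable_selector)
    if key in pool:
        return pool[key]
    node_variable = ".".join(variable_selector)
    if node_variable not in user_inputs:
        raise ValueError(f"Variable key {node_variable} not found in user inputs.")
    return user_inputs[node_variable]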

View File

@ -33,6 +33,7 @@ else
--bind "${DIFY_BIND_ADDRESS:-0.0.0.0}:${DIFY_PORT:-5001}" \ --bind "${DIFY_BIND_ADDRESS:-0.0.0.0}:${DIFY_PORT:-5001}" \
--workers ${SERVER_WORKER_AMOUNT:-1} \ --workers ${SERVER_WORKER_AMOUNT:-1} \
--worker-class ${SERVER_WORKER_CLASS:-gevent} \ --worker-class ${SERVER_WORKER_CLASS:-gevent} \
--worker-connections ${SERVER_WORKER_CONNECTIONS:-10} \
--timeout ${GUNICORN_TIMEOUT:-200} \ --timeout ${GUNICORN_TIMEOUT:-200} \
app:app app:app
fi fi

View File

@ -46,7 +46,7 @@ def init_app(app: DifyApp):
timezone = pytz.timezone(log_tz) timezone = pytz.timezone(log_tz)
def time_converter(seconds): def time_converter(seconds):
return datetime.utcfromtimestamp(seconds).astimezone(timezone).timetuple() return datetime.fromtimestamp(seconds, tz=timezone).timetuple()
for handler in logging.root.handlers: for handler in logging.root.handlers:
if handler.formatter: if handler.formatter:
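The one-line logging fix matters for correctness: datetime.utcfromtimestamp() returns a naive datetime, so the subsequent astimezone() call reinterprets it in the host's local zone and skews log timestamps whenever the machine is not on UTC (the function is also deprecated as of Python 3.12). datetime.fromtimestamp(seconds, tz=...) yields an aware value directly:

from datetime import datetime
import pytz

tokyo = pytz.timezone("Asia/Tokyo")
ts = 0.0  # 1970-01-01T00:00:00Z

# Old approach: naive UTC value; astimezone() then assumes the *local* machine zone.
naive = datetime.utcfromtimestamp(ts)

# New approach: timezone-aware conversion, independent of the host's local zone.
aware = datetime.fromtimestamp(ts, tz=tokyo)
print(aware.isoformat())  # 1970-01-01T09:00:00+09:00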

View File

@ -158,7 +158,7 @@ def _build_from_remote_url(
tenant_id: str, tenant_id: str,
transfer_method: FileTransferMethod, transfer_method: FileTransferMethod,
) -> File: ) -> File:
url = mapping.get("url") url = mapping.get("url") or mapping.get("remote_url")
if not url: if not url:
raise ValueError("Invalid file url") raise ValueError("Invalid file url")

View File

@ -255,7 +255,8 @@ class NotionOAuth(OAuthDataSource):
response = requests.get(url=f"{self._NOTION_BLOCK_SEARCH}/{block_id}", headers=headers) response = requests.get(url=f"{self._NOTION_BLOCK_SEARCH}/{block_id}", headers=headers)
response_json = response.json() response_json = response.json()
if response.status_code != 200: if response.status_code != 200:
raise ValueError(f"Error fetching block parent page ID: {response_json.message}") message = response_json.get("message", "unknown error")
raise ValueError(f"Error fetching block parent page ID: {message}")
parent = response_json["parent"] parent = response_json["parent"]
parent_type = parent["type"] parent_type = parent["type"]
if parent_type == "block_id": if parent_type == "block_id":

View File

@ -0,0 +1,41 @@
"""change workflow_runs.total_tokens to bigint
Revision ID: a91b476a53de
Revises: 923752d42eb6
Create Date: 2025-01-01 20:00:01.207369
"""
from alembic import op
import models as models
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = 'a91b476a53de'
down_revision = '923752d42eb6'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('workflow_runs', schema=None) as batch_op:
batch_op.alter_column('total_tokens',
existing_type=sa.INTEGER(),
type_=sa.BigInteger(),
existing_nullable=False,
existing_server_default=sa.text('0'))
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('workflow_runs', schema=None) as batch_op:
batch_op.alter_column('total_tokens',
existing_type=sa.BigInteger(),
type_=sa.INTEGER(),
existing_nullable=False,
existing_server_default=sa.text('0'))
# ### end Alembic commands ###

View File

@ -415,8 +415,8 @@ class WorkflowRun(Base):
status: Mapped[str] = mapped_column(db.String(255)) # running, succeeded, failed, stopped, partial-succeeded status: Mapped[str] = mapped_column(db.String(255)) # running, succeeded, failed, stopped, partial-succeeded
outputs: Mapped[Optional[str]] = mapped_column(sa.Text, default="{}") outputs: Mapped[Optional[str]] = mapped_column(sa.Text, default="{}")
error: Mapped[Optional[str]] = mapped_column(db.Text) error: Mapped[Optional[str]] = mapped_column(db.Text)
elapsed_time = db.Column(db.Float, nullable=False, server_default=db.text("0")) elapsed_time = db.Column(db.Float, nullable=False, server_default=sa.text("0"))
total_tokens: Mapped[int] = mapped_column(server_default=db.text("0")) total_tokens: Mapped[int] = mapped_column(sa.BigInteger, server_default=sa.text("0"))
total_steps = db.Column(db.Integer, server_default=db.text("0")) total_steps = db.Column(db.Integer, server_default=db.text("0"))
created_by_role: Mapped[str] = mapped_column(db.String(255)) # account, end_user created_by_role: Mapped[str] = mapped_column(db.String(255)) # account, end_user
created_by = db.Column(StringUUID, nullable=False) created_by = db.Column(StringUUID, nullable=False)

api/poetry.lock (generated)

File diff suppressed because it is too large

View File

@ -71,7 +71,7 @@ pyjwt = "~2.8.0"
pypdfium2 = "~4.30.0" pypdfium2 = "~4.30.0"
python = ">=3.11,<3.13" python = ">=3.11,<3.13"
python-docx = "~1.1.0" python-docx = "~1.1.0"
python-dotenv = "1.0.0" python-dotenv = "1.0.1"
pyyaml = "~6.0.1" pyyaml = "~6.0.1"
readabilipy = "0.2.0" readabilipy = "0.2.0"
redis = { version = "~5.0.3", extras = ["hiredis"] } redis = { version = "~5.0.3", extras = ["hiredis"] }
@ -82,7 +82,7 @@ scikit-learn = "~1.5.1"
sentry-sdk = { version = "~1.44.1", extras = ["flask"] } sentry-sdk = { version = "~1.44.1", extras = ["flask"] }
sqlalchemy = "~2.0.29" sqlalchemy = "~2.0.29"
starlette = "0.41.0" starlette = "0.41.0"
tencentcloud-sdk-python-hunyuan = "~3.0.1158" tencentcloud-sdk-python-hunyuan = "~3.0.1294"
tiktoken = "~0.8.0" tiktoken = "~0.8.0"
tokenizers = "~0.15.0" tokenizers = "~0.15.0"
transformers = "~4.35.0" transformers = "~4.35.0"
@ -92,7 +92,7 @@ validators = "0.21.0"
volcengine-python-sdk = {extras = ["ark"], version = "~1.0.98"} volcengine-python-sdk = {extras = ["ark"], version = "~1.0.98"}
websocket-client = "~1.7.0" websocket-client = "~1.7.0"
xinference-client = "0.15.2" xinference-client = "0.15.2"
yarl = "~1.9.4" yarl = "~1.18.3"
youtube-transcript-api = "~0.6.2" youtube-transcript-api = "~0.6.2"
zhipuai = "~2.1.5" zhipuai = "~2.1.5"
# Before adding new dependency, consider place it in alphabet order (a-z) and suitable group. # Before adding new dependency, consider place it in alphabet order (a-z) and suitable group.
@ -157,7 +157,7 @@ opensearch-py = "2.4.0"
oracledb = "~2.2.1" oracledb = "~2.2.1"
pgvecto-rs = { version = "~0.2.1", extras = ['sqlalchemy'] } pgvecto-rs = { version = "~0.2.1", extras = ['sqlalchemy'] }
pgvector = "0.2.5" pgvector = "0.2.5"
pymilvus = "~2.4.4" pymilvus = "~2.5.0"
pymochow = "1.3.1" pymochow = "1.3.1"
pyobvector = "~0.1.6" pyobvector = "~0.1.6"
qdrant-client = "1.7.3" qdrant-client = "1.7.3"

View File

@ -168,23 +168,6 @@ def clean_unused_datasets_task():
else: else:
plan = plan_cache.decode() plan = plan_cache.decode()
if plan == "sandbox": if plan == "sandbox":
# add auto disable log
documents = (
db.session.query(Document)
.filter(
Document.dataset_id == dataset.id,
Document.enabled == True,
Document.archived == False,
)
.all()
)
for document in documents:
dataset_auto_disable_log = DatasetAutoDisableLog(
tenant_id=dataset.tenant_id,
dataset_id=dataset.id,
document_id=document.id,
)
db.session.add(dataset_auto_disable_log)
# remove index # remove index
index_processor = IndexProcessorFactory(dataset.doc_form).init_index_processor() index_processor = IndexProcessorFactory(dataset.doc_form).init_index_processor()
index_processor.clean(dataset, None) index_processor.clean(dataset, None)

View File

@ -66,7 +66,7 @@ class TokenPair(BaseModel):
REFRESH_TOKEN_PREFIX = "refresh_token:" REFRESH_TOKEN_PREFIX = "refresh_token:"
ACCOUNT_REFRESH_TOKEN_PREFIX = "account_refresh_token:" ACCOUNT_REFRESH_TOKEN_PREFIX = "account_refresh_token:"
REFRESH_TOKEN_EXPIRY = timedelta(days=30) REFRESH_TOKEN_EXPIRY = timedelta(days=dify_config.REFRESH_TOKEN_EXPIRE_DAYS)
class AccountService: class AccountService:
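Deriving REFRESH_TOKEN_EXPIRY from the new REFRESH_TOKEN_EXPIRE_DAYS setting only changes the timedelta; wherever the token is cached, the TTL should follow it. A speculative sketch of how that expiry might be applied when storing a refresh token in Redis (the key format and helper are illustrative, not necessarily Dify's actual code):

from datetime import timedelta

REFRESH_TOKEN_EXPIRE_DAYS = 30.0  # normally read from dify_config
REFRESH_TOKEN_EXPIRY = timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS)

def cache_refresh_token(redis_client, token: str, account_id: str) -> None:
    # setex expects a TTL in seconds (or a timedelta); derive it from the configured expiry.
    redis_client.setex(f"refresh_token:{token}", int(REFRESH_TOKEN_EXPIRY.total_seconds()), account_id)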

View File

@ -2,6 +2,7 @@ import logging
import uuid import uuid
from enum import StrEnum from enum import StrEnum
from typing import Optional, cast from typing import Optional, cast
from urllib.parse import urlparse
from uuid import uuid4 from uuid import uuid4
import yaml # type: ignore import yaml # type: ignore
@ -124,7 +125,7 @@ class AppDslService:
raise ValueError(f"Invalid import_mode: {import_mode}") raise ValueError(f"Invalid import_mode: {import_mode}")
# Get YAML content # Get YAML content
content: bytes | str = b"" content: str = ""
if mode == ImportMode.YAML_URL: if mode == ImportMode.YAML_URL:
if not yaml_url: if not yaml_url:
return Import( return Import(
@ -133,13 +134,17 @@ class AppDslService:
error="yaml_url is required when import_mode is yaml-url", error="yaml_url is required when import_mode is yaml-url",
) )
try: try:
# tricky way to handle url from github to github raw url parsed_url = urlparse(yaml_url)
if yaml_url.startswith("https://github.com") and yaml_url.endswith((".yml", ".yaml")): if (
parsed_url.scheme == "https"
and parsed_url.netloc == "github.com"
and parsed_url.path.endswith((".yml", ".yaml"))
):
yaml_url = yaml_url.replace("https://github.com", "https://raw.githubusercontent.com") yaml_url = yaml_url.replace("https://github.com", "https://raw.githubusercontent.com")
yaml_url = yaml_url.replace("/blob/", "/") yaml_url = yaml_url.replace("/blob/", "/")
response = ssrf_proxy.get(yaml_url.strip(), follow_redirects=True, timeout=(10, 10)) response = ssrf_proxy.get(yaml_url.strip(), follow_redirects=True, timeout=(10, 10))
response.raise_for_status() response.raise_for_status()
content = response.content content = response.content.decode()
if len(content) > DSL_MAX_SIZE: if len(content) > DSL_MAX_SIZE:
return Import( return Import(
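The replacement for the old startswith heuristic parses the URL properly before mapping a GitHub blob link onto raw.githubusercontent.com. The rewrite itself is simple string surgery and can be exercised standalone:

from urllib.parse import urlparse

def to_raw_github_url(yaml_url: str) -> str:
    parsed = urlparse(yaml_url)
    if (
        parsed.scheme == "https"
        and parsed.netloc == "github.com"
        and parsed.path.endswith((".yml", ".yaml"))
    ):
        # https://github.com/<org>/<repo>/blob/<ref>/<path> -> raw.githubusercontent.com/<org>/<repo>/<ref>/<path>
        yaml_url = yaml_url.replace("https://github.com", "https://raw.githubusercontent.com")
        yaml_url = yaml_url.replace("/blob/", "/")
    return yaml_url

print(to_raw_github_url("https://github.com/langgenius/dify/blob/main/app.yaml"))
# https://raw.githubusercontent.com/langgenius/dify/main/app.yaml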

View File

@ -26,9 +26,10 @@ from tasks.remove_app_and_related_data_task import remove_app_and_related_data_t
class AppService: class AppService:
def get_paginate_apps(self, tenant_id: str, args: dict) -> Pagination | None: def get_paginate_apps(self, user_id: str, tenant_id: str, args: dict) -> Pagination | None:
""" """
Get app list with pagination Get app list with pagination
:param user_id: user id
:param tenant_id: tenant id :param tenant_id: tenant id
:param args: request args :param args: request args
:return: :return:
@ -44,6 +45,8 @@ class AppService:
elif args["mode"] == "channel": elif args["mode"] == "channel":
filters.append(App.mode == AppMode.CHANNEL.value) filters.append(App.mode == AppMode.CHANNEL.value)
if args.get("is_created_by_me", False):
filters.append(App.created_by == user_id)
if args.get("name"): if args.get("name"):
name = args["name"][:30] name = args["name"][:30]
filters.append(App.name.ilike(f"%{name}%")) filters.append(App.name.ilike(f"%{name}%"))

View File

@ -1,5 +1,5 @@
import os import os
from typing import Optional from typing import Literal, Optional
import httpx import httpx
from tenacity import retry, retry_if_exception_type, stop_before_delay, wait_fixed from tenacity import retry, retry_if_exception_type, stop_before_delay, wait_fixed
@ -17,7 +17,6 @@ class BillingService:
params = {"tenant_id": tenant_id} params = {"tenant_id": tenant_id}
billing_info = cls._send_request("GET", "/subscription/info", params=params) billing_info = cls._send_request("GET", "/subscription/info", params=params)
return billing_info return billing_info
@classmethod @classmethod
@ -47,12 +46,13 @@ class BillingService:
retry=retry_if_exception_type(httpx.RequestError), retry=retry_if_exception_type(httpx.RequestError),
reraise=True, reraise=True,
) )
def _send_request(cls, method, endpoint, json=None, params=None): def _send_request(cls, method: Literal["GET", "POST", "DELETE"], endpoint: str, json=None, params=None):
headers = {"Content-Type": "application/json", "Billing-Api-Secret-Key": cls.secret_key} headers = {"Content-Type": "application/json", "Billing-Api-Secret-Key": cls.secret_key}
url = f"{cls.base_url}{endpoint}" url = f"{cls.base_url}{endpoint}"
response = httpx.request(method, url, json=json, params=params, headers=headers) response = httpx.request(method, url, json=json, params=params, headers=headers)
if method == "GET" and response.status_code != httpx.codes.OK:
raise ValueError("Unable to retrieve billing information. Please try again later or contact support.")
return response.json() return response.json()
@staticmethod @staticmethod

View File

@ -86,7 +86,7 @@ class DatasetService:
else: else:
return [], 0 return [], 0
else: else:
if user.current_role not in (TenantAccountRole.OWNER, TenantAccountRole.ADMIN): if user.current_role != TenantAccountRole.OWNER:
# show all datasets that the user has permission to access # show all datasets that the user has permission to access
if permitted_dataset_ids: if permitted_dataset_ids:
query = query.filter( query = query.filter(
@ -382,7 +382,7 @@ class DatasetService:
if dataset.tenant_id != user.current_tenant_id: if dataset.tenant_id != user.current_tenant_id:
logging.debug(f"User {user.id} does not have permission to access dataset {dataset.id}") logging.debug(f"User {user.id} does not have permission to access dataset {dataset.id}")
raise NoPermissionError("You do not have permission to access this dataset.") raise NoPermissionError("You do not have permission to access this dataset.")
if user.current_role not in (TenantAccountRole.OWNER, TenantAccountRole.ADMIN): if user.current_role != TenantAccountRole.OWNER:
if dataset.permission == DatasetPermissionEnum.ONLY_ME and dataset.created_by != user.id: if dataset.permission == DatasetPermissionEnum.ONLY_ME and dataset.created_by != user.id:
logging.debug(f"User {user.id} does not have permission to access dataset {dataset.id}") logging.debug(f"User {user.id} does not have permission to access dataset {dataset.id}")
raise NoPermissionError("You do not have permission to access this dataset.") raise NoPermissionError("You do not have permission to access this dataset.")
@ -404,7 +404,7 @@ class DatasetService:
if not user: if not user:
raise ValueError("User not found") raise ValueError("User not found")
if user.current_role not in (TenantAccountRole.OWNER, TenantAccountRole.ADMIN): if user.current_role != TenantAccountRole.OWNER:
if dataset.permission == DatasetPermissionEnum.ONLY_ME: if dataset.permission == DatasetPermissionEnum.ONLY_ME:
if dataset.created_by != user.id: if dataset.created_by != user.id:
raise NoPermissionError("You do not have permission to access this dataset.") raise NoPermissionError("You do not have permission to access this dataset.")
@ -434,6 +434,12 @@ class DatasetService:
@staticmethod @staticmethod
def get_dataset_auto_disable_logs(dataset_id: str) -> dict: def get_dataset_auto_disable_logs(dataset_id: str) -> dict:
features = FeatureService.get_features(current_user.current_tenant_id)
if not features.billing.enabled or features.billing.subscription.plan == "sandbox":
return {
"document_ids": [],
"count": 0,
}
# get recent 30 days auto disable logs # get recent 30 days auto disable logs
start_date = datetime.datetime.now() - datetime.timedelta(days=30) start_date = datetime.datetime.now() - datetime.timedelta(days=30)
dataset_auto_disable_logs = DatasetAutoDisableLog.query.filter( dataset_auto_disable_logs = DatasetAutoDisableLog.query.filter(
@ -786,13 +792,19 @@ class DocumentService:
dataset.indexing_technique = knowledge_config.indexing_technique dataset.indexing_technique = knowledge_config.indexing_technique
if knowledge_config.indexing_technique == "high_quality": if knowledge_config.indexing_technique == "high_quality":
model_manager = ModelManager() model_manager = ModelManager()
embedding_model = model_manager.get_default_model_instance( if knowledge_config.embedding_model and knowledge_config.embedding_model_provider:
tenant_id=current_user.current_tenant_id, model_type=ModelType.TEXT_EMBEDDING dataset_embedding_model = knowledge_config.embedding_model
) dataset_embedding_model_provider = knowledge_config.embedding_model_provider
dataset.embedding_model = embedding_model.model else:
dataset.embedding_model_provider = embedding_model.provider embedding_model = model_manager.get_default_model_instance(
tenant_id=current_user.current_tenant_id, model_type=ModelType.TEXT_EMBEDDING
)
dataset_embedding_model = embedding_model.model
dataset_embedding_model_provider = embedding_model.provider
dataset.embedding_model = dataset_embedding_model
dataset.embedding_model_provider = dataset_embedding_model_provider
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
embedding_model.provider, embedding_model.model dataset_embedding_model_provider, dataset_embedding_model
) )
dataset.collection_binding_id = dataset_collection_binding.id dataset.collection_binding_id = dataset_collection_binding.id
if not dataset.retrieval_model: if not dataset.retrieval_model:
@ -804,7 +816,11 @@ class DocumentService:
"score_threshold_enabled": False, "score_threshold_enabled": False,
} }
dataset.retrieval_model = knowledge_config.retrieval_model.model_dump() or default_retrieval_model # type: ignore dataset.retrieval_model = (
knowledge_config.retrieval_model.model_dump()
if knowledge_config.retrieval_model
else default_retrieval_model
) # type: ignore
documents = [] documents = []
if knowledge_config.original_document_id: if knowledge_config.original_document_id:

View File

@ -27,7 +27,7 @@ class WorkflowAppService:
query = query.join(WorkflowRun, WorkflowRun.id == WorkflowAppLog.workflow_run_id) query = query.join(WorkflowRun, WorkflowRun.id == WorkflowAppLog.workflow_run_id)
if keyword: if keyword:
keyword_like_val = f"%{args['keyword'][:30]}%" keyword_like_val = f"%{keyword[:30].encode('unicode_escape').decode('utf-8')}%".replace(r"\u", r"\\u")
keyword_conditions = [ keyword_conditions = [
WorkflowRun.inputs.ilike(keyword_like_val), WorkflowRun.inputs.ilike(keyword_like_val),
WorkflowRun.outputs.ilike(keyword_like_val), WorkflowRun.outputs.ilike(keyword_like_val),
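WorkflowRun.inputs and outputs are stored as JSON strings, so any non-ASCII character sits in the column as a \uXXXX escape. The new keyword handling mirrors that encoding before building the ILIKE pattern, and doubles the backslash so the pattern treats it as a literal character. A small sketch of the transformation:

def to_json_escaped_pattern(keyword: str) -> str:
    # "你好" is stored in the JSON column as \u4f60\u597d, so escape the keyword the same way;
    # the backslash is then doubled so LIKE matches it literally.
    escaped = keyword[:30].encode("unicode_escape").decode("utf-8")
    return f"%{escaped}%".replace(r"\u", r"\\u")

print(to_json_escaped_pattern("你好"))  # %\\u4f60\\u597d%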

View File

@ -28,7 +28,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
if not dataset: if not dataset:
raise Exception("Dataset not found") raise Exception("Dataset not found")
index_type = dataset.doc_form index_type = dataset.doc_form or IndexType.PARAGRAPH_INDEX
index_processor = IndexProcessorFactory(index_type).init_index_processor() index_processor = IndexProcessorFactory(index_type).init_index_processor()
if action == "remove": if action == "remove":
index_processor.clean(dataset, None, with_keywords=False) index_processor.clean(dataset, None, with_keywords=False)
@ -157,6 +157,9 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
{"indexing_status": "error", "error": str(e)}, synchronize_session=False {"indexing_status": "error", "error": str(e)}, synchronize_session=False
) )
db.session.commit() db.session.commit()
else:
# clean collection
index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False)
end_at = time.perf_counter() end_at = time.perf_counter()
logging.info( logging.info(

View File

@ -0,0 +1,55 @@
import os
from pathlib import Path
import pytest
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.gpustack.speech2text.speech2text import GPUStackSpeech2TextModel
def test_validate_credentials():
model = GPUStackSpeech2TextModel()
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model="faster-whisper-medium",
credentials={
"endpoint_url": "invalid_url",
"api_key": "invalid_api_key",
},
)
model.validate_credentials(
model="faster-whisper-medium",
credentials={
"endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
"api_key": os.environ.get("GPUSTACK_API_KEY"),
},
)
def test_invoke_model():
model = GPUStackSpeech2TextModel()
# Get the directory of the current file
current_dir = os.path.dirname(os.path.abspath(__file__))
# Get assets directory
assets_dir = os.path.join(os.path.dirname(current_dir), "assets")
# Construct the path to the audio file
audio_file_path = os.path.join(assets_dir, "audio.mp3")
file = Path(audio_file_path).read_bytes()
result = model.invoke(
model="faster-whisper-medium",
credentials={
"endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
"api_key": os.environ.get("GPUSTACK_API_KEY"),
},
file=file,
)
assert isinstance(result, str)
assert result == "1, 2, 3, 4, 5, 6, 7, 8, 9, 10"

View File

@ -0,0 +1,24 @@
import os
from core.model_runtime.model_providers.gpustack.tts.tts import GPUStackText2SpeechModel
def test_invoke_model():
model = GPUStackText2SpeechModel()
result = model.invoke(
model="cosyvoice-300m-sft",
tenant_id="test",
credentials={
"endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
"api_key": os.environ.get("GPUSTACK_API_KEY"),
},
content_text="Hello world",
voice="Chinese Female",
)
content = b""
for chunk in result:
content += chunk
assert content != b""

View File

@ -19,9 +19,9 @@ class MilvusVectorTest(AbstractVectorTest):
) )
def search_by_full_text(self): def search_by_full_text(self):
# milvus dos not support full text searching yet in < 2.3.x # Milvus supports BM25 full-text search since version 2.5.0-beta
hits_by_full_text = self.vector.search_by_full_text(query=get_example_text()) hits_by_full_text = self.vector.search_by_full_text(query=get_example_text())
assert len(hits_by_full_text) == 0 assert len(hits_by_full_text) >= 0
def get_ids_by_metadata_field(self): def get_ids_by_metadata_field(self):
ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id) ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id)

View File

@ -2,7 +2,7 @@ version: '3'
services: services:
# API service # API service
api: api:
image: langgenius/dify-api:0.14.2 image: langgenius/dify-api:0.15.0
restart: always restart: always
environment: environment:
# Startup mode, 'api' starts the API server. # Startup mode, 'api' starts the API server.
@ -227,7 +227,7 @@ services:
# worker service # worker service
# The Celery worker for processing the queue. # The Celery worker for processing the queue.
worker: worker:
image: langgenius/dify-api:0.14.2 image: langgenius/dify-api:0.15.0
restart: always restart: always
environment: environment:
CONSOLE_WEB_URL: '' CONSOLE_WEB_URL: ''
@ -397,7 +397,7 @@ services:
# Frontend web application. # Frontend web application.
web: web:
image: langgenius/dify-web:0.14.2 image: langgenius/dify-web:0.15.0
restart: always restart: always
environment: environment:
# The base URL of console application api server, refers to the Console base URL of WEB service if console domain is # The base URL of console application api server, refers to the Console base URL of WEB service if console domain is

View File

@ -105,6 +105,9 @@ FILES_ACCESS_TIMEOUT=300
# Access token expiration time in minutes # Access token expiration time in minutes
ACCESS_TOKEN_EXPIRE_MINUTES=60 ACCESS_TOKEN_EXPIRE_MINUTES=60
# Refresh token expiration time in days
REFRESH_TOKEN_EXPIRE_DAYS=30
# The maximum number of active requests for the application, where 0 means unlimited, should be a non-negative integer. # The maximum number of active requests for the application, where 0 means unlimited, should be a non-negative integer.
APP_MAX_ACTIVE_REQUESTS=0 APP_MAX_ACTIVE_REQUESTS=0
APP_MAX_EXECUTION_TIME=1200 APP_MAX_EXECUTION_TIME=1200
@ -123,10 +126,13 @@ DIFY_PORT=5001
# The number of API server workers, i.e., the number of workers. # The number of API server workers, i.e., the number of workers.
# Formula: number of cpu cores x 2 + 1 for sync, 1 for Gevent # Formula: number of cpu cores x 2 + 1 for sync, 1 for Gevent
# Reference: https://docs.gunicorn.org/en/stable/design.html#how-many-workers # Reference: https://docs.gunicorn.org/en/stable/design.html#how-many-workers
SERVER_WORKER_AMOUNT= SERVER_WORKER_AMOUNT=1
# Defaults to gevent. If using windows, it can be switched to sync or solo. # Defaults to gevent. If using windows, it can be switched to sync or solo.
SERVER_WORKER_CLASS= SERVER_WORKER_CLASS=gevent
# The number of worker connections per worker process; defaults to 10.
SERVER_WORKER_CONNECTIONS=10
# Similar to SERVER_WORKER_CLASS. # Similar to SERVER_WORKER_CLASS.
# If using windows, it can be switched to sync or solo. # If using windows, it can be switched to sync or solo.
@ -377,7 +383,7 @@ SUPABASE_URL=your-server-url
# ------------------------------ # ------------------------------
# The type of vector store to use. # The type of vector store to use.
# Supported values are `weaviate`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `tidb_vector`, `oracle`, `tencent`, `elasticsearch`, `analyticdb`, `couchbase`, `vikingdb`, `oceanbase`. # Supported values are `weaviate`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `tidb_vector`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `oceanbase`.
VECTOR_STORE=weaviate VECTOR_STORE=weaviate
# The Weaviate endpoint URL. Only available when VECTOR_STORE is `weaviate`. # The Weaviate endpoint URL. Only available when VECTOR_STORE is `weaviate`.
@ -397,6 +403,7 @@ MILVUS_URI=http://127.0.0.1:19530
MILVUS_TOKEN= MILVUS_TOKEN=
MILVUS_USER=root MILVUS_USER=root
MILVUS_PASSWORD=Milvus MILVUS_PASSWORD=Milvus
MILVUS_ENABLE_HYBRID_SEARCH=False
# MyScale configuration, only available when VECTOR_STORE is `myscale` # MyScale configuration, only available when VECTOR_STORE is `myscale`
# For multi-language support, please set MYSCALE_FTS_PARAMS with referring to: # For multi-language support, please set MYSCALE_FTS_PARAMS with referring to:
@ -506,7 +513,7 @@ TENCENT_VECTOR_DB_SHARD=1
TENCENT_VECTOR_DB_REPLICAS=2 TENCENT_VECTOR_DB_REPLICAS=2
# ElasticSearch configuration, only available when VECTOR_STORE is `elasticsearch` # ElasticSearch configuration, only available when VECTOR_STORE is `elasticsearch`
ELASTICSEARCH_HOST=0.0.0.0 ELASTICSEARCH_HOST=elasticsearch
ELASTICSEARCH_PORT=9200 ELASTICSEARCH_PORT=9200
ELASTICSEARCH_USERNAME=elastic ELASTICSEARCH_USERNAME=elastic
ELASTICSEARCH_PASSWORD=elastic ELASTICSEARCH_PASSWORD=elastic
@ -923,6 +930,9 @@ CREATE_TIDB_SERVICE_JOB_ENABLED=false
# Maximum number of submitted thread count in a ThreadPool for parallel node execution # Maximum number of submitted thread count in a ThreadPool for parallel node execution
MAX_SUBMIT_COUNT=100 MAX_SUBMIT_COUNT=100
# The maximum top-k value for RAG retrieval.
TOP_K_MAX_VALUE=10
# ------------------------------ # ------------------------------
# Plugin Daemon Configuration # Plugin Daemon Configuration
# ------------------------------ # ------------------------------
@ -947,3 +957,4 @@ ENDPOINT_URL_TEMPLATE=http://localhost/e/{hook_id}
MARKETPLACE_ENABLED=true MARKETPLACE_ENABLED=true
MARKETPLACE_API_URL=https://marketplace-plugin.dify.dev MARKETPLACE_API_URL=https://marketplace-plugin.dify.dev

View File

@ -73,6 +73,7 @@ services:
CSP_WHITELIST: ${CSP_WHITELIST:-} CSP_WHITELIST: ${CSP_WHITELIST:-}
MARKETPLACE_API_URL: ${MARKETPLACE_API_URL:-https://marketplace-plugin.dify.dev} MARKETPLACE_API_URL: ${MARKETPLACE_API_URL:-https://marketplace-plugin.dify.dev}
MARKETPLACE_URL: ${MARKETPLACE_URL:-https://marketplace-plugin.dify.dev} MARKETPLACE_URL: ${MARKETPLACE_URL:-https://marketplace-plugin.dify.dev}
TOP_K_MAX_VALUE: ${TOP_K_MAX_VALUE:-}
# The postgres database. # The postgres database.
db: db:
@ -92,7 +93,7 @@ services:
volumes: volumes:
- ./volumes/db/data:/var/lib/postgresql/data - ./volumes/db/data:/var/lib/postgresql/data
healthcheck: healthcheck:
test: ['CMD', 'pg_isready'] test: [ 'CMD', 'pg_isready' ]
interval: 1s interval: 1s
timeout: 3s timeout: 3s
retries: 30 retries: 30
@ -111,7 +112,7 @@ services:
# Set the redis password when startup redis server. # Set the redis password when startup redis server.
command: redis-server --requirepass ${REDIS_PASSWORD:-difyai123456} command: redis-server --requirepass ${REDIS_PASSWORD:-difyai123456}
healthcheck: healthcheck:
test: ['CMD', 'redis-cli', 'ping'] test: [ 'CMD', 'redis-cli', 'ping' ]
# The DifySandbox # The DifySandbox
sandbox: sandbox:
@ -131,7 +132,7 @@ services:
volumes: volumes:
- ./volumes/sandbox/dependencies:/dependencies - ./volumes/sandbox/dependencies:/dependencies
healthcheck: healthcheck:
test: ['CMD', 'curl', '-f', 'http://localhost:8194/health'] test: [ 'CMD', 'curl', '-f', 'http://localhost:8194/health' ]
networks: networks:
- ssrf_proxy_network - ssrf_proxy_network
@ -167,12 +168,7 @@ services:
volumes: volumes:
- ./ssrf_proxy/squid.conf.template:/etc/squid/squid.conf.template - ./ssrf_proxy/squid.conf.template:/etc/squid/squid.conf.template
- ./ssrf_proxy/docker-entrypoint.sh:/docker-entrypoint-mount.sh - ./ssrf_proxy/docker-entrypoint.sh:/docker-entrypoint-mount.sh
entrypoint: entrypoint: [ 'sh', '-c', "cp /docker-entrypoint-mount.sh /docker-entrypoint.sh && sed -i 's/\r$$//' /docker-entrypoint.sh && chmod +x /docker-entrypoint.sh && /docker-entrypoint.sh" ]
[
'sh',
'-c',
"cp /docker-entrypoint-mount.sh /docker-entrypoint.sh && sed -i 's/\r$$//' /docker-entrypoint.sh && chmod +x /docker-entrypoint.sh && /docker-entrypoint.sh",
]
environment: environment:
# pls clearly modify the squid env vars to fit your network environment. # pls clearly modify the squid env vars to fit your network environment.
HTTP_PORT: ${SSRF_HTTP_PORT:-3128} HTTP_PORT: ${SSRF_HTTP_PORT:-3128}
@ -201,8 +197,8 @@ services:
- CERTBOT_EMAIL=${CERTBOT_EMAIL} - CERTBOT_EMAIL=${CERTBOT_EMAIL}
- CERTBOT_DOMAIN=${CERTBOT_DOMAIN} - CERTBOT_DOMAIN=${CERTBOT_DOMAIN}
- CERTBOT_OPTIONS=${CERTBOT_OPTIONS:-} - CERTBOT_OPTIONS=${CERTBOT_OPTIONS:-}
entrypoint: ['/docker-entrypoint.sh'] entrypoint: [ '/docker-entrypoint.sh' ]
command: ['tail', '-f', '/dev/null'] command: [ 'tail', '-f', '/dev/null' ]
# The nginx reverse proxy. # The nginx reverse proxy.
# used for reverse proxying the API service and Web service. # used for reverse proxying the API service and Web service.
@ -219,12 +215,7 @@ services:
- ./volumes/certbot/conf/live:/etc/letsencrypt/live # cert dir (with certbot container) - ./volumes/certbot/conf/live:/etc/letsencrypt/live # cert dir (with certbot container)
- ./volumes/certbot/conf:/etc/letsencrypt - ./volumes/certbot/conf:/etc/letsencrypt
- ./volumes/certbot/www:/var/www/html - ./volumes/certbot/www:/var/www/html
entrypoint: entrypoint: [ 'sh', '-c', "cp /docker-entrypoint-mount.sh /docker-entrypoint.sh && sed -i 's/\r$$//' /docker-entrypoint.sh && chmod +x /docker-entrypoint.sh && /docker-entrypoint.sh" ]
[
'sh',
'-c',
"cp /docker-entrypoint-mount.sh /docker-entrypoint.sh && sed -i 's/\r$$//' /docker-entrypoint.sh && chmod +x /docker-entrypoint.sh && /docker-entrypoint.sh",
]
environment: environment:
NGINX_SERVER_NAME: ${NGINX_SERVER_NAME:-_} NGINX_SERVER_NAME: ${NGINX_SERVER_NAME:-_}
NGINX_HTTPS_ENABLED: ${NGINX_HTTPS_ENABLED:-false} NGINX_HTTPS_ENABLED: ${NGINX_HTTPS_ENABLED:-false}
@ -316,7 +307,7 @@ services:
working_dir: /opt/couchbase working_dir: /opt/couchbase
stdin_open: true stdin_open: true
tty: true tty: true
entrypoint: [""] entrypoint: [ "" ]
command: sh -c "/opt/couchbase/init/init-cbserver.sh" command: sh -c "/opt/couchbase/init/init-cbserver.sh"
volumes: volumes:
- ./volumes/couchbase/data:/opt/couchbase/var/lib/couchbase/data - ./volumes/couchbase/data:/opt/couchbase/var/lib/couchbase/data
@ -345,7 +336,7 @@ services:
volumes: volumes:
- ./volumes/pgvector/data:/var/lib/postgresql/data - ./volumes/pgvector/data:/var/lib/postgresql/data
healthcheck: healthcheck:
test: ['CMD', 'pg_isready'] test: [ 'CMD', 'pg_isready' ]
interval: 1s interval: 1s
timeout: 3s timeout: 3s
retries: 30 retries: 30
@ -367,7 +358,7 @@ services:
volumes: volumes:
- ./volumes/pgvecto_rs/data:/var/lib/postgresql/data - ./volumes/pgvecto_rs/data:/var/lib/postgresql/data
healthcheck: healthcheck:
test: ['CMD', 'pg_isready'] test: [ 'CMD', 'pg_isready' ]
interval: 1s interval: 1s
timeout: 3s timeout: 3s
retries: 30 retries: 30
@ -432,7 +423,7 @@ services:
- ./volumes/milvus/etcd:/etcd - ./volumes/milvus/etcd:/etcd
command: etcd -advertise-client-urls=http://127.0.0.1:2379 -listen-client-urls http://0.0.0.0:2379 --data-dir /etcd command: etcd -advertise-client-urls=http://127.0.0.1:2379 -listen-client-urls http://0.0.0.0:2379 --data-dir /etcd
healthcheck: healthcheck:
test: ['CMD', 'etcdctl', 'endpoint', 'health'] test: [ 'CMD', 'etcdctl', 'endpoint', 'health' ]
interval: 30s interval: 30s
timeout: 20s timeout: 20s
retries: 3 retries: 3
@ -451,7 +442,7 @@ services:
- ./volumes/milvus/minio:/minio_data - ./volumes/milvus/minio:/minio_data
command: minio server /minio_data --console-address ":9001" command: minio server /minio_data --console-address ":9001"
healthcheck: healthcheck:
test: ['CMD', 'curl', '-f', 'http://localhost:9000/minio/health/live'] test: [ 'CMD', 'curl', '-f', 'http://localhost:9000/minio/health/live' ]
interval: 30s interval: 30s
timeout: 20s timeout: 20s
retries: 3 retries: 3
@ -463,7 +454,7 @@ services:
image: milvusdb/milvus:v2.3.1 image: milvusdb/milvus:v2.3.1
profiles: profiles:
- milvus - milvus
command: ['milvus', 'run', 'standalone'] command: [ 'milvus', 'run', 'standalone' ]
environment: environment:
ETCD_ENDPOINTS: ${ETCD_ENDPOINTS:-etcd:2379} ETCD_ENDPOINTS: ${ETCD_ENDPOINTS:-etcd:2379}
MINIO_ADDRESS: ${MINIO_ADDRESS:-minio:9000} MINIO_ADDRESS: ${MINIO_ADDRESS:-minio:9000}
@ -471,7 +462,7 @@ services:
volumes: volumes:
- ./volumes/milvus/milvus:/var/lib/milvus - ./volumes/milvus/milvus:/var/lib/milvus
healthcheck: healthcheck:
test: ['CMD', 'curl', '-f', 'http://localhost:9091/healthz'] test: [ 'CMD', 'curl', '-f', 'http://localhost:9091/healthz' ]
interval: 30s interval: 30s
start_period: 90s start_period: 90s
timeout: 20s timeout: 20s
@ -559,7 +550,7 @@ services:
ports: ports:
- ${ELASTICSEARCH_PORT:-9200}:9200 - ${ELASTICSEARCH_PORT:-9200}:9200
healthcheck: healthcheck:
test: ['CMD', 'curl', '-s', 'http://localhost:9200/_cluster/health?pretty'] test: [ 'CMD', 'curl', '-s', 'http://localhost:9200/_cluster/health?pretty' ]
interval: 30s interval: 30s
timeout: 10s timeout: 10s
retries: 50 retries: 50
@ -587,7 +578,7 @@ services:
ports: ports:
- ${KIBANA_PORT:-5601}:5601 - ${KIBANA_PORT:-5601}:5601
healthcheck: healthcheck:
test: ['CMD-SHELL', 'curl -s http://localhost:5601 >/dev/null || exit 1'] test: [ 'CMD-SHELL', 'curl -s http://localhost:5601 >/dev/null || exit 1' ]
interval: 30s interval: 30s
timeout: 10s timeout: 10s
retries: 3 retries: 3

View File

@ -27,12 +27,14 @@ x-shared-env: &shared-api-worker-env
MIGRATION_ENABLED: ${MIGRATION_ENABLED:-true} MIGRATION_ENABLED: ${MIGRATION_ENABLED:-true}
FILES_ACCESS_TIMEOUT: ${FILES_ACCESS_TIMEOUT:-300} FILES_ACCESS_TIMEOUT: ${FILES_ACCESS_TIMEOUT:-300}
ACCESS_TOKEN_EXPIRE_MINUTES: ${ACCESS_TOKEN_EXPIRE_MINUTES:-60} ACCESS_TOKEN_EXPIRE_MINUTES: ${ACCESS_TOKEN_EXPIRE_MINUTES:-60}
REFRESH_TOKEN_EXPIRE_DAYS: ${REFRESH_TOKEN_EXPIRE_DAYS:-30}
APP_MAX_ACTIVE_REQUESTS: ${APP_MAX_ACTIVE_REQUESTS:-0} APP_MAX_ACTIVE_REQUESTS: ${APP_MAX_ACTIVE_REQUESTS:-0}
APP_MAX_EXECUTION_TIME: ${APP_MAX_EXECUTION_TIME:-1200} APP_MAX_EXECUTION_TIME: ${APP_MAX_EXECUTION_TIME:-1200}
DIFY_BIND_ADDRESS: ${DIFY_BIND_ADDRESS:-0.0.0.0} DIFY_BIND_ADDRESS: ${DIFY_BIND_ADDRESS:-0.0.0.0}
DIFY_PORT: ${DIFY_PORT:-5001} DIFY_PORT: ${DIFY_PORT:-5001}
SERVER_WORKER_AMOUNT: ${SERVER_WORKER_AMOUNT:-} SERVER_WORKER_AMOUNT: ${SERVER_WORKER_AMOUNT:-1}
SERVER_WORKER_CLASS: ${SERVER_WORKER_CLASS:-} SERVER_WORKER_CLASS: ${SERVER_WORKER_CLASS:-gevent}
SERVER_WORKER_CONNECTIONS: ${SERVER_WORKER_CONNECTIONS:-10}
CELERY_WORKER_CLASS: ${CELERY_WORKER_CLASS:-} CELERY_WORKER_CLASS: ${CELERY_WORKER_CLASS:-}
GUNICORN_TIMEOUT: ${GUNICORN_TIMEOUT:-360} GUNICORN_TIMEOUT: ${GUNICORN_TIMEOUT:-360}
CELERY_WORKER_AMOUNT: ${CELERY_WORKER_AMOUNT:-} CELERY_WORKER_AMOUNT: ${CELERY_WORKER_AMOUNT:-}
@ -136,6 +138,7 @@ x-shared-env: &shared-api-worker-env
MILVUS_TOKEN: ${MILVUS_TOKEN:-} MILVUS_TOKEN: ${MILVUS_TOKEN:-}
MILVUS_USER: ${MILVUS_USER:-root} MILVUS_USER: ${MILVUS_USER:-root}
MILVUS_PASSWORD: ${MILVUS_PASSWORD:-Milvus} MILVUS_PASSWORD: ${MILVUS_PASSWORD:-Milvus}
MILVUS_ENABLE_HYBRID_SEARCH: ${MILVUS_ENABLE_HYBRID_SEARCH:-False}
MYSCALE_HOST: ${MYSCALE_HOST:-myscale} MYSCALE_HOST: ${MYSCALE_HOST:-myscale}
MYSCALE_PORT: ${MYSCALE_PORT:-8123} MYSCALE_PORT: ${MYSCALE_PORT:-8123}
MYSCALE_USER: ${MYSCALE_USER:-default} MYSCALE_USER: ${MYSCALE_USER:-default}
@ -401,6 +404,7 @@ x-shared-env: &shared-api-worker-env
ENDPOINT_URL_TEMPLATE: ${ENDPOINT_URL_TEMPLATE:-http://localhost/e/{hook_id}} ENDPOINT_URL_TEMPLATE: ${ENDPOINT_URL_TEMPLATE:-http://localhost/e/{hook_id}}
MARKETPLACE_ENABLED: ${MARKETPLACE_ENABLED:-true} MARKETPLACE_ENABLED: ${MARKETPLACE_ENABLED:-true}
MARKETPLACE_API_URL: ${MARKETPLACE_API_URL:-https://marketplace-plugin.dify.dev} MARKETPLACE_API_URL: ${MARKETPLACE_API_URL:-https://marketplace-plugin.dify.dev}
TOP_K_MAX_VALUE: ${TOP_K_MAX_VALUE:-10}
services: services:
# API service # API service
@ -476,6 +480,7 @@ services:
CSP_WHITELIST: ${CSP_WHITELIST:-} CSP_WHITELIST: ${CSP_WHITELIST:-}
MARKETPLACE_API_URL: ${MARKETPLACE_API_URL:-https://marketplace-plugin.dify.dev} MARKETPLACE_API_URL: ${MARKETPLACE_API_URL:-https://marketplace-plugin.dify.dev}
MARKETPLACE_URL: ${MARKETPLACE_URL:-https://marketplace-plugin.dify.dev} MARKETPLACE_URL: ${MARKETPLACE_URL:-https://marketplace-plugin.dify.dev}
TOP_K_MAX_VALUE: ${TOP_K_MAX_VALUE:-}
# The postgres database. # The postgres database.
db: db:
@ -495,7 +500,7 @@ services:
volumes: volumes:
- ./volumes/db/data:/var/lib/postgresql/data - ./volumes/db/data:/var/lib/postgresql/data
healthcheck: healthcheck:
test: ['CMD', 'pg_isready'] test: [ 'CMD', 'pg_isready' ]
interval: 1s interval: 1s
timeout: 3s timeout: 3s
retries: 30 retries: 30
@ -514,7 +519,7 @@ services:
# Set the redis password when startup redis server. # Set the redis password when startup redis server.
command: redis-server --requirepass ${REDIS_PASSWORD:-difyai123456} command: redis-server --requirepass ${REDIS_PASSWORD:-difyai123456}
healthcheck: healthcheck:
test: ['CMD', 'redis-cli', 'ping'] test: [ 'CMD', 'redis-cli', 'ping' ]
# The DifySandbox # The DifySandbox
sandbox: sandbox:
@ -534,7 +539,7 @@ services:
volumes: volumes:
- ./volumes/sandbox/dependencies:/dependencies - ./volumes/sandbox/dependencies:/dependencies
healthcheck: healthcheck:
test: ['CMD', 'curl', '-f', 'http://localhost:8194/health'] test: [ 'CMD', 'curl', '-f', 'http://localhost:8194/health' ]
networks: networks:
- ssrf_proxy_network - ssrf_proxy_network
@ -571,12 +576,7 @@ services:
volumes: volumes:
- ./ssrf_proxy/squid.conf.template:/etc/squid/squid.conf.template - ./ssrf_proxy/squid.conf.template:/etc/squid/squid.conf.template
- ./ssrf_proxy/docker-entrypoint.sh:/docker-entrypoint-mount.sh - ./ssrf_proxy/docker-entrypoint.sh:/docker-entrypoint-mount.sh
entrypoint: entrypoint: [ 'sh', '-c', "cp /docker-entrypoint-mount.sh /docker-entrypoint.sh && sed -i 's/\r$$//' /docker-entrypoint.sh && chmod +x /docker-entrypoint.sh && /docker-entrypoint.sh" ]
[
'sh',
'-c',
"cp /docker-entrypoint-mount.sh /docker-entrypoint.sh && sed -i 's/\r$$//' /docker-entrypoint.sh && chmod +x /docker-entrypoint.sh && /docker-entrypoint.sh",
]
environment: environment:
# pls clearly modify the squid env vars to fit your network environment. # pls clearly modify the squid env vars to fit your network environment.
HTTP_PORT: ${SSRF_HTTP_PORT:-3128} HTTP_PORT: ${SSRF_HTTP_PORT:-3128}
@ -605,8 +605,8 @@ services:
- CERTBOT_EMAIL=${CERTBOT_EMAIL} - CERTBOT_EMAIL=${CERTBOT_EMAIL}
- CERTBOT_DOMAIN=${CERTBOT_DOMAIN} - CERTBOT_DOMAIN=${CERTBOT_DOMAIN}
- CERTBOT_OPTIONS=${CERTBOT_OPTIONS:-} - CERTBOT_OPTIONS=${CERTBOT_OPTIONS:-}
entrypoint: ['/docker-entrypoint.sh'] entrypoint: [ '/docker-entrypoint.sh' ]
command: ['tail', '-f', '/dev/null'] command: [ 'tail', '-f', '/dev/null' ]
# The nginx reverse proxy. # The nginx reverse proxy.
# used for reverse proxying the API service and Web service. # used for reverse proxying the API service and Web service.
@ -623,12 +623,7 @@ services:
- ./volumes/certbot/conf/live:/etc/letsencrypt/live # cert dir (with certbot container) - ./volumes/certbot/conf/live:/etc/letsencrypt/live # cert dir (with certbot container)
- ./volumes/certbot/conf:/etc/letsencrypt - ./volumes/certbot/conf:/etc/letsencrypt
- ./volumes/certbot/www:/var/www/html - ./volumes/certbot/www:/var/www/html
entrypoint: entrypoint: [ 'sh', '-c', "cp /docker-entrypoint-mount.sh /docker-entrypoint.sh && sed -i 's/\r$$//' /docker-entrypoint.sh && chmod +x /docker-entrypoint.sh && /docker-entrypoint.sh" ]
[
'sh',
'-c',
"cp /docker-entrypoint-mount.sh /docker-entrypoint.sh && sed -i 's/\r$$//' /docker-entrypoint.sh && chmod +x /docker-entrypoint.sh && /docker-entrypoint.sh",
]
environment: environment:
NGINX_SERVER_NAME: ${NGINX_SERVER_NAME:-_} NGINX_SERVER_NAME: ${NGINX_SERVER_NAME:-_}
NGINX_HTTPS_ENABLED: ${NGINX_HTTPS_ENABLED:-false} NGINX_HTTPS_ENABLED: ${NGINX_HTTPS_ENABLED:-false}
@ -720,7 +715,7 @@ services:
working_dir: /opt/couchbase working_dir: /opt/couchbase
stdin_open: true stdin_open: true
tty: true tty: true
entrypoint: [""] entrypoint: [ "" ]
command: sh -c "/opt/couchbase/init/init-cbserver.sh" command: sh -c "/opt/couchbase/init/init-cbserver.sh"
volumes: volumes:
- ./volumes/couchbase/data:/opt/couchbase/var/lib/couchbase/data - ./volumes/couchbase/data:/opt/couchbase/var/lib/couchbase/data
@ -749,7 +744,7 @@ services:
volumes: volumes:
- ./volumes/pgvector/data:/var/lib/postgresql/data - ./volumes/pgvector/data:/var/lib/postgresql/data
healthcheck: healthcheck:
test: ['CMD', 'pg_isready'] test: [ 'CMD', 'pg_isready' ]
interval: 1s interval: 1s
timeout: 3s timeout: 3s
retries: 30 retries: 30
@ -771,7 +766,7 @@ services:
volumes: volumes:
- ./volumes/pgvecto_rs/data:/var/lib/postgresql/data - ./volumes/pgvecto_rs/data:/var/lib/postgresql/data
healthcheck: healthcheck:
test: ['CMD', 'pg_isready'] test: [ 'CMD', 'pg_isready' ]
interval: 1s interval: 1s
timeout: 3s timeout: 3s
retries: 30 retries: 30
@ -836,7 +831,7 @@ services:
- ./volumes/milvus/etcd:/etcd - ./volumes/milvus/etcd:/etcd
command: etcd -advertise-client-urls=http://127.0.0.1:2379 -listen-client-urls http://0.0.0.0:2379 --data-dir /etcd command: etcd -advertise-client-urls=http://127.0.0.1:2379 -listen-client-urls http://0.0.0.0:2379 --data-dir /etcd
healthcheck: healthcheck:
test: ['CMD', 'etcdctl', 'endpoint', 'health'] test: [ 'CMD', 'etcdctl', 'endpoint', 'health' ]
interval: 30s interval: 30s
timeout: 20s timeout: 20s
retries: 3 retries: 3
@ -855,7 +850,7 @@ services:
- ./volumes/milvus/minio:/minio_data - ./volumes/milvus/minio:/minio_data
command: minio server /minio_data --console-address ":9001" command: minio server /minio_data --console-address ":9001"
healthcheck: healthcheck:
test: ['CMD', 'curl', '-f', 'http://localhost:9000/minio/health/live'] test: [ 'CMD', 'curl', '-f', 'http://localhost:9000/minio/health/live' ]
interval: 30s interval: 30s
timeout: 20s timeout: 20s
retries: 3 retries: 3
@ -864,10 +859,10 @@ services:
milvus-standalone: milvus-standalone:
container_name: milvus-standalone container_name: milvus-standalone
image: milvusdb/milvus:v2.3.1 image: milvusdb/milvus:v2.5.0-beta
profiles: profiles:
- milvus - milvus
command: ['milvus', 'run', 'standalone'] command: [ 'milvus', 'run', 'standalone' ]
environment: environment:
ETCD_ENDPOINTS: ${ETCD_ENDPOINTS:-etcd:2379} ETCD_ENDPOINTS: ${ETCD_ENDPOINTS:-etcd:2379}
MINIO_ADDRESS: ${MINIO_ADDRESS:-minio:9000} MINIO_ADDRESS: ${MINIO_ADDRESS:-minio:9000}
@ -875,7 +870,7 @@ services:
volumes: volumes:
- ./volumes/milvus/milvus:/var/lib/milvus - ./volumes/milvus/milvus:/var/lib/milvus
healthcheck: healthcheck:
test: ['CMD', 'curl', '-f', 'http://localhost:9091/healthz'] test: [ 'CMD', 'curl', '-f', 'http://localhost:9091/healthz' ]
interval: 30s interval: 30s
start_period: 90s start_period: 90s
timeout: 20s timeout: 20s
@ -948,22 +943,30 @@ services:
container_name: elasticsearch container_name: elasticsearch
profiles: profiles:
- elasticsearch - elasticsearch
- elasticsearch-ja
restart: always restart: always
volumes: volumes:
- ./elasticsearch/docker-entrypoint.sh:/docker-entrypoint-mount.sh
- dify_es01_data:/usr/share/elasticsearch/data - dify_es01_data:/usr/share/elasticsearch/data
environment: environment:
ELASTIC_PASSWORD: ${ELASTICSEARCH_PASSWORD:-elastic} ELASTIC_PASSWORD: ${ELASTICSEARCH_PASSWORD:-elastic}
VECTOR_STORE: ${VECTOR_STORE:-}
cluster.name: dify-es-cluster cluster.name: dify-es-cluster
node.name: dify-es0 node.name: dify-es0
discovery.type: single-node discovery.type: single-node
xpack.license.self_generated.type: trial xpack.license.self_generated.type: basic
xpack.security.enabled: 'true' xpack.security.enabled: 'true'
xpack.security.enrollment.enabled: 'false' xpack.security.enrollment.enabled: 'false'
xpack.security.http.ssl.enabled: 'false' xpack.security.http.ssl.enabled: 'false'
ports: ports:
- ${ELASTICSEARCH_PORT:-9200}:9200 - ${ELASTICSEARCH_PORT:-9200}:9200
deploy:
resources:
limits:
memory: 2g
entrypoint: [ 'sh', '-c', "sh /docker-entrypoint-mount.sh" ]
healthcheck: healthcheck:
test: ['CMD', 'curl', '-s', 'http://localhost:9200/_cluster/health?pretty'] test: [ 'CMD', 'curl', '-s', 'http://localhost:9200/_cluster/health?pretty' ]
interval: 30s interval: 30s
timeout: 10s timeout: 10s
retries: 50 retries: 50
@ -991,7 +994,7 @@ services:
ports: ports:
- ${KIBANA_PORT:-5601}:5601 - ${KIBANA_PORT:-5601}:5601
healthcheck: healthcheck:
test: ['CMD-SHELL', 'curl -s http://localhost:5601 >/dev/null || exit 1'] test: [ 'CMD-SHELL', 'curl -s http://localhost:5601 >/dev/null || exit 1' ]
interval: 30s interval: 30s
timeout: 10s timeout: 10s
retries: 3 retries: 3

View File

@ -0,0 +1,25 @@
#!/bin/bash
set -e
# Install the Japanese text-analysis plugins only when the elasticsearch-ja vector store
# is selected, then hand control back to the stock Elasticsearch entrypoint below.
if [ "${VECTOR_STORE}" = "elasticsearch-ja" ]; then
# Check if the ICU tokenizer plugin is installed
if ! /usr/share/elasticsearch/bin/elasticsearch-plugin list | grep -q analysis-icu; then
printf '%s\n' "Installing the ICU tokenizer plugin"
if ! /usr/share/elasticsearch/bin/elasticsearch-plugin install analysis-icu; then
printf '%s\n' "Failed to install the ICU tokenizer plugin"
exit 1
fi
fi
# Check if the Japanese language analyzer plugin is installed
if ! /usr/share/elasticsearch/bin/elasticsearch-plugin list | grep -q analysis-kuromoji; then
printf '%s\n' "Installing the Japanese language analyzer plugin"
if ! /usr/share/elasticsearch/bin/elasticsearch-plugin install analysis-kuromoji; then
printf '%s\n' "Failed to install the Japanese language analyzer plugin"
exit 1
fi
fi
fi
# Run the original entrypoint script
exec /bin/tini -- /usr/local/bin/docker-entrypoint.sh


View File

@ -25,3 +25,6 @@ NEXT_PUBLIC_TEXT_GENERATION_TIMEOUT_MS=60000
# CSP https://developer.mozilla.org/en-US/docs/Web/HTTP/CSP # CSP https://developer.mozilla.org/en-US/docs/Web/HTTP/CSP
NEXT_PUBLIC_CSP_WHITELIST= NEXT_PUBLIC_CSP_WHITELIST=
# The maximum number of top-k value for RAG.
NEXT_PUBLIC_TOP_K_MAX_VALUE=10

View File

@ -25,16 +25,18 @@ import Input from '@/app/components/base/input'
import { useStore as useTagStore } from '@/app/components/base/tag-management/store' import { useStore as useTagStore } from '@/app/components/base/tag-management/store'
import TagManagementModal from '@/app/components/base/tag-management' import TagManagementModal from '@/app/components/base/tag-management'
import TagFilter from '@/app/components/base/tag-management/filter' import TagFilter from '@/app/components/base/tag-management/filter'
import CheckboxWithLabel from '@/app/components/datasets/create/website/base/checkbox-with-label'
const getKey = ( const getKey = (
pageIndex: number, pageIndex: number,
previousPageData: AppListResponse, previousPageData: AppListResponse,
activeTab: string, activeTab: string,
isCreatedByMe: boolean,
tags: string[], tags: string[],
keywords: string, keywords: string,
) => { ) => {
if (!pageIndex || previousPageData.has_more) { if (!pageIndex || previousPageData.has_more) {
const params: any = { url: 'apps', params: { page: pageIndex + 1, limit: 30, name: keywords } } const params: any = { url: 'apps', params: { page: pageIndex + 1, limit: 30, name: keywords, is_created_by_me: isCreatedByMe } }
if (activeTab !== 'all') if (activeTab !== 'all')
params.params.mode = activeTab params.params.mode = activeTab
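
For orientation, the key above resolves to something like the following for the first page with the new filter enabled; the console API path in the comment is an assumption about how `fetchAppList` serializes it:

```ts
// Illustrative only: first-page key with "created by me" filtering, a keyword search,
// and the Chatbot tab active (activeTab !== 'all' adds the mode field).
const exampleKey = {
  url: 'apps',
  params: { page: 1, limit: 30, name: 'bot', is_created_by_me: true, mode: 'chat' },
}

// fetchAppList would then issue a request along the lines of:
//   GET /console/api/apps?page=1&limit=30&name=bot&is_created_by_me=true&mode=chat
```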
@ -58,6 +60,7 @@ const Apps = () => {
defaultTab: 'all', defaultTab: 'all',
}) })
const { query: { tagIDs = [], keywords = '' }, setQuery } = useAppsQueryState() const { query: { tagIDs = [], keywords = '' }, setQuery } = useAppsQueryState()
const [isCreatedByMe, setIsCreatedByMe] = useState(false)
const [tagFilterValue, setTagFilterValue] = useState<string[]>(tagIDs) const [tagFilterValue, setTagFilterValue] = useState<string[]>(tagIDs)
const [searchKeywords, setSearchKeywords] = useState(keywords) const [searchKeywords, setSearchKeywords] = useState(keywords)
const setKeywords = useCallback((keywords: string) => { const setKeywords = useCallback((keywords: string) => {
@ -68,7 +71,7 @@ const Apps = () => {
}, [setQuery]) }, [setQuery])
const { data, isLoading, setSize, mutate } = useSWRInfinite( const { data, isLoading, setSize, mutate } = useSWRInfinite(
(pageIndex: number, previousPageData: AppListResponse) => getKey(pageIndex, previousPageData, activeTab, tagIDs, searchKeywords), (pageIndex: number, previousPageData: AppListResponse) => getKey(pageIndex, previousPageData, activeTab, isCreatedByMe, tagIDs, searchKeywords),
fetchAppList, fetchAppList,
{ revalidateFirstPage: true }, { revalidateFirstPage: true },
) )
@ -132,6 +135,12 @@ const Apps = () => {
options={options} options={options}
/> />
<div className='flex items-center gap-2'> <div className='flex items-center gap-2'>
<CheckboxWithLabel
className='mr-2'
label={t('app.showMyCreatedAppsOnly')}
isChecked={isCreatedByMe}
onChange={() => setIsCreatedByMe(!isCreatedByMe)}
/>
<TagFilter type='app' value={tagFilterValue} onChange={handleTagsChange} /> <TagFilter type='app' value={tagFilterValue} onChange={handleTagsChange} />
<Input <Input
showLeftIcon showLeftIcon

View File

@ -52,6 +52,15 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
- <code>high_quality</code> High quality: embedding using embedding model, built as vector database index - <code>high_quality</code> High quality: embedding using embedding model, built as vector database index
- <code>economy</code> Economy: Build using inverted index of keyword table index - <code>economy</code> Economy: Build using inverted index of keyword table index
</Property> </Property>
<Property name='doc_form' type='string' key='doc_form'>
Format of indexed content
- <code>text_model</code> Text documents are embedded directly; <code>economy</code> mode defaults to this form
- <code>hierarchical_model</code> Parent-child mode
- <code>qa_model</code> Q&A Mode: Generates Q&A pairs for segmented documents and then embeds the questions
</Property>
<Property name='doc_language' type='string' key='doc_language'>
In Q&A mode, specify the language of the document, for example: <code>English</code>, <code>Chinese</code>
</Property>
<Property name='process_rule' type='object' key='process_rule'> <Property name='process_rule' type='object' key='process_rule'>
Processing rules Processing rules
- <code>mode</code> (string) Cleaning, segmentation mode, automatic / custom - <code>mode</code> (string) Cleaning, segmentation mode, automatic / custom
@ -65,6 +74,11 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
- <code>segmentation</code> (object) Segmentation rules - <code>segmentation</code> (object) Segmentation rules
- <code>separator</code> Custom segment identifier, currently only allows one delimiter to be set. Default is \n - <code>separator</code> Custom segment identifier, currently only allows one delimiter to be set. Default is \n
- <code>max_tokens</code> Maximum length (token) defaults to 1000 - <code>max_tokens</code> Maximum length (token) defaults to 1000
- <code>parent_mode</code> Retrieval mode of parent chunks: <code>full-doc</code> full text retrieval / <code>paragraph</code> paragraph retrieval
- <code>subchunk_segmentation</code> (object) Child chunk rules
- <code>separator</code> Segmentation identifier. Currently, only one delimiter is allowed. The default is <code>***</code>
- <code>max_tokens</code> Maximum length (tokens); must be shorter than the parent chunk's maximum length
- <code>chunk_overlap</code> Overlap between adjacent chunks (optional)
</Property> </Property>
</Properties> </Properties>
</Col> </Col>
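
A hedged example of how the new fields combine in a create-by-text request; the base URL, endpoint path, and API key are placeholders, and the field placement follows the property list above:

```ts
// Illustrative only: parent-child indexing with custom child-chunk rules.
const body = {
  name: 'handbook',
  text: 'Long source text ...',
  indexing_technique: 'high_quality',
  doc_form: 'hierarchical_model',
  process_rule: {
    mode: 'custom',
    rules: {
      pre_processing_rules: [
        { id: 'remove_extra_spaces', enabled: true },
        { id: 'remove_urls_emails', enabled: false },
      ],
      segmentation: { separator: '\n', max_tokens: 1000 },
      parent_mode: 'paragraph',
      subchunk_segmentation: { separator: '***', max_tokens: 256, chunk_overlap: 32 },
    },
  },
}

// Assumed request shape; replace the placeholders with real values.
await fetch('https://api.dify.ai/v1/datasets/{dataset_id}/document/create-by-text', {
  method: 'POST',
  headers: { Authorization: 'Bearer dataset-api-key', 'Content-Type': 'application/json' },
  body: JSON.stringify(body),
})
```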
@ -155,6 +169,13 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
- <code>high_quality</code> High quality: embedding using embedding model, built as vector database index - <code>high_quality</code> High quality: embedding using embedding model, built as vector database index
- <code>economy</code> Economy: Build using inverted index of keyword table index - <code>economy</code> Economy: Build using inverted index of keyword table index
- <code>doc_form</code> Format of indexed content
- <code>text_model</code> Text documents are embedded directly; <code>economy</code> mode defaults to this form
- <code>hierarchical_model</code> Parent-child mode
- <code>qa_model</code> Q&A Mode: Generates Q&A pairs for segmented documents and then embeds the questions
- <code>doc_language</code> In Q&A mode, specify the language of the document, for example: <code>English</code>, <code>Chinese</code>
- <code>process_rule</code> Processing rules - <code>process_rule</code> Processing rules
- <code>mode</code> (string) Cleaning, segmentation mode, automatic / custom - <code>mode</code> (string) Cleaning, segmentation mode, automatic / custom
- <code>rules</code> (object) Custom rules (in automatic mode, this field is empty) - <code>rules</code> (object) Custom rules (in automatic mode, this field is empty)
@ -167,6 +188,11 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
- <code>segmentation</code> (object) Segmentation rules - <code>segmentation</code> (object) Segmentation rules
- <code>separator</code> Custom segment identifier, currently only allows one delimiter to be set. Default is \n - <code>separator</code> Custom segment identifier, currently only allows one delimiter to be set. Default is \n
- <code>max_tokens</code> Maximum length (token) defaults to 1000 - <code>max_tokens</code> Maximum length (token) defaults to 1000
- <code>parent_mode</code> Retrieval mode of parent chunks: <code>full-doc</code> full text retrieval / <code>paragraph</code> paragraph retrieval
- <code>subchunk_segmentation</code> (object) Child chunk rules
- <code>separator</code> Segmentation identifier. Currently, only one delimiter is allowed. The default is <code>***</code>
- <code>max_tokens</code> Maximum length (tokens); must be shorter than the parent chunk's maximum length
- <code>chunk_overlap</code> Overlap between adjacent chunks (optional)
</Property> </Property>
<Property name='file' type='multipart/form-data' key='file'> <Property name='file' type='multipart/form-data' key='file'>
Files that need to be uploaded. Files that need to be uploaded.
@ -449,6 +475,11 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
- <code>segmentation</code> (object) Segmentation rules - <code>segmentation</code> (object) Segmentation rules
- <code>separator</code> Custom segment identifier, currently only allows one delimiter to be set. Default is \n - <code>separator</code> Custom segment identifier, currently only allows one delimiter to be set. Default is \n
- <code>max_tokens</code> Maximum length (token) defaults to 1000 - <code>max_tokens</code> Maximum length (token) defaults to 1000
- <code>parent_mode</code> Retrieval mode of parent chunks: <code>full-doc</code> full text retrieval / <code>paragraph</code> paragraph retrieval
- <code>subchunk_segmentation</code> (object) Child chunk rules
- <code>separator</code> Segmentation identifier. Currently, only one delimiter is allowed. The default is <code>***</code>
- <code>max_tokens</code> Maximum length (tokens); must be shorter than the parent chunk's maximum length
- <code>chunk_overlap</code> Overlap between adjacent chunks (optional)
</Property> </Property>
</Properties> </Properties>
</Col> </Col>
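
And a short sketch of the multipart variant for the file-based endpoints, where the same fields travel as a JSON string in the `data` part; the path and key are again placeholders:

```ts
// Illustrative only: upload a file with parent-child indexing rules.
const data = {
  indexing_technique: 'high_quality',
  doc_form: 'hierarchical_model',
  doc_language: 'English',
  process_rule: {
    mode: 'custom',
    rules: {
      segmentation: { separator: '\n', max_tokens: 1000 },
      parent_mode: 'full-doc',
      subchunk_segmentation: { separator: '***', max_tokens: 256 },
    },
  },
}

const form = new FormData()
form.append('data', JSON.stringify(data))
form.append('file', new Blob(['file contents'], { type: 'text/plain' }), 'handbook.txt')

// Assumed request shape; do not set Content-Type so fetch can add the multipart boundary.
await fetch('https://api.dify.ai/v1/datasets/{dataset_id}/document/create-by-file', {
  method: 'POST',
  headers: { Authorization: 'Bearer dataset-api-key' },
  body: form,
})
```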
@ -546,6 +577,11 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
- <code>segmentation</code> (object) Segmentation rules - <code>segmentation</code> (object) Segmentation rules
- <code>separator</code> Custom segment identifier, currently only allows one delimiter to be set. Default is \n - <code>separator</code> Custom segment identifier, currently only allows one delimiter to be set. Default is \n
- <code>max_tokens</code> Maximum length (token) defaults to 1000 - <code>max_tokens</code> Maximum length (token) defaults to 1000
- <code>parent_mode</code> Retrieval mode of parent chunks: <code>full-doc</code> full text retrieval / <code>paragraph</code> paragraph retrieval
- <code>subchunk_segmentation</code> (object) Child chunk rules
- <code>separator</code> Segmentation identifier. Currently, only one delimiter is allowed. The default is <code>***</code>
- <code>max_tokens</code> Maximum length (tokens); must be shorter than the parent chunk's maximum length
- <code>chunk_overlap</code> Overlap between adjacent chunks (optional)
</Property> </Property>
</Properties> </Properties>
</Col> </Col>
@ -984,7 +1020,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
<Heading <Heading
url='/datasets/{dataset_id}/documents/{document_id}/segments/{segment_id}' url='/datasets/{dataset_id}/documents/{document_id}/segments/{segment_id}'
method='POST' method='POST'
title='Update a Chunk in a Document ' title='Update a Chunk in a Document'
name='#update_segment' name='#update_segment'
/> />
<Row> <Row>
@ -1009,6 +1045,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
- <code>answer</code> (text) Answer content, passed if the knowledge is in Q&A mode (optional) - <code>answer</code> (text) Answer content, passed if the knowledge is in Q&A mode (optional)
- <code>keywords</code> (list) Keyword (optional) - <code>keywords</code> (list) Keyword (optional)
- <code>enabled</code> (bool) False / true (optional) - <code>enabled</code> (bool) False / true (optional)
- <code>regenerate_child_chunks</code> (bool) Whether to regenerate child chunks (optional)
</Property> </Property>
</Properties> </Properties>
</Col> </Col>
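
A hedged sketch of the updated chunk-update call with the new flag; the nesting of the fields under a `segment` object and the API key are assumptions based on the surrounding documentation:

```ts
// Illustrative only: update a chunk and ask the service to regenerate its child chunks.
await fetch(
  'https://api.dify.ai/v1/datasets/{dataset_id}/documents/{document_id}/segments/{segment_id}',
  {
    method: 'POST',
    headers: { Authorization: 'Bearer dataset-api-key', 'Content-Type': 'application/json' },
    body: JSON.stringify({
      segment: {
        content: 'Updated parent chunk text',
        keywords: ['handbook'],
        enabled: true,
        regenerate_child_chunks: true, // new optional flag documented above
      },
    }),
  },
)
```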

View File

@ -52,6 +52,15 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
- <code>high_quality</code> 高质量:使用 embedding 模型进行嵌入,构建为向量数据库索引 - <code>high_quality</code> 高质量:使用 embedding 模型进行嵌入,构建为向量数据库索引
- <code>economy</code> 经济:使用 keyword table index 的倒排索引进行构建 - <code>economy</code> 经济:使用 keyword table index 的倒排索引进行构建
</Property> </Property>
<Property name='doc_form' type='string' key='doc_form'>
索引内容的形式
- <code>text_model</code> text 文档直接 embedding经济模式默认为该模式
- <code>hierarchical_model</code> parent-child 模式
- <code>qa_model</code> Q&A 模式:为分片文档生成 Q&A 对,然后对问题进行 embedding
</Property>
<Property name='doc_language' type='string' key='doc_language'>
在 Q&A 模式下,指定文档的语言,例如:<code>English</code>、<code>Chinese</code>
</Property>
<Property name='process_rule' type='object' key='process_rule'> <Property name='process_rule' type='object' key='process_rule'>
处理规则 处理规则
- <code>mode</code> (string) 清洗、分段模式 automatic 自动 / custom 自定义 - <code>mode</code> (string) 清洗、分段模式 automatic 自动 / custom 自定义
@ -63,8 +72,13 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
- <code>remove_urls_emails</code> 删除 URL、电子邮件地址 - <code>remove_urls_emails</code> 删除 URL、电子邮件地址
- <code>enabled</code> (bool) 是否选中该规则,不传入文档 ID 时代表默认值 - <code>enabled</code> (bool) 是否选中该规则,不传入文档 ID 时代表默认值
- <code>segmentation</code> (object) 分段规则 - <code>segmentation</code> (object) 分段规则
- <code>separator</code> 自定义分段标识符,目前仅允许设置一个分隔符。默认为 \n - <code>separator</code> 自定义分段标识符,目前仅允许设置一个分隔符。默认为 <code>\n</code>
- <code>max_tokens</code> 最大长度token默认为 1000 - <code>max_tokens</code> 最大长度token默认为 1000
- <code>parent_mode</code> 父分段的召回模式 <code>full-doc</code> 全文召回 / <code>paragraph</code> 段落召回
- <code>subchunk_segmentation</code> (object) 子分段规则
- <code>separator</code> 分段标识符,目前仅允许设置一个分隔符。默认为 <code>***</code>
- <code>max_tokens</code> 最大长度 (token) 需要校验小于父级的长度
- <code>chunk_overlap</code> 分段重叠指的是在对数据进行分段时,段与段之间存在一定的重叠部分(选填)
</Property> </Property>
</Properties> </Properties>
</Col> </Col>
@ -155,6 +169,13 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
- <code>high_quality</code> 高质量:使用 embedding 模型进行嵌入,构建为向量数据库索引 - <code>high_quality</code> 高质量:使用 embedding 模型进行嵌入,构建为向量数据库索引
- <code>economy</code> 经济:使用 keyword table index 的倒排索引进行构建 - <code>economy</code> 经济:使用 keyword table index 的倒排索引进行构建
- <code>doc_form</code> 索引内容的形式
- <code>text_model</code> text 文档直接 embedding经济模式默认为该模式
- <code>hierarchical_model</code> parent-child 模式
- <code>qa_model</code> Q&A 模式:为分片文档生成 Q&A 对,然后对问题进行 embedding
- <code>doc_language</code> 在 Q&A 模式下,指定文档的语言,例如:<code>English</code>、<code>Chinese</code>
- <code>process_rule</code> 处理规则 - <code>process_rule</code> 处理规则
- <code>mode</code> (string) 清洗、分段模式 automatic 自动 / custom 自定义 - <code>mode</code> (string) 清洗、分段模式 automatic 自动 / custom 自定义
- <code>rules</code> (object) 自定义规则(自动模式下,该字段为空) - <code>rules</code> (object) 自定义规则(自动模式下,该字段为空)
@ -167,6 +188,11 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
- <code>segmentation</code> (object) 分段规则 - <code>segmentation</code> (object) 分段规则
- <code>separator</code> 自定义分段标识符,目前仅允许设置一个分隔符。默认为 \n - <code>separator</code> 自定义分段标识符,目前仅允许设置一个分隔符。默认为 \n
- <code>max_tokens</code> 最大长度token默认为 1000 - <code>max_tokens</code> 最大长度token默认为 1000
- <code>parent_mode</code> 父分段的召回模式 <code>full-doc</code> 全文召回 / <code>paragraph</code> 段落召回
- <code>subchunk_segmentation</code> (object) 子分段规则
- <code>separator</code> 分段标识符,目前仅允许设置一个分隔符。默认为 <code>***</code>
- <code>max_tokens</code> 最大长度 (token) 需要校验小于父级的长度
- <code>chunk_overlap</code> 分段重叠指的是在对数据进行分段时,段与段之间存在一定的重叠部分(选填)
</Property> </Property>
<Property name='file' type='multipart/form-data' key='file'> <Property name='file' type='multipart/form-data' key='file'>
需要上传的文件。 需要上传的文件。
@ -411,7 +437,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
<Heading <Heading
url='/datasets/{dataset_id}/documents/{document_id}/update-by-text' url='/datasets/{dataset_id}/documents/{document_id}/update-by-text'
method='POST' method='POST'
title='通过文本更新文档 ' title='通过文本更新文档'
name='#update-by-text' name='#update-by-text'
/> />
<Row> <Row>
@ -449,6 +475,11 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
- <code>segmentation</code> (object) 分段规则 - <code>segmentation</code> (object) 分段规则
- <code>separator</code> 自定义分段标识符,目前仅允许设置一个分隔符。默认为 \n - <code>separator</code> 自定义分段标识符,目前仅允许设置一个分隔符。默认为 \n
- <code>max_tokens</code> 最大长度token默认为 1000 - <code>max_tokens</code> 最大长度token默认为 1000
- <code>parent_mode</code> 父分段的召回模式 <code>full-doc</code> 全文召回 / <code>paragraph</code> 段落召回
- <code>subchunk_segmentation</code> (object) 子分段规则
- <code>separator</code> 分段标识符,目前仅允许设置一个分隔符。默认为 <code>***</code>
- <code>max_tokens</code> 最大长度 (token) 需要校验小于父级的长度
- <code>chunk_overlap</code> 分段重叠指的是在对数据进行分段时,段与段之间存在一定的重叠部分(选填)
</Property> </Property>
</Properties> </Properties>
</Col> </Col>
@ -508,7 +539,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
<Heading <Heading
url='/datasets/{dataset_id}/documents/{document_id}/update-by-file' url='/datasets/{dataset_id}/documents/{document_id}/update-by-file'
method='POST' method='POST'
title='通过文件更新文档 ' title='通过文件更新文档'
name='#update-by-file' name='#update-by-file'
/> />
<Row> <Row>
@ -546,6 +577,11 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
- <code>segmentation</code> (object) 分段规则 - <code>segmentation</code> (object) 分段规则
- <code>separator</code> 自定义分段标识符,目前仅允许设置一个分隔符。默认为 \n - <code>separator</code> 自定义分段标识符,目前仅允许设置一个分隔符。默认为 \n
- <code>max_tokens</code> 最大长度token默认为 1000 - <code>max_tokens</code> 最大长度token默认为 1000
- <code>parent_mode</code> 父分段的召回模式 <code>full-doc</code> 全文召回 / <code>paragraph</code> 段落召回
- <code>subchunk_segmentation</code> (object) 子分段规则
- <code>separator</code> 分段标识符,目前仅允许设置一个分隔符。默认为 <code>***</code>
- <code>max_tokens</code> 最大长度 (token) 需要校验小于父级的长度
- <code>chunk_overlap</code> 分段重叠指的是在对数据进行分段时,段与段之间存在一定的重叠部分(选填)
</Property> </Property>
</Properties> </Properties>
</Col> </Col>
@ -1009,6 +1045,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
- <code>answer</code> (text) 答案内容,非必填,如果知识库的模式为 Q&A 模式则传值 - <code>answer</code> (text) 答案内容,非必填,如果知识库的模式为 Q&A 模式则传值
- <code>keywords</code> (list) 关键字,非必填 - <code>keywords</code> (list) 关键字,非必填
- <code>enabled</code> (bool) false/true非必填 - <code>enabled</code> (bool) false/true非必填
- <code>regenerate_child_chunks</code> (bool) 是否重新生成子分段,非必填
</Property> </Property>
</Properties> </Properties>
</Col> </Col>

View File

@ -26,13 +26,15 @@ const PromptEditorHeightResizeWrap: FC<Props> = ({
const [clientY, setClientY] = useState(0) const [clientY, setClientY] = useState(0)
const [isResizing, setIsResizing] = useState(false) const [isResizing, setIsResizing] = useState(false)
const [prevUserSelectStyle, setPrevUserSelectStyle] = useState(getComputedStyle(document.body).userSelect) const [prevUserSelectStyle, setPrevUserSelectStyle] = useState(getComputedStyle(document.body).userSelect)
const [oldHeight, setOldHeight] = useState(height)
const handleStartResize = useCallback((e: React.MouseEvent<HTMLElement>) => { const handleStartResize = useCallback((e: React.MouseEvent<HTMLElement>) => {
setClientY(e.clientY) setClientY(e.clientY)
setIsResizing(true) setIsResizing(true)
setOldHeight(height)
setPrevUserSelectStyle(getComputedStyle(document.body).userSelect) setPrevUserSelectStyle(getComputedStyle(document.body).userSelect)
document.body.style.userSelect = 'none' document.body.style.userSelect = 'none'
}, []) }, [height])
const handleStopResize = useCallback(() => { const handleStopResize = useCallback(() => {
setIsResizing(false) setIsResizing(false)
@ -44,8 +46,7 @@ const PromptEditorHeightResizeWrap: FC<Props> = ({
return return
const offset = e.clientY - clientY const offset = e.clientY - clientY
let newHeight = height + offset let newHeight = oldHeight + offset
setClientY(e.clientY)
if (newHeight < minHeight) if (newHeight < minHeight)
newHeight = minHeight newHeight = minHeight
onHeightChange(newHeight) onHeightChange(newHeight)
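
Condensed, the fix above amounts to measuring each drag against the height captured when resizing started instead of moving the anchor on every mouse event; a minimal sketch of that pattern (names simplified, not the component's exact code):

```ts
// Sketch of drag-to-resize without accumulation drift.
let startY = 0 // pointer position when the drag began
let startHeight = 0 // editor height when the drag began

function onResizeStart(e: MouseEvent, currentHeight: number) {
  startY = e.clientY
  startHeight = currentHeight
}

function onResizeMove(e: MouseEvent, minHeight: number, apply: (h: number) => void) {
  // Always derive the new height from the drag origin, so repeated moves
  // cannot compound stale state or rounding into the result.
  const next = Math.max(minHeight, startHeight + (e.clientY - startY))
  apply(next)
}
```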

View File

@ -27,6 +27,7 @@ import { ADD_EXTERNAL_DATA_TOOL } from '@/app/components/app/configuration/confi
import { INSERT_VARIABLE_VALUE_BLOCK_COMMAND } from '@/app/components/base/prompt-editor/plugins/variable-block' import { INSERT_VARIABLE_VALUE_BLOCK_COMMAND } from '@/app/components/base/prompt-editor/plugins/variable-block'
import { PROMPT_EDITOR_UPDATE_VALUE_BY_EVENT_EMITTER } from '@/app/components/base/prompt-editor/plugins/update-block' import { PROMPT_EDITOR_UPDATE_VALUE_BY_EVENT_EMITTER } from '@/app/components/base/prompt-editor/plugins/update-block'
import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints' import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints'
import { useFeaturesStore } from '@/app/components/base/features/hooks'
export type ISimplePromptInput = { export type ISimplePromptInput = {
mode: AppType mode: AppType
@ -54,6 +55,11 @@ const Prompt: FC<ISimplePromptInput> = ({
const { t } = useTranslation() const { t } = useTranslation()
const media = useBreakpoints() const media = useBreakpoints()
const isMobile = media === MediaType.mobile const isMobile = media === MediaType.mobile
const featuresStore = useFeaturesStore()
const {
features,
setFeatures,
} = featuresStore!.getState()
const { eventEmitter } = useEventEmitterContextContext() const { eventEmitter } = useEventEmitterContextContext()
const { const {
@ -137,8 +143,18 @@ const Prompt: FC<ISimplePromptInput> = ({
}) })
setModelConfig(newModelConfig) setModelConfig(newModelConfig)
setPrevPromptConfig(modelConfig.configs) setPrevPromptConfig(modelConfig.configs)
if (mode !== AppType.completion)
if (mode !== AppType.completion) {
setIntroduction(res.opening_statement) setIntroduction(res.opening_statement)
const newFeatures = produce(features, (draft) => {
draft.opening = {
...draft.opening,
enabled: !!res.opening_statement,
opening_statement: res.opening_statement,
}
})
setFeatures(newFeatures)
}
showAutomaticFalse() showAutomaticFalse()
} }
const minHeight = initEditorHeight || 228 const minHeight = initEditorHeight || 228

View File

@ -59,36 +59,24 @@ const ConfigContent: FC<Props> = ({
const { const {
modelList: rerankModelList, modelList: rerankModelList,
defaultModel: rerankDefaultModel,
currentModel: isRerankDefaultModelValid,
} = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank) } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)
const { const {
currentModel: currentRerankModel, currentModel: currentRerankModel,
} = useCurrentProviderAndModel( } = useCurrentProviderAndModel(
rerankModelList, rerankModelList,
rerankDefaultModel {
? { provider: datasetConfigs.reranking_model?.reranking_provider_name,
...rerankDefaultModel, model: datasetConfigs.reranking_model?.reranking_model_name,
provider: rerankDefaultModel.provider.provider, },
}
: undefined,
) )
const rerankModel = (() => { const rerankModel = useMemo(() => {
if (datasetConfigs.reranking_model?.reranking_provider_name) { return {
return { provider_name: datasetConfigs?.reranking_model?.reranking_provider_name ?? '',
provider_name: datasetConfigs.reranking_model.reranking_provider_name, model_name: datasetConfigs?.reranking_model?.reranking_model_name ?? '',
model_name: datasetConfigs.reranking_model.reranking_model_name,
}
} }
else if (rerankDefaultModel) { }, [datasetConfigs.reranking_model])
return {
provider_name: rerankDefaultModel.provider.provider,
model_name: rerankDefaultModel.model,
}
}
})()
const handleParamChange = (key: string, value: number) => { const handleParamChange = (key: string, value: number) => {
if (key === 'top_k') { if (key === 'top_k') {
@ -133,6 +121,12 @@ const ConfigContent: FC<Props> = ({
} }
const handleRerankModeChange = (mode: RerankingModeEnum) => { const handleRerankModeChange = (mode: RerankingModeEnum) => {
if (mode === datasetConfigs.reranking_mode)
return
if (mode === RerankingModeEnum.RerankingModel && !currentRerankModel)
Toast.notify({ type: 'error', message: t('workflow.errorMsg.rerankModelRequired') })
onChange({ onChange({
...datasetConfigs, ...datasetConfigs,
reranking_mode: mode, reranking_mode: mode,
@ -162,31 +156,25 @@ const ConfigContent: FC<Props> = ({
const canManuallyToggleRerank = useMemo(() => { const canManuallyToggleRerank = useMemo(() => {
return (selectedDatasetsMode.allInternal && selectedDatasetsMode.allEconomic) return (selectedDatasetsMode.allInternal && selectedDatasetsMode.allEconomic)
|| selectedDatasetsMode.allExternal || selectedDatasetsMode.allExternal
}, [selectedDatasetsMode.allEconomic, selectedDatasetsMode.allExternal, selectedDatasetsMode.allInternal]) }, [selectedDatasetsMode.allEconomic, selectedDatasetsMode.allExternal, selectedDatasetsMode.allInternal])
const showRerankModel = useMemo(() => { const showRerankModel = useMemo(() => {
if (!canManuallyToggleRerank) if (!canManuallyToggleRerank)
return true return true
else if (canManuallyToggleRerank && !isRerankDefaultModelValid)
return false
return datasetConfigs.reranking_enable return datasetConfigs.reranking_enable
}, [canManuallyToggleRerank, datasetConfigs.reranking_enable, isRerankDefaultModelValid]) }, [datasetConfigs.reranking_enable, canManuallyToggleRerank])
const handleDisabledSwitchClick = useCallback(() => { const handleDisabledSwitchClick = useCallback((enable: boolean) => {
if (!currentRerankModel && !showRerankModel) if (!currentRerankModel && enable)
Toast.notify({ type: 'error', message: t('workflow.errorMsg.rerankModelRequired') }) Toast.notify({ type: 'error', message: t('workflow.errorMsg.rerankModelRequired') })
}, [currentRerankModel, showRerankModel, t]) onChange({
...datasetConfigs,
useEffect(() => { reranking_enable: enable,
if (canManuallyToggleRerank && showRerankModel !== datasetConfigs.reranking_enable) { })
onChange({ // eslint-disable-next-line react-hooks/exhaustive-deps
...datasetConfigs, }, [currentRerankModel, datasetConfigs, onChange])
reranking_enable: showRerankModel,
})
}
}, [canManuallyToggleRerank, showRerankModel, datasetConfigs, onChange])
return ( return (
<div> <div>
@ -267,24 +255,12 @@ const ConfigContent: FC<Props> = ({
<div className='flex items-center'> <div className='flex items-center'>
{ {
selectedDatasetsMode.allEconomic && !selectedDatasetsMode.mixtureInternalAndExternal && ( selectedDatasetsMode.allEconomic && !selectedDatasetsMode.mixtureInternalAndExternal && (
<div <Switch
className='flex items-center' size='md'
onClick={handleDisabledSwitchClick} defaultValue={showRerankModel}
> disabled={!canManuallyToggleRerank}
<Switch onChange={handleDisabledSwitchClick}
size='md' />
defaultValue={showRerankModel}
disabled={!currentRerankModel || !canManuallyToggleRerank}
onChange={(v) => {
if (canManuallyToggleRerank) {
onChange({
...datasetConfigs,
reranking_enable: v,
})
}
}}
/>
</div>
) )
} }
<div className='leading-[32px] ml-1 text-text-secondary system-sm-semibold'>{t('common.modelProvider.rerankModel.key')}</div> <div className='leading-[32px] ml-1 text-text-secondary system-sm-semibold'>{t('common.modelProvider.rerankModel.key')}</div>
@ -298,21 +274,24 @@ const ConfigContent: FC<Props> = ({
triggerClassName='ml-1 w-4 h-4' triggerClassName='ml-1 w-4 h-4'
/> />
</div> </div>
<div> {
<ModelSelector showRerankModel && (
defaultModel={rerankModel && { provider: rerankModel?.provider_name, model: rerankModel?.model_name }} <div>
onSelect={(v) => { <ModelSelector
onChange({ defaultModel={rerankModel && { provider: rerankModel?.provider_name, model: rerankModel?.model_name }}
...datasetConfigs, onSelect={(v) => {
reranking_model: { onChange({
reranking_provider_name: v.provider, ...datasetConfigs,
reranking_model_name: v.model, reranking_model: {
}, reranking_provider_name: v.provider,
}) reranking_model_name: v.model,
}} },
modelList={rerankModelList} })
/> }}
</div> modelList={rerankModelList}
/>
</div>
)}
</div> </div>
) )
} }

View File

@ -10,7 +10,7 @@ import Modal from '@/app/components/base/modal'
import Button from '@/app/components/base/button' import Button from '@/app/components/base/button'
import { RETRIEVE_TYPE } from '@/types/app' import { RETRIEVE_TYPE } from '@/types/app'
import Toast from '@/app/components/base/toast' import Toast from '@/app/components/base/toast'
import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks' import { useCurrentProviderAndModel, useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
import { RerankingModeEnum } from '@/models/datasets' import { RerankingModeEnum } from '@/models/datasets'
import type { DataSet } from '@/models/datasets' import type { DataSet } from '@/models/datasets'
@ -41,17 +41,27 @@ const ParamsConfig = ({
}, [datasetConfigs]) }, [datasetConfigs])
const { const {
defaultModel: rerankDefaultModel, modelList: rerankModelList,
currentModel: isRerankDefaultModelValid, currentModel: rerankDefaultModel,
currentProvider: rerankDefaultProvider, currentProvider: rerankDefaultProvider,
} = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank) } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)
const {
currentModel: isCurrentRerankModelValid,
} = useCurrentProviderAndModel(
rerankModelList,
{
provider: tempDataSetConfigs.reranking_model?.reranking_provider_name ?? '',
model: tempDataSetConfigs.reranking_model?.reranking_model_name ?? '',
},
)
const isValid = () => { const isValid = () => {
let errMsg = '' let errMsg = ''
if (tempDataSetConfigs.retrieval_model === RETRIEVE_TYPE.multiWay) { if (tempDataSetConfigs.retrieval_model === RETRIEVE_TYPE.multiWay) {
if (tempDataSetConfigs.reranking_enable if (tempDataSetConfigs.reranking_enable
&& tempDataSetConfigs.reranking_mode === RerankingModeEnum.RerankingModel && tempDataSetConfigs.reranking_mode === RerankingModeEnum.RerankingModel
&& !isRerankDefaultModelValid && !isCurrentRerankModelValid
) )
errMsg = t('appDebug.datasetConfig.rerankModelRequired') errMsg = t('appDebug.datasetConfig.rerankModelRequired')
} }
@ -66,16 +76,7 @@ const ParamsConfig = ({
const handleSave = () => { const handleSave = () => {
if (!isValid()) if (!isValid())
return return
const config = { ...tempDataSetConfigs } setDatasetConfigs(tempDataSetConfigs)
if (config.retrieval_model === RETRIEVE_TYPE.multiWay
&& config.reranking_mode === RerankingModeEnum.RerankingModel
&& !config.reranking_model) {
config.reranking_model = {
reranking_provider_name: rerankDefaultModel?.provider?.provider,
reranking_model_name: rerankDefaultModel?.model,
} as any
}
setDatasetConfigs(config)
setRerankSettingModalOpen(false) setRerankSettingModalOpen(false)
} }
@ -94,14 +95,14 @@ const ParamsConfig = ({
reranking_enable: restConfigs.reranking_enable, reranking_enable: restConfigs.reranking_enable,
}, selectedDatasets, selectedDatasets, { }, selectedDatasets, selectedDatasets, {
provider: rerankDefaultProvider?.provider, provider: rerankDefaultProvider?.provider,
model: isRerankDefaultModelValid?.model, model: rerankDefaultModel?.model,
}) })
setTempDataSetConfigs({ setTempDataSetConfigs({
...retrievalConfig, ...retrievalConfig,
reranking_model: restConfigs.reranking_model && { reranking_model: {
reranking_provider_name: restConfigs.reranking_model.reranking_provider_name, reranking_provider_name: retrievalConfig.reranking_model?.provider || '',
reranking_model_name: restConfigs.reranking_model.reranking_model_name, reranking_model_name: retrievalConfig.reranking_model?.model || '',
}, },
retrieval_model, retrieval_model,
score_threshold_enabled, score_threshold_enabled,

View File

@ -29,7 +29,7 @@ const WeightedScore = ({
return ( return (
<div> <div>
<div className='px-3 pt-5 h-[52px] space-x-3 rounded-lg border border-components-panel-border'> <div className='px-3 pt-5 pb-2 space-x-3 rounded-lg border border-components-panel-border'>
<Slider <Slider
className={cn('grow h-0.5 !bg-util-colors-teal-teal-500 rounded-full')} className={cn('grow h-0.5 !bg-util-colors-teal-teal-500 rounded-full')}
max={1.0} max={1.0}
@ -39,7 +39,7 @@ const WeightedScore = ({
onChange={v => onChange({ value: [v, (10 - v * 10) / 10] })} onChange={v => onChange({ value: [v, (10 - v * 10) / 10] })}
trackClassName='weightedScoreSliderTrack' trackClassName='weightedScoreSliderTrack'
/> />
<div className='flex justify-between mt-1'> <div className='flex justify-between mt-3'>
<div className='shrink-0 flex items-center w-[90px] system-xs-semibold-uppercase text-util-colors-blue-light-blue-light-500'> <div className='shrink-0 flex items-center w-[90px] system-xs-semibold-uppercase text-util-colors-blue-light-blue-light-500'>
<div className='mr-1 truncate uppercase' title={t('dataset.weightedScore.semantic') || ''}> <div className='mr-1 truncate uppercase' title={t('dataset.weightedScore.semantic') || ''}>
{t('dataset.weightedScore.semantic')} {t('dataset.weightedScore.semantic')}

View File

@ -12,7 +12,7 @@ import Divider from '@/app/components/base/divider'
import Button from '@/app/components/base/button' import Button from '@/app/components/base/button'
import Input from '@/app/components/base/input' import Input from '@/app/components/base/input'
import Textarea from '@/app/components/base/textarea' import Textarea from '@/app/components/base/textarea'
import { type DataSet, RerankingModeEnum } from '@/models/datasets' import { type DataSet } from '@/models/datasets'
import { useToastContext } from '@/app/components/base/toast' import { useToastContext } from '@/app/components/base/toast'
import { updateDatasetSetting } from '@/service/datasets' import { updateDatasetSetting } from '@/service/datasets'
import { useAppContext } from '@/context/app-context' import { useAppContext } from '@/context/app-context'
@ -21,7 +21,7 @@ import type { RetrievalConfig } from '@/types/app'
import RetrievalSettings from '@/app/components/datasets/external-knowledge-base/create/RetrievalSettings' import RetrievalSettings from '@/app/components/datasets/external-knowledge-base/create/RetrievalSettings'
import RetrievalMethodConfig from '@/app/components/datasets/common/retrieval-method-config' import RetrievalMethodConfig from '@/app/components/datasets/common/retrieval-method-config'
import EconomicalRetrievalMethodConfig from '@/app/components/datasets/common/economical-retrieval-method-config' import EconomicalRetrievalMethodConfig from '@/app/components/datasets/common/economical-retrieval-method-config'
import { ensureRerankModelSelected, isReRankModelSelected } from '@/app/components/datasets/common/check-rerank-model' import { isReRankModelSelected } from '@/app/components/datasets/common/check-rerank-model'
import { AlertTriangle } from '@/app/components/base/icons/src/vender/solid/alertsAndFeedback' import { AlertTriangle } from '@/app/components/base/icons/src/vender/solid/alertsAndFeedback'
import PermissionSelector from '@/app/components/datasets/settings/permission-selector' import PermissionSelector from '@/app/components/datasets/settings/permission-selector'
import ModelSelector from '@/app/components/header/account-setting/model-provider-page/model-selector' import ModelSelector from '@/app/components/header/account-setting/model-provider-page/model-selector'
@ -99,8 +99,6 @@ const SettingsModal: FC<SettingsModalProps> = ({
} }
if ( if (
!isReRankModelSelected({ !isReRankModelSelected({
rerankDefaultModel,
isRerankDefaultModelValid: !!isRerankDefaultModelValid,
rerankModelList, rerankModelList,
retrievalConfig, retrievalConfig,
indexMethod, indexMethod,
@ -109,14 +107,6 @@ const SettingsModal: FC<SettingsModalProps> = ({
notify({ type: 'error', message: t('appDebug.datasetConfig.rerankModelRequired') }) notify({ type: 'error', message: t('appDebug.datasetConfig.rerankModelRequired') })
return return
} }
const postRetrievalConfig = ensureRerankModelSelected({
rerankDefaultModel: rerankDefaultModel!,
retrievalConfig: {
...retrievalConfig,
reranking_enable: retrievalConfig.reranking_mode === RerankingModeEnum.RerankingModel,
},
indexMethod,
})
try { try {
setLoading(true) setLoading(true)
const { id, name, description, permission } = localeCurrentDataset const { id, name, description, permission } = localeCurrentDataset
@ -128,8 +118,8 @@ const SettingsModal: FC<SettingsModalProps> = ({
permission, permission,
indexing_technique: indexMethod, indexing_technique: indexMethod,
retrieval_model: { retrieval_model: {
...postRetrievalConfig, ...retrievalConfig,
score_threshold: postRetrievalConfig.score_threshold_enabled ? postRetrievalConfig.score_threshold : 0, score_threshold: retrievalConfig.score_threshold_enabled ? retrievalConfig.score_threshold : 0,
}, },
embedding_model: localeCurrentDataset.embedding_model, embedding_model: localeCurrentDataset.embedding_model,
embedding_model_provider: localeCurrentDataset.embedding_model_provider, embedding_model_provider: localeCurrentDataset.embedding_model_provider,
@ -157,7 +147,7 @@ const SettingsModal: FC<SettingsModalProps> = ({
onSave({ onSave({
...localeCurrentDataset, ...localeCurrentDataset,
indexing_technique: indexMethod, indexing_technique: indexMethod,
retrieval_model_dict: postRetrievalConfig, retrieval_model_dict: retrievalConfig,
}) })
} }
catch (e) { catch (e) {

View File

@ -287,9 +287,9 @@ const Configuration: FC = () => {
setDatasetConfigs({ setDatasetConfigs({
...retrievalConfig, ...retrievalConfig,
reranking_model: restConfigs.reranking_model && { reranking_model: {
reranking_provider_name: restConfigs.reranking_model.reranking_provider_name, reranking_provider_name: retrievalConfig?.reranking_model?.provider || '',
reranking_model_name: restConfigs.reranking_model.reranking_model_name, reranking_model_name: retrievalConfig?.reranking_model?.model || '',
}, },
retrieval_model, retrieval_model,
score_threshold_enabled, score_threshold_enabled,

View File

@ -39,6 +39,7 @@ type ChatInputAreaProps = {
inputs?: Record<string, any> inputs?: Record<string, any>
inputsForm?: InputForm[] inputsForm?: InputForm[]
theme?: Theme | null theme?: Theme | null
isResponding?: boolean
} }
const ChatInputArea = ({ const ChatInputArea = ({
showFeatureBar, showFeatureBar,
@ -51,6 +52,7 @@ const ChatInputArea = ({
inputs = {}, inputs = {},
inputsForm = [], inputsForm = [],
theme, theme,
isResponding,
}: ChatInputAreaProps) => { }: ChatInputAreaProps) => {
const { t } = useTranslation() const { t } = useTranslation()
const { notify } = useToastContext() const { notify } = useToastContext()
@ -77,6 +79,11 @@ const ChatInputArea = ({
const historyRef = useRef(['']) const historyRef = useRef([''])
const [currentIndex, setCurrentIndex] = useState(-1) const [currentIndex, setCurrentIndex] = useState(-1)
const handleSend = () => { const handleSend = () => {
if (isResponding) {
notify({ type: 'info', message: t('appDebug.errorMessage.waitForResponse') })
return
}
if (onSend) { if (onSend) {
const { files, setFiles } = filesStore.getState() const { files, setFiles } = filesStore.getState()
if (files.find(item => item.transferMethod === TransferMethod.local_file && !item.uploadedId)) { if (files.find(item => item.transferMethod === TransferMethod.local_file && !item.uploadedId)) {
@ -116,7 +123,7 @@ const ChatInputArea = ({
setQuery(historyRef.current[currentIndex + 1]) setQuery(historyRef.current[currentIndex + 1])
} }
else if (currentIndex === historyRef.current.length - 1) { else if (currentIndex === historyRef.current.length - 1) {
// If it is the last element, clear the input box // If it is the last element, clear the input box
setCurrentIndex(historyRef.current.length) setCurrentIndex(historyRef.current.length)
setQuery('') setQuery('')
} }
@ -169,6 +176,7 @@ const ChatInputArea = ({
'p-1 w-full leading-6 body-lg-regular text-text-tertiary outline-none', 'p-1 w-full leading-6 body-lg-regular text-text-tertiary outline-none',
)} )}
placeholder={t('common.chat.inputPlaceholder') || ''} placeholder={t('common.chat.inputPlaceholder') || ''}
autoFocus
autoSize={{ minRows: 1 }} autoSize={{ minRows: 1 }}
onResize={handleTextareaResize} onResize={handleTextareaResize}
value={query} value={query}

View File

@ -292,6 +292,7 @@ const Chat: FC<ChatProps> = ({
inputs={inputs} inputs={inputs}
inputsForm={inputsForm} inputsForm={inputsForm}
theme={themeBuilder?.theme} theme={themeBuilder?.theme}
isResponding={isResponding}
/> />
) )
} }

View File

@ -28,8 +28,8 @@ const Question: FC<QuestionProps> = ({
} = item } = item
return ( return (
<div className='flex justify-end mb-2 last:mb-0 pl-10'> <div className='flex justify-end mb-2 last:mb-0 pl-14'>
<div className='group relative mr-4'> <div className='group relative mr-4 max-w-full'>
<div <div
className='px-4 py-3 bg-[#D1E9FF]/50 rounded-2xl text-sm text-gray-900' className='px-4 py-3 bg-[#D1E9FF]/50 rounded-2xl text-sm text-gray-900'
style={theme?.chatBubbleColorStyle ? CssTransform(theme.chatBubbleColorStyle) : {}} style={theme?.chatBubbleColorStyle ? CssTransform(theme.chatBubbleColorStyle) : {}}

View File

@ -111,9 +111,9 @@ const CodeBlock: CodeComponent = memo(({ inline, className, children, ...props }
} }
else if (language === 'echarts') { else if (language === 'echarts') {
return ( return (
<div style={{ minHeight: '350px', minWidth: '700px' }}> <div style={{ minHeight: '350px', minWidth: '100%', overflowX: 'scroll' }}>
<ErrorBoundary> <ErrorBoundary>
<ReactEcharts option={chartData} /> <ReactEcharts option={chartData} style={{ minWidth: '700px' }} />
</ErrorBoundary> </ErrorBoundary>
</div> </div>
) )

View File

@ -11,11 +11,17 @@ type Props = {
enable: boolean enable: boolean
} }
const maxTopK = (() => {
const configValue = parseInt(globalThis.document?.body?.getAttribute('data-public-top-k-max-value') || '', 10)
if (configValue && !isNaN(configValue))
return configValue
return 10
})()
const VALUE_LIMIT = { const VALUE_LIMIT = {
default: 2, default: 2,
step: 1, step: 1,
min: 1, min: 1,
max: 10, max: maxTopK,
} }
const key = 'top_k' const key = 'top_k'
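
Presumably the new `NEXT_PUBLIC_TOP_K_MAX_VALUE` variable reaches the browser by being written onto the document body as the data attribute read above; a hedged sketch of that producing side, which is not shown in this diff:

```tsx
import type { ReactNode } from 'react'

// Assumed wiring: the root layout forwards the env value to the client as a data
// attribute on <body>, which the maxTopK reader above then parses (falling back to 10).
export default function RootLayout({ children }: { children: ReactNode }) {
  return (
    <html lang="en">
      <body data-public-top-k-max-value={process.env.NEXT_PUBLIC_TOP_K_MAX_VALUE || ''}>
        {children}
      </body>
    </html>
  )
}
```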

View File

@ -6,14 +6,10 @@ import type {
import { RerankingModeEnum } from '@/models/datasets' import { RerankingModeEnum } from '@/models/datasets'
export const isReRankModelSelected = ({ export const isReRankModelSelected = ({
rerankDefaultModel,
isRerankDefaultModelValid,
retrievalConfig, retrievalConfig,
rerankModelList, rerankModelList,
indexMethod, indexMethod,
}: { }: {
rerankDefaultModel?: DefaultModelResponse
isRerankDefaultModelValid: boolean
retrievalConfig: RetrievalConfig retrievalConfig: RetrievalConfig
rerankModelList: Model[] rerankModelList: Model[]
indexMethod?: string indexMethod?: string
@ -25,12 +21,17 @@ export const isReRankModelSelected = ({
return provider?.models.find(({ model }) => model === retrievalConfig.reranking_model?.reranking_model_name) return provider?.models.find(({ model }) => model === retrievalConfig.reranking_model?.reranking_model_name)
} }
if (isRerankDefaultModelValid)
return !!rerankDefaultModel
return false return false
})() })()
if (
indexMethod === 'high_quality'
&& ([RETRIEVE_METHOD.semantic, RETRIEVE_METHOD.fullText].includes(retrievalConfig.search_method))
&& retrievalConfig.reranking_enable
&& !rerankModelSelected
)
return false
if ( if (
indexMethod === 'high_quality' indexMethod === 'high_quality'
&& (retrievalConfig.search_method === RETRIEVE_METHOD.hybrid && retrievalConfig.reranking_mode !== RerankingModeEnum.WeightedScore) && (retrievalConfig.search_method === RETRIEVE_METHOD.hybrid && retrievalConfig.reranking_mode !== RerankingModeEnum.WeightedScore)

View File

@ -10,11 +10,13 @@ import { RETRIEVE_METHOD } from '@/types/app'
import type { RetrievalConfig } from '@/types/app' import type { RetrievalConfig } from '@/types/app'
type Props = { type Props = {
disabled?: boolean
value: RetrievalConfig value: RetrievalConfig
onChange: (value: RetrievalConfig) => void onChange: (value: RetrievalConfig) => void
} }
const EconomicalRetrievalMethodConfig: FC<Props> = ({ const EconomicalRetrievalMethodConfig: FC<Props> = ({
disabled = false,
value, value,
onChange, onChange,
}) => { }) => {
@ -22,7 +24,8 @@ const EconomicalRetrievalMethodConfig: FC<Props> = ({
return ( return (
<div className='space-y-2'> <div className='space-y-2'>
<OptionCard icon={<Image className='w-4 h-4' src={retrievalIcon.vector} alt='' />} <OptionCard
disabled={disabled} icon={<Image className='w-4 h-4' src={retrievalIcon.vector} alt='' />}
title={t('dataset.retrieval.invertedIndex.title')} title={t('dataset.retrieval.invertedIndex.title')}
description={t('dataset.retrieval.invertedIndex.description')} isActive description={t('dataset.retrieval.invertedIndex.description')} isActive
activeHeaderClassName='bg-dataset-option-card-purple-gradient' activeHeaderClassName='bg-dataset-option-card-purple-gradient'

View File

@ -1,6 +1,6 @@
'use client' 'use client'
import type { FC } from 'react' import type { FC } from 'react'
import React from 'react' import React, { useCallback } from 'react'
import { useTranslation } from 'react-i18next' import { useTranslation } from 'react-i18next'
import Image from 'next/image' import Image from 'next/image'
import RetrievalParamConfig from '../retrieval-param-config' import RetrievalParamConfig from '../retrieval-param-config'
@ -10,7 +10,7 @@ import { retrievalIcon } from '../../create/icons'
import type { RetrievalConfig } from '@/types/app' import type { RetrievalConfig } from '@/types/app'
import { RETRIEVE_METHOD } from '@/types/app' import { RETRIEVE_METHOD } from '@/types/app'
import { useProviderContext } from '@/context/provider-context' import { useProviderContext } from '@/context/provider-context'
import { useDefaultModel } from '@/app/components/header/account-setting/model-provider-page/hooks' import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
import { import {
DEFAULT_WEIGHTED_SCORE, DEFAULT_WEIGHTED_SCORE,
@ -20,54 +20,87 @@ import {
import Badge from '@/app/components/base/badge' import Badge from '@/app/components/base/badge'
type Props = { type Props = {
disabled?: boolean
value: RetrievalConfig value: RetrievalConfig
onChange: (value: RetrievalConfig) => void onChange: (value: RetrievalConfig) => void
} }
const RetrievalMethodConfig: FC<Props> = ({ const RetrievalMethodConfig: FC<Props> = ({
value: passValue, disabled = false,
value,
onChange, onChange,
}) => { }) => {
const { t } = useTranslation() const { t } = useTranslation()
const { supportRetrievalMethods } = useProviderContext() const { supportRetrievalMethods } = useProviderContext()
const { data: rerankDefaultModel } = useDefaultModel(ModelTypeEnum.rerank) const {
const value = (() => { defaultModel: rerankDefaultModel,
if (!passValue.reranking_model.reranking_model_name) { currentModel: isRerankDefaultModelValid,
return { } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)
...passValue,
reranking_model: { const onSwitch = useCallback((retrieveMethod: RETRIEVE_METHOD) => {
reranking_provider_name: rerankDefaultModel?.provider.provider || '', if ([RETRIEVE_METHOD.semantic, RETRIEVE_METHOD.fullText].includes(retrieveMethod)) {
reranking_model_name: rerankDefaultModel?.model || '', onChange({
}, ...value,
reranking_mode: passValue.reranking_mode || (rerankDefaultModel ? RerankingModeEnum.RerankingModel : RerankingModeEnum.WeightedScore), search_method: retrieveMethod,
weights: passValue.weights || { ...(!value.reranking_model.reranking_model_name
weight_type: WeightedScoreEnum.Customized, ? {
vector_setting: { reranking_model: {
vector_weight: DEFAULT_WEIGHTED_SCORE.other.semantic, reranking_provider_name: isRerankDefaultModelValid ? rerankDefaultModel?.provider?.provider ?? '' : '',
embedding_provider_name: '', reranking_model_name: isRerankDefaultModelValid ? rerankDefaultModel?.model ?? '' : '',
embedding_model_name: '', },
}, reranking_enable: !!isRerankDefaultModelValid,
keyword_setting: { }
keyword_weight: DEFAULT_WEIGHTED_SCORE.other.keyword, : {
}, reranking_enable: true,
}, }),
} })
} }
return passValue if (retrieveMethod === RETRIEVE_METHOD.hybrid) {
})() onChange({
...value,
search_method: retrieveMethod,
...(!value.reranking_model.reranking_model_name
? {
reranking_model: {
reranking_provider_name: isRerankDefaultModelValid ? rerankDefaultModel?.provider?.provider ?? '' : '',
reranking_model_name: isRerankDefaultModelValid ? rerankDefaultModel?.model ?? '' : '',
},
reranking_enable: !!isRerankDefaultModelValid,
reranking_mode: isRerankDefaultModelValid ? RerankingModeEnum.RerankingModel : RerankingModeEnum.WeightedScore,
}
: {
reranking_enable: true,
reranking_mode: RerankingModeEnum.RerankingModel,
}),
...(!value.weights
? {
weights: {
weight_type: WeightedScoreEnum.Customized,
vector_setting: {
vector_weight: DEFAULT_WEIGHTED_SCORE.other.semantic,
embedding_provider_name: '',
embedding_model_name: '',
},
keyword_setting: {
keyword_weight: DEFAULT_WEIGHTED_SCORE.other.keyword,
},
},
}
: {}),
})
}
}, [value, rerankDefaultModel, isRerankDefaultModelValid, onChange])
return ( return (
<div className='space-y-2'> <div className='space-y-2'>
{supportRetrievalMethods.includes(RETRIEVE_METHOD.semantic) && ( {supportRetrievalMethods.includes(RETRIEVE_METHOD.semantic) && (
<OptionCard icon={<Image className='w-4 h-4' src={retrievalIcon.vector} alt='' />} <OptionCard disabled={disabled} icon={<Image className='w-4 h-4' src={retrievalIcon.vector} alt='' />}
title={t('dataset.retrieval.semantic_search.title')} title={t('dataset.retrieval.semantic_search.title')}
description={t('dataset.retrieval.semantic_search.description')} description={t('dataset.retrieval.semantic_search.description')}
isActive={ isActive={
value.search_method === RETRIEVE_METHOD.semantic value.search_method === RETRIEVE_METHOD.semantic
} }
onSwitched={() => onChange({ onSwitched={() => onSwitch(RETRIEVE_METHOD.semantic)}
...value,
search_method: RETRIEVE_METHOD.semantic,
})}
effectImg={Effect.src} effectImg={Effect.src}
activeHeaderClassName='bg-dataset-option-card-purple-gradient' activeHeaderClassName='bg-dataset-option-card-purple-gradient'
> >
@ -78,17 +111,14 @@ const RetrievalMethodConfig: FC<Props> = ({
/> />
</OptionCard> </OptionCard>
)} )}
{supportRetrievalMethods.includes(RETRIEVE_METHOD.semantic) && ( {supportRetrievalMethods.includes(RETRIEVE_METHOD.fullText) && (
<OptionCard icon={<Image className='w-4 h-4' src={retrievalIcon.fullText} alt='' />} <OptionCard disabled={disabled} icon={<Image className='w-4 h-4' src={retrievalIcon.fullText} alt='' />}
title={t('dataset.retrieval.full_text_search.title')} title={t('dataset.retrieval.full_text_search.title')}
description={t('dataset.retrieval.full_text_search.description')} description={t('dataset.retrieval.full_text_search.description')}
isActive={ isActive={
value.search_method === RETRIEVE_METHOD.fullText value.search_method === RETRIEVE_METHOD.fullText
} }
onSwitched={() => onChange({ onSwitched={() => onSwitch(RETRIEVE_METHOD.fullText)}
...value,
search_method: RETRIEVE_METHOD.fullText,
})}
effectImg={Effect.src} effectImg={Effect.src}
activeHeaderClassName='bg-dataset-option-card-purple-gradient' activeHeaderClassName='bg-dataset-option-card-purple-gradient'
> >
@ -99,8 +129,8 @@ const RetrievalMethodConfig: FC<Props> = ({
/> />
</OptionCard> </OptionCard>
)} )}
{supportRetrievalMethods.includes(RETRIEVE_METHOD.semantic) && ( {supportRetrievalMethods.includes(RETRIEVE_METHOD.hybrid) && (
<OptionCard icon={<Image className='w-4 h-4' src={retrievalIcon.hybrid} alt='' />} <OptionCard disabled={disabled} icon={<Image className='w-4 h-4' src={retrievalIcon.hybrid} alt='' />}
title={ title={
<div className='flex items-center space-x-1'> <div className='flex items-center space-x-1'>
<div>{t('dataset.retrieval.hybrid_search.title')}</div> <div>{t('dataset.retrieval.hybrid_search.title')}</div>
@ -110,11 +140,7 @@ const RetrievalMethodConfig: FC<Props> = ({
description={t('dataset.retrieval.hybrid_search.description')} isActive={ description={t('dataset.retrieval.hybrid_search.description')} isActive={
value.search_method === RETRIEVE_METHOD.hybrid value.search_method === RETRIEVE_METHOD.hybrid
} }
onSwitched={() => onChange({ onSwitched={() => onSwitch(RETRIEVE_METHOD.hybrid)}
...value,
search_method: RETRIEVE_METHOD.hybrid,
reranking_enable: true,
})}
effectImg={Effect.src} effectImg={Effect.src}
activeHeaderClassName='bg-dataset-option-card-purple-gradient' activeHeaderClassName='bg-dataset-option-card-purple-gradient'
> >

View File

@ -1,6 +1,6 @@
'use client' 'use client'
import type { FC } from 'react' import type { FC } from 'react'
import React, { useCallback } from 'react' import React, { useCallback, useMemo } from 'react'
import { useTranslation } from 'react-i18next' import { useTranslation } from 'react-i18next'
import Image from 'next/image' import Image from 'next/image'
@ -39,8 +39,8 @@ const RetrievalParamConfig: FC<Props> = ({
const { t } = useTranslation() const { t } = useTranslation()
const canToggleRerankModalEnable = type !== RETRIEVE_METHOD.hybrid const canToggleRerankModalEnable = type !== RETRIEVE_METHOD.hybrid
const isEconomical = type === RETRIEVE_METHOD.invertedIndex const isEconomical = type === RETRIEVE_METHOD.invertedIndex
const isHybridSearch = type === RETRIEVE_METHOD.hybrid
const { const {
defaultModel: rerankDefaultModel,
modelList: rerankModelList, modelList: rerankModelList,
} = useModelListAndDefaultModel(ModelTypeEnum.rerank) } = useModelListAndDefaultModel(ModelTypeEnum.rerank)
@ -48,35 +48,28 @@ const RetrievalParamConfig: FC<Props> = ({
currentModel, currentModel,
} = useCurrentProviderAndModel( } = useCurrentProviderAndModel(
rerankModelList, rerankModelList,
rerankDefaultModel {
? { provider: value.reranking_model?.reranking_provider_name ?? '',
...rerankDefaultModel, model: value.reranking_model?.reranking_model_name ?? '',
provider: rerankDefaultModel.provider.provider, },
}
: undefined,
) )
const handleDisabledSwitchClick = useCallback(() => { const handleDisabledSwitchClick = useCallback((enable: boolean) => {
if (!currentModel) if (enable && !currentModel)
Toast.notify({ type: 'error', message: t('workflow.errorMsg.rerankModelRequired') }) Toast.notify({ type: 'error', message: t('workflow.errorMsg.rerankModelRequired') })
}, [currentModel, rerankDefaultModel, t]) onChange({
...value,
reranking_enable: enable,
})
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [currentModel, onChange, value])
const isHybridSearch = type === RETRIEVE_METHOD.hybrid const rerankModel = useMemo(() => {
return {
const rerankModel = (() => { provider_name: value.reranking_model.reranking_provider_name,
if (value.reranking_model) { model_name: value.reranking_model.reranking_model_name,
return {
provider_name: value.reranking_model.reranking_provider_name,
model_name: value.reranking_model.reranking_model_name,
}
} }
else if (rerankDefaultModel) { }, [value.reranking_model])
return {
provider_name: rerankDefaultModel.provider.provider,
model_name: rerankDefaultModel.model,
}
}
})()
const handleChangeRerankMode = (v: RerankingModeEnum) => { const handleChangeRerankMode = (v: RerankingModeEnum) => {
if (v === value.reranking_mode) if (v === value.reranking_mode)
@ -100,6 +93,8 @@ const RetrievalParamConfig: FC<Props> = ({
}, },
} }
} }
if (v === RerankingModeEnum.RerankingModel && !currentModel)
Toast.notify({ type: 'error', message: t('workflow.errorMsg.rerankModelRequired') })
onChange(result) onChange(result)
} }
@ -122,22 +117,11 @@ const RetrievalParamConfig: FC<Props> = ({
<div> <div>
<div className='flex items-center space-x-2 mb-2'> <div className='flex items-center space-x-2 mb-2'>
{canToggleRerankModalEnable && ( {canToggleRerankModalEnable && (
<div <Switch
className='flex items-center' size='md'
onClick={handleDisabledSwitchClick} defaultValue={value.reranking_enable}
> onChange={handleDisabledSwitchClick}
<Switch />
size='md'
defaultValue={currentModel ? value.reranking_enable : false}
onChange={(v) => {
onChange({
...value,
reranking_enable: v,
})
}}
disabled={!currentModel}
/>
</div>
)} )}
<div className='flex items-center'> <div className='flex items-center'>
<span className='mr-0.5 system-sm-semibold text-text-secondary'>{t('common.modelProvider.rerankModel.key')}</span> <span className='mr-0.5 system-sm-semibold text-text-secondary'>{t('common.modelProvider.rerankModel.key')}</span>
@ -148,21 +132,23 @@ const RetrievalParamConfig: FC<Props> = ({
/> />
</div> </div>
</div> </div>
<ModelSelector {
triggerClassName={`${!value.reranking_enable && '!opacity-60 !cursor-not-allowed'}`} value.reranking_enable && (
defaultModel={rerankModel && { provider: rerankModel.provider_name, model: rerankModel.model_name }} <ModelSelector
modelList={rerankModelList} defaultModel={rerankModel && { provider: rerankModel.provider_name, model: rerankModel.model_name }}
readonly={!value.reranking_enable} modelList={rerankModelList}
onSelect={(v) => { onSelect={(v) => {
onChange({ onChange({
...value, ...value,
reranking_model: { reranking_model: {
reranking_provider_name: v.provider, reranking_provider_name: v.provider,
reranking_model_name: v.model, reranking_model_name: v.model,
}, },
}) })
}} }}
/> />
)
}
</div> </div>
)} )}
{ {
@ -255,10 +241,8 @@ const RetrievalParamConfig: FC<Props> = ({
{ {
value.reranking_mode !== RerankingModeEnum.WeightedScore && ( value.reranking_mode !== RerankingModeEnum.WeightedScore && (
<ModelSelector <ModelSelector
triggerClassName={`${!value.reranking_enable && '!opacity-60 !cursor-not-allowed'}`}
defaultModel={rerankModel && { provider: rerankModel.provider_name, model: rerankModel.model_name }} defaultModel={rerankModel && { provider: rerankModel.provider_name, model: rerankModel.model_name }}
modelList={rerankModelList} modelList={rerankModelList}
readonly={!value.reranking_enable}
onSelect={(v) => { onSelect={(v) => {
onChange({ onChange({
...value, ...value,
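
The reworked param config above stops falling back to the workspace's default rerank model and reads the provider/model pair strictly from the stored retrieval config. The mapping it builds with `useMemo` can be sketched as a standalone helper (hypothetical name; the component does this inline):

// Hypothetical helper mirroring the useMemo above: map the persisted reranking_model
// shape onto the { provider, model } shape that <ModelSelector /> expects.
type RerankingModel = {
  reranking_provider_name: string
  reranking_model_name: string
}

const toSelectorModel = (m: RerankingModel) => ({
  provider: m.reranking_provider_name,
  model: m.reranking_model_name,
})

// usage: <ModelSelector defaultModel={toSelectorModel(value.reranking_model)} ... />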

View File

@ -30,6 +30,7 @@ import { useProviderContext } from '@/context/provider-context'
import { sleep } from '@/utils' import { sleep } from '@/utils'
import { RETRIEVE_METHOD } from '@/types/app' import { RETRIEVE_METHOD } from '@/types/app'
import Tooltip from '@/app/components/base/tooltip' import Tooltip from '@/app/components/base/tooltip'
import { useInvalidDocumentList } from '@/service/knowledge/use-document'
type Props = { type Props = {
datasetId: string datasetId: string
@ -207,7 +208,9 @@ const EmbeddingProcess: FC<Props> = ({ datasetId, batchId, documents = [], index
}) })
const router = useRouter() const router = useRouter()
const invalidDocumentList = useInvalidDocumentList()
const navToDocumentList = () => { const navToDocumentList = () => {
invalidDocumentList()
router.push(`/datasets/${datasetId}/documents`) router.push(`/datasets/${datasetId}/documents`)
} }
const navToApiDocs = () => { const navToApiDocs = () => {
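
The hunk above makes `navToDocumentList` invalidate the cached document list before routing back, so the documents page refetches instead of showing stale data. The real `useInvalidDocumentList` lives in '@/service/knowledge/use-document' and is not shown here; a hedged sketch of what such a hook typically wraps, assuming the data layer is @tanstack/react-query and using a placeholder query key:

import { useQueryClient } from '@tanstack/react-query'

// Sketch only: 'knowledge-document-list' is a placeholder key, not the project's actual one.
const useInvalidDocumentList = () => {
  const queryClient = useQueryClient()
  return () => {
    queryClient.invalidateQueries({ queryKey: ['knowledge-document-list'] })
  }
}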

View File

@ -31,17 +31,17 @@ import LanguageSelect from './language-select'
import { DelimiterInput, MaxLengthInput, OverlapInput } from './inputs' import { DelimiterInput, MaxLengthInput, OverlapInput } from './inputs'
import cn from '@/utils/classnames' import cn from '@/utils/classnames'
import type { CrawlOptions, CrawlResultItem, CreateDocumentReq, CustomFile, DocumentItem, FullDocumentDetail, ParentMode, PreProcessingRule, ProcessRule, Rules, createDocumentResponse } from '@/models/datasets' import type { CrawlOptions, CrawlResultItem, CreateDocumentReq, CustomFile, DocumentItem, FullDocumentDetail, ParentMode, PreProcessingRule, ProcessRule, Rules, createDocumentResponse } from '@/models/datasets'
import { ChunkingMode, DataSourceType, ProcessMode } from '@/models/datasets'
import Button from '@/app/components/base/button' import Button from '@/app/components/base/button'
import FloatRightContainer from '@/app/components/base/float-right-container' import FloatRightContainer from '@/app/components/base/float-right-container'
import RetrievalMethodConfig from '@/app/components/datasets/common/retrieval-method-config' import RetrievalMethodConfig from '@/app/components/datasets/common/retrieval-method-config'
import EconomicalRetrievalMethodConfig from '@/app/components/datasets/common/economical-retrieval-method-config' import EconomicalRetrievalMethodConfig from '@/app/components/datasets/common/economical-retrieval-method-config'
import { type RetrievalConfig } from '@/types/app' import { type RetrievalConfig } from '@/types/app'
import { ensureRerankModelSelected, isReRankModelSelected } from '@/app/components/datasets/common/check-rerank-model' import { isReRankModelSelected } from '@/app/components/datasets/common/check-rerank-model'
import Toast from '@/app/components/base/toast' import Toast from '@/app/components/base/toast'
import type { NotionPage } from '@/models/common' import type { NotionPage } from '@/models/common'
import { DataSourceProvider } from '@/models/common' import { DataSourceProvider } from '@/models/common'
import { ChunkingMode, DataSourceType, RerankingModeEnum } from '@/models/datasets'
import { useDatasetDetailContext } from '@/context/dataset-detail' import { useDatasetDetailContext } from '@/context/dataset-detail'
import I18n from '@/context/i18n' import I18n from '@/context/i18n'
import { RETRIEVE_METHOD } from '@/types/app' import { RETRIEVE_METHOD } from '@/types/app'
@ -53,7 +53,7 @@ import type { DefaultModel } from '@/app/components/header/account-setting/model
import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
import Checkbox from '@/app/components/base/checkbox' import Checkbox from '@/app/components/base/checkbox'
import RadioCard from '@/app/components/base/radio-card' import RadioCard from '@/app/components/base/radio-card'
import { IS_CE_EDITION } from '@/config' import { FULL_DOC_PREVIEW_LENGTH, IS_CE_EDITION } from '@/config'
import Divider from '@/app/components/base/divider' import Divider from '@/app/components/base/divider'
import { getNotionInfo, getWebsiteInfo, useCreateDocument, useCreateFirstDocument, useFetchDefaultProcessRule, useFetchFileIndexingEstimateForFile, useFetchFileIndexingEstimateForNotion, useFetchFileIndexingEstimateForWeb } from '@/service/knowledge/use-create-dataset' import { getNotionInfo, getWebsiteInfo, useCreateDocument, useCreateFirstDocument, useFetchDefaultProcessRule, useFetchFileIndexingEstimateForFile, useFetchFileIndexingEstimateForNotion, useFetchFileIndexingEstimateForWeb } from '@/service/knowledge/use-create-dataset'
import Badge from '@/app/components/base/badge' import Badge from '@/app/components/base/badge'
@ -90,17 +90,13 @@ type StepTwoProps = {
onCancel?: () => void onCancel?: () => void
} }
export enum SegmentType {
AUTO = 'automatic',
CUSTOM = 'custom',
}
export enum IndexingType { export enum IndexingType {
QUALIFIED = 'high_quality', QUALIFIED = 'high_quality',
ECONOMICAL = 'economy', ECONOMICAL = 'economy',
} }
const DEFAULT_SEGMENT_IDENTIFIER = '\\n\\n' const DEFAULT_SEGMENT_IDENTIFIER = '\\n\\n'
const DEFAULT_MAXMIMUM_CHUNK_LENGTH = 500 const DEFAULT_MAXIMUM_CHUNK_LENGTH = 500
const DEFAULT_OVERLAP = 50 const DEFAULT_OVERLAP = 50
type ParentChildConfig = { type ParentChildConfig = {
@ -131,7 +127,6 @@ const StepTwo = ({
isSetting, isSetting,
documentDetail, documentDetail,
isAPIKeySet, isAPIKeySet,
onSetting,
datasetId, datasetId,
indexingType, indexingType,
dataSourceType: inCreatePageDataSourceType, dataSourceType: inCreatePageDataSourceType,
@ -162,12 +157,12 @@ const StepTwo = ({
const isInCreatePage = !datasetId || (datasetId && !currentDataset?.data_source_type) const isInCreatePage = !datasetId || (datasetId && !currentDataset?.data_source_type)
const dataSourceType = isInCreatePage ? inCreatePageDataSourceType : currentDataset?.data_source_type const dataSourceType = isInCreatePage ? inCreatePageDataSourceType : currentDataset?.data_source_type
const [segmentationType, setSegmentationType] = useState<SegmentType>(SegmentType.CUSTOM) const [segmentationType, setSegmentationType] = useState<ProcessMode>(ProcessMode.general)
const [segmentIdentifier, doSetSegmentIdentifier] = useState(DEFAULT_SEGMENT_IDENTIFIER) const [segmentIdentifier, doSetSegmentIdentifier] = useState(DEFAULT_SEGMENT_IDENTIFIER)
const setSegmentIdentifier = useCallback((value: string, canEmpty?: boolean) => { const setSegmentIdentifier = useCallback((value: string, canEmpty?: boolean) => {
doSetSegmentIdentifier(value ? escape(value) : (canEmpty ? '' : DEFAULT_SEGMENT_IDENTIFIER)) doSetSegmentIdentifier(value ? escape(value) : (canEmpty ? '' : DEFAULT_SEGMENT_IDENTIFIER))
}, []) }, [])
const [maxChunkLength, setMaxChunkLength] = useState(DEFAULT_MAXMIMUM_CHUNK_LENGTH) // default chunk length const [maxChunkLength, setMaxChunkLength] = useState(DEFAULT_MAXIMUM_CHUNK_LENGTH) // default chunk length
const [limitMaxChunkLength, setLimitMaxChunkLength] = useState(4000) const [limitMaxChunkLength, setLimitMaxChunkLength] = useState(4000)
const [overlap, setOverlap] = useState(DEFAULT_OVERLAP) const [overlap, setOverlap] = useState(DEFAULT_OVERLAP)
const [rules, setRules] = useState<PreProcessingRule[]>([]) const [rules, setRules] = useState<PreProcessingRule[]>([])
@ -198,7 +193,6 @@ const StepTwo = ({
) )
// QA Related // QA Related
const [isLanguageSelectDisabled, _setIsLanguageSelectDisabled] = useState(false)
const [isQAConfirmDialogOpen, setIsQAConfirmDialogOpen] = useState(false) const [isQAConfirmDialogOpen, setIsQAConfirmDialogOpen] = useState(false)
const [docForm, setDocForm] = useState<ChunkingMode>( const [docForm, setDocForm] = useState<ChunkingMode>(
(datasetId && documentDetail) ? documentDetail.doc_form as ChunkingMode : ChunkingMode.text, (datasetId && documentDetail) ? documentDetail.doc_form as ChunkingMode : ChunkingMode.text,
@ -348,7 +342,7 @@ const StepTwo = ({
} }
const updatePreview = () => { const updatePreview = () => {
if (segmentationType === SegmentType.CUSTOM && maxChunkLength > 4000) { if (segmentationType === ProcessMode.general && maxChunkLength > 4000) {
Toast.notify({ type: 'error', message: t('datasetCreation.stepTwo.maxLengthCheck') }) Toast.notify({ type: 'error', message: t('datasetCreation.stepTwo.maxLengthCheck') })
return return
} }
@ -373,13 +367,42 @@ const StepTwo = ({
model: defaultEmbeddingModel?.model || '', model: defaultEmbeddingModel?.model || '',
}, },
) )
const [retrievalConfig, setRetrievalConfig] = useState(currentDataset?.retrieval_model_dict || {
search_method: RETRIEVE_METHOD.semantic,
reranking_enable: false,
reranking_model: {
reranking_provider_name: '',
reranking_model_name: '',
},
top_k: 3,
score_threshold_enabled: false,
score_threshold: 0.5,
} as RetrievalConfig)
useEffect(() => {
if (currentDataset?.retrieval_model_dict)
return
setRetrievalConfig({
search_method: RETRIEVE_METHOD.semantic,
reranking_enable: !!isRerankDefaultModelValid,
reranking_model: {
reranking_provider_name: isRerankDefaultModelValid ? rerankDefaultModel?.provider.provider ?? '' : '',
reranking_model_name: isRerankDefaultModelValid ? rerankDefaultModel?.model ?? '' : '',
},
top_k: 3,
score_threshold_enabled: false,
score_threshold: 0.5,
})
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [rerankDefaultModel, isRerankDefaultModelValid])
const getCreationParams = () => { const getCreationParams = () => {
let params let params
if (segmentationType === SegmentType.CUSTOM && overlap > maxChunkLength) { if (segmentationType === ProcessMode.general && overlap > maxChunkLength) {
Toast.notify({ type: 'error', message: t('datasetCreation.stepTwo.overlapCheck') }) Toast.notify({ type: 'error', message: t('datasetCreation.stepTwo.overlapCheck') })
return return
} }
if (segmentationType === SegmentType.CUSTOM && maxChunkLength > limitMaxChunkLength) { if (segmentationType === ProcessMode.general && maxChunkLength > limitMaxChunkLength) {
Toast.notify({ type: 'error', message: t('datasetCreation.stepTwo.maxLengthCheck', { limit: limitMaxChunkLength }) }) Toast.notify({ type: 'error', message: t('datasetCreation.stepTwo.maxLengthCheck', { limit: limitMaxChunkLength }) })
return return
} }
@ -389,7 +412,6 @@ const StepTwo = ({
doc_form: currentDocForm, doc_form: currentDocForm,
doc_language: docLanguage, doc_language: docLanguage,
process_rule: getProcessRule(), process_rule: getProcessRule(),
// eslint-disable-next-line @typescript-eslint/no-use-before-define
retrieval_model: retrievalConfig, // Readonly. If want to changed, just go to settings page. retrieval_model: retrievalConfig, // Readonly. If want to changed, just go to settings page.
embedding_model: embeddingModel.model, // Readonly embedding_model: embeddingModel.model, // Readonly
embedding_model_provider: embeddingModel.provider, // Readonly embedding_model_provider: embeddingModel.provider, // Readonly
@ -400,10 +422,7 @@ const StepTwo = ({
const indexMethod = getIndexing_technique() const indexMethod = getIndexing_technique()
if ( if (
!isReRankModelSelected({ !isReRankModelSelected({
rerankDefaultModel,
isRerankDefaultModelValid: !!isRerankDefaultModelValid,
rerankModelList, rerankModelList,
// eslint-disable-next-line @typescript-eslint/no-use-before-define
retrievalConfig, retrievalConfig,
indexMethod: indexMethod as string, indexMethod: indexMethod as string,
}) })
@ -411,16 +430,6 @@ const StepTwo = ({
Toast.notify({ type: 'error', message: t('appDebug.datasetConfig.rerankModelRequired') }) Toast.notify({ type: 'error', message: t('appDebug.datasetConfig.rerankModelRequired') })
return return
} }
const postRetrievalConfig = ensureRerankModelSelected({
rerankDefaultModel: rerankDefaultModel!,
retrievalConfig: {
// eslint-disable-next-line @typescript-eslint/no-use-before-define
...retrievalConfig,
// eslint-disable-next-line @typescript-eslint/no-use-before-define
reranking_enable: retrievalConfig.reranking_mode === RerankingModeEnum.RerankingModel,
},
indexMethod: indexMethod as string,
})
params = { params = {
data_source: { data_source: {
type: dataSourceType, type: dataSourceType,
@ -432,8 +441,7 @@ const StepTwo = ({
process_rule: getProcessRule(), process_rule: getProcessRule(),
doc_form: currentDocForm, doc_form: currentDocForm,
doc_language: docLanguage, doc_language: docLanguage,
retrieval_model: retrievalConfig,
retrieval_model: postRetrievalConfig,
embedding_model: embeddingModel.model, embedding_model: embeddingModel.model,
embedding_model_provider: embeddingModel.provider, embedding_model_provider: embeddingModel.provider,
} as CreateDocumentReq } as CreateDocumentReq
@ -490,7 +498,6 @@ const StepTwo = ({
const getDefaultMode = () => { const getDefaultMode = () => {
if (documentDetail) if (documentDetail)
// @ts-expect-error fix after api refactored
setSegmentationType(documentDetail.dataset_process_rule.mode) setSegmentationType(documentDetail.dataset_process_rule.mode)
} }
@ -525,7 +532,6 @@ const StepTwo = ({
onSuccess(data) { onSuccess(data) {
updateIndexingTypeCache && updateIndexingTypeCache(indexType as string) updateIndexingTypeCache && updateIndexingTypeCache(indexType as string)
updateResultCache && updateResultCache(data) updateResultCache && updateResultCache(data)
// eslint-disable-next-line @typescript-eslint/no-use-before-define
updateRetrievalMethodCache && updateRetrievalMethodCache(retrievalConfig.search_method as string) updateRetrievalMethodCache && updateRetrievalMethodCache(retrievalConfig.search_method as string)
}, },
}, },
@ -545,14 +551,6 @@ const StepTwo = ({
isSetting && onSave && onSave() isSetting && onSave && onSave()
} }
const changeToEconomicalType = () => {
if (docForm !== ChunkingMode.text)
return
if (!hasSetIndexType)
setIndexType(IndexingType.ECONOMICAL)
}
useEffect(() => { useEffect(() => {
// fetch rules // fetch rules
if (!isSetting) { if (!isSetting) {
@ -574,18 +572,6 @@ const StepTwo = ({
setIndexType(isAPIKeySet ? IndexingType.QUALIFIED : IndexingType.ECONOMICAL) setIndexType(isAPIKeySet ? IndexingType.QUALIFIED : IndexingType.ECONOMICAL)
}, [isAPIKeySet, indexingType, datasetId]) }, [isAPIKeySet, indexingType, datasetId])
const [retrievalConfig, setRetrievalConfig] = useState(currentDataset?.retrieval_model_dict || {
search_method: RETRIEVE_METHOD.semantic,
reranking_enable: false,
reranking_model: {
reranking_provider_name: rerankDefaultModel?.provider.provider,
reranking_model_name: rerankDefaultModel?.model,
},
top_k: 3,
score_threshold_enabled: false,
score_threshold: 0.5,
} as RetrievalConfig)
const economyDomRef = useRef<HTMLDivElement>(null) const economyDomRef = useRef<HTMLDivElement>(null)
const isHoveringEconomy = useHover(economyDomRef) const isHoveringEconomy = useHover(economyDomRef)
@ -946,6 +932,7 @@ const StepTwo = ({
<div className={cn('system-md-semibold mb-1', datasetId && 'flex justify-between items-center')}>{t('datasetSettings.form.embeddingModel')}</div> <div className={cn('system-md-semibold mb-1', datasetId && 'flex justify-between items-center')}>{t('datasetSettings.form.embeddingModel')}</div>
<ModelSelector <ModelSelector
readonly={!!datasetId} readonly={!!datasetId}
triggerClassName={datasetId ? 'opacity-50' : ''}
defaultModel={embeddingModel} defaultModel={embeddingModel}
modelList={embeddingModelList} modelList={embeddingModelList}
onSelect={(model: DefaultModel) => { onSelect={(model: DefaultModel) => {
@ -984,12 +971,14 @@ const StepTwo = ({
getIndexing_technique() === IndexingType.QUALIFIED getIndexing_technique() === IndexingType.QUALIFIED
? ( ? (
<RetrievalMethodConfig <RetrievalMethodConfig
disabled={!!datasetId}
value={retrievalConfig} value={retrievalConfig}
onChange={setRetrievalConfig} onChange={setRetrievalConfig}
/> />
) )
: ( : (
<EconomicalRetrievalMethodConfig <EconomicalRetrievalMethodConfig
disabled={!!datasetId}
value={retrievalConfig} value={retrievalConfig}
onChange={setRetrievalConfig} onChange={setRetrievalConfig}
/> />
@ -1010,7 +999,7 @@ const StepTwo = ({
) )
: ( : (
<div className='flex items-center mt-8 py-2'> <div className='flex items-center mt-8 py-2'>
<Button loading={isCreating} variant='primary' onClick={createHandle}>{t('datasetCreation.stepTwo.save')}</Button> {!datasetId && <Button loading={isCreating} variant='primary' onClick={createHandle}>{t('datasetCreation.stepTwo.save')}</Button>}
<Button className='ml-2' onClick={onCancel}>{t('datasetCreation.stepTwo.cancel')}</Button> <Button className='ml-2' onClick={onCancel}>{t('datasetCreation.stepTwo.cancel')}</Button>
</div> </div>
)} )}
@ -1081,11 +1070,11 @@ const StepTwo = ({
} }
{ {
currentDocForm !== ChunkingMode.qa currentDocForm !== ChunkingMode.qa
&& <Badge text={t( && <Badge text={t(
'datasetCreation.stepTwo.previewChunkCount', { 'datasetCreation.stepTwo.previewChunkCount', {
count: estimate?.total_segments || 0, count: estimate?.total_segments || 0,
}) as string} }) as string}
/> />
} }
</div> </div>
</PreviewHeader>} </PreviewHeader>}
@ -1117,6 +1106,9 @@ const StepTwo = ({
{currentDocForm === ChunkingMode.parentChild && currentEstimateMutation.data?.preview && ( {currentDocForm === ChunkingMode.parentChild && currentEstimateMutation.data?.preview && (
estimate?.preview?.map((item, index) => { estimate?.preview?.map((item, index) => {
const indexForLabel = index + 1 const indexForLabel = index + 1
const childChunks = parentChildConfig.chunkForContext === 'full-doc'
? item.child_chunks.slice(0, FULL_DOC_PREVIEW_LENGTH)
: item.child_chunks
return ( return (
<ChunkContainer <ChunkContainer
key={item.content} key={item.content}
@ -1124,7 +1116,7 @@ const StepTwo = ({
characterCount={item.content.length} characterCount={item.content.length}
> >
<FormattedText> <FormattedText>
{item.child_chunks.map((child, index) => { {childChunks.map((child, index) => {
const indexForLabel = index + 1 const indexForLabel = index + 1
return ( return (
<PreviewSlice <PreviewSlice
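
The parent-child preview above caps how many child chunks are rendered when the chunk-for-context mode is 'full-doc', using the new `FULL_DOC_PREVIEW_LENGTH` import from '@/config' (its value is not shown in this diff). The same guard as a tiny standalone helper (hypothetical name; the component does it inline):

import { FULL_DOC_PREVIEW_LENGTH } from '@/config'
import type { ParentMode } from '@/models/datasets'

// Hypothetical helper: only trim the child-chunk list in full-doc mode, mirroring the
// ternary in the hunk above.
const limitChildChunksForPreview = <T,>(childChunks: T[], chunkForContext: ParentMode): T[] =>
  chunkForContext === 'full-doc'
    ? childChunks.slice(0, FULL_DOC_PREVIEW_LENGTH)
    : childChunks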

View File

@ -4,7 +4,7 @@ import classNames from '@/utils/classnames'
const TriangleArrow: FC<ComponentProps<'svg'>> = props => ( const TriangleArrow: FC<ComponentProps<'svg'>> = props => (
<svg xmlns="http://www.w3.org/2000/svg" width="24" height="11" viewBox="0 0 24 11" fill="none" {...props}> <svg xmlns="http://www.w3.org/2000/svg" width="24" height="11" viewBox="0 0 24 11" fill="none" {...props}>
<path d="M9.87868 1.12132C11.0503 -0.0502525 12.9497 -0.0502525 14.1213 1.12132L23.3137 10.3137H0.686292L9.87868 1.12132Z" fill="currentColor"/> <path d="M9.87868 1.12132C11.0503 -0.0502525 12.9497 -0.0502525 14.1213 1.12132L23.3137 10.3137H0.686292L9.87868 1.12132Z" fill="currentColor" />
</svg> </svg>
) )
@ -65,7 +65,7 @@ export const OptionCard: FC<OptionCardProps> = forwardRef((props, ref) => {
(isActive && !noHighlight) (isActive && !noHighlight)
? 'border-[1.5px] border-components-option-card-option-selected-border' ? 'border-[1.5px] border-components-option-card-option-selected-border'
: 'border border-components-option-card-option-border', : 'border border-components-option-card-option-border',
disabled && 'opacity-50 cursor-not-allowed', disabled && 'opacity-50 pointer-events-none',
className, className,
)} )}
style={{ style={{

View File

@ -32,6 +32,9 @@ import Checkbox from '@/app/components/base/checkbox'
import { import {
useChildSegmentList, useChildSegmentList,
useChildSegmentListKey, useChildSegmentListKey,
useChunkListAllKey,
useChunkListDisabledKey,
useChunkListEnabledKey,
useDeleteChildSegment, useDeleteChildSegment,
useDeleteSegment, useDeleteSegment,
useDisableSegment, useDisableSegment,
@ -156,18 +159,18 @@ const Completed: FC<ICompletedProps> = ({
page: isFullDocMode ? 1 : currentPage, page: isFullDocMode ? 1 : currentPage,
limit: isFullDocMode ? 10 : limit, limit: isFullDocMode ? 10 : limit,
keyword: isFullDocMode ? '' : searchValue, keyword: isFullDocMode ? '' : searchValue,
enabled: selectedStatus === 'all' ? 'all' : !!selectedStatus, enabled: selectedStatus,
}, },
}, },
currentPage === 0,
) )
const invalidSegmentList = useInvalid(useSegmentListKey) const invalidSegmentList = useInvalid(useSegmentListKey)
useEffect(() => { useEffect(() => {
if (segmentListData) { if (segmentListData) {
setSegments(segmentListData.data || []) setSegments(segmentListData.data || [])
if (segmentListData.total_pages < currentPage) const totalPages = segmentListData.total_pages
setCurrentPage(segmentListData.total_pages) if (totalPages < currentPage)
setCurrentPage(totalPages === 0 ? 1 : totalPages)
} }
// eslint-disable-next-line react-hooks/exhaustive-deps // eslint-disable-next-line react-hooks/exhaustive-deps
}, [segmentListData]) }, [segmentListData])
@ -185,12 +188,12 @@ const Completed: FC<ICompletedProps> = ({
documentId, documentId,
segmentId: segments[0]?.id || '', segmentId: segments[0]?.id || '',
params: { params: {
page: currentPage, page: currentPage === 0 ? 1 : currentPage,
limit, limit,
keyword: searchValue, keyword: searchValue,
}, },
}, },
!isFullDocMode || segments.length === 0 || currentPage === 0, !isFullDocMode || segments.length === 0,
) )
const invalidChildSegmentList = useInvalid(useChildSegmentListKey) const invalidChildSegmentList = useInvalid(useChildSegmentListKey)
@ -204,21 +207,20 @@ const Completed: FC<ICompletedProps> = ({
useEffect(() => { useEffect(() => {
if (childChunkListData) { if (childChunkListData) {
setChildSegments(childChunkListData.data || []) setChildSegments(childChunkListData.data || [])
if (childChunkListData.total_pages < currentPage) const totalPages = childChunkListData.total_pages
setCurrentPage(childChunkListData.total_pages) if (totalPages < currentPage)
setCurrentPage(totalPages === 0 ? 1 : totalPages)
} }
// eslint-disable-next-line react-hooks/exhaustive-deps // eslint-disable-next-line react-hooks/exhaustive-deps
}, [childChunkListData]) }, [childChunkListData])
const resetList = useCallback(() => { const resetList = useCallback(() => {
setSegments([])
setSelectedSegmentIds([]) setSelectedSegmentIds([])
invalidSegmentList() invalidSegmentList()
// eslint-disable-next-line react-hooks/exhaustive-deps // eslint-disable-next-line react-hooks/exhaustive-deps
}, []) }, [])
const resetChildList = useCallback(() => { const resetChildList = useCallback(() => {
setChildSegments([])
invalidChildSegmentList() invalidChildSegmentList()
// eslint-disable-next-line react-hooks/exhaustive-deps // eslint-disable-next-line react-hooks/exhaustive-deps
}, []) }, [])
@ -232,8 +234,32 @@ const Completed: FC<ICompletedProps> = ({
setFullScreen(false) setFullScreen(false)
}, []) }, [])
const onCloseNewSegmentModal = useCallback(() => {
onNewSegmentModalChange(false)
setFullScreen(false)
}, [onNewSegmentModalChange])
const onCloseNewChildChunkModal = useCallback(() => {
setShowNewChildSegmentModal(false)
setFullScreen(false)
}, [])
const { mutateAsync: enableSegment } = useEnableSegment() const { mutateAsync: enableSegment } = useEnableSegment()
const { mutateAsync: disableSegment } = useDisableSegment() const { mutateAsync: disableSegment } = useDisableSegment()
const invalidChunkListAll = useInvalid(useChunkListAllKey)
const invalidChunkListEnabled = useInvalid(useChunkListEnabledKey)
const invalidChunkListDisabled = useInvalid(useChunkListDisabledKey)
const refreshChunkListWithStatusChanged = () => {
switch (selectedStatus) {
case 'all':
invalidChunkListDisabled()
invalidChunkListEnabled()
break
default:
invalidSegmentList()
}
}
const onChangeSwitch = useCallback(async (enable: boolean, segId?: string) => { const onChangeSwitch = useCallback(async (enable: boolean, segId?: string) => {
const operationApi = enable ? enableSegment : disableSegment const operationApi = enable ? enableSegment : disableSegment
@ -245,6 +271,7 @@ const Completed: FC<ICompletedProps> = ({
seg.enabled = enable seg.enabled = enable
} }
setSegments([...segments]) setSegments([...segments])
refreshChunkListWithStatusChanged()
}, },
onError: () => { onError: () => {
notify({ type: 'error', message: t('common.actionMsg.modifiedUnsuccessfully') }) notify({ type: 'error', message: t('common.actionMsg.modifiedUnsuccessfully') })
@ -271,6 +298,23 @@ const Completed: FC<ICompletedProps> = ({
const { mutateAsync: updateSegment } = useUpdateSegment() const { mutateAsync: updateSegment } = useUpdateSegment()
const refreshChunkListDataWithDetailChanged = () => {
switch (selectedStatus) {
case 'all':
invalidChunkListDisabled()
invalidChunkListEnabled()
break
case true:
invalidChunkListAll()
invalidChunkListDisabled()
break
case false:
invalidChunkListAll()
invalidChunkListEnabled()
break
}
}
const handleUpdateSegment = useCallback(async ( const handleUpdateSegment = useCallback(async (
segmentId: string, segmentId: string,
question: string, question: string,
@ -320,6 +364,7 @@ const Completed: FC<ICompletedProps> = ({
} }
} }
setSegments([...segments]) setSegments([...segments])
refreshChunkListDataWithDetailChanged()
eventEmitter?.emit('update-segment-success') eventEmitter?.emit('update-segment-success')
}, },
onSettled() { onSettled() {
@ -432,6 +477,7 @@ const Completed: FC<ICompletedProps> = ({
seg.child_chunks?.push(newChildChunk!) seg.child_chunks?.push(newChildChunk!)
} }
setSegments([...segments]) setSegments([...segments])
refreshChunkListDataWithDetailChanged()
} }
else { else {
resetChildList() resetChildList()
@ -496,17 +542,10 @@ const Completed: FC<ICompletedProps> = ({
} }
} }
setSegments([...segments]) setSegments([...segments])
refreshChunkListDataWithDetailChanged()
} }
else { else {
for (const childSeg of childSegments) { resetChildList()
if (childSeg.id === childChunkId) {
childSeg.content = res.data.content
childSeg.type = res.data.type
childSeg.word_count = res.data.word_count
childSeg.updated_at = res.data.updated_at
}
}
setChildSegments([...childSegments])
} }
}, },
onSettled: () => { onSettled: () => {
@ -544,12 +583,13 @@ const Completed: FC<ICompletedProps> = ({
<SimpleSelect <SimpleSelect
onSelect={onChangeStatus} onSelect={onChangeStatus}
items={statusList.current} items={statusList.current}
defaultValue={'all'} defaultValue={selectedStatus === 'all' ? 'all' : selectedStatus ? 1 : 0}
className={s.select} className={s.select}
wrapperClassName='h-fit mr-2' wrapperClassName='h-fit mr-2'
optionWrapClassName='w-[160px]' optionWrapClassName='w-[160px]'
optionClassName='p-0' optionClassName='p-0'
renderOption={({ item, selected }) => <StatusItem item={item} selected={selected} />} renderOption={({ item, selected }) => <StatusItem item={item} selected={selected} />}
notClearable
/> />
<Input <Input
showLeftIcon showLeftIcon
@ -623,6 +663,7 @@ const Completed: FC<ICompletedProps> = ({
<FullScreenDrawer <FullScreenDrawer
isOpen={currSegment.showModal} isOpen={currSegment.showModal}
fullScreen={fullScreen} fullScreen={fullScreen}
onClose={onCloseSegmentDetail}
> >
<SegmentDetail <SegmentDetail
segInfo={currSegment.segInfo ?? { id: '' }} segInfo={currSegment.segInfo ?? { id: '' }}
@ -636,13 +677,11 @@ const Completed: FC<ICompletedProps> = ({
<FullScreenDrawer <FullScreenDrawer
isOpen={showNewSegmentModal} isOpen={showNewSegmentModal}
fullScreen={fullScreen} fullScreen={fullScreen}
onClose={onCloseNewSegmentModal}
> >
<NewSegment <NewSegment
docForm={docForm} docForm={docForm}
onCancel={() => { onCancel={onCloseNewSegmentModal}
onNewSegmentModalChange(false)
setFullScreen(false)
}}
onSave={resetList} onSave={resetList}
viewNewlyAddedChunk={viewNewlyAddedChunk} viewNewlyAddedChunk={viewNewlyAddedChunk}
/> />
@ -651,6 +690,7 @@ const Completed: FC<ICompletedProps> = ({
<FullScreenDrawer <FullScreenDrawer
isOpen={currChildChunk.showModal} isOpen={currChildChunk.showModal}
fullScreen={fullScreen} fullScreen={fullScreen}
onClose={onCloseChildSegmentDetail}
> >
<ChildSegmentDetail <ChildSegmentDetail
chunkId={currChunkId} chunkId={currChunkId}
@ -664,13 +704,11 @@ const Completed: FC<ICompletedProps> = ({
<FullScreenDrawer <FullScreenDrawer
isOpen={showNewChildSegmentModal} isOpen={showNewChildSegmentModal}
fullScreen={fullScreen} fullScreen={fullScreen}
onClose={onCloseNewChildChunkModal}
> >
<NewChildSegment <NewChildSegment
chunkId={currChunkId} chunkId={currChunkId}
onCancel={() => { onCancel={onCloseNewChildChunkModal}
setShowNewChildSegmentModal(false)
setFullScreen(false)
}}
onSave={onSaveNewChildChunk} onSave={onSaveNewChildChunk}
viewNewlyAddedChildChunk={viewNewlyAddedChildChunk} viewNewlyAddedChildChunk={viewNewlyAddedChildChunk}
/> />
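
Both list effects above clamp the current page when the total page count shrinks, falling back to page 1 once the list is empty. Expressed as a standalone helper (hypothetical name; the component applies the rule inline in each useEffect):

// Hypothetical helper capturing the clamp rule used for segments and child chunks above.
const clampPage = (currentPage: number, totalPages: number): number => {
  if (totalPages < currentPage)
    return totalPages === 0 ? 1 : totalPages
  return currentPage
}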
