mirror of https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git (synced 2025-08-17 02:45:52 +08:00)

Merge branch 'main' into fix/chore-fix

commit fb309462ad
@@ -23,6 +9,9 @@ FILES_ACCESS_TIMEOUT=300
# Access token expiration time in minutes
ACCESS_TOKEN_EXPIRE_MINUTES=60

# Refresh token expiration time in days
REFRESH_TOKEN_EXPIRE_DAYS=30

# celery configuration
CELERY_BROKER_URL=redis://:difyai123456@localhost:6379/1
@@ -14,7 +14,10 @@ if is_db_command():

app = create_migrations_app()
else:
if os.environ.get("FLASK_DEBUG", "False") != "True":
# It seems that JetBrains Python debugger does not work well with gevent,
# so we need to disable gevent in debug mode.
# If you are using debugpy and set GEVENT_SUPPORT=True, you can debug with gevent.
if (flask_debug := os.environ.get("FLASK_DEBUG", "0")) and flask_debug.lower() in {"false", "0", "no"}:
from gevent import monkey # type: ignore

# gevent
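Note on the app.py hunk above: the old guard compared FLASK_DEBUG against the literal string "True", while the new guard binds the value once with the walrus operator and treats "false", "0" and "no" (any casing) as "debug off", which is when gevent monkey patching should run. A minimal, hedged sketch of the same pattern, runnable on its own; the print call is illustrative and not part of Dify:

```python
import os

# Read the flag once, keep the value, and only patch when debug is off.
# Values such as "false", "0" or "no" (case-insensitive) count as disabled.
if (flask_debug := os.environ.get("FLASK_DEBUG", "0")) and flask_debug.lower() in {"false", "0", "no"}:
    from gevent import monkey  # type: ignore

    # Patch the standard library for cooperative I/O before sockets or
    # threads are created (gevent's documented entry point).
    monkey.patch_all()
else:
    print(f"FLASK_DEBUG={flask_debug!r}: skipping gevent monkey patching")
```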
@@ -546,6 +546,11 @@ class AuthConfig(BaseSettings):
default=60,
)

REFRESH_TOKEN_EXPIRE_DAYS: PositiveFloat = Field(
description="Expiration time for refresh tokens in days",
default=30,
)

LOGIN_LOCKOUT_DURATION: PositiveInt = Field(
description="Time (in seconds) a user must wait before retrying login after exceeding the rate limit.",
default=86400,
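The REFRESH_TOKEN_EXPIRE_DAYS field above pairs with the new .env entries (ACCESS_TOKEN_EXPIRE_MINUTES, REFRESH_TOKEN_EXPIRE_DAYS) shown earlier. A hedged, self-contained sketch of how such settings can be turned into concrete expiry timestamps; the ACCESS_TOKEN_EXPIRE_MINUTES type and the token_expiries helper are illustrative assumptions, not code from this commit:

```python
from datetime import UTC, datetime, timedelta

from pydantic import Field, PositiveFloat
from pydantic_settings import BaseSettings


class AuthConfig(BaseSettings):
    # Field shapes mirror the diff; defaults match the new .env entries.
    ACCESS_TOKEN_EXPIRE_MINUTES: PositiveFloat = Field(
        description="Expiration time for access tokens in minutes",
        default=60,
    )
    REFRESH_TOKEN_EXPIRE_DAYS: PositiveFloat = Field(
        description="Expiration time for refresh tokens in days",
        default=30,
    )


def token_expiries(config: AuthConfig) -> tuple[datetime, datetime]:
    """Illustrative helper: derive absolute expiry times from the settings."""
    now = datetime.now(UTC)
    return (
        now + timedelta(minutes=config.ACCESS_TOKEN_EXPIRE_MINUTES),
        now + timedelta(days=config.REFRESH_TOKEN_EXPIRE_DAYS),
    )


access_exp, refresh_exp = token_expiries(AuthConfig())
print(access_exp, refresh_exp)
```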
@@ -725,6 +730,11 @@ class IndexingConfig(BaseSettings):
default=4000,
)

CHILD_CHUNKS_PREVIEW_NUMBER: PositiveInt = Field(
description="Maximum number of child chunks to preview",
default=50,
)


class MultiModalTransferConfig(BaseSettings):
MULTIMODAL_SEND_FORMAT: Literal["base64", "url"] = Field(
@@ -33,3 +33,9 @@ class MilvusConfig(BaseSettings):
description="Name of the Milvus database to connect to (default is 'default')",
default="default",
)

MILVUS_ENABLE_HYBRID_SEARCH: bool = Field(
description="Enable hybrid search features (requires Milvus >= 2.5.0). Set to false for compatibility with "
"older versions",
default=True,
)
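The MILVUS_ENABLE_HYBRID_SEARCH flag above is a plain boolean setting; the version requirement in its description is enforced elsewhere. A hedged sketch of one way a caller could combine the flag with a server-version check before enabling hybrid search; the supports_hybrid_search helper and its naive version parsing are illustrative assumptions, not Dify code:

```python
from pydantic import Field
from pydantic_settings import BaseSettings


class MilvusConfig(BaseSettings):
    MILVUS_ENABLE_HYBRID_SEARCH: bool = Field(
        description="Enable hybrid search features (requires Milvus >= 2.5.0). "
        "Set to false for compatibility with older versions",
        default=True,
    )


def supports_hybrid_search(config: MilvusConfig, server_version: str) -> bool:
    """Honour the flag only when the reported server version is new enough.
    Deliberately naive parsing; real code should use a proper version type."""
    major, minor, *_ = (int(part) for part in server_version.split("."))
    return config.MILVUS_ENABLE_HYBRID_SEARCH and (major, minor) >= (2, 5)


print(supports_hybrid_search(MilvusConfig(), "2.5.0"))  # True
print(supports_hybrid_search(MilvusConfig(), "2.4.1"))  # False
```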
@@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):

CURRENT_VERSION: str = Field(
description="Dify version",
default="0.14.2",
default="0.15.0",
)

COMMIT_SHA: str = Field(
@@ -57,12 +57,13 @@ class AppListApi(Resource):
)
parser.add_argument("name", type=str, location="args", required=False)
parser.add_argument("tag_ids", type=uuid_list, location="args", required=False)
parser.add_argument("is_created_by_me", type=inputs.boolean, location="args", required=False)

args = parser.parse_args()

# get app list
app_service = AppService()
app_pagination = app_service.get_paginate_apps(current_user.current_tenant_id, args)
app_pagination = app_service.get_paginate_apps(current_user.id, current_user.current_tenant_id, args)
if not app_pagination:
return {"data": [], "total": 0, "page": 1, "limit": 20, "has_more": False}
@@ -20,7 +20,6 @@ from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpErr
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import (
AppInvokeQuotaExceededError,
ModelCurrentlyNotSupportError,
ProviderTokenNotInitError,
QuotaExceededError,

@@ -76,7 +75,7 @@ class CompletionMessageApi(Resource):
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
except (ValueError, AppInvokeQuotaExceededError) as e:
except ValueError as e:
raise e
except Exception as e:
logging.exception("internal server error.")

@@ -141,7 +140,7 @@ class ChatMessageApi(Resource):
raise InvokeRateLimitHttpError(ex.description)
except InvokeError as e:
raise CompletionRequestError(e.description)
except (ValueError, AppInvokeQuotaExceededError) as e:
except ValueError as e:
raise e
except Exception as e:
logging.exception("internal server error.")
@@ -273,8 +273,7 @@ FROM
messages m
ON c.id = m.conversation_id
WHERE
c.override_model_configs IS NULL
AND c.app_id = :app_id"""
c.app_id = :app_id"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id}

timezone = pytz.timezone(account.timezone)
@@ -640,6 +640,7 @@ class DatasetRetrievalSettingApi(Resource):
| VectorType.MYSCALE
| VectorType.ORACLE
| VectorType.ELASTICSEARCH
| VectorType.ELASTICSEARCH_JA
| VectorType.PGVECTOR
| VectorType.TIDB_ON_QDRANT
| VectorType.LINDORM

@@ -683,6 +684,7 @@ class DatasetRetrievalSettingMockApi(Resource):
| VectorType.MYSCALE
| VectorType.ORACLE
| VectorType.ELASTICSEARCH
| VectorType.ELASTICSEARCH_JA
| VectorType.COUCHBASE
| VectorType.PGVECTOR
| VectorType.LINDORM
@@ -269,7 +269,8 @@ class DatasetDocumentListApi(Resource):
parser.add_argument("original_document_id", type=str, required=False, location="json")
parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")

parser.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
parser.add_argument(
"doc_language", type=str, default="English", required=False, nullable=False, location="json"
)
@@ -18,7 +18,11 @@ from controllers.console.explore.error import NotChatAppError, NotCompletionAppE
from controllers.console.explore.wraps import InstalledAppResource
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.errors.error import (
ModelCurrentlyNotSupportError,
ProviderTokenNotInitError,
QuotaExceededError,
)
from core.model_runtime.errors.invoke import InvokeError
from extensions.ext_database import db
from libs import helper
@@ -13,7 +13,11 @@ from controllers.console.explore.error import NotWorkflowAppError
from controllers.console.explore.wraps import InstalledAppResource
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.errors.error import (
ModelCurrentlyNotSupportError,
ProviderTokenNotInitError,
QuotaExceededError,
)
from core.model_runtime.errors.invoke import InvokeError
from libs import helper
from libs.login import current_user
@@ -18,7 +18,6 @@ from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import (
AppInvokeQuotaExceededError,
ModelCurrentlyNotSupportError,
ProviderTokenNotInitError,
QuotaExceededError,

@@ -74,7 +73,7 @@ class CompletionApi(Resource):
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
except (ValueError, AppInvokeQuotaExceededError) as e:
except ValueError as e:
raise e
except Exception as e:
logging.exception("internal server error.")

@@ -133,7 +132,7 @@ class ChatApi(Resource):
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
except (ValueError, AppInvokeQuotaExceededError) as e:
except ValueError as e:
raise e
except Exception as e:
logging.exception("internal server error.")
@@ -16,7 +16,6 @@ from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import (
AppInvokeQuotaExceededError,
ModelCurrentlyNotSupportError,
ProviderTokenNotInitError,
QuotaExceededError,

@@ -94,7 +93,7 @@ class WorkflowRunApi(Resource):
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
except (ValueError, AppInvokeQuotaExceededError) as e:
except ValueError as e:
raise e
except Exception as e:
logging.exception("internal server error.")
@@ -190,7 +190,10 @@ class DocumentAddByFileApi(DatasetApiResource):
user=current_user,
source="datasets",
)
data_source = {"type": "upload_file", "info_list": {"file_info_list": {"file_ids": [upload_file.id]}}}
data_source = {
"type": "upload_file",
"info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
}
args["data_source"] = data_source
# validate args
knowledge_config = KnowledgeConfig(**args)

@@ -254,7 +257,10 @@ class DocumentUpdateByFileApi(DatasetApiResource):
raise FileTooLargeError(file_too_large_error.description)
except services.errors.file.UnsupportedFileTypeError:
raise UnsupportedFileTypeError()
data_source = {"type": "upload_file", "info_list": {"file_info_list": {"file_ids": [upload_file.id]}}}
data_source = {
"type": "upload_file",
"info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
}
args["data_source"] = data_source
# validate args
args["original_document_id"] = str(document_id)
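Both hunks above replace the one-line data_source literal with an expanded dict that also carries data_source_type inside info_list. For reference, the corrected payload shape; the file id below is a placeholder standing in for upload_file.id:

```python
# Shape taken from the diff; "<upload-file-id>" is a placeholder value.
data_source = {
    "type": "upload_file",
    "info_list": {
        "data_source_type": "upload_file",
        "file_info_list": {"file_ids": ["<upload-file-id>"]},
    },
}
```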
@@ -1,5 +1,5 @@
from collections.abc import Callable
from datetime import UTC, datetime
from datetime import UTC, datetime, timedelta
from enum import Enum
from functools import wraps
from typing import Optional

@@ -8,6 +8,8 @@ from flask import current_app, request
from flask_login import user_logged_in # type: ignore
from flask_restful import Resource # type: ignore
from pydantic import BaseModel
from sqlalchemy import select, update
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, Unauthorized

from extensions.ext_database import db

@@ -174,7 +176,7 @@ def validate_dataset_token(view=None):
return decorator


def validate_and_get_api_token(scope=None):
def validate_and_get_api_token(scope: str | None = None):
"""
Validate and get API token.
"""

@@ -188,20 +190,25 @@ def validate_and_get_api_token(scope=None):
if auth_scheme != "bearer":
raise Unauthorized("Authorization scheme must be 'Bearer'")

api_token = (
db.session.query(ApiToken)
.filter(
ApiToken.token == auth_token,
ApiToken.type == scope,
current_time = datetime.now(UTC).replace(tzinfo=None)
cutoff_time = current_time - timedelta(minutes=1)
with Session(db.engine, expire_on_commit=False) as session:
update_stmt = (
update(ApiToken)
.where(ApiToken.token == auth_token, ApiToken.last_used_at < cutoff_time, ApiToken.type == scope)
.values(last_used_at=current_time)
.returning(ApiToken)
)
.first()
)
result = session.execute(update_stmt)
api_token = result.scalar_one_or_none()

if not api_token:
raise Unauthorized("Access token is invalid")

api_token.last_used_at = datetime.now(UTC).replace(tzinfo=None)
db.session.commit()
if not api_token:
stmt = select(ApiToken).where(ApiToken.token == auth_token, ApiToken.type == scope)
api_token = session.scalar(stmt)
if not api_token:
raise Unauthorized("Access token is invalid")
else:
session.commit()

return api_token
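The rewritten validate_and_get_api_token above folds the lookup and the last_used_at bump into one conditional UPDATE ... RETURNING, so the row is written at most once per minute, and it only falls back to a plain SELECT when nothing was updated. A condensed, hedged sketch of the same flow, lifted out of the decorator for readability; it assumes the ApiToken model is importable (the import path below is a guess), a database backend that supports RETURNING (e.g. PostgreSQL), and an illustrative function name:

```python
from datetime import UTC, datetime, timedelta

from sqlalchemy import select, update
from sqlalchemy.orm import Session
from werkzeug.exceptions import Unauthorized

from models.model import ApiToken  # assumption: actual import path may differ


def touch_and_fetch_api_token(engine, auth_token: str, scope: str | None = None):
    current_time = datetime.now(UTC).replace(tzinfo=None)
    cutoff_time = current_time - timedelta(minutes=1)

    with Session(engine, expire_on_commit=False) as session:
        # Single round trip: bump last_used_at only when it is older than the
        # cutoff, and get the matching row back in the same statement.
        update_stmt = (
            update(ApiToken)
            .where(ApiToken.token == auth_token, ApiToken.last_used_at < cutoff_time, ApiToken.type == scope)
            .values(last_used_at=current_time)
            .returning(ApiToken)
        )
        api_token = session.execute(update_stmt).scalar_one_or_none()

        if not api_token:
            # Token missing, or used within the last minute: a read-only
            # lookup distinguishes the two cases without another write.
            stmt = select(ApiToken).where(ApiToken.token == auth_token, ApiToken.type == scope)
            api_token = session.scalar(stmt)
            if not api_token:
                raise Unauthorized("Access token is invalid")
        else:
            session.commit()

    return api_token
```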
@@ -19,7 +19,11 @@ from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpErr
from controllers.web.wraps import WebApiResource
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.errors.error import (
ModelCurrentlyNotSupportError,
ProviderTokenNotInitError,
QuotaExceededError,
)
from core.model_runtime.errors.invoke import InvokeError
from libs import helper
from libs.helper import uuid_value
@@ -14,7 +14,11 @@ from controllers.web.error import (
from controllers.web.wraps import WebApiResource
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.errors.error import (
ModelCurrentlyNotSupportError,
ProviderTokenNotInitError,
QuotaExceededError,
)
from core.model_runtime.errors.invoke import InvokeError
from libs import helper
from models.model import App, AppMode, EndUser
@@ -21,7 +21,7 @@ from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.ops.ops_trace_manager import TraceQueueManager
from core.prompt.utils.get_thread_messages_length import get_thread_messages_length
from extensions.ext_database import db

@@ -346,7 +346,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
except ValidationError as e:
logger.exception("Validation Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except (ValueError, InvokeError) as e:
except ValueError as e:
if dify_config.DEBUG:
logger.exception("Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
@@ -68,24 +68,17 @@ from models.account import Account
from models.enums import CreatedByRole
from models.workflow import (
Workflow,
WorkflowNodeExecution,
WorkflowRunStatus,
)

logger = logging.getLogger(__name__)


class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleManage, MessageCycleManage):
class AdvancedChatAppGenerateTaskPipeline:
"""
AdvancedChatAppGenerateTaskPipeline is a class that generate stream output and state management for Application.
"""

_task_state: WorkflowTaskState
_application_generate_entity: AdvancedChatAppGenerateEntity
_workflow_system_variables: dict[SystemVariableKey, Any]
_wip_workflow_node_executions: dict[str, WorkflowNodeExecution]
_conversation_name_generate_thread: Optional[Thread] = None

def __init__(
self,
application_generate_entity: AdvancedChatAppGenerateEntity,
@@ -97,7 +90,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
stream: bool,
dialogue_count: int,
) -> None:
super().__init__(
self._base_task_pipeline = BasedGenerateTaskPipeline(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
stream=stream,

@@ -114,33 +107,35 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
else:
raise NotImplementedError(f"User type not supported: {type(user)}")

self._workflow_id = workflow.id
self._workflow_features_dict = workflow.features_dict

self._conversation_id = conversation.id
self._conversation_mode = conversation.mode

self._message_id = message.id
self._message_created_at = int(message.created_at.timestamp())

self._workflow_system_variables = {
SystemVariableKey.QUERY: message.query,
SystemVariableKey.FILES: application_generate_entity.files,
SystemVariableKey.CONVERSATION_ID: conversation.id,
SystemVariableKey.USER_ID: user_session_id,
SystemVariableKey.DIALOGUE_COUNT: dialogue_count,
SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
SystemVariableKey.WORKFLOW_ID: workflow.id,
SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
}
self._workflow_cycle_manager = WorkflowCycleManage(
application_generate_entity=application_generate_entity,
workflow_system_variables={
SystemVariableKey.QUERY: message.query,
SystemVariableKey.FILES: application_generate_entity.files,
SystemVariableKey.CONVERSATION_ID: conversation.id,
SystemVariableKey.USER_ID: user_session_id,
SystemVariableKey.DIALOGUE_COUNT: dialogue_count,
SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
SystemVariableKey.WORKFLOW_ID: workflow.id,
SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
},
)

self._task_state = WorkflowTaskState()
self._wip_workflow_node_executions = {}
self._wip_workflow_agent_logs = {}
self._message_cycle_manager = MessageCycleManage(
application_generate_entity=application_generate_entity, task_state=self._task_state
)

self._conversation_name_generate_thread = None
self._application_generate_entity = application_generate_entity
self._workflow_id = workflow.id
self._workflow_features_dict = workflow.features_dict
self._conversation_id = conversation.id
self._conversation_mode = conversation.mode
self._message_id = message.id
self._message_created_at = int(message.created_at.timestamp())
self._conversation_name_generate_thread: Thread | None = None
self._recorded_files: list[Mapping[str, Any]] = []
self._workflow_run_id = ""
self._workflow_run_id: str = ""

def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
"""
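The constructor hunk above shows the shape of the refactor: AdvancedChatAppGenerateTaskPipeline no longer inherits from BasedGenerateTaskPipeline, WorkflowCycleManage and MessageCycleManage; it builds them and keeps them as _base_task_pipeline, _workflow_cycle_manager and _message_cycle_manager, so later call sites name the collaborator they use. A minimal, self-contained illustration of that inheritance-to-composition move, using toy classes rather than Dify code:

```python
class QueueManager:  # stand-in collaborator
    def listen(self):
        yield "ping"


class BaseTaskPipeline:  # behaviour that used to arrive via a mixin base class
    def __init__(self, queue_manager: QueueManager, stream: bool):
        self._queue_manager = queue_manager
        self._stream = stream

    def ping_response(self) -> str:
        return "pong"


class TaskPipeline:
    """Composition: the pipeline owns its collaborators instead of being them."""

    def __init__(self, queue_manager: QueueManager, stream: bool):
        self._base_task_pipeline = BaseTaskPipeline(queue_manager, stream)

    def process(self):
        # Call sites spell out which component provides each behaviour,
        # mirroring self._base_task_pipeline._queue_manager.listen() in the diff.
        for event in self._base_task_pipeline._queue_manager.listen():
            if event == "ping":
                yield self._base_task_pipeline.ping_response()


print(list(TaskPipeline(QueueManager(), stream=True).process()))  # ['pong']
```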
@@ -148,13 +143,13 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
:return:
"""
# start generate conversation name thread
self._conversation_name_generate_thread = self._generate_conversation_name(
self._conversation_name_generate_thread = self._message_cycle_manager._generate_conversation_name(
conversation_id=self._conversation_id, query=self._application_generate_entity.query
)

generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)

if self._stream:
if self._base_task_pipeline._stream:
return self._to_stream_response(generator)
else:
return self._to_blocking_response(generator)
@@ -273,24 +268,26 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
# init fake graph runtime state
graph_runtime_state: Optional[GraphRuntimeState] = None

for queue_message in self._queue_manager.listen():
for queue_message in self._base_task_pipeline._queue_manager.listen():
event = queue_message.event

if isinstance(event, QueuePingEvent):
yield self._ping_stream_response()
yield self._base_task_pipeline._ping_stream_response()
elif isinstance(event, QueueErrorEvent):
with Session(db.engine) as session:
err = self._handle_error(event=event, session=session, message_id=self._message_id)
with Session(db.engine, expire_on_commit=False) as session:
err = self._base_task_pipeline._handle_error(
event=event, session=session, message_id=self._message_id
)
session.commit()
yield self._error_to_stream_response(err)
yield self._base_task_pipeline._error_to_stream_response(err)
break
elif isinstance(event, QueueWorkflowStartedEvent):
# override graph runtime state
graph_runtime_state = event.graph_runtime_state

with Session(db.engine) as session:
with Session(db.engine, expire_on_commit=False) as session:
# init workflow run
workflow_run = self._handle_workflow_run_start(
workflow_run = self._workflow_cycle_manager._handle_workflow_run_start(
session=session,
workflow_id=self._workflow_id,
user_id=self._user_id,
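From this point on, the diff consistently opens sessions as Session(db.engine, expire_on_commit=False). With the default expire_on_commit=True, attributes of loaded objects are expired on commit, so reading them later triggers a refresh (or fails once the session is gone); keeping them un-expired lets the workflow_run and message objects built inside these blocks stay usable after session.commit(). A small, self-contained sketch of the difference, using a toy model and in-memory SQLite rather than Dify code:

```python
from sqlalchemy import create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class WorkflowRun(Base):  # toy stand-in for the real model
    __tablename__ = "workflow_run"
    id: Mapped[int] = mapped_column(primary_key=True)
    status: Mapped[str]


engine = create_engine("sqlite+pysqlite:///:memory:")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(WorkflowRun(id=1, status="running"))
    session.commit()

# expire_on_commit=False keeps loaded attribute values after commit, so the
# object remains readable once the with-block has closed the session.
with Session(engine, expire_on_commit=False) as session:
    run = session.execute(select(WorkflowRun).where(WorkflowRun.id == 1)).scalar_one()
    session.commit()

print(run.status)  # prints "running" without needing a live session
```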
@@ -301,7 +298,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
if not message:
raise ValueError(f"Message not found: {self._message_id}")
message.workflow_run_id = workflow_run.id
workflow_start_resp = self._workflow_start_to_stream_response(
workflow_start_resp = self._workflow_cycle_manager._workflow_start_to_stream_response(
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
session.commit()

@@ -314,12 +311,14 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")

with Session(db.engine) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
workflow_node_execution = self._handle_workflow_node_execution_retried(
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_retried(
session=session, workflow_run=workflow_run, event=event
)
node_retry_resp = self._workflow_node_retry_to_stream_response(
node_retry_resp = self._workflow_cycle_manager._workflow_node_retry_to_stream_response(
session=session,
event=event,
task_id=self._application_generate_entity.task_id,

@@ -333,13 +332,15 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")

with Session(db.engine) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
workflow_node_execution = self._handle_node_execution_start(
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
workflow_node_execution = self._workflow_cycle_manager._handle_node_execution_start(
session=session, workflow_run=workflow_run, event=event
)

node_start_resp = self._workflow_node_start_to_stream_response(
node_start_resp = self._workflow_cycle_manager._workflow_node_start_to_stream_response(
session=session,
event=event,
task_id=self._application_generate_entity.task_id,

@@ -352,12 +353,16 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
elif isinstance(event, QueueNodeSucceededEvent):
# Record files if it's an answer node or end node
if event.node_type in [NodeType.ANSWER, NodeType.END]:
self._recorded_files.extend(self._fetch_files_from_node_outputs(event.outputs or {}))
self._recorded_files.extend(
self._workflow_cycle_manager._fetch_files_from_node_outputs(event.outputs or {})
)

with Session(db.engine) as session:
workflow_node_execution = self._handle_workflow_node_execution_success(session=session, event=event)
with Session(db.engine, expire_on_commit=False) as session:
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_success(
session=session, event=event
)

node_finish_resp = self._workflow_node_finish_to_stream_response(
node_finish_resp = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
session=session,
event=event,
task_id=self._application_generate_entity.task_id,

@@ -368,10 +373,12 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
if node_finish_resp:
yield node_finish_resp
elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent):
with Session(db.engine) as session:
workflow_node_execution = self._handle_workflow_node_execution_failed(session=session, event=event)
with Session(db.engine, expire_on_commit=False) as session:
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_failed(
session=session, event=event
)

node_finish_resp = self._workflow_node_finish_to_stream_response(
node_finish_resp = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
session=session,
event=event,
task_id=self._application_generate_entity.task_id,

@@ -385,13 +392,17 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")

with Session(db.engine) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
parallel_start_resp = self._workflow_parallel_branch_start_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event,
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
parallel_start_resp = (
self._workflow_cycle_manager._workflow_parallel_branch_start_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event,
)
)

yield parallel_start_resp

@@ -399,13 +410,17 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")

with Session(db.engine) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
parallel_finish_resp = self._workflow_parallel_branch_finished_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event,
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
parallel_finish_resp = (
self._workflow_cycle_manager._workflow_parallel_branch_finished_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event,
)
)

yield parallel_finish_resp

@@ -413,9 +428,11 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")

with Session(db.engine) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
iter_start_resp = self._workflow_iteration_start_to_stream_response(
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
iter_start_resp = self._workflow_cycle_manager._workflow_iteration_start_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,

@@ -427,9 +444,11 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")

with Session(db.engine) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
iter_next_resp = self._workflow_iteration_next_to_stream_response(
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
iter_next_resp = self._workflow_cycle_manager._workflow_iteration_next_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,

@@ -441,9 +460,11 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")

with Session(db.engine) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
iter_finish_resp = self._workflow_iteration_completed_to_stream_response(
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
iter_finish_resp = self._workflow_cycle_manager._workflow_iteration_completed_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
@@ -458,8 +479,8 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
if not graph_runtime_state:
raise ValueError("workflow run not initialized.")

with Session(db.engine) as session:
workflow_run = self._handle_workflow_run_success(
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._handle_workflow_run_success(
session=session,
workflow_run_id=self._workflow_run_id,
start_at=graph_runtime_state.start_at,

@@ -470,21 +491,23 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
trace_manager=trace_manager,
)

workflow_finish_resp = self._workflow_finish_to_stream_response(
workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
session.commit()

yield workflow_finish_resp
self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
self._base_task_pipeline._queue_manager.publish(
QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE
)
elif isinstance(event, QueueWorkflowPartialSuccessEvent):
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
if not graph_runtime_state:
raise ValueError("graph runtime state not initialized.")

with Session(db.engine) as session:
workflow_run = self._handle_workflow_run_partial_success(
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._handle_workflow_run_partial_success(
session=session,
workflow_run_id=self._workflow_run_id,
start_at=graph_runtime_state.start_at,

@@ -495,21 +518,23 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
conversation_id=None,
trace_manager=trace_manager,
)
workflow_finish_resp = self._workflow_finish_to_stream_response(
workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
session.commit()

yield workflow_finish_resp
self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
self._base_task_pipeline._queue_manager.publish(
QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE
)
elif isinstance(event, QueueWorkflowFailedEvent):
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
if not graph_runtime_state:
raise ValueError("graph runtime state not initialized.")

with Session(db.engine) as session:
workflow_run = self._handle_workflow_run_failed(
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._handle_workflow_run_failed(
session=session,
workflow_run_id=self._workflow_run_id,
start_at=graph_runtime_state.start_at,

@@ -521,20 +546,22 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
trace_manager=trace_manager,
exceptions_count=event.exceptions_count,
)
workflow_finish_resp = self._workflow_finish_to_stream_response(
workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_run.error}"))
err = self._handle_error(event=err_event, session=session, message_id=self._message_id)
err = self._base_task_pipeline._handle_error(
event=err_event, session=session, message_id=self._message_id
)
session.commit()

yield workflow_finish_resp
yield self._error_to_stream_response(err)
yield self._base_task_pipeline._error_to_stream_response(err)
break
elif isinstance(event, QueueStopEvent):
if self._workflow_run_id and graph_runtime_state:
with Session(db.engine) as session:
workflow_run = self._handle_workflow_run_failed(
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._handle_workflow_run_failed(
session=session,
workflow_run_id=self._workflow_run_id,
start_at=graph_runtime_state.start_at,

@@ -545,7 +572,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
conversation_id=self._conversation_id,
trace_manager=trace_manager,
)
workflow_finish_resp = self._workflow_finish_to_stream_response(
workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,

@@ -559,18 +586,18 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
yield self._message_end_to_stream_response()
break
elif isinstance(event, QueueRetrieverResourcesEvent):
self._handle_retriever_resources(event)
self._message_cycle_manager._handle_retriever_resources(event)

with Session(db.engine) as session:
with Session(db.engine, expire_on_commit=False) as session:
message = self._get_message(session=session)
message.message_metadata = (
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
)
session.commit()
elif isinstance(event, QueueAnnotationReplyEvent):
self._handle_annotation_reply(event)
self._message_cycle_manager._handle_annotation_reply(event)

with Session(db.engine) as session:
with Session(db.engine, expire_on_commit=False) as session:
message = self._get_message(session=session)
message.message_metadata = (
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None

@@ -591,23 +618,27 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
tts_publisher.publish(queue_message)

self._task_state.answer += delta_text
yield self._message_to_stream_response(
yield self._message_cycle_manager._message_to_stream_response(
answer=delta_text, message_id=self._message_id, from_variable_selector=event.from_variable_selector
)
elif isinstance(event, QueueMessageReplaceEvent):
# published by moderation
yield self._message_replace_to_stream_response(answer=event.text)
yield self._message_cycle_manager._message_replace_to_stream_response(answer=event.text)
elif isinstance(event, QueueAdvancedChatMessageEndEvent):
if not graph_runtime_state:
raise ValueError("graph runtime state not initialized.")

output_moderation_answer = self._handle_output_moderation_when_task_finished(self._task_state.answer)
output_moderation_answer = self._base_task_pipeline._handle_output_moderation_when_task_finished(
self._task_state.answer
)
if output_moderation_answer:
self._task_state.answer = output_moderation_answer
yield self._message_replace_to_stream_response(answer=output_moderation_answer)
yield self._message_cycle_manager._message_replace_to_stream_response(
answer=output_moderation_answer
)

# Save message
with Session(db.engine) as session:
with Session(db.engine, expire_on_commit=False) as session:
self._save_message(session=session, graph_runtime_state=graph_runtime_state)
session.commit()

@@ -627,7 +658,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
def _save_message(self, *, session: Session, graph_runtime_state: Optional[GraphRuntimeState] = None) -> None:
message = self._get_message(session=session)
message.answer = self._task_state.answer
message.provider_response_latency = time.perf_counter() - self._start_at
message.provider_response_latency = time.perf_counter() - self._base_task_pipeline._start_at
message.message_metadata = (
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
)

@@ -691,20 +722,20 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
:param text: text
:return: True if output moderation should direct output, otherwise False
"""
if self._output_moderation_handler:
if self._output_moderation_handler.should_direct_output():
if self._base_task_pipeline._output_moderation_handler:
if self._base_task_pipeline._output_moderation_handler.should_direct_output():
# stop subscribe new token when output moderation should direct output
self._task_state.answer = self._output_moderation_handler.get_final_output()
self._queue_manager.publish(
self._task_state.answer = self._base_task_pipeline._output_moderation_handler.get_final_output()
self._base_task_pipeline._queue_manager.publish(
QueueTextChunkEvent(text=self._task_state.answer), PublishFrom.TASK_PIPELINE
)

self._queue_manager.publish(
self._base_task_pipeline._queue_manager.publish(
QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION), PublishFrom.TASK_PIPELINE
)
return True
else:
self._output_moderation_handler.append_new_token(text)
self._base_task_pipeline._output_moderation_handler.append_new_token(text)

return False
|
@ -19,7 +19,7 @@ from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskSt
|
||||
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
|
||||
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
|
||||
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, InvokeFrom
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
@@ -251,7 +251,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
except ValidationError as e:
logger.exception("Validation Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except (ValueError, InvokeError) as e:
except ValueError as e:
if dify_config.DEBUG:
logger.exception("Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)

@@ -18,7 +18,7 @@ from core.app.apps.chat.generate_response_converter import ChatAppGenerateRespon
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, InvokeFrom
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.ops.ops_trace_manager import TraceQueueManager
from extensions.ext_database import db
from factories import file_factory

@@ -237,7 +237,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
except ValidationError as e:
logger.exception("Validation Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except (ValueError, InvokeError) as e:
except ValueError as e:
if dify_config.DEBUG:
logger.exception("Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)

@@ -17,7 +17,7 @@ from core.app.apps.completion.generate_response_converter import CompletionAppGe
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import CompletionAppGenerateEntity, InvokeFrom
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.ops.ops_trace_manager import TraceQueueManager
from extensions.ext_database import db
from factories import file_factory

@@ -214,7 +214,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
except ValidationError as e:
logger.exception("Validation Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except (ValueError, InvokeError) as e:
except ValueError as e:
if dify_config.DEBUG:
logger.exception("Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)

@@ -20,7 +20,7 @@ from core.app.apps.workflow.generate_response_converter import WorkflowAppGenera
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.ops.ops_trace_manager import TraceQueueManager
from extensions.ext_database import db
from factories import file_factory

@@ -235,6 +235,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
single_iteration_run=WorkflowAppGenerateEntity.SingleIterationRunEntity(
node_id=node_id, inputs=args["inputs"]
),
workflow_run_id=str(uuid.uuid4()),
)
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
contexts.plugin_tool_providers.set({})

@@ -286,7 +287,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
except ValidationError as e:
logger.exception("Validation Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except (ValueError, InvokeError) as e:
except ValueError as e:
if dify_config.DEBUG:
logger.exception("Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
@ -1,7 +1,7 @@
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Optional, Union
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@ -59,7 +59,6 @@ from models.workflow import (
|
||||
Workflow,
|
||||
WorkflowAppLog,
|
||||
WorkflowAppLogCreatedFrom,
|
||||
WorkflowNodeExecution,
|
||||
WorkflowRun,
|
||||
WorkflowRunStatus,
|
||||
)
|
||||
@ -67,16 +66,11 @@ from models.workflow import (
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleManage):
|
||||
class WorkflowAppGenerateTaskPipeline:
|
||||
"""
|
||||
WorkflowAppGenerateTaskPipeline is a class that generate stream output and state management for Application.
|
||||
"""
|
||||
|
||||
_task_state: WorkflowTaskState
|
||||
_application_generate_entity: WorkflowAppGenerateEntity
|
||||
_workflow_system_variables: dict[SystemVariableKey, Any]
|
||||
_wip_workflow_node_executions: dict[str, WorkflowNodeExecution]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
application_generate_entity: WorkflowAppGenerateEntity,
|
||||
@ -85,7 +79,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
user: Union[Account, EndUser],
|
||||
stream: bool,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
self._base_task_pipeline = BasedGenerateTaskPipeline(
|
||||
application_generate_entity=application_generate_entity,
|
||||
queue_manager=queue_manager,
|
||||
stream=stream,
|
||||
@@ -102,17 +96,20 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
else:
raise ValueError(f"Invalid user type: {type(user)}")

self._workflow_cycle_manager = WorkflowCycleManage(
application_generate_entity=application_generate_entity,
workflow_system_variables={
SystemVariableKey.FILES: application_generate_entity.files,
SystemVariableKey.USER_ID: user_session_id,
SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
SystemVariableKey.WORKFLOW_ID: workflow.id,
SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
},
)

self._application_generate_entity = application_generate_entity
self._workflow_id = workflow.id
self._workflow_features_dict = workflow.features_dict

self._workflow_system_variables = {
SystemVariableKey.FILES: application_generate_entity.files,
SystemVariableKey.USER_ID: user_session_id,
SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
SystemVariableKey.WORKFLOW_ID: workflow.id,
SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
}

self._task_state = WorkflowTaskState()
self._workflow_run_id = ""

@@ -122,7 +119,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
:return:
"""
generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
if self._stream:
if self._base_task_pipeline._stream:
return self._to_stream_response(generator)
else:
return self._to_blocking_response(generator)
@@ -239,29 +236,29 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
"""
graph_runtime_state = None

for queue_message in self._queue_manager.listen():
for queue_message in self._base_task_pipeline._queue_manager.listen():
event = queue_message.event

if isinstance(event, QueuePingEvent):
yield self._ping_stream_response()
yield self._base_task_pipeline._ping_stream_response()
elif isinstance(event, QueueErrorEvent):
err = self._handle_error(event=event)
yield self._error_to_stream_response(err)
err = self._base_task_pipeline._handle_error(event=event)
yield self._base_task_pipeline._error_to_stream_response(err)
break
elif isinstance(event, QueueWorkflowStartedEvent):
# override graph runtime state
graph_runtime_state = event.graph_runtime_state

with Session(db.engine) as session:
with Session(db.engine, expire_on_commit=False) as session:
# init workflow run
workflow_run = self._handle_workflow_run_start(
workflow_run = self._workflow_cycle_manager._handle_workflow_run_start(
session=session,
workflow_id=self._workflow_id,
user_id=self._user_id,
created_by_role=self._created_by_role,
)
self._workflow_run_id = workflow_run.id
start_resp = self._workflow_start_to_stream_response(
start_resp = self._workflow_cycle_manager._workflow_start_to_stream_response(
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
session.commit()
@@ -273,12 +270,14 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
):
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
with Session(db.engine) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
workflow_node_execution = self._handle_workflow_node_execution_retried(
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_retried(
session=session, workflow_run=workflow_run, event=event
)
response = self._workflow_node_retry_to_stream_response(
response = self._workflow_cycle_manager._workflow_node_retry_to_stream_response(
session=session,
event=event,
task_id=self._application_generate_entity.task_id,

@@ -292,12 +291,14 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")

with Session(db.engine) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
workflow_node_execution = self._handle_node_execution_start(
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
workflow_node_execution = self._workflow_cycle_manager._handle_node_execution_start(
session=session, workflow_run=workflow_run, event=event
)
node_start_response = self._workflow_node_start_to_stream_response(
node_start_response = self._workflow_cycle_manager._workflow_node_start_to_stream_response(
session=session,
event=event,
task_id=self._application_generate_entity.task_id,

@@ -308,9 +309,11 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if node_start_response:
yield node_start_response
elif isinstance(event, QueueNodeSucceededEvent):
with Session(db.engine) as session:
workflow_node_execution = self._handle_workflow_node_execution_success(session=session, event=event)
node_success_response = self._workflow_node_finish_to_stream_response(
with Session(db.engine, expire_on_commit=False) as session:
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_success(
session=session, event=event
)
node_success_response = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
session=session,
event=event,
task_id=self._application_generate_entity.task_id,

@@ -321,12 +324,12 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if node_success_response:
yield node_success_response
elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent):
with Session(db.engine) as session:
workflow_node_execution = self._handle_workflow_node_execution_failed(
with Session(db.engine, expire_on_commit=False) as session:
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_failed(
session=session,
event=event,
)
node_failed_response = self._workflow_node_finish_to_stream_response(
node_failed_response = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
session=session,
event=event,
task_id=self._application_generate_entity.task_id,
@@ -341,13 +344,17 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")

with Session(db.engine) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
parallel_start_resp = self._workflow_parallel_branch_start_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event,
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
parallel_start_resp = (
self._workflow_cycle_manager._workflow_parallel_branch_start_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event,
)
)

yield parallel_start_resp

@@ -356,13 +363,17 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")

with Session(db.engine) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
parallel_finish_resp = self._workflow_parallel_branch_finished_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event,
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
parallel_finish_resp = (
self._workflow_cycle_manager._workflow_parallel_branch_finished_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event,
)
)

yield parallel_finish_resp

@@ -371,9 +382,11 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")

with Session(db.engine) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
iter_start_resp = self._workflow_iteration_start_to_stream_response(
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
iter_start_resp = self._workflow_cycle_manager._workflow_iteration_start_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,

@@ -386,9 +399,11 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")

with Session(db.engine) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
iter_next_resp = self._workflow_iteration_next_to_stream_response(
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
iter_next_resp = self._workflow_cycle_manager._workflow_iteration_next_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,

@@ -401,9 +416,11 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")

with Session(db.engine) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
iter_finish_resp = self._workflow_iteration_completed_to_stream_response(
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow_run = self._workflow_cycle_manager._get_workflow_run(
|
||||
session=session, workflow_run_id=self._workflow_run_id
|
||||
)
|
||||
iter_finish_resp = self._workflow_cycle_manager._workflow_iteration_completed_to_stream_response(
|
||||
session=session,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
@ -418,8 +435,8 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
if not graph_runtime_state:
|
||||
raise ValueError("graph runtime state not initialized.")
|
||||
|
||||
with Session(db.engine) as session:
|
||||
workflow_run = self._handle_workflow_run_success(
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow_run = self._workflow_cycle_manager._handle_workflow_run_success(
|
||||
session=session,
|
||||
workflow_run_id=self._workflow_run_id,
|
||||
start_at=graph_runtime_state.start_at,
|
||||
@ -433,7 +450,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
# save workflow app log
|
||||
self._save_workflow_app_log(session=session, workflow_run=workflow_run)
|
||||
|
||||
workflow_finish_resp = self._workflow_finish_to_stream_response(
|
||||
workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
|
||||
session=session,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
@ -447,8 +464,8 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
if not graph_runtime_state:
|
||||
raise ValueError("graph runtime state not initialized.")
|
||||
|
||||
with Session(db.engine) as session:
|
||||
workflow_run = self._handle_workflow_run_partial_success(
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow_run = self._workflow_cycle_manager._handle_workflow_run_partial_success(
|
||||
session=session,
|
||||
workflow_run_id=self._workflow_run_id,
|
||||
start_at=graph_runtime_state.start_at,
|
||||
@ -463,7 +480,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
# save workflow app log
|
||||
self._save_workflow_app_log(session=session, workflow_run=workflow_run)
|
||||
|
||||
workflow_finish_resp = self._workflow_finish_to_stream_response(
|
||||
workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
|
||||
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
|
||||
)
|
||||
session.commit()
|
||||
@ -475,8 +492,8 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
if not graph_runtime_state:
|
||||
raise ValueError("graph runtime state not initialized.")
|
||||
|
||||
with Session(db.engine) as session:
|
||||
workflow_run = self._handle_workflow_run_failed(
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow_run = self._workflow_cycle_manager._handle_workflow_run_failed(
|
||||
session=session,
|
||||
workflow_run_id=self._workflow_run_id,
|
||||
start_at=graph_runtime_state.start_at,
|
||||
@ -494,7 +511,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
# save workflow app log
|
||||
self._save_workflow_app_log(session=session, workflow_run=workflow_run)
|
||||
|
||||
workflow_finish_resp = self._workflow_finish_to_stream_response(
|
||||
workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
|
||||
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
|
||||
)
|
||||
session.commit()
|
||||
|
@ -195,7 +195,7 @@ class WorkflowAppGenerateEntity(AppGenerateEntity):
|
||||
|
||||
# app config
|
||||
app_config: WorkflowUIBasedAppConfig
|
||||
workflow_run_id: Optional[str] = None
|
||||
workflow_run_id: str
|
||||
|
||||
class SingleIterationRunEntity(BaseModel):
|
||||
"""
|
||||
|
@ -15,7 +15,6 @@ from core.app.entities.queue_entities import (
|
||||
from core.app.entities.task_entities import (
|
||||
ErrorStreamResponse,
|
||||
PingStreamResponse,
|
||||
TaskState,
|
||||
)
|
||||
from core.errors.error import QuotaExceededError
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
||||
@ -30,22 +29,12 @@ class BasedGenerateTaskPipeline:
|
||||
BasedGenerateTaskPipeline is a class that generate stream output and state management for Application.
|
||||
"""
|
||||
|
||||
_task_state: TaskState
|
||||
_application_generate_entity: AppGenerateEntity
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
application_generate_entity: AppGenerateEntity,
|
||||
queue_manager: AppQueueManager,
|
||||
stream: bool,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize GenerateTaskPipeline.
|
||||
:param application_generate_entity: application generate entity
|
||||
:param queue_manager: queue manager
|
||||
:param user: user
|
||||
:param stream: stream
|
||||
"""
|
||||
self._application_generate_entity = application_generate_entity
|
||||
self._queue_manager = queue_manager
|
||||
self._start_at = time.perf_counter()
|
||||
|
@ -31,10 +31,19 @@ from services.annotation_service import AppAnnotationService
|
||||
|
||||
|
||||
class MessageCycleManage:
|
||||
_application_generate_entity: Union[
|
||||
ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity, AdvancedChatAppGenerateEntity
|
||||
]
|
||||
_task_state: Union[EasyUITaskState, WorkflowTaskState]
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
application_generate_entity: Union[
|
||||
ChatAppGenerateEntity,
|
||||
CompletionAppGenerateEntity,
|
||||
AgentChatAppGenerateEntity,
|
||||
AdvancedChatAppGenerateEntity,
|
||||
],
|
||||
task_state: Union[EasyUITaskState, WorkflowTaskState],
|
||||
) -> None:
|
||||
self._application_generate_entity = application_generate_entity
|
||||
self._task_state = task_state
|
||||
|
||||
def _generate_conversation_name(self, *, conversation_id: str, query: str) -> Optional[Thread]:
|
||||
"""
|
||||
|
@ -36,7 +36,6 @@ from core.app.entities.task_entities import (
|
||||
ParallelBranchStartStreamResponse,
|
||||
WorkflowFinishStreamResponse,
|
||||
WorkflowStartStreamResponse,
|
||||
WorkflowTaskState,
|
||||
)
|
||||
from core.file import FILE_MODEL_IDENTITY, File
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
@ -60,13 +59,20 @@ from models.workflow import (
|
||||
WorkflowRunStatus,
|
||||
)
|
||||
|
||||
from .exc import WorkflowNodeExecutionNotFoundError, WorkflowRunNotFoundError
|
||||
from .exc import WorkflowRunNotFoundError
|
||||
|
||||
|
||||
class WorkflowCycleManage:
|
||||
_application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity]
|
||||
_task_state: WorkflowTaskState
|
||||
_workflow_system_variables: dict[SystemVariableKey, Any]
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity],
|
||||
workflow_system_variables: dict[SystemVariableKey, Any],
|
||||
) -> None:
|
||||
self._workflow_run: WorkflowRun | None = None
|
||||
self._workflow_node_executions: dict[str, WorkflowNodeExecution] = {}
|
||||
self._application_generate_entity = application_generate_entity
|
||||
self._workflow_system_variables = workflow_system_variables
|
||||
|
||||
def _handle_workflow_run_start(
|
||||
self,
|
||||
@ -104,7 +110,8 @@ class WorkflowCycleManage:
|
||||
inputs = dict(WorkflowEntry.handle_special_values(inputs) or {})
|
||||
|
||||
# init workflow run
|
||||
workflow_run_id = str(self._workflow_system_variables.get(SystemVariableKey.WORKFLOW_RUN_ID, uuid4()))
|
||||
# TODO: This workflow_run_id should always not be None, maybe we can use a more elegant way to handle this
|
||||
workflow_run_id = str(self._workflow_system_variables.get(SystemVariableKey.WORKFLOW_RUN_ID) or uuid4())
|
||||
|
||||
workflow_run = WorkflowRun()
|
||||
workflow_run.id = workflow_run_id
|
||||
@ -241,7 +248,7 @@ class WorkflowCycleManage:
|
||||
workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
workflow_run.exceptions_count = exceptions_count
|
||||
|
||||
stmt = select(WorkflowNodeExecution).where(
|
||||
stmt = select(WorkflowNodeExecution.node_execution_id).where(
|
||||
WorkflowNodeExecution.tenant_id == workflow_run.tenant_id,
|
||||
WorkflowNodeExecution.app_id == workflow_run.app_id,
|
||||
WorkflowNodeExecution.workflow_id == workflow_run.workflow_id,
|
||||
@ -249,15 +256,18 @@ class WorkflowCycleManage:
|
||||
WorkflowNodeExecution.workflow_run_id == workflow_run.id,
|
||||
WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING.value,
|
||||
)
|
||||
|
||||
running_workflow_node_executions = session.scalars(stmt).all()
|
||||
ids = session.scalars(stmt).all()
|
||||
# Use self._get_workflow_node_execution here to make sure the cache is updated
|
||||
running_workflow_node_executions = [
|
||||
self._get_workflow_node_execution(session=session, node_execution_id=id) for id in ids if id
|
||||
]
|
||||
|
||||
for workflow_node_execution in running_workflow_node_executions:
|
||||
now = datetime.now(UTC).replace(tzinfo=None)
|
||||
workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
|
||||
workflow_node_execution.error = error
|
||||
finish_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
workflow_node_execution.finished_at = finish_at
|
||||
workflow_node_execution.elapsed_time = (finish_at - workflow_node_execution.created_at).total_seconds()
|
||||
workflow_node_execution.finished_at = now
|
||||
workflow_node_execution.elapsed_time = (now - workflow_node_execution.created_at).total_seconds()
|
||||
|
||||
if trace_manager:
|
||||
trace_manager.add_trace_task(
|
||||
@ -299,6 +309,8 @@ class WorkflowCycleManage:
|
||||
workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
|
||||
session.add(workflow_node_execution)
|
||||
|
||||
self._workflow_node_executions[event.node_execution_id] = workflow_node_execution
|
||||
return workflow_node_execution
|
||||
|
||||
def _handle_workflow_node_execution_success(
|
||||
@ -325,6 +337,7 @@ class WorkflowCycleManage:
|
||||
workflow_node_execution.finished_at = finished_at
|
||||
workflow_node_execution.elapsed_time = elapsed_time
|
||||
|
||||
workflow_node_execution = session.merge(workflow_node_execution)
|
||||
return workflow_node_execution
|
||||
|
||||
def _handle_workflow_node_execution_failed(
|
||||
@ -364,6 +377,7 @@ class WorkflowCycleManage:
|
||||
workflow_node_execution.elapsed_time = elapsed_time
|
||||
workflow_node_execution.execution_metadata = execution_metadata
|
||||
|
||||
workflow_node_execution = session.merge(workflow_node_execution)
|
||||
return workflow_node_execution
|
||||
|
||||
def _handle_workflow_node_execution_retried(
|
||||
@ -415,6 +429,8 @@ class WorkflowCycleManage:
|
||||
workflow_node_execution.index = event.node_run_index
|
||||
|
||||
session.add(workflow_node_execution)
|
||||
|
||||
self._workflow_node_executions[event.node_execution_id] = workflow_node_execution
|
||||
return workflow_node_execution
|
||||
|
||||
#################################################
|
||||
@ -811,25 +827,23 @@ class WorkflowCycleManage:
|
||||
return None
|
||||
|
||||
def _get_workflow_run(self, *, session: Session, workflow_run_id: str) -> WorkflowRun:
|
||||
"""
|
||||
Refetch workflow run
|
||||
:param workflow_run_id: workflow run id
|
||||
:return:
|
||||
"""
|
||||
if self._workflow_run and self._workflow_run.id == workflow_run_id:
|
||||
cached_workflow_run = self._workflow_run
|
||||
cached_workflow_run = session.merge(cached_workflow_run)
|
||||
return cached_workflow_run
|
||||
stmt = select(WorkflowRun).where(WorkflowRun.id == workflow_run_id)
|
||||
workflow_run = session.scalar(stmt)
|
||||
if not workflow_run:
|
||||
raise WorkflowRunNotFoundError(workflow_run_id)
|
||||
self._workflow_run = workflow_run
|
||||
|
||||
return workflow_run
|
||||
|
||||
def _get_workflow_node_execution(self, session: Session, node_execution_id: str) -> WorkflowNodeExecution:
|
||||
stmt = select(WorkflowNodeExecution).where(WorkflowNodeExecution.node_execution_id == node_execution_id)
|
||||
workflow_node_execution = session.scalar(stmt)
|
||||
if not workflow_node_execution:
|
||||
raise WorkflowNodeExecutionNotFoundError(node_execution_id)
|
||||
|
||||
return workflow_node_execution
|
||||
if node_execution_id not in self._workflow_node_executions:
|
||||
raise ValueError(f"Workflow node execution not found: {node_execution_id}")
|
||||
cached_workflow_node_execution = self._workflow_node_executions[node_execution_id]
|
||||
return cached_workflow_node_execution
|
||||
|
||||
def _handle_agent_log(self, task_id: str, event: QueueAgentLogEvent) -> AgentLogStreamResponse:
|
||||
"""
|
||||
|
@ -1,9 +1,47 @@
import tiktoken
from threading import Lock
from typing import Any

_tokenizer: Any = None
_lock = Lock()


class GPT2Tokenizer:
@staticmethod
def _get_num_tokens_by_gpt2(text: str) -> int:
"""
use gpt2 tokenizer to get num tokens
"""
_tokenizer = GPT2Tokenizer.get_encoder()
tokens = _tokenizer.encode(text)
return len(tokens)

@staticmethod
def get_num_tokens(text: str) -> int:
encoding = tiktoken.encoding_for_model("gpt2")
tiktoken_vec = encoding.encode(text)
return len(tiktoken_vec)
# Because this process needs more cpu resource, we turn this back before we find a better way to handle it.
#
# future = _executor.submit(GPT2Tokenizer._get_num_tokens_by_gpt2, text)
# result = future.result()
# return cast(int, result)
return GPT2Tokenizer._get_num_tokens_by_gpt2(text)

@staticmethod
def get_encoder() -> Any:
global _tokenizer, _lock
with _lock:
if _tokenizer is None:
# Try to use tiktoken to get the tokenizer because it is faster
#
try:
import tiktoken

_tokenizer = tiktoken.get_encoding("gpt2")
except Exception:
from os.path import abspath, dirname, join

from transformers import GPT2Tokenizer as TransformerGPT2Tokenizer # type: ignore

base_path = abspath(__file__)
gpt2_tokenizer_path = join(dirname(base_path), "gpt2")
_tokenizer = TransformerGPT2Tokenizer.from_pretrained(gpt2_tokenizer_path)

return _tokenizer
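A minimal usage sketch of the rewritten tokenizer above, assuming GPT2Tokenizer is imported from the module this hunk patches and that tiktoken's "gpt2" encoding is available (the sample string is made up):

# Hypothetical usage; not part of the commit.
num_tokens = GPT2Tokenizer.get_num_tokens("Dify estimates prompt length with a GPT-2 tokenizer.")
assert isinstance(num_tokens, int)  # the first call builds the shared encoder under the lock; later calls reuse it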
@ -113,6 +113,8 @@ class BaiduVector(BaseVector):
return False

def delete_by_ids(self, ids: list[str]) -> None:
if not ids:
return
quoted_ids = [f"'{id}'" for id in ids]
self._db.table(self._collection_name).delete(filter=f"id IN({', '.join(quoted_ids)})")

@ -83,6 +83,8 @@ class ChromaVector(BaseVector):
self._client.delete_collection(self._collection_name)

def delete_by_ids(self, ids: list[str]) -> None:
if not ids:
return
collection = self._client.get_or_create_collection(self._collection_name)
collection.delete(ids=ids)

@ -0,0 +1,104 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
|
||||
from flask import current_app
|
||||
|
||||
from core.rag.datasource.vdb.elasticsearch.elasticsearch_vector import (
|
||||
ElasticSearchConfig,
|
||||
ElasticSearchVector,
|
||||
ElasticSearchVectorFactory,
|
||||
)
|
||||
from core.rag.datasource.vdb.field import Field
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
from core.rag.embedding.embedding_base import Embeddings
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.dataset import Dataset
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ElasticSearchJaVector(ElasticSearchVector):
|
||||
def create_collection(
|
||||
self,
|
||||
embeddings: list[list[float]],
|
||||
metadatas: Optional[list[dict[Any, Any]]] = None,
|
||||
index_params: Optional[dict] = None,
|
||||
):
|
||||
lock_name = f"vector_indexing_lock_{self._collection_name}"
|
||||
with redis_client.lock(lock_name, timeout=20):
|
||||
collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
|
||||
if redis_client.get(collection_exist_cache_key):
|
||||
logger.info(f"Collection {self._collection_name} already exists.")
|
||||
return
|
||||
|
||||
if not self._client.indices.exists(index=self._collection_name):
|
||||
dim = len(embeddings[0])
|
||||
settings = {
|
||||
"analysis": {
|
||||
"analyzer": {
|
||||
"ja_analyzer": {
|
||||
"type": "custom",
|
||||
"char_filter": [
|
||||
"icu_normalizer",
|
||||
"kuromoji_iteration_mark",
|
||||
],
|
||||
"tokenizer": "kuromoji_tokenizer",
|
||||
"filter": [
|
||||
"kuromoji_baseform",
|
||||
"kuromoji_part_of_speech",
|
||||
"ja_stop",
|
||||
"kuromoji_number",
|
||||
"kuromoji_stemmer",
|
||||
],
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
mappings = {
|
||||
"properties": {
|
||||
Field.CONTENT_KEY.value: {
|
||||
"type": "text",
|
||||
"analyzer": "ja_analyzer",
|
||||
"search_analyzer": "ja_analyzer",
|
||||
},
|
||||
Field.VECTOR.value: { # Make sure the dimension is correct here
|
||||
"type": "dense_vector",
|
||||
"dims": dim,
|
||||
"index": True,
|
||||
"similarity": "cosine",
|
||||
},
|
||||
Field.METADATA_KEY.value: {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"doc_id": {"type": "keyword"} # Map doc_id to keyword type
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
self._client.indices.create(index=self._collection_name, settings=settings, mappings=mappings)
|
||||
|
||||
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
||||
|
||||
|
||||
class ElasticSearchJaVectorFactory(ElasticSearchVectorFactory):
|
||||
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> ElasticSearchJaVector:
|
||||
if dataset.index_struct_dict:
|
||||
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
|
||||
collection_name = class_prefix
|
||||
else:
|
||||
dataset_id = dataset.id
|
||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
||||
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.ELASTICSEARCH, collection_name))
|
||||
|
||||
config = current_app.config
|
||||
return ElasticSearchJaVector(
|
||||
index_name=collection_name,
|
||||
config=ElasticSearchConfig(
|
||||
host=config.get("ELASTICSEARCH_HOST", "localhost"),
|
||||
port=config.get("ELASTICSEARCH_PORT", 9200),
|
||||
username=config.get("ELASTICSEARCH_USERNAME", ""),
|
||||
password=config.get("ELASTICSEARCH_PASSWORD", ""),
|
||||
),
|
||||
attributes=[],
|
||||
)
|
@ -98,6 +98,8 @@ class ElasticSearchVector(BaseVector):
return bool(self._client.exists(index=self._collection_name, id=id))

def delete_by_ids(self, ids: list[str]) -> None:
if not ids:
return
for id in ids:
self._client.delete(index=self._collection_name, id=id)

@ -6,6 +6,8 @@ class Field(Enum):
METADATA_KEY = "metadata"
GROUP_KEY = "group_id"
VECTOR = "vector"
# Sparse Vector aims to support full text search
SPARSE_VECTOR = "sparse_vector"
TEXT_KEY = "text"
PRIMARY_KEY = "id"
DOC_ID = "metadata.doc_id"

@ -2,6 +2,7 @@ import json
import logging
from typing import Any, Optional

from packaging import version
from pydantic import BaseModel, model_validator
from pymilvus import MilvusClient, MilvusException # type: ignore
from pymilvus.milvus_client import IndexParams # type: ignore
@ -20,16 +21,25 @@ logger = logging.getLogger(__name__)


class MilvusConfig(BaseModel):
uri: str
token: Optional[str] = None
user: str
password: str
batch_size: int = 100
database: str = "default"
"""
Configuration class for Milvus connection.
"""

uri: str # Milvus server URI
token: Optional[str] = None # Optional token for authentication
user: str # Username for authentication
password: str # Password for authentication
batch_size: int = 100 # Batch size for operations
database: str = "default" # Database name
enable_hybrid_search: bool = False # Flag to enable hybrid search

@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict) -> dict:
"""
Validate the configuration values.
Raises ValueError if required fields are missing.
"""
if not values.get("uri"):
raise ValueError("config MILVUS_URI is required")
if not values.get("user"):
@ -39,6 +49,9 @@ class MilvusConfig(BaseModel):
return values

def to_milvus_params(self):
"""
Convert the configuration to a dictionary of Milvus connection parameters.
"""
return {
"uri": self.uri,
"token": self.token,
@ -49,26 +62,57 @@

class MilvusVector(BaseVector):
"""
Milvus vector storage implementation.
"""

def __init__(self, collection_name: str, config: MilvusConfig):
super().__init__(collection_name)
self._client_config = config
self._client = self._init_client(config)
self._consistency_level = "Session"
self._fields: list[str] = []
self._consistency_level = "Session" # Consistency level for Milvus operations
self._fields: list[str] = [] # List of fields in the collection
self._hybrid_search_enabled = self._check_hybrid_search_support() # Check if hybrid search is supported

def _check_hybrid_search_support(self) -> bool:
"""
Check if the current Milvus version supports hybrid search.
Returns True if the version is >= 2.5.0, otherwise False.
"""
if not self._client_config.enable_hybrid_search:
return False

try:
milvus_version = self._client.get_server_version()
return version.parse(milvus_version).base_version >= version.parse("2.5.0").base_version
except Exception as e:
logger.warning(f"Failed to check Milvus version: {str(e)}. Disabling hybrid search.")
return False

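The version gate above can be approximated in isolation; a standalone sketch of the same idea, not the committed helper itself, and simplified to compare parsed versions directly (the sample version strings are made up):

# Hypothetical standalone check mirroring _check_hybrid_search_support.
from packaging import version

def supports_hybrid_search(server_version: str, enabled: bool = True) -> bool:
    # Hybrid (BM25 sparse-vector) search needs the config flag and Milvus >= 2.5.0.
    return enabled and version.parse(server_version) >= version.parse("2.5.0")

print(supports_hybrid_search("2.5.1"))  # True
print(supports_hybrid_search("2.4.9"))  # False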
def get_type(self) -> str:
|
||||
"""
|
||||
Get the type of vector storage (Milvus).
|
||||
"""
|
||||
return VectorType.MILVUS
|
||||
|
||||
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
"""
|
||||
Create a collection and add texts with embeddings.
|
||||
"""
|
||||
index_params = {"metric_type": "IP", "index_type": "HNSW", "params": {"M": 8, "efConstruction": 64}}
|
||||
metadatas = [d.metadata if d.metadata is not None else {} for d in texts]
|
||||
self.create_collection(embeddings, metadatas, index_params)
|
||||
self.add_texts(texts, embeddings)
|
||||
|
||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
"""
|
||||
Add texts and their embeddings to the collection.
|
||||
"""
|
||||
insert_dict_list = []
|
||||
for i in range(len(documents)):
|
||||
insert_dict = {
|
||||
# Do not need to insert the sparse_vector field separately, as the text_bm25_emb
|
||||
# function will automatically convert the native text into a sparse vector for us.
|
||||
Field.CONTENT_KEY.value: documents[i].page_content,
|
||||
Field.VECTOR.value: embeddings[i],
|
||||
Field.METADATA_KEY.value: documents[i].metadata,
|
||||
@ -76,12 +120,11 @@ class MilvusVector(BaseVector):
|
||||
insert_dict_list.append(insert_dict)
|
||||
# Total insert count
|
||||
total_count = len(insert_dict_list)
|
||||
|
||||
pks: list[str] = []
|
||||
|
||||
for i in range(0, total_count, 1000):
|
||||
batch_insert_list = insert_dict_list[i : i + 1000]
|
||||
# Insert into the collection.
|
||||
batch_insert_list = insert_dict_list[i : i + 1000]
|
||||
try:
|
||||
ids = self._client.insert(collection_name=self._collection_name, data=batch_insert_list)
|
||||
pks.extend(ids)
|
||||
@ -91,6 +134,9 @@ class MilvusVector(BaseVector):
|
||||
return pks
|
||||
|
||||
def get_ids_by_metadata_field(self, key: str, value: str):
|
||||
"""
|
||||
Get document IDs by metadata field key and value.
|
||||
"""
|
||||
result = self._client.query(
|
||||
collection_name=self._collection_name, filter=f'metadata["{key}"] == "{value}"', output_fields=["id"]
|
||||
)
|
||||
@ -100,12 +146,18 @@ class MilvusVector(BaseVector):
|
||||
return None
|
||||
|
||||
def delete_by_metadata_field(self, key: str, value: str):
|
||||
"""
|
||||
Delete documents by metadata field key and value.
|
||||
"""
|
||||
if self._client.has_collection(self._collection_name):
|
||||
ids = self.get_ids_by_metadata_field(key, value)
|
||||
if ids:
|
||||
self._client.delete(collection_name=self._collection_name, pks=ids)
|
||||
|
||||
def delete_by_ids(self, ids: list[str]) -> None:
|
||||
"""
|
||||
Delete documents by their IDs.
|
||||
"""
|
||||
if self._client.has_collection(self._collection_name):
|
||||
result = self._client.query(
|
||||
collection_name=self._collection_name, filter=f'metadata["doc_id"] in {ids}', output_fields=["id"]
|
||||
@ -115,10 +167,16 @@ class MilvusVector(BaseVector):
|
||||
self._client.delete(collection_name=self._collection_name, pks=ids)
|
||||
|
||||
def delete(self) -> None:
|
||||
"""
|
||||
Delete the entire collection.
|
||||
"""
|
||||
if self._client.has_collection(self._collection_name):
|
||||
self._client.drop_collection(self._collection_name, None)
|
||||
|
||||
def text_exists(self, id: str) -> bool:
|
||||
"""
|
||||
Check if a text with the given ID exists in the collection.
|
||||
"""
|
||||
if not self._client.has_collection(self._collection_name):
|
||||
return False
|
||||
|
||||
@ -128,32 +186,80 @@ class MilvusVector(BaseVector):
|
||||
|
||||
return len(result) > 0
|
||||
|
||||
def field_exists(self, field: str) -> bool:
|
||||
"""
|
||||
Check if a field exists in the collection.
|
||||
"""
|
||||
return field in self._fields
|
||||
|
||||
def _process_search_results(
|
||||
self, results: list[Any], output_fields: list[str], score_threshold: float = 0.0
|
||||
) -> list[Document]:
|
||||
"""
|
||||
Common method to process search results
|
||||
|
||||
:param results: Search results
|
||||
:param output_fields: Fields to be output
|
||||
:param score_threshold: Score threshold for filtering
|
||||
:return: List of documents
|
||||
"""
|
||||
docs = []
|
||||
for result in results[0]:
|
||||
metadata = result["entity"].get(output_fields[1], {})
|
||||
metadata["score"] = result["distance"]
|
||||
|
||||
if result["distance"] > score_threshold:
|
||||
doc = Document(page_content=result["entity"].get(output_fields[0], ""), metadata=metadata)
|
||||
docs.append(doc)
|
||||
|
||||
return docs
|
||||
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
# Set search parameters.
|
||||
"""
|
||||
Search for documents by vector similarity.
|
||||
"""
|
||||
results = self._client.search(
|
||||
collection_name=self._collection_name,
|
||||
data=[query_vector],
|
||||
anns_field=Field.VECTOR.value,
|
||||
limit=kwargs.get("top_k", 4),
|
||||
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
|
||||
)
|
||||
# Organize results.
|
||||
docs = []
|
||||
for result in results[0]:
|
||||
metadata = result["entity"].get(Field.METADATA_KEY.value)
|
||||
metadata["score"] = result["distance"]
|
||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||
if result["distance"] > score_threshold:
|
||||
doc = Document(page_content=result["entity"].get(Field.CONTENT_KEY.value), metadata=metadata)
|
||||
docs.append(doc)
|
||||
return docs
|
||||
|
||||
return self._process_search_results(
|
||||
results,
|
||||
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
|
||||
score_threshold=float(kwargs.get("score_threshold") or 0.0),
|
||||
)
|
||||
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
# milvus/zilliz doesn't support bm25 search
|
||||
return []
|
||||
"""
|
||||
Search for documents by full-text search (if hybrid search is enabled).
|
||||
"""
|
||||
if not self._hybrid_search_enabled or not self.field_exists(Field.SPARSE_VECTOR.value):
|
||||
logger.warning("Full-text search is not supported in current Milvus version (requires >= 2.5.0)")
|
||||
return []
|
||||
|
||||
results = self._client.search(
|
||||
collection_name=self._collection_name,
|
||||
data=[query],
|
||||
anns_field=Field.SPARSE_VECTOR.value,
|
||||
limit=kwargs.get("top_k", 4),
|
||||
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
|
||||
)
|
||||
|
||||
return self._process_search_results(
|
||||
results,
|
||||
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
|
||||
score_threshold=float(kwargs.get("score_threshold") or 0.0),
|
||||
)
|
||||
|
||||
def create_collection(
|
||||
self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None
|
||||
):
|
||||
"""
|
||||
Create a new collection in Milvus with the specified schema and index parameters.
|
||||
"""
|
||||
lock_name = "vector_indexing_lock_{}".format(self._collection_name)
|
||||
with redis_client.lock(lock_name, timeout=20):
|
||||
collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
|
||||
@ -161,7 +267,7 @@ class MilvusVector(BaseVector):
|
||||
return
|
||||
# Grab the existing collection if it exists
|
||||
if not self._client.has_collection(self._collection_name):
|
||||
from pymilvus import CollectionSchema, DataType, FieldSchema # type: ignore
|
||||
from pymilvus import CollectionSchema, DataType, FieldSchema, Function, FunctionType # type: ignore
|
||||
from pymilvus.orm.types import infer_dtype_bydata # type: ignore
|
||||
|
||||
# Determine embedding dim
|
||||
@ -170,16 +276,36 @@ class MilvusVector(BaseVector):
|
||||
if metadatas:
|
||||
fields.append(FieldSchema(Field.METADATA_KEY.value, DataType.JSON, max_length=65_535))
|
||||
|
||||
# Create the text field
|
||||
fields.append(FieldSchema(Field.CONTENT_KEY.value, DataType.VARCHAR, max_length=65_535))
|
||||
# Create the text field, enable_analyzer will be set True to support milvus automatically
|
||||
# transfer text to sparse_vector, reference: https://milvus.io/docs/full-text-search.md
|
||||
fields.append(
|
||||
FieldSchema(
|
||||
Field.CONTENT_KEY.value,
|
||||
DataType.VARCHAR,
|
||||
max_length=65_535,
|
||||
enable_analyzer=self._hybrid_search_enabled,
|
||||
)
|
||||
)
|
||||
# Create the primary key field
|
||||
fields.append(FieldSchema(Field.PRIMARY_KEY.value, DataType.INT64, is_primary=True, auto_id=True))
|
||||
# Create the vector field, supports binary or float vectors
|
||||
fields.append(FieldSchema(Field.VECTOR.value, infer_dtype_bydata(embeddings[0]), dim=dim))
|
||||
# Create Sparse Vector Index for the collection
|
||||
if self._hybrid_search_enabled:
|
||||
fields.append(FieldSchema(Field.SPARSE_VECTOR.value, DataType.SPARSE_FLOAT_VECTOR))
|
||||
|
||||
# Create the schema for the collection
|
||||
schema = CollectionSchema(fields)
|
||||
|
||||
# Create custom function to support text to sparse vector by BM25
|
||||
if self._hybrid_search_enabled:
|
||||
bm25_function = Function(
|
||||
name="text_bm25_emb",
|
||||
input_field_names=[Field.CONTENT_KEY.value],
|
||||
output_field_names=[Field.SPARSE_VECTOR.value],
|
||||
function_type=FunctionType.BM25,
|
||||
)
|
||||
schema.add_function(bm25_function)
|
||||
|
||||
for x in schema.fields:
|
||||
self._fields.append(x.name)
|
||||
# Since primary field is auto-id, no need to track it
|
||||
@ -189,10 +315,15 @@ class MilvusVector(BaseVector):
|
||||
index_params_obj = IndexParams()
|
||||
index_params_obj.add_index(field_name=Field.VECTOR.value, **index_params)
|
||||
|
||||
# Create Sparse Vector Index for the collection
|
||||
if self._hybrid_search_enabled:
|
||||
index_params_obj.add_index(
|
||||
field_name=Field.SPARSE_VECTOR.value, index_type="AUTOINDEX", metric_type="BM25"
|
||||
)
|
||||
|
||||
# Create the collection
|
||||
collection_name = self._collection_name
|
||||
self._client.create_collection(
|
||||
collection_name=collection_name,
|
||||
collection_name=self._collection_name,
|
||||
schema=schema,
|
||||
index_params=index_params_obj,
|
||||
consistency_level=self._consistency_level,
|
||||
@ -200,12 +331,22 @@ class MilvusVector(BaseVector):
|
||||
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
||||
|
||||
def _init_client(self, config) -> MilvusClient:
|
||||
"""
|
||||
Initialize and return a Milvus client.
|
||||
"""
|
||||
client = MilvusClient(uri=config.uri, user=config.user, password=config.password, db_name=config.database)
|
||||
return client
|
||||
|
||||
|
||||
class MilvusVectorFactory(AbstractVectorFactory):
|
||||
"""
|
||||
Factory class for creating MilvusVector instances.
|
||||
"""
|
||||
|
||||
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> MilvusVector:
|
||||
"""
|
||||
Initialize a MilvusVector instance for the given dataset.
|
||||
"""
|
||||
if dataset.index_struct_dict:
|
||||
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
|
||||
collection_name = class_prefix
|
||||
@ -222,5 +363,6 @@ class MilvusVectorFactory(AbstractVectorFactory):
|
||||
user=dify_config.MILVUS_USER or "",
|
||||
password=dify_config.MILVUS_PASSWORD or "",
|
||||
database=dify_config.MILVUS_DATABASE or "",
|
||||
enable_hybrid_search=dify_config.MILVUS_ENABLE_HYBRID_SEARCH or False,
|
||||
),
|
||||
)
|
||||
|
@ -100,6 +100,8 @@ class MyScaleVector(BaseVector):
|
||||
return results.row_count > 0
|
||||
|
||||
def delete_by_ids(self, ids: list[str]) -> None:
|
||||
if not ids:
|
||||
return
|
||||
self._client.command(
|
||||
f"DELETE FROM {self._config.database}.{self._collection_name} WHERE id IN {str(tuple(ids))}"
|
||||
)
|
||||
|
@ -134,6 +134,8 @@ class OceanBaseVector(BaseVector):
|
||||
return bool(cur.rowcount != 0)
|
||||
|
||||
def delete_by_ids(self, ids: list[str]) -> None:
|
||||
if not ids:
|
||||
return
|
||||
self._client.delete(table_name=self._collection_name, ids=ids)
|
||||
|
||||
def get_ids_by_metadata_field(self, key: str, value: str) -> list[str]:
|
||||
|
@ -167,6 +167,8 @@ class OracleVector(BaseVector):
|
||||
return docs
|
||||
|
||||
def delete_by_ids(self, ids: list[str]) -> None:
|
||||
if not ids:
|
||||
return
|
||||
with self._get_cursor() as cur:
|
||||
cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s" % (tuple(ids),))
|
||||
|
||||
|
@ -129,6 +129,11 @@ class PGVector(BaseVector):
return docs

def delete_by_ids(self, ids: list[str]) -> None:
# Avoiding crashes caused by performing delete operations on empty lists in certain scenarios
# Scenario 1: extract a document fails, resulting in a table not being created.
# Then clicking the retry button triggers a delete operation on an empty list.
if not ids:
return
with self._get_cursor() as cur:
cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s", (tuple(ids),))

@ -140,6 +140,8 @@ class TencentVector(BaseVector):
|
||||
return False
|
||||
|
||||
def delete_by_ids(self, ids: list[str]) -> None:
|
||||
if not ids:
|
||||
return
|
||||
self._db.collection(self._collection_name).delete(document_ids=ids)
|
||||
|
||||
def delete_by_metadata_field(self, key: str, value: str) -> None:
|
||||
|
@ -409,27 +409,27 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory):
|
||||
db.session.query(TidbAuthBinding).filter(TidbAuthBinding.tenant_id == dataset.tenant_id).one_or_none()
|
||||
)
|
||||
if not tidb_auth_binding:
|
||||
idle_tidb_auth_binding = (
|
||||
db.session.query(TidbAuthBinding)
|
||||
.filter(TidbAuthBinding.active == False, TidbAuthBinding.status == "ACTIVE")
|
||||
.limit(1)
|
||||
.one_or_none()
|
||||
)
|
||||
if idle_tidb_auth_binding:
|
||||
idle_tidb_auth_binding.active = True
|
||||
idle_tidb_auth_binding.tenant_id = dataset.tenant_id
|
||||
db.session.commit()
|
||||
TIDB_ON_QDRANT_API_KEY = f"{idle_tidb_auth_binding.account}:{idle_tidb_auth_binding.password}"
|
||||
else:
|
||||
with redis_client.lock("create_tidb_serverless_cluster_lock", timeout=900):
|
||||
tidb_auth_binding = (
|
||||
with redis_client.lock("create_tidb_serverless_cluster_lock", timeout=900):
|
||||
tidb_auth_binding = (
|
||||
db.session.query(TidbAuthBinding)
|
||||
.filter(TidbAuthBinding.tenant_id == dataset.tenant_id)
|
||||
.one_or_none()
|
||||
)
|
||||
if tidb_auth_binding:
|
||||
TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}"
|
||||
|
||||
else:
|
||||
idle_tidb_auth_binding = (
|
||||
db.session.query(TidbAuthBinding)
|
||||
.filter(TidbAuthBinding.tenant_id == dataset.tenant_id)
|
||||
.filter(TidbAuthBinding.active == False, TidbAuthBinding.status == "ACTIVE")
|
||||
.limit(1)
|
||||
.one_or_none()
|
||||
)
|
||||
if tidb_auth_binding:
|
||||
TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}"
|
||||
|
||||
if idle_tidb_auth_binding:
|
||||
idle_tidb_auth_binding.active = True
|
||||
idle_tidb_auth_binding.tenant_id = dataset.tenant_id
|
||||
db.session.commit()
|
||||
TIDB_ON_QDRANT_API_KEY = f"{idle_tidb_auth_binding.account}:{idle_tidb_auth_binding.password}"
|
||||
else:
|
||||
new_cluster = TidbService.create_tidb_serverless_cluster(
|
||||
dify_config.TIDB_PROJECT_ID or "",
|
||||
@ -451,7 +451,6 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory):
|
||||
db.session.add(new_tidb_auth_binding)
|
||||
db.session.commit()
|
||||
TIDB_ON_QDRANT_API_KEY = f"{new_tidb_auth_binding.account}:{new_tidb_auth_binding.password}"
|
||||
|
||||
else:
|
||||
TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}"
|
||||
|
||||
|
@ -90,6 +90,12 @@ class Vector:
|
||||
from core.rag.datasource.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
|
||||
|
||||
return ElasticSearchVectorFactory
|
||||
case VectorType.ELASTICSEARCH_JA:
|
||||
from core.rag.datasource.vdb.elasticsearch.elasticsearch_ja_vector import (
|
||||
ElasticSearchJaVectorFactory,
|
||||
)
|
||||
|
||||
return ElasticSearchJaVectorFactory
|
||||
case VectorType.TIDB_VECTOR:
|
||||
from core.rag.datasource.vdb.tidb_vector.tidb_vector import TiDBVectorFactory
|
||||
|
||||
|
@ -16,6 +16,7 @@ class VectorType(StrEnum):
|
||||
TENCENT = "tencent"
|
||||
ORACLE = "oracle"
|
||||
ELASTICSEARCH = "elasticsearch"
|
||||
ELASTICSEARCH_JA = "elasticsearch-ja"
|
||||
LINDORM = "lindorm"
|
||||
COUCHBASE = "couchbase"
|
||||
BAIDU = "baidu"
|
||||
|
@ -23,7 +23,6 @@ class PdfExtractor(BaseExtractor):
|
||||
self._file_cache_key = file_cache_key
|
||||
|
||||
def extract(self) -> list[Document]:
|
||||
plaintext_file_key = ""
|
||||
plaintext_file_exists = False
|
||||
if self._file_cache_key:
|
||||
try:
|
||||
@ -39,8 +38,8 @@ class PdfExtractor(BaseExtractor):
|
||||
text = "\n\n".join(text_list)
|
||||
|
||||
# save plaintext file for caching
|
||||
if not plaintext_file_exists and plaintext_file_key:
|
||||
storage.save(plaintext_file_key, text.encode("utf-8"))
|
||||
if not plaintext_file_exists and self._file_cache_key:
|
||||
storage.save(self._file_cache_key, text.encode("utf-8"))
|
||||
|
||||
return documents
|
||||
|
||||
|
@ -3,6 +3,7 @@
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from configs import dify_config
|
||||
from core.model_manager import ModelInstance
|
||||
from core.rag.cleaner.clean_processor import CleanProcessor
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
@ -80,6 +81,10 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
|
||||
child_nodes = self._split_child_nodes(
|
||||
document, rules, process_rule.get("mode"), kwargs.get("embedding_model_instance")
|
||||
)
|
||||
if kwargs.get("preview"):
|
||||
if len(child_nodes) > dify_config.CHILD_CHUNKS_PREVIEW_NUMBER:
|
||||
child_nodes = child_nodes[: dify_config.CHILD_CHUNKS_PREVIEW_NUMBER]
|
||||
|
||||
document.children = child_nodes
|
||||
doc_id = str(uuid.uuid4())
|
||||
hash = helper.generate_text_hash(document.page_content)
|
||||
|
@ -212,8 +212,23 @@ class ApiTool(Tool):
|
||||
else:
|
||||
body = body
|
||||
|
||||
if method in {"get", "head", "post", "put", "delete", "patch"}:
|
||||
response: httpx.Response = getattr(ssrf_proxy, method)(
|
||||
if method in {
|
||||
"get",
|
||||
"head",
|
||||
"post",
|
||||
"put",
|
||||
"delete",
|
||||
"patch",
|
||||
"options",
|
||||
"GET",
|
||||
"POST",
|
||||
"PUT",
|
||||
"PATCH",
|
||||
"DELETE",
|
||||
"HEAD",
|
||||
"OPTIONS",
|
||||
}:
|
||||
response: httpx.Response = getattr(ssrf_proxy, method.lower())(
|
||||
url,
|
||||
params=params,
|
||||
headers=headers,
|
||||
|
@ -2,14 +2,18 @@ import csv
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import operator
|
||||
import os
|
||||
import tempfile
|
||||
from typing import cast
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, cast
|
||||
|
||||
import docx
|
||||
import pandas as pd
|
||||
import pypdfium2 # type: ignore
|
||||
import yaml # type: ignore
|
||||
from docx.table import Table
|
||||
from docx.text.paragraph import Paragraph
|
||||
|
||||
from configs import dify_config
|
||||
from core.file import File, FileTransferMethod, file_manager
|
||||
@ -78,6 +82,23 @@ class DocumentExtractorNode(BaseNode[DocumentExtractorNodeData]):
|
||||
process_data=process_data,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: DocumentExtractorNodeData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
:param graph_config: graph config
|
||||
:param node_id: node id
|
||||
:param node_data: node data
|
||||
:return:
|
||||
"""
|
||||
return {node_id + ".files": node_data.variable_selector}
|
||||
|
||||
|
||||
def _extract_text_by_mime_type(*, file_content: bytes, mime_type: str) -> str:
|
||||
"""Extract text from a file based on its MIME type."""
|
||||
@ -189,35 +210,56 @@ def _extract_text_from_doc(file_content: bytes) -> str:
|
||||
doc_file = io.BytesIO(file_content)
|
||||
doc = docx.Document(doc_file)
|
||||
text = []
|
||||
# Process paragraphs
|
||||
for paragraph in doc.paragraphs:
|
||||
if paragraph.text.strip():
|
||||
text.append(paragraph.text)
|
||||
|
||||
# Process tables
|
||||
for table in doc.tables:
|
||||
# Table header
|
||||
try:
|
||||
# table maybe cause errors so ignore it.
|
||||
if len(table.rows) > 0 and table.rows[0].cells is not None:
|
||||
# Keep track of paragraph and table positions
|
||||
content_items: list[tuple[int, str, Table | Paragraph]] = []
|
||||
|
||||
# Process paragraphs and tables
|
||||
for i, paragraph in enumerate(doc.paragraphs):
|
||||
if paragraph.text.strip():
|
||||
content_items.append((i, "paragraph", paragraph))
|
||||
|
||||
for i, table in enumerate(doc.tables):
|
||||
content_items.append((i, "table", table))
|
||||
|
||||
# Sort content items based on their original position
|
||||
content_items.sort(key=operator.itemgetter(0))
|
||||
|
||||
# Process sorted content
|
||||
for _, item_type, item in content_items:
|
||||
if item_type == "paragraph":
|
||||
if isinstance(item, Table):
|
||||
continue
|
||||
text.append(item.text)
|
||||
elif item_type == "table":
|
||||
# Process tables
|
||||
if not isinstance(item, Table):
|
||||
continue
|
||||
try:
|
||||
# Check if any cell in the table has text
|
||||
has_content = False
|
||||
for row in table.rows:
|
||||
for row in item.rows:
|
||||
if any(cell.text.strip() for cell in row.cells):
|
||||
has_content = True
|
||||
break
|
||||
|
||||
if has_content:
|
||||
markdown_table = "| " + " | ".join(cell.text for cell in table.rows[0].cells) + " |\n"
|
||||
markdown_table += "| " + " | ".join(["---"] * len(table.rows[0].cells)) + " |\n"
|
||||
for row in table.rows[1:]:
|
||||
markdown_table += "| " + " | ".join(cell.text for cell in row.cells) + " |\n"
|
||||
cell_texts = [cell.text.replace("\n", "<br>") for cell in item.rows[0].cells]
|
||||
markdown_table = f"| {' | '.join(cell_texts)} |\n"
|
||||
markdown_table += f"| {' | '.join(['---'] * len(item.rows[0].cells))} |\n"
|
||||
|
||||
for row in item.rows[1:]:
|
||||
# Replace newlines with <br> in each cell
|
||||
row_cells = [cell.text.replace("\n", "<br>") for cell in row.cells]
|
||||
markdown_table += "| " + " | ".join(row_cells) + " |\n"
|
||||
|
||||
text.append(markdown_table)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to extract table from DOC/DOCX: {e}")
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to extract table from DOC/DOCX: {e}")
|
||||
continue
|
||||
|
||||
return "\n".join(text)
|
||||
|
||||
except Exception as e:
|
||||
raise TextExtractionError(f"Failed to extract text from DOC/DOCX: {str(e)}") from e
|
||||
|
||||
|
@ -68,7 +68,22 @@ class HttpRequestNodeData(BaseNodeData):
|
||||
Code Node Data.
|
||||
"""
|
||||
|
||||
method: Literal["get", "post", "put", "patch", "delete", "head"]
|
||||
method: Literal[
|
||||
"get",
|
||||
"post",
|
||||
"put",
|
||||
"patch",
|
||||
"delete",
|
||||
"head",
|
||||
"options",
|
||||
"GET",
|
||||
"POST",
|
||||
"PUT",
|
||||
"PATCH",
|
||||
"DELETE",
|
||||
"HEAD",
|
||||
"OPTIONS",
|
||||
]
|
||||
url: str
|
||||
authorization: HttpRequestNodeAuthorization
|
||||
headers: str
|
||||
|
@ -37,7 +37,22 @@ BODY_TYPE_TO_CONTENT_TYPE = {
|
||||
|
||||
|
||||
class Executor:
|
||||
method: Literal["get", "head", "post", "put", "delete", "patch"]
|
||||
method: Literal[
|
||||
"get",
|
||||
"head",
|
||||
"post",
|
||||
"put",
|
||||
"delete",
|
||||
"patch",
|
||||
"options",
|
||||
"GET",
|
||||
"POST",
|
||||
"PUT",
|
||||
"PATCH",
|
||||
"DELETE",
|
||||
"HEAD",
|
||||
"OPTIONS",
|
||||
]
|
||||
url: str
|
||||
params: list[tuple[str, str]] | None
|
||||
content: str | bytes | None
|
||||
@ -67,12 +82,6 @@ class Executor:
|
||||
node_data.authorization.config.api_key
|
||||
).text
|
||||
|
||||
# check if node_data.url is a valid URL
|
||||
if not node_data.url:
|
||||
raise InvalidURLError("url is required")
|
||||
if not node_data.url.startswith(("http://", "https://")):
|
||||
raise InvalidURLError("url should start with http:// or https://")
|
||||
|
||||
self.url: str = node_data.url
|
||||
self.method = node_data.method
|
||||
self.auth = node_data.authorization
|
||||
@ -99,6 +108,12 @@ class Executor:
|
||||
def _init_url(self):
|
||||
self.url = self.variable_pool.convert_template(self.node_data.url).text
|
||||
|
||||
# check if url is a valid URL
|
||||
if not self.url:
|
||||
raise InvalidURLError("url is required")
|
||||
if not self.url.startswith(("http://", "https://")):
|
||||
raise InvalidURLError("url should start with http:// or https://")
|
||||
|
||||
def _init_params(self):
|
||||
"""
|
||||
Almost same as _init_headers(), difference:
|
||||
@ -158,7 +173,10 @@ class Executor:
|
||||
if len(data) != 1:
|
||||
raise RequestBodyError("json body type should have exactly one item")
|
||||
json_string = self.variable_pool.convert_template(data[0].value).text
|
||||
json_object = json.loads(json_string, strict=False)
|
||||
try:
|
||||
json_object = json.loads(json_string, strict=False)
|
||||
except json.JSONDecodeError as e:
|
||||
raise RequestBodyError(f"Failed to parse JSON: {json_string}") from e
|
||||
self.json = json_object
|
||||
# self.json = self._parse_object_contains_variables(json_object)
|
||||
case "binary":
|
||||
@ -246,7 +264,22 @@ class Executor:
|
||||
"""
|
||||
do http request depending on api bundle
|
||||
"""
|
||||
if self.method not in {"get", "head", "post", "put", "delete", "patch"}:
|
||||
if self.method not in {
|
||||
"get",
|
||||
"head",
|
||||
"post",
|
||||
"put",
|
||||
"delete",
|
||||
"patch",
|
||||
"options",
|
||||
"GET",
|
||||
"POST",
|
||||
"PUT",
|
||||
"PATCH",
|
||||
"DELETE",
|
||||
"HEAD",
|
||||
"OPTIONS",
|
||||
}:
|
||||
raise InvalidHttpMethodError(f"Invalid http method {self.method}")
|
||||
|
||||
request_args = {
|
||||
@ -263,7 +296,7 @@ class Executor:
|
||||
}
|
||||
# request_args = {k: v for k, v in request_args.items() if v is not None}
|
||||
try:
|
||||
response = getattr(ssrf_proxy, self.method)(**request_args)
|
||||
response = getattr(ssrf_proxy, self.method.lower())(**request_args)
|
||||
except (ssrf_proxy.MaxRetriesExceededError, httpx.RequestError) as e:
|
||||
raise HttpRequestNodeError(str(e))
|
||||
# FIXME: fix type ignore, this maybe httpx type issue
|
||||
|
@ -340,6 +340,10 @@ class WorkflowEntry:
|
||||
):
|
||||
raise ValueError(f"Variable key {node_variable} not found in user inputs.")
|
||||
|
||||
# environment variable already exist in variable pool, not from user inputs
|
||||
if variable_pool.get(variable_selector):
|
||||
continue
|
||||
|
||||
# fetch variable node id from variable selector
|
||||
variable_node_id = variable_selector[0]
|
||||
variable_key_list = variable_selector[1:]
|
||||
|
@ -33,6 +33,7 @@ else
|
||||
--bind "${DIFY_BIND_ADDRESS:-0.0.0.0}:${DIFY_PORT:-5001}" \
|
||||
--workers ${SERVER_WORKER_AMOUNT:-1} \
|
||||
--worker-class ${SERVER_WORKER_CLASS:-gevent} \
|
||||
--worker-connections ${SERVER_WORKER_CONNECTIONS:-10} \
|
||||
--timeout ${GUNICORN_TIMEOUT:-200} \
|
||||
app:app
|
||||
fi
|
||||
|
@ -46,7 +46,7 @@ def init_app(app: DifyApp):
timezone = pytz.timezone(log_tz)

def time_converter(seconds):
return datetime.utcfromtimestamp(seconds).astimezone(timezone).timetuple()
return datetime.fromtimestamp(seconds, tz=timezone).timetuple()

for handler in logging.root.handlers:
if handler.formatter:

@ -158,7 +158,7 @@ def _build_from_remote_url(
tenant_id: str,
transfer_method: FileTransferMethod,
) -> File:
url = mapping.get("url")
url = mapping.get("url") or mapping.get("remote_url")
if not url:
raise ValueError("Invalid file url")

@ -255,7 +255,8 @@ class NotionOAuth(OAuthDataSource):
response = requests.get(url=f"{self._NOTION_BLOCK_SEARCH}/{block_id}", headers=headers)
response_json = response.json()
if response.status_code != 200:
raise ValueError(f"Error fetching block parent page ID: {response_json.message}")
message = response_json.get("message", "unknown error")
raise ValueError(f"Error fetching block parent page ID: {message}")
parent = response_json["parent"]
parent_type = parent["type"]
if parent_type == "block_id":

@ -0,0 +1,41 @@
"""change workflow_runs.total_tokens to bigint

Revision ID: a91b476a53de
Revises: 923752d42eb6
Create Date: 2025-01-01 20:00:01.207369

"""
from alembic import op
import models as models
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = 'a91b476a53de'
down_revision = '923752d42eb6'
branch_labels = None
depends_on = None


def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('workflow_runs', schema=None) as batch_op:
batch_op.alter_column('total_tokens',
existing_type=sa.INTEGER(),
type_=sa.BigInteger(),
existing_nullable=False,
existing_server_default=sa.text('0'))

# ### end Alembic commands ###


def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('workflow_runs', schema=None) as batch_op:
batch_op.alter_column('total_tokens',
existing_type=sa.BigInteger(),
type_=sa.INTEGER(),
existing_nullable=False,
existing_server_default=sa.text('0'))

# ### end Alembic commands ###
@ -415,8 +415,8 @@ class WorkflowRun(Base):
status: Mapped[str] = mapped_column(db.String(255)) # running, succeeded, failed, stopped, partial-succeeded
outputs: Mapped[Optional[str]] = mapped_column(sa.Text, default="{}")
error: Mapped[Optional[str]] = mapped_column(db.Text)
elapsed_time = db.Column(db.Float, nullable=False, server_default=db.text("0"))
total_tokens: Mapped[int] = mapped_column(server_default=db.text("0"))
elapsed_time = db.Column(db.Float, nullable=False, server_default=sa.text("0"))
total_tokens: Mapped[int] = mapped_column(sa.BigInteger, server_default=sa.text("0"))
total_steps = db.Column(db.Integer, server_default=db.text("0"))
created_by_role: Mapped[str] = mapped_column(db.String(255)) # account, end_user
created_by = db.Column(StringUUID, nullable=False)

2397
api/poetry.lock
generated
File diff suppressed because it is too large
@@ -71,7 +71,7 @@ pyjwt = "~2.8.0"
pypdfium2 = "~4.30.0"
python = ">=3.11,<3.13"
python-docx = "~1.1.0"
python-dotenv = "1.0.0"
python-dotenv = "1.0.1"
pyyaml = "~6.0.1"
readabilipy = "0.2.0"
redis = { version = "~5.0.3", extras = ["hiredis"] }
@@ -82,7 +82,7 @@ scikit-learn = "~1.5.1"
sentry-sdk = { version = "~1.44.1", extras = ["flask"] }
sqlalchemy = "~2.0.29"
starlette = "0.41.0"
tencentcloud-sdk-python-hunyuan = "~3.0.1158"
tencentcloud-sdk-python-hunyuan = "~3.0.1294"
tiktoken = "~0.8.0"
tokenizers = "~0.15.0"
transformers = "~4.35.0"
@@ -92,7 +92,7 @@ validators = "0.21.0"
volcengine-python-sdk = {extras = ["ark"], version = "~1.0.98"}
websocket-client = "~1.7.0"
xinference-client = "0.15.2"
yarl = "~1.9.4"
yarl = "~1.18.3"
youtube-transcript-api = "~0.6.2"
zhipuai = "~2.1.5"
# Before adding new dependency, consider place it in alphabet order (a-z) and suitable group.
@@ -157,7 +157,7 @@ opensearch-py = "2.4.0"
oracledb = "~2.2.1"
pgvecto-rs = { version = "~0.2.1", extras = ['sqlalchemy'] }
pgvector = "0.2.5"
pymilvus = "~2.4.4"
pymilvus = "~2.5.0"
pymochow = "1.3.1"
pyobvector = "~0.1.6"
qdrant-client = "1.7.3"
@@ -168,23 +168,6 @@ def clean_unused_datasets_task():
else:
plan = plan_cache.decode()
if plan == "sandbox":
# add auto disable log
documents = (
db.session.query(Document)
.filter(
Document.dataset_id == dataset.id,
Document.enabled == True,
Document.archived == False,
)
.all()
)
for document in documents:
dataset_auto_disable_log = DatasetAutoDisableLog(
tenant_id=dataset.tenant_id,
dataset_id=dataset.id,
document_id=document.id,
)
db.session.add(dataset_auto_disable_log)
# remove index
index_processor = IndexProcessorFactory(dataset.doc_form).init_index_processor()
index_processor.clean(dataset, None)
@@ -66,7 +66,7 @@ class TokenPair(BaseModel):

REFRESH_TOKEN_PREFIX = "refresh_token:"
ACCOUNT_REFRESH_TOKEN_PREFIX = "account_refresh_token:"
REFRESH_TOKEN_EXPIRY = timedelta(days=30)
REFRESH_TOKEN_EXPIRY = timedelta(days=dify_config.REFRESH_TOKEN_EXPIRE_DAYS)


class AccountService:
@@ -2,6 +2,7 @@ import logging
import uuid
from enum import StrEnum
from typing import Optional, cast
from urllib.parse import urlparse
from uuid import uuid4

import yaml # type: ignore
@@ -124,7 +125,7 @@ class AppDslService:
raise ValueError(f"Invalid import_mode: {import_mode}")

# Get YAML content
content: bytes | str = b""
content: str = ""
if mode == ImportMode.YAML_URL:
if not yaml_url:
return Import(
@@ -133,13 +134,17 @@ class AppDslService:
error="yaml_url is required when import_mode is yaml-url",
)
try:
# tricky way to handle url from github to github raw url
if yaml_url.startswith("https://github.com") and yaml_url.endswith((".yml", ".yaml")):
parsed_url = urlparse(yaml_url)
if (
parsed_url.scheme == "https"
and parsed_url.netloc == "github.com"
and parsed_url.path.endswith((".yml", ".yaml"))
):
yaml_url = yaml_url.replace("https://github.com", "https://raw.githubusercontent.com")
yaml_url = yaml_url.replace("/blob/", "/")
response = ssrf_proxy.get(yaml_url.strip(), follow_redirects=True, timeout=(10, 10))
response.raise_for_status()
content = response.content
content = response.content.decode()

if len(content) > DSL_MAX_SIZE:
return Import(
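For readers skimming the hunk above, the URL rewrite is easier to see on a concrete input. Below is a minimal, self-contained sketch of the same conversion; only the github.com blob-to-raw mapping is taken from the diff, and the repository path in the example is hypothetical:

from urllib.parse import urlparse

def to_raw_github_url(yaml_url: str) -> str:
    # Mirrors the hunk above: only https://github.com URLs ending in .yml/.yaml
    # are rewritten from the blob page to the raw content host.
    parsed = urlparse(yaml_url)
    if (
        parsed.scheme == "https"
        and parsed.netloc == "github.com"
        and parsed.path.endswith((".yml", ".yaml"))
    ):
        yaml_url = yaml_url.replace("https://github.com", "https://raw.githubusercontent.com")
        yaml_url = yaml_url.replace("/blob/", "/")
    return yaml_url

# hypothetical repository path, for illustration only
print(to_raw_github_url("https://github.com/acme/app/blob/main/app.yaml"))
# -> https://raw.githubusercontent.com/acme/app/main/app.yaml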
@@ -26,9 +26,10 @@ from tasks.remove_app_and_related_data_task import remove_app_and_related_data_t


class AppService:
def get_paginate_apps(self, tenant_id: str, args: dict) -> Pagination | None:
def get_paginate_apps(self, user_id: str, tenant_id: str, args: dict) -> Pagination | None:
"""
Get app list with pagination
:param user_id: user id
:param tenant_id: tenant id
:param args: request args
:return:
@@ -44,6 +45,8 @@ class AppService:
elif args["mode"] == "channel":
filters.append(App.mode == AppMode.CHANNEL.value)

if args.get("is_created_by_me", False):
filters.append(App.created_by == user_id)
if args.get("name"):
name = args["name"][:30]
filters.append(App.name.ilike(f"%{name}%"))
@@ -1,5 +1,5 @@
import os
from typing import Optional
from typing import Literal, Optional

import httpx
from tenacity import retry, retry_if_exception_type, stop_before_delay, wait_fixed
@@ -17,7 +17,6 @@ class BillingService:
params = {"tenant_id": tenant_id}

billing_info = cls._send_request("GET", "/subscription/info", params=params)

return billing_info

@classmethod
@@ -47,12 +46,13 @@ class BillingService:
retry=retry_if_exception_type(httpx.RequestError),
reraise=True,
)
def _send_request(cls, method, endpoint, json=None, params=None):
def _send_request(cls, method: Literal["GET", "POST", "DELETE"], endpoint: str, json=None, params=None):
headers = {"Content-Type": "application/json", "Billing-Api-Secret-Key": cls.secret_key}

url = f"{cls.base_url}{endpoint}"
response = httpx.request(method, url, json=json, params=params, headers=headers)

if method == "GET" and response.status_code != httpx.codes.OK:
raise ValueError("Unable to retrieve billing information. Please try again later or contact support.")
return response.json()

@staticmethod
@@ -86,7 +86,7 @@ class DatasetService:
else:
return [], 0
else:
if user.current_role not in (TenantAccountRole.OWNER, TenantAccountRole.ADMIN):
if user.current_role != TenantAccountRole.OWNER:
# show all datasets that the user has permission to access
if permitted_dataset_ids:
query = query.filter(
@@ -382,7 +382,7 @@ class DatasetService:
if dataset.tenant_id != user.current_tenant_id:
logging.debug(f"User {user.id} does not have permission to access dataset {dataset.id}")
raise NoPermissionError("You do not have permission to access this dataset.")
if user.current_role not in (TenantAccountRole.OWNER, TenantAccountRole.ADMIN):
if user.current_role != TenantAccountRole.OWNER:
if dataset.permission == DatasetPermissionEnum.ONLY_ME and dataset.created_by != user.id:
logging.debug(f"User {user.id} does not have permission to access dataset {dataset.id}")
raise NoPermissionError("You do not have permission to access this dataset.")
@@ -404,7 +404,7 @@ class DatasetService:
if not user:
raise ValueError("User not found")

if user.current_role not in (TenantAccountRole.OWNER, TenantAccountRole.ADMIN):
if user.current_role != TenantAccountRole.OWNER:
if dataset.permission == DatasetPermissionEnum.ONLY_ME:
if dataset.created_by != user.id:
raise NoPermissionError("You do not have permission to access this dataset.")
@@ -434,6 +434,12 @@ class DatasetService:

@staticmethod
def get_dataset_auto_disable_logs(dataset_id: str) -> dict:
features = FeatureService.get_features(current_user.current_tenant_id)
if not features.billing.enabled or features.billing.subscription.plan == "sandbox":
return {
"document_ids": [],
"count": 0,
}
# get recent 30 days auto disable logs
start_date = datetime.datetime.now() - datetime.timedelta(days=30)
dataset_auto_disable_logs = DatasetAutoDisableLog.query.filter(
@@ -786,13 +792,19 @@ class DocumentService:
dataset.indexing_technique = knowledge_config.indexing_technique
if knowledge_config.indexing_technique == "high_quality":
model_manager = ModelManager()
embedding_model = model_manager.get_default_model_instance(
tenant_id=current_user.current_tenant_id, model_type=ModelType.TEXT_EMBEDDING
)
dataset.embedding_model = embedding_model.model
dataset.embedding_model_provider = embedding_model.provider
if knowledge_config.embedding_model and knowledge_config.embedding_model_provider:
dataset_embedding_model = knowledge_config.embedding_model
dataset_embedding_model_provider = knowledge_config.embedding_model_provider
else:
embedding_model = model_manager.get_default_model_instance(
tenant_id=current_user.current_tenant_id, model_type=ModelType.TEXT_EMBEDDING
)
dataset_embedding_model = embedding_model.model
dataset_embedding_model_provider = embedding_model.provider
dataset.embedding_model = dataset_embedding_model
dataset.embedding_model_provider = dataset_embedding_model_provider
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
embedding_model.provider, embedding_model.model
dataset_embedding_model_provider, dataset_embedding_model
)
dataset.collection_binding_id = dataset_collection_binding.id
if not dataset.retrieval_model:
@@ -804,7 +816,11 @@ class DocumentService:
"score_threshold_enabled": False,
}

dataset.retrieval_model = knowledge_config.retrieval_model.model_dump() or default_retrieval_model # type: ignore
dataset.retrieval_model = (
knowledge_config.retrieval_model.model_dump()
if knowledge_config.retrieval_model
else default_retrieval_model
) # type: ignore

documents = []
if knowledge_config.original_document_id:
@@ -27,7 +27,7 @@ class WorkflowAppService:
query = query.join(WorkflowRun, WorkflowRun.id == WorkflowAppLog.workflow_run_id)

if keyword:
keyword_like_val = f"%{args['keyword'][:30]}%"
keyword_like_val = f"%{keyword[:30].encode('unicode_escape').decode('utf-8')}%".replace(r"\u", r"\\u")
keyword_conditions = [
WorkflowRun.inputs.ilike(keyword_like_val),
WorkflowRun.outputs.ilike(keyword_like_val),
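The replacement line above folds two steps into one expression, which is easy to misread. A minimal sketch of the same transformation, assuming (as the change suggests) that the JSON stored in the inputs/outputs columns keeps non-ASCII characters as \uXXXX escape sequences, so the search keyword has to be converted to the same representation before the ILIKE comparison:

def to_escaped_like_pattern(keyword: str) -> str:
    # Same steps as the hunk above: truncate to 30 chars, escape non-ASCII to \uXXXX,
    # then double the backslashes so the literal survives the SQL LIKE pattern.
    escaped = keyword[:30].encode("unicode_escape").decode("utf-8")
    return f"%{escaped}%".replace(r"\u", r"\\u")

print(to_escaped_like_pattern("hello"))  # %hello%
print(to_escaped_like_pattern("你好"))    # %\\u4f60\\u597d%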
@@ -28,7 +28,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):

if not dataset:
raise Exception("Dataset not found")
index_type = dataset.doc_form
index_type = dataset.doc_form or IndexType.PARAGRAPH_INDEX
index_processor = IndexProcessorFactory(index_type).init_index_processor()
if action == "remove":
index_processor.clean(dataset, None, with_keywords=False)
@@ -157,6 +157,9 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
{"indexing_status": "error", "error": str(e)}, synchronize_session=False
)
db.session.commit()
else:
# clean collection
index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False)

end_at = time.perf_counter()
logging.info(
@ -0,0 +1,55 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.gpustack.speech2text.speech2text import GPUStackSpeech2TextModel
|
||||
|
||||
|
||||
def test_validate_credentials():
|
||||
model = GPUStackSpeech2TextModel()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
model.validate_credentials(
|
||||
model="faster-whisper-medium",
|
||||
credentials={
|
||||
"endpoint_url": "invalid_url",
|
||||
"api_key": "invalid_api_key",
|
||||
},
|
||||
)
|
||||
|
||||
model.validate_credentials(
|
||||
model="faster-whisper-medium",
|
||||
credentials={
|
||||
"endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
|
||||
"api_key": os.environ.get("GPUSTACK_API_KEY"),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def test_invoke_model():
|
||||
model = GPUStackSpeech2TextModel()
|
||||
|
||||
# Get the directory of the current file
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
# Get assets directory
|
||||
assets_dir = os.path.join(os.path.dirname(current_dir), "assets")
|
||||
|
||||
# Construct the path to the audio file
|
||||
audio_file_path = os.path.join(assets_dir, "audio.mp3")
|
||||
|
||||
file = Path(audio_file_path).read_bytes()
|
||||
|
||||
result = model.invoke(
|
||||
model="faster-whisper-medium",
|
||||
credentials={
|
||||
"endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
|
||||
"api_key": os.environ.get("GPUSTACK_API_KEY"),
|
||||
},
|
||||
file=file,
|
||||
)
|
||||
|
||||
assert isinstance(result, str)
|
||||
assert result == "1, 2, 3, 4, 5, 6, 7, 8, 9, 10"
|
@ -0,0 +1,24 @@
|
||||
import os
|
||||
|
||||
from core.model_runtime.model_providers.gpustack.tts.tts import GPUStackText2SpeechModel
|
||||
|
||||
|
||||
def test_invoke_model():
|
||||
model = GPUStackText2SpeechModel()
|
||||
|
||||
result = model.invoke(
|
||||
model="cosyvoice-300m-sft",
|
||||
tenant_id="test",
|
||||
credentials={
|
||||
"endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
|
||||
"api_key": os.environ.get("GPUSTACK_API_KEY"),
|
||||
},
|
||||
content_text="Hello world",
|
||||
voice="Chinese Female",
|
||||
)
|
||||
|
||||
content = b""
|
||||
for chunk in result:
|
||||
content += chunk
|
||||
|
||||
assert content != b""
|
@@ -19,9 +19,9 @@ class MilvusVectorTest(AbstractVectorTest):
)

def search_by_full_text(self):
# milvus dos not support full text searching yet in < 2.3.x
# milvus support BM25 full text search after version 2.5.0-beta
hits_by_full_text = self.vector.search_by_full_text(query=get_example_text())
assert len(hits_by_full_text) == 0
assert len(hits_by_full_text) >= 0

def get_ids_by_metadata_field(self):
ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id)
@ -2,7 +2,7 @@ version: '3'
|
||||
services:
|
||||
# API service
|
||||
api:
|
||||
image: langgenius/dify-api:0.14.2
|
||||
image: langgenius/dify-api:0.15.0
|
||||
restart: always
|
||||
environment:
|
||||
# Startup mode, 'api' starts the API server.
|
||||
@ -227,7 +227,7 @@ services:
|
||||
# worker service
|
||||
# The Celery worker for processing the queue.
|
||||
worker:
|
||||
image: langgenius/dify-api:0.14.2
|
||||
image: langgenius/dify-api:0.15.0
|
||||
restart: always
|
||||
environment:
|
||||
CONSOLE_WEB_URL: ''
|
||||
@ -397,7 +397,7 @@ services:
|
||||
|
||||
# Frontend web application.
|
||||
web:
|
||||
image: langgenius/dify-web:0.14.2
|
||||
image: langgenius/dify-web:0.15.0
|
||||
restart: always
|
||||
environment:
|
||||
# The base URL of console application api server, refers to the Console base URL of WEB service if console domain is
|
||||
|
@ -105,6 +105,9 @@ FILES_ACCESS_TIMEOUT=300
|
||||
# Access token expiration time in minutes
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES=60
|
||||
|
||||
# Refresh token expiration time in days
|
||||
REFRESH_TOKEN_EXPIRE_DAYS=30
|
||||
|
||||
# The maximum number of active requests for the application, where 0 means unlimited, should be a non-negative integer.
|
||||
APP_MAX_ACTIVE_REQUESTS=0
|
||||
APP_MAX_EXECUTION_TIME=1200
|
||||
@@ -123,10 +126,13 @@ DIFY_PORT=5001
# The number of API server workers, i.e., the number of workers.
# Formula: number of cpu cores x 2 + 1 for sync, 1 for Gevent
# Reference: https://docs.gunicorn.org/en/stable/design.html#how-many-workers
SERVER_WORKER_AMOUNT=
SERVER_WORKER_AMOUNT=1

# Defaults to gevent. If using windows, it can be switched to sync or solo.
SERVER_WORKER_CLASS=
SERVER_WORKER_CLASS=gevent

# Default number of worker connections, the default is 10.
SERVER_WORKER_CONNECTIONS=10

# Similar to SERVER_WORKER_CLASS.
# If using windows, it can be switched to sync or solo.
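The worker-count comment above is a formula rather than a fixed value. A minimal sketch of the arithmetic, assuming multiprocessing.cpu_count() is an acceptable stand-in for the available cores:

import multiprocessing

def suggested_server_worker_amount(worker_class: str = "gevent") -> int:
    # Mirrors the comment above: (cpu cores x 2) + 1 for sync workers, 1 for gevent.
    if worker_class == "gevent":
        return 1
    return multiprocessing.cpu_count() * 2 + 1

# e.g. on a 4-core host: sync -> 9 workers, gevent -> 1 worker
print(suggested_server_worker_amount("sync"), suggested_server_worker_amount("gevent"))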
@ -377,7 +383,7 @@ SUPABASE_URL=your-server-url
|
||||
# ------------------------------
|
||||
|
||||
# The type of vector store to use.
|
||||
# Supported values are `weaviate`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `tidb_vector`, `oracle`, `tencent`, `elasticsearch`, `analyticdb`, `couchbase`, `vikingdb`, `oceanbase`.
|
||||
# Supported values are `weaviate`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `tidb_vector`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `oceanbase`.
|
||||
VECTOR_STORE=weaviate
|
||||
|
||||
# The Weaviate endpoint URL. Only available when VECTOR_STORE is `weaviate`.
|
||||
@ -397,6 +403,7 @@ MILVUS_URI=http://127.0.0.1:19530
|
||||
MILVUS_TOKEN=
|
||||
MILVUS_USER=root
|
||||
MILVUS_PASSWORD=Milvus
|
||||
MILVUS_ENABLE_HYBRID_SEARCH=False
|
||||
|
||||
# MyScale configuration, only available when VECTOR_STORE is `myscale`
|
||||
# For multi-language support, please set MYSCALE_FTS_PARAMS with referring to:
|
||||
@ -506,7 +513,7 @@ TENCENT_VECTOR_DB_SHARD=1
|
||||
TENCENT_VECTOR_DB_REPLICAS=2
|
||||
|
||||
# ElasticSearch configuration, only available when VECTOR_STORE is `elasticsearch`
|
||||
ELASTICSEARCH_HOST=0.0.0.0
|
||||
ELASTICSEARCH_HOST=elasticsearch
|
||||
ELASTICSEARCH_PORT=9200
|
||||
ELASTICSEARCH_USERNAME=elastic
|
||||
ELASTICSEARCH_PASSWORD=elastic
|
||||
@ -923,6 +930,9 @@ CREATE_TIDB_SERVICE_JOB_ENABLED=false
|
||||
# Maximum number of submitted thread count in a ThreadPool for parallel node execution
|
||||
MAX_SUBMIT_COUNT=100
|
||||
|
||||
# The maximum number of top-k value for RAG.
|
||||
TOP_K_MAX_VALUE=10
|
||||
|
||||
# ------------------------------
|
||||
# Plugin Daemon Configuration
|
||||
# ------------------------------
|
||||
@ -947,3 +957,4 @@ ENDPOINT_URL_TEMPLATE=http://localhost/e/{hook_id}
|
||||
|
||||
MARKETPLACE_ENABLED=true
|
||||
MARKETPLACE_API_URL=https://marketplace-plugin.dify.dev
|
||||
|
||||
|
@ -73,6 +73,7 @@ services:
|
||||
CSP_WHITELIST: ${CSP_WHITELIST:-}
|
||||
MARKETPLACE_API_URL: ${MARKETPLACE_API_URL:-https://marketplace-plugin.dify.dev}
|
||||
MARKETPLACE_URL: ${MARKETPLACE_URL:-https://marketplace-plugin.dify.dev}
|
||||
TOP_K_MAX_VALUE: ${TOP_K_MAX_VALUE:-}
|
||||
|
||||
# The postgres database.
|
||||
db:
|
||||
@ -92,7 +93,7 @@ services:
|
||||
volumes:
|
||||
- ./volumes/db/data:/var/lib/postgresql/data
|
||||
healthcheck:
|
||||
test: ['CMD', 'pg_isready']
|
||||
test: [ 'CMD', 'pg_isready' ]
|
||||
interval: 1s
|
||||
timeout: 3s
|
||||
retries: 30
|
||||
@ -111,7 +112,7 @@ services:
|
||||
# Set the redis password when startup redis server.
|
||||
command: redis-server --requirepass ${REDIS_PASSWORD:-difyai123456}
|
||||
healthcheck:
|
||||
test: ['CMD', 'redis-cli', 'ping']
|
||||
test: [ 'CMD', 'redis-cli', 'ping' ]
|
||||
|
||||
# The DifySandbox
|
||||
sandbox:
|
||||
@ -131,7 +132,7 @@ services:
|
||||
volumes:
|
||||
- ./volumes/sandbox/dependencies:/dependencies
|
||||
healthcheck:
|
||||
test: ['CMD', 'curl', '-f', 'http://localhost:8194/health']
|
||||
test: [ 'CMD', 'curl', '-f', 'http://localhost:8194/health' ]
|
||||
networks:
|
||||
- ssrf_proxy_network
|
||||
|
||||
@ -167,12 +168,7 @@ services:
|
||||
volumes:
|
||||
- ./ssrf_proxy/squid.conf.template:/etc/squid/squid.conf.template
|
||||
- ./ssrf_proxy/docker-entrypoint.sh:/docker-entrypoint-mount.sh
|
||||
entrypoint:
|
||||
[
|
||||
'sh',
|
||||
'-c',
|
||||
"cp /docker-entrypoint-mount.sh /docker-entrypoint.sh && sed -i 's/\r$$//' /docker-entrypoint.sh && chmod +x /docker-entrypoint.sh && /docker-entrypoint.sh",
|
||||
]
|
||||
entrypoint: [ 'sh', '-c', "cp /docker-entrypoint-mount.sh /docker-entrypoint.sh && sed -i 's/\r$$//' /docker-entrypoint.sh && chmod +x /docker-entrypoint.sh && /docker-entrypoint.sh" ]
|
||||
environment:
|
||||
# pls clearly modify the squid env vars to fit your network environment.
|
||||
HTTP_PORT: ${SSRF_HTTP_PORT:-3128}
|
||||
@ -201,8 +197,8 @@ services:
|
||||
- CERTBOT_EMAIL=${CERTBOT_EMAIL}
|
||||
- CERTBOT_DOMAIN=${CERTBOT_DOMAIN}
|
||||
- CERTBOT_OPTIONS=${CERTBOT_OPTIONS:-}
|
||||
entrypoint: ['/docker-entrypoint.sh']
|
||||
command: ['tail', '-f', '/dev/null']
|
||||
entrypoint: [ '/docker-entrypoint.sh' ]
|
||||
command: [ 'tail', '-f', '/dev/null' ]
|
||||
|
||||
# The nginx reverse proxy.
|
||||
# used for reverse proxying the API service and Web service.
|
||||
@ -219,12 +215,7 @@ services:
|
||||
- ./volumes/certbot/conf/live:/etc/letsencrypt/live # cert dir (with certbot container)
|
||||
- ./volumes/certbot/conf:/etc/letsencrypt
|
||||
- ./volumes/certbot/www:/var/www/html
|
||||
entrypoint:
|
||||
[
|
||||
'sh',
|
||||
'-c',
|
||||
"cp /docker-entrypoint-mount.sh /docker-entrypoint.sh && sed -i 's/\r$$//' /docker-entrypoint.sh && chmod +x /docker-entrypoint.sh && /docker-entrypoint.sh",
|
||||
]
|
||||
entrypoint: [ 'sh', '-c', "cp /docker-entrypoint-mount.sh /docker-entrypoint.sh && sed -i 's/\r$$//' /docker-entrypoint.sh && chmod +x /docker-entrypoint.sh && /docker-entrypoint.sh" ]
|
||||
environment:
|
||||
NGINX_SERVER_NAME: ${NGINX_SERVER_NAME:-_}
|
||||
NGINX_HTTPS_ENABLED: ${NGINX_HTTPS_ENABLED:-false}
|
||||
@ -316,7 +307,7 @@ services:
|
||||
working_dir: /opt/couchbase
|
||||
stdin_open: true
|
||||
tty: true
|
||||
entrypoint: [""]
|
||||
entrypoint: [ "" ]
|
||||
command: sh -c "/opt/couchbase/init/init-cbserver.sh"
|
||||
volumes:
|
||||
- ./volumes/couchbase/data:/opt/couchbase/var/lib/couchbase/data
|
||||
@ -345,7 +336,7 @@ services:
|
||||
volumes:
|
||||
- ./volumes/pgvector/data:/var/lib/postgresql/data
|
||||
healthcheck:
|
||||
test: ['CMD', 'pg_isready']
|
||||
test: [ 'CMD', 'pg_isready' ]
|
||||
interval: 1s
|
||||
timeout: 3s
|
||||
retries: 30
|
||||
@ -367,7 +358,7 @@ services:
|
||||
volumes:
|
||||
- ./volumes/pgvecto_rs/data:/var/lib/postgresql/data
|
||||
healthcheck:
|
||||
test: ['CMD', 'pg_isready']
|
||||
test: [ 'CMD', 'pg_isready' ]
|
||||
interval: 1s
|
||||
timeout: 3s
|
||||
retries: 30
|
||||
@ -432,7 +423,7 @@ services:
|
||||
- ./volumes/milvus/etcd:/etcd
|
||||
command: etcd -advertise-client-urls=http://127.0.0.1:2379 -listen-client-urls http://0.0.0.0:2379 --data-dir /etcd
|
||||
healthcheck:
|
||||
test: ['CMD', 'etcdctl', 'endpoint', 'health']
|
||||
test: [ 'CMD', 'etcdctl', 'endpoint', 'health' ]
|
||||
interval: 30s
|
||||
timeout: 20s
|
||||
retries: 3
|
||||
@ -451,7 +442,7 @@ services:
|
||||
- ./volumes/milvus/minio:/minio_data
|
||||
command: minio server /minio_data --console-address ":9001"
|
||||
healthcheck:
|
||||
test: ['CMD', 'curl', '-f', 'http://localhost:9000/minio/health/live']
|
||||
test: [ 'CMD', 'curl', '-f', 'http://localhost:9000/minio/health/live' ]
|
||||
interval: 30s
|
||||
timeout: 20s
|
||||
retries: 3
|
||||
@ -463,7 +454,7 @@ services:
|
||||
image: milvusdb/milvus:v2.3.1
|
||||
profiles:
|
||||
- milvus
|
||||
command: ['milvus', 'run', 'standalone']
|
||||
command: [ 'milvus', 'run', 'standalone' ]
|
||||
environment:
|
||||
ETCD_ENDPOINTS: ${ETCD_ENDPOINTS:-etcd:2379}
|
||||
MINIO_ADDRESS: ${MINIO_ADDRESS:-minio:9000}
|
||||
@ -471,7 +462,7 @@ services:
|
||||
volumes:
|
||||
- ./volumes/milvus/milvus:/var/lib/milvus
|
||||
healthcheck:
|
||||
test: ['CMD', 'curl', '-f', 'http://localhost:9091/healthz']
|
||||
test: [ 'CMD', 'curl', '-f', 'http://localhost:9091/healthz' ]
|
||||
interval: 30s
|
||||
start_period: 90s
|
||||
timeout: 20s
|
||||
@ -559,7 +550,7 @@ services:
|
||||
ports:
|
||||
- ${ELASTICSEARCH_PORT:-9200}:9200
|
||||
healthcheck:
|
||||
test: ['CMD', 'curl', '-s', 'http://localhost:9200/_cluster/health?pretty']
|
||||
test: [ 'CMD', 'curl', '-s', 'http://localhost:9200/_cluster/health?pretty' ]
|
||||
interval: 30s
|
||||
timeout: 10s
|
||||
retries: 50
|
||||
@ -587,7 +578,7 @@ services:
|
||||
ports:
|
||||
- ${KIBANA_PORT:-5601}:5601
|
||||
healthcheck:
|
||||
test: ['CMD-SHELL', 'curl -s http://localhost:5601 >/dev/null || exit 1']
|
||||
test: [ 'CMD-SHELL', 'curl -s http://localhost:5601 >/dev/null || exit 1' ]
|
||||
interval: 30s
|
||||
timeout: 10s
|
||||
retries: 3
|
||||
|
@ -27,12 +27,14 @@ x-shared-env: &shared-api-worker-env
|
||||
MIGRATION_ENABLED: ${MIGRATION_ENABLED:-true}
|
||||
FILES_ACCESS_TIMEOUT: ${FILES_ACCESS_TIMEOUT:-300}
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES: ${ACCESS_TOKEN_EXPIRE_MINUTES:-60}
|
||||
REFRESH_TOKEN_EXPIRE_DAYS: ${REFRESH_TOKEN_EXPIRE_DAYS:-30}
|
||||
APP_MAX_ACTIVE_REQUESTS: ${APP_MAX_ACTIVE_REQUESTS:-0}
|
||||
APP_MAX_EXECUTION_TIME: ${APP_MAX_EXECUTION_TIME:-1200}
|
||||
DIFY_BIND_ADDRESS: ${DIFY_BIND_ADDRESS:-0.0.0.0}
|
||||
DIFY_PORT: ${DIFY_PORT:-5001}
|
||||
SERVER_WORKER_AMOUNT: ${SERVER_WORKER_AMOUNT:-}
|
||||
SERVER_WORKER_CLASS: ${SERVER_WORKER_CLASS:-}
|
||||
SERVER_WORKER_AMOUNT: ${SERVER_WORKER_AMOUNT:-1}
|
||||
SERVER_WORKER_CLASS: ${SERVER_WORKER_CLASS:-gevent}
|
||||
SERVER_WORKER_CONNECTIONS: ${SERVER_WORKER_CONNECTIONS:-10}
|
||||
CELERY_WORKER_CLASS: ${CELERY_WORKER_CLASS:-}
|
||||
GUNICORN_TIMEOUT: ${GUNICORN_TIMEOUT:-360}
|
||||
CELERY_WORKER_AMOUNT: ${CELERY_WORKER_AMOUNT:-}
|
||||
@ -136,6 +138,7 @@ x-shared-env: &shared-api-worker-env
|
||||
MILVUS_TOKEN: ${MILVUS_TOKEN:-}
|
||||
MILVUS_USER: ${MILVUS_USER:-root}
|
||||
MILVUS_PASSWORD: ${MILVUS_PASSWORD:-Milvus}
|
||||
MILVUS_ENABLE_HYBRID_SEARCH: ${MILVUS_ENABLE_HYBRID_SEARCH:-False}
|
||||
MYSCALE_HOST: ${MYSCALE_HOST:-myscale}
|
||||
MYSCALE_PORT: ${MYSCALE_PORT:-8123}
|
||||
MYSCALE_USER: ${MYSCALE_USER:-default}
|
||||
@ -401,6 +404,7 @@ x-shared-env: &shared-api-worker-env
|
||||
ENDPOINT_URL_TEMPLATE: ${ENDPOINT_URL_TEMPLATE:-http://localhost/e/{hook_id}}
|
||||
MARKETPLACE_ENABLED: ${MARKETPLACE_ENABLED:-true}
|
||||
MARKETPLACE_API_URL: ${MARKETPLACE_API_URL:-https://marketplace-plugin.dify.dev}
|
||||
TOP_K_MAX_VALUE: ${TOP_K_MAX_VALUE:-10}
|
||||
|
||||
services:
|
||||
# API service
|
||||
@ -476,6 +480,7 @@ services:
|
||||
CSP_WHITELIST: ${CSP_WHITELIST:-}
|
||||
MARKETPLACE_API_URL: ${MARKETPLACE_API_URL:-https://marketplace-plugin.dify.dev}
|
||||
MARKETPLACE_URL: ${MARKETPLACE_URL:-https://marketplace-plugin.dify.dev}
|
||||
TOP_K_MAX_VALUE: ${TOP_K_MAX_VALUE:-}
|
||||
|
||||
# The postgres database.
|
||||
db:
|
||||
@ -495,7 +500,7 @@ services:
|
||||
volumes:
|
||||
- ./volumes/db/data:/var/lib/postgresql/data
|
||||
healthcheck:
|
||||
test: ['CMD', 'pg_isready']
|
||||
test: [ 'CMD', 'pg_isready' ]
|
||||
interval: 1s
|
||||
timeout: 3s
|
||||
retries: 30
|
||||
@ -514,7 +519,7 @@ services:
|
||||
# Set the redis password when startup redis server.
|
||||
command: redis-server --requirepass ${REDIS_PASSWORD:-difyai123456}
|
||||
healthcheck:
|
||||
test: ['CMD', 'redis-cli', 'ping']
|
||||
test: [ 'CMD', 'redis-cli', 'ping' ]
|
||||
|
||||
# The DifySandbox
|
||||
sandbox:
|
||||
@ -534,7 +539,7 @@ services:
|
||||
volumes:
|
||||
- ./volumes/sandbox/dependencies:/dependencies
|
||||
healthcheck:
|
||||
test: ['CMD', 'curl', '-f', 'http://localhost:8194/health']
|
||||
test: [ 'CMD', 'curl', '-f', 'http://localhost:8194/health' ]
|
||||
networks:
|
||||
- ssrf_proxy_network
|
||||
|
||||
@ -571,12 +576,7 @@ services:
|
||||
volumes:
|
||||
- ./ssrf_proxy/squid.conf.template:/etc/squid/squid.conf.template
|
||||
- ./ssrf_proxy/docker-entrypoint.sh:/docker-entrypoint-mount.sh
|
||||
entrypoint:
|
||||
[
|
||||
'sh',
|
||||
'-c',
|
||||
"cp /docker-entrypoint-mount.sh /docker-entrypoint.sh && sed -i 's/\r$$//' /docker-entrypoint.sh && chmod +x /docker-entrypoint.sh && /docker-entrypoint.sh",
|
||||
]
|
||||
entrypoint: [ 'sh', '-c', "cp /docker-entrypoint-mount.sh /docker-entrypoint.sh && sed -i 's/\r$$//' /docker-entrypoint.sh && chmod +x /docker-entrypoint.sh && /docker-entrypoint.sh" ]
|
||||
environment:
|
||||
# pls clearly modify the squid env vars to fit your network environment.
|
||||
HTTP_PORT: ${SSRF_HTTP_PORT:-3128}
|
||||
@ -605,8 +605,8 @@ services:
|
||||
- CERTBOT_EMAIL=${CERTBOT_EMAIL}
|
||||
- CERTBOT_DOMAIN=${CERTBOT_DOMAIN}
|
||||
- CERTBOT_OPTIONS=${CERTBOT_OPTIONS:-}
|
||||
entrypoint: ['/docker-entrypoint.sh']
|
||||
command: ['tail', '-f', '/dev/null']
|
||||
entrypoint: [ '/docker-entrypoint.sh' ]
|
||||
command: [ 'tail', '-f', '/dev/null' ]
|
||||
|
||||
# The nginx reverse proxy.
|
||||
# used for reverse proxying the API service and Web service.
|
||||
@ -623,12 +623,7 @@ services:
|
||||
- ./volumes/certbot/conf/live:/etc/letsencrypt/live # cert dir (with certbot container)
|
||||
- ./volumes/certbot/conf:/etc/letsencrypt
|
||||
- ./volumes/certbot/www:/var/www/html
|
||||
entrypoint:
|
||||
[
|
||||
'sh',
|
||||
'-c',
|
||||
"cp /docker-entrypoint-mount.sh /docker-entrypoint.sh && sed -i 's/\r$$//' /docker-entrypoint.sh && chmod +x /docker-entrypoint.sh && /docker-entrypoint.sh",
|
||||
]
|
||||
entrypoint: [ 'sh', '-c', "cp /docker-entrypoint-mount.sh /docker-entrypoint.sh && sed -i 's/\r$$//' /docker-entrypoint.sh && chmod +x /docker-entrypoint.sh && /docker-entrypoint.sh" ]
|
||||
environment:
|
||||
NGINX_SERVER_NAME: ${NGINX_SERVER_NAME:-_}
|
||||
NGINX_HTTPS_ENABLED: ${NGINX_HTTPS_ENABLED:-false}
|
||||
@ -720,7 +715,7 @@ services:
|
||||
working_dir: /opt/couchbase
|
||||
stdin_open: true
|
||||
tty: true
|
||||
entrypoint: [""]
|
||||
entrypoint: [ "" ]
|
||||
command: sh -c "/opt/couchbase/init/init-cbserver.sh"
|
||||
volumes:
|
||||
- ./volumes/couchbase/data:/opt/couchbase/var/lib/couchbase/data
|
||||
@ -749,7 +744,7 @@ services:
|
||||
volumes:
|
||||
- ./volumes/pgvector/data:/var/lib/postgresql/data
|
||||
healthcheck:
|
||||
test: ['CMD', 'pg_isready']
|
||||
test: [ 'CMD', 'pg_isready' ]
|
||||
interval: 1s
|
||||
timeout: 3s
|
||||
retries: 30
|
||||
@ -771,7 +766,7 @@ services:
|
||||
volumes:
|
||||
- ./volumes/pgvecto_rs/data:/var/lib/postgresql/data
|
||||
healthcheck:
|
||||
test: ['CMD', 'pg_isready']
|
||||
test: [ 'CMD', 'pg_isready' ]
|
||||
interval: 1s
|
||||
timeout: 3s
|
||||
retries: 30
|
||||
@ -836,7 +831,7 @@ services:
|
||||
- ./volumes/milvus/etcd:/etcd
|
||||
command: etcd -advertise-client-urls=http://127.0.0.1:2379 -listen-client-urls http://0.0.0.0:2379 --data-dir /etcd
|
||||
healthcheck:
|
||||
test: ['CMD', 'etcdctl', 'endpoint', 'health']
|
||||
test: [ 'CMD', 'etcdctl', 'endpoint', 'health' ]
|
||||
interval: 30s
|
||||
timeout: 20s
|
||||
retries: 3
|
||||
@ -855,7 +850,7 @@ services:
|
||||
- ./volumes/milvus/minio:/minio_data
|
||||
command: minio server /minio_data --console-address ":9001"
|
||||
healthcheck:
|
||||
test: ['CMD', 'curl', '-f', 'http://localhost:9000/minio/health/live']
|
||||
test: [ 'CMD', 'curl', '-f', 'http://localhost:9000/minio/health/live' ]
|
||||
interval: 30s
|
||||
timeout: 20s
|
||||
retries: 3
|
||||
@ -864,10 +859,10 @@ services:
|
||||
|
||||
milvus-standalone:
|
||||
container_name: milvus-standalone
|
||||
image: milvusdb/milvus:v2.3.1
|
||||
image: milvusdb/milvus:v2.5.0-beta
|
||||
profiles:
|
||||
- milvus
|
||||
command: ['milvus', 'run', 'standalone']
|
||||
command: [ 'milvus', 'run', 'standalone' ]
|
||||
environment:
|
||||
ETCD_ENDPOINTS: ${ETCD_ENDPOINTS:-etcd:2379}
|
||||
MINIO_ADDRESS: ${MINIO_ADDRESS:-minio:9000}
|
||||
@ -875,7 +870,7 @@ services:
|
||||
volumes:
|
||||
- ./volumes/milvus/milvus:/var/lib/milvus
|
||||
healthcheck:
|
||||
test: ['CMD', 'curl', '-f', 'http://localhost:9091/healthz']
|
||||
test: [ 'CMD', 'curl', '-f', 'http://localhost:9091/healthz' ]
|
||||
interval: 30s
|
||||
start_period: 90s
|
||||
timeout: 20s
|
||||
@ -948,22 +943,30 @@ services:
|
||||
container_name: elasticsearch
|
||||
profiles:
|
||||
- elasticsearch
|
||||
- elasticsearch-ja
|
||||
restart: always
|
||||
volumes:
|
||||
- ./elasticsearch/docker-entrypoint.sh:/docker-entrypoint-mount.sh
|
||||
- dify_es01_data:/usr/share/elasticsearch/data
|
||||
environment:
|
||||
ELASTIC_PASSWORD: ${ELASTICSEARCH_PASSWORD:-elastic}
|
||||
VECTOR_STORE: ${VECTOR_STORE:-}
|
||||
cluster.name: dify-es-cluster
|
||||
node.name: dify-es0
|
||||
discovery.type: single-node
|
||||
xpack.license.self_generated.type: trial
|
||||
xpack.license.self_generated.type: basic
|
||||
xpack.security.enabled: 'true'
|
||||
xpack.security.enrollment.enabled: 'false'
|
||||
xpack.security.http.ssl.enabled: 'false'
|
||||
ports:
|
||||
- ${ELASTICSEARCH_PORT:-9200}:9200
|
||||
deploy:
|
||||
resources:
|
||||
limits:
|
||||
memory: 2g
|
||||
entrypoint: [ 'sh', '-c', "sh /docker-entrypoint-mount.sh" ]
|
||||
healthcheck:
|
||||
test: ['CMD', 'curl', '-s', 'http://localhost:9200/_cluster/health?pretty']
|
||||
test: [ 'CMD', 'curl', '-s', 'http://localhost:9200/_cluster/health?pretty' ]
|
||||
interval: 30s
|
||||
timeout: 10s
|
||||
retries: 50
|
||||
@ -991,7 +994,7 @@ services:
|
||||
ports:
|
||||
- ${KIBANA_PORT:-5601}:5601
|
||||
healthcheck:
|
||||
test: ['CMD-SHELL', 'curl -s http://localhost:5601 >/dev/null || exit 1']
|
||||
test: [ 'CMD-SHELL', 'curl -s http://localhost:5601 >/dev/null || exit 1' ]
|
||||
interval: 30s
|
||||
timeout: 10s
|
||||
retries: 3
|
||||
|
docker/elasticsearch/docker-entrypoint.sh (new executable file, 25 lines)
@ -0,0 +1,25 @@
|
||||
#!/bin/bash
|
||||
|
||||
set -e
|
||||
|
||||
if [ "${VECTOR_STORE}" = "elasticsearch-ja" ]; then
|
||||
# Check if the ICU tokenizer plugin is installed
|
||||
if ! /usr/share/elasticsearch/bin/elasticsearch-plugin list | grep -q analysis-icu; then
|
||||
printf '%s\n' "Installing the ICU tokenizer plugin"
|
||||
if ! /usr/share/elasticsearch/bin/elasticsearch-plugin install analysis-icu; then
|
||||
printf '%s\n' "Failed to install the ICU tokenizer plugin"
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
# Check if the Japanese language analyzer plugin is installed
|
||||
if ! /usr/share/elasticsearch/bin/elasticsearch-plugin list | grep -q analysis-kuromoji; then
|
||||
printf '%s\n' "Installing the Japanese language analyzer plugin"
|
||||
if ! /usr/share/elasticsearch/bin/elasticsearch-plugin install analysis-kuromoji; then
|
||||
printf '%s\n' "Failed to install the Japanese language analyzer plugin"
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
fi
|
||||
|
||||
# Run the original entrypoint script
|
||||
exec /bin/tini -- /usr/local/bin/docker-entrypoint.sh
|
@ -25,3 +25,6 @@ NEXT_PUBLIC_TEXT_GENERATION_TIMEOUT_MS=60000
|
||||
|
||||
# CSP https://developer.mozilla.org/en-US/docs/Web/HTTP/CSP
|
||||
NEXT_PUBLIC_CSP_WHITELIST=
|
||||
|
||||
# The maximum number of top-k value for RAG.
|
||||
NEXT_PUBLIC_TOP_K_MAX_VALUE=10
|
||||
|
@ -25,16 +25,18 @@ import Input from '@/app/components/base/input'
|
||||
import { useStore as useTagStore } from '@/app/components/base/tag-management/store'
|
||||
import TagManagementModal from '@/app/components/base/tag-management'
|
||||
import TagFilter from '@/app/components/base/tag-management/filter'
|
||||
import CheckboxWithLabel from '@/app/components/datasets/create/website/base/checkbox-with-label'
|
||||
|
||||
const getKey = (
|
||||
pageIndex: number,
|
||||
previousPageData: AppListResponse,
|
||||
activeTab: string,
|
||||
isCreatedByMe: boolean,
|
||||
tags: string[],
|
||||
keywords: string,
|
||||
) => {
|
||||
if (!pageIndex || previousPageData.has_more) {
|
||||
const params: any = { url: 'apps', params: { page: pageIndex + 1, limit: 30, name: keywords } }
|
||||
const params: any = { url: 'apps', params: { page: pageIndex + 1, limit: 30, name: keywords, is_created_by_me: isCreatedByMe } }
|
||||
|
||||
if (activeTab !== 'all')
|
||||
params.params.mode = activeTab
|
||||
@ -58,6 +60,7 @@ const Apps = () => {
|
||||
defaultTab: 'all',
|
||||
})
|
||||
const { query: { tagIDs = [], keywords = '' }, setQuery } = useAppsQueryState()
|
||||
const [isCreatedByMe, setIsCreatedByMe] = useState(false)
|
||||
const [tagFilterValue, setTagFilterValue] = useState<string[]>(tagIDs)
|
||||
const [searchKeywords, setSearchKeywords] = useState(keywords)
|
||||
const setKeywords = useCallback((keywords: string) => {
|
||||
@ -68,7 +71,7 @@ const Apps = () => {
|
||||
}, [setQuery])
|
||||
|
||||
const { data, isLoading, setSize, mutate } = useSWRInfinite(
|
||||
(pageIndex: number, previousPageData: AppListResponse) => getKey(pageIndex, previousPageData, activeTab, tagIDs, searchKeywords),
|
||||
(pageIndex: number, previousPageData: AppListResponse) => getKey(pageIndex, previousPageData, activeTab, isCreatedByMe, tagIDs, searchKeywords),
|
||||
fetchAppList,
|
||||
{ revalidateFirstPage: true },
|
||||
)
|
||||
@ -132,6 +135,12 @@ const Apps = () => {
|
||||
options={options}
|
||||
/>
|
||||
<div className='flex items-center gap-2'>
|
||||
<CheckboxWithLabel
|
||||
className='mr-2'
|
||||
label={t('app.showMyCreatedAppsOnly')}
|
||||
isChecked={isCreatedByMe}
|
||||
onChange={() => setIsCreatedByMe(!isCreatedByMe)}
|
||||
/>
|
||||
<TagFilter type='app' value={tagFilterValue} onChange={handleTagsChange} />
|
||||
<Input
|
||||
showLeftIcon
|
||||
|
@ -52,6 +52,15 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
|
||||
- <code>high_quality</code> High quality: embedding using embedding model, built as vector database index
|
||||
- <code>economy</code> Economy: Build using inverted index of keyword table index
|
||||
</Property>
|
||||
<Property name='doc_form' type='string' key='doc_form'>
|
||||
Format of indexed content
|
||||
- <code>text_model</code> Text documents are directly embedded; `economy` mode defaults to using this form
|
||||
- <code>hierarchical_model</code> Parent-child mode
|
||||
- <code>qa_model</code> Q&A Mode: Generates Q&A pairs for segmented documents and then embeds the questions
|
||||
</Property>
|
||||
<Property name='doc_language' type='string' key='doc_language'>
|
||||
In Q&A mode, specify the language of the document, for example: <code>English</code>, <code>Chinese</code>
|
||||
</Property>
|
||||
<Property name='process_rule' type='object' key='process_rule'>
|
||||
Processing rules
|
||||
- <code>mode</code> (string) Cleaning, segmentation mode, automatic / custom
|
||||
@@ -65,6 +74,11 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
- <code>segmentation</code> (object) Segmentation rules
- <code>separator</code> Custom segment identifier, currently only allows one delimiter to be set. Default is \n
- <code>max_tokens</code> Maximum length (token) defaults to 1000
- <code>parent_mode</code> Retrieval mode of parent chunks: <code>full-doc</code> full text retrieval / <code>paragraph</code> paragraph retrieval
- <code>subchunk_segmentation</code> (object) Child chunk rules
- <code>separator</code> Segmentation identifier. Currently, only one delimiter is allowed. The default is <code>***</code>
- <code>max_tokens</code> The maximum length (tokens) must be validated to be shorter than the length of the parent chunk
- <code>chunk_overlap</code> Define the overlap between adjacent chunks (optional)
</Property>
</Properties>
</Col>
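The nested process_rule options above are easier to follow as a concrete payload. A minimal sketch of a custom parent-child (hierarchical) rule built only from the fields listed above; the surrounding name/text keys are assumptions for illustration and are not taken from this page:

# Hypothetical request body for creating a document by text with parent-child indexing.
payload = {
    "name": "example-document",      # assumed wrapper field, for illustration
    "text": "Long source text ...",  # assumed wrapper field, for illustration
    "indexing_technique": "high_quality",
    "doc_form": "hierarchical_model",
    "process_rule": {
        "mode": "custom",
        "rules": {
            "segmentation": {"separator": "\n", "max_tokens": 1000},
            "parent_mode": "paragraph",
            # child chunks must stay shorter than the parent chunk, per the note above
            "subchunk_segmentation": {"separator": "***", "max_tokens": 200},
        },
    },
}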
@ -155,6 +169,13 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
|
||||
- <code>high_quality</code> High quality: embedding using embedding model, built as vector database index
|
||||
- <code>economy</code> Economy: Build using inverted index of keyword table index
|
||||
|
||||
- <code>doc_form</code> Format of indexed content
|
||||
- <code>text_model</code> Text documents are directly embedded; `economy` mode defaults to using this form
|
||||
- <code>hierarchical_model</code> Parent-child mode
|
||||
- <code>qa_model</code> Q&A Mode: Generates Q&A pairs for segmented documents and then embeds the questions
|
||||
|
||||
- <code>doc_language</code> In Q&A mode, specify the language of the document, for example: <code>English</code>, <code>Chinese</code>
|
||||
|
||||
- <code>process_rule</code> Processing rules
|
||||
- <code>mode</code> (string) Cleaning, segmentation mode, automatic / custom
|
||||
- <code>rules</code> (object) Custom rules (in automatic mode, this field is empty)
|
||||
@ -167,6 +188,11 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
|
||||
- <code>segmentation</code> (object) Segmentation rules
|
||||
- <code>separator</code> Custom segment identifier, currently only allows one delimiter to be set. Default is \n
|
||||
- <code>max_tokens</code> Maximum length (token) defaults to 1000
|
||||
- <code>parent_mode</code> Retrieval mode of parent chunks: <code>full-doc</code> full text retrieval / <code>paragraph</code> paragraph retrieval
|
||||
- <code>subchunk_segmentation</code> (object) Child chunk rules
|
||||
- <code>separator</code> Segmentation identifier. Currently, only one delimiter is allowed. The default is <code>***</code>
|
||||
- <code>max_tokens</code> The maximum length (tokens) must be validated to be shorter than the length of the parent chunk
|
||||
- <code>chunk_overlap</code> Define the overlap between adjacent chunks (optional)
|
||||
</Property>
|
||||
<Property name='file' type='multipart/form-data' key='file'>
|
||||
Files that need to be uploaded.
|
||||
@ -449,6 +475,11 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
|
||||
- <code>segmentation</code> (object) Segmentation rules
|
||||
- <code>separator</code> Custom segment identifier, currently only allows one delimiter to be set. Default is \n
|
||||
- <code>max_tokens</code> Maximum length (token) defaults to 1000
|
||||
- <code>parent_mode</code> Retrieval mode of parent chunks: <code>full-doc</code> full text retrieval / <code>paragraph</code> paragraph retrieval
|
||||
- <code>subchunk_segmentation</code> (object) Child chunk rules
|
||||
- <code>separator</code> Segmentation identifier. Currently, only one delimiter is allowed. The default is <code>***</code>
|
||||
- <code>max_tokens</code> The maximum length (tokens) must be validated to be shorter than the length of the parent chunk
|
||||
- <code>chunk_overlap</code> Define the overlap between adjacent chunks (optional)
|
||||
</Property>
|
||||
</Properties>
|
||||
</Col>
|
||||
@ -546,6 +577,11 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
|
||||
- <code>segmentation</code> (object) Segmentation rules
|
||||
- <code>separator</code> Custom segment identifier, currently only allows one delimiter to be set. Default is \n
|
||||
- <code>max_tokens</code> Maximum length (token) defaults to 1000
|
||||
- <code>parent_mode</code> Retrieval mode of parent chunks: <code>full-doc</code> full text retrieval / <code>paragraph</code> paragraph retrieval
|
||||
- <code>subchunk_segmentation</code> (object) Child chunk rules
|
||||
- <code>separator</code> Segmentation identifier. Currently, only one delimiter is allowed. The default is <code>***</code>
|
||||
- <code>max_tokens</code> The maximum length (tokens) must be validated to be shorter than the length of the parent chunk
|
||||
- <code>chunk_overlap</code> Define the overlap between adjacent chunks (optional)
|
||||
</Property>
|
||||
</Properties>
|
||||
</Col>
|
||||
@ -984,7 +1020,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
|
||||
<Heading
|
||||
url='/datasets/{dataset_id}/documents/{document_id}/segments/{segment_id}'
|
||||
method='POST'
|
||||
title='Update a Chunk in a Document '
|
||||
title='Update a Chunk in a Document'
|
||||
name='#update_segment'
|
||||
/>
|
||||
<Row>
|
||||
@ -1009,6 +1045,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
|
||||
- <code>answer</code> (text) Answer content, passed if the knowledge is in Q&A mode (optional)
|
||||
- <code>keywords</code> (list) Keyword (optional)
|
||||
- <code>enabled</code> (bool) False / true (optional)
|
||||
- <code>regenerate_child_chunks</code> (bool) Whether to regenerate child chunks (optional)
|
||||
</Property>
|
||||
</Properties>
|
||||
</Col>
|
||||
|
@ -52,6 +52,15 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
|
||||
- <code>high_quality</code> 高质量:使用 embedding 模型进行嵌入,构建为向量数据库索引
|
||||
- <code>economy</code> 经济:使用 keyword table index 的倒排索引进行构建
|
||||
</Property>
|
||||
<Property name='doc_form' type='string' key='doc_form'>
|
||||
索引内容的形式
|
||||
- <code>text_model</code> text 文档直接 embedding,经济模式默认为该模式
|
||||
- <code>hierarchical_model</code> parent-child 模式
|
||||
- <code>qa_model</code> Q&A 模式:为分片文档生成 Q&A 对,然后对问题进行 embedding
|
||||
</Property>
|
||||
<Property name='doc_language' type='string' key='doc_language'>
|
||||
在 Q&A 模式下,指定文档的语言,例如:<code>English</code>、<code>Chinese</code>
|
||||
</Property>
|
||||
<Property name='process_rule' type='object' key='process_rule'>
|
||||
处理规则
|
||||
- <code>mode</code> (string) 清洗、分段模式 ,automatic 自动 / custom 自定义
|
||||
@ -63,8 +72,13 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
|
||||
- <code>remove_urls_emails</code> 删除 URL、电子邮件地址
|
||||
- <code>enabled</code> (bool) 是否选中该规则,不传入文档 ID 时代表默认值
|
||||
- <code>segmentation</code> (object) 分段规则
|
||||
- <code>separator</code> 自定义分段标识符,目前仅允许设置一个分隔符。默认为 \n
|
||||
- <code>separator</code> 自定义分段标识符,目前仅允许设置一个分隔符。默认为 <code>\n</code>
|
||||
- <code>max_tokens</code> 最大长度(token)默认为 1000
|
||||
- <code>parent_mode</code> 父分段的召回模式 <code>full-doc</code> 全文召回 / <code>paragraph</code> 段落召回
|
||||
- <code>subchunk_segmentation</code> (object) 子分段规则
|
||||
- <code>separator</code> 分段标识符,目前仅允许设置一个分隔符。默认为 <code>***</code>
|
||||
- <code>max_tokens</code> 最大长度 (token) 需要校验小于父级的长度
|
||||
- <code>chunk_overlap</code> 分段重叠指的是在对数据进行分段时,段与段之间存在一定的重叠部分(选填)
|
||||
</Property>
|
||||
</Properties>
|
||||
</Col>
|
||||
@ -155,6 +169,13 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
|
||||
- <code>high_quality</code> 高质量:使用 embedding 模型进行嵌入,构建为向量数据库索引
|
||||
- <code>economy</code> 经济:使用 keyword table index 的倒排索引进行构建
|
||||
|
||||
- <code>doc_form</code> 索引内容的形式
|
||||
- <code>text_model</code> text 文档直接 embedding,经济模式默认为该模式
|
||||
- <code>hierarchical_model</code> parent-child 模式
|
||||
- <code>qa_model</code> Q&A 模式:为分片文档生成 Q&A 对,然后对问题进行 embedding
|
||||
|
||||
- <code>doc_language</code> 在 Q&A 模式下,指定文档的语言,例如:<code>English</code>、<code>Chinese</code>
|
||||
|
||||
- <code>process_rule</code> 处理规则
|
||||
- <code>mode</code> (string) 清洗、分段模式 ,automatic 自动 / custom 自定义
|
||||
- <code>rules</code> (object) 自定义规则(自动模式下,该字段为空)
|
||||
@ -167,6 +188,11 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
|
||||
- <code>segmentation</code> (object) 分段规则
|
||||
- <code>separator</code> 自定义分段标识符,目前仅允许设置一个分隔符。默认为 \n
|
||||
- <code>max_tokens</code> 最大长度(token)默认为 1000
|
||||
- <code>parent_mode</code> 父分段的召回模式 <code>full-doc</code> 全文召回 / <code>paragraph</code> 段落召回
|
||||
- <code>subchunk_segmentation</code> (object) 子分段规则
|
||||
- <code>separator</code> 分段标识符,目前仅允许设置一个分隔符。默认为 <code>***</code>
|
||||
- <code>max_tokens</code> 最大长度 (token) 需要校验小于父级的长度
|
||||
- <code>chunk_overlap</code> 分段重叠指的是在对数据进行分段时,段与段之间存在一定的重叠部分(选填)
|
||||
</Property>
|
||||
<Property name='file' type='multipart/form-data' key='file'>
|
||||
需要上传的文件。
|
||||
@ -411,7 +437,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
|
||||
<Heading
|
||||
url='/datasets/{dataset_id}/documents/{document_id}/update-by-text'
|
||||
method='POST'
|
||||
title='通过文本更新文档 '
|
||||
title='通过文本更新文档'
|
||||
name='#update-by-text'
|
||||
/>
|
||||
<Row>
|
||||
@ -449,6 +475,11 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
|
||||
- <code>segmentation</code> (object) 分段规则
|
||||
- <code>separator</code> 自定义分段标识符,目前仅允许设置一个分隔符。默认为 \n
|
||||
- <code>max_tokens</code> 最大长度(token)默认为 1000
|
||||
- <code>parent_mode</code> 父分段的召回模式 <code>full-doc</code> 全文召回 / <code>paragraph</code> 段落召回
|
||||
- <code>subchunk_segmentation</code> (object) 子分段规则
|
||||
- <code>separator</code> 分段标识符,目前仅允许设置一个分隔符。默认为 <code>***</code>
|
||||
- <code>max_tokens</code> 最大长度 (token) 需要校验小于父级的长度
|
||||
- <code>chunk_overlap</code> 分段重叠指的是在对数据进行分段时,段与段之间存在一定的重叠部分(选填)
|
||||
</Property>
|
||||
</Properties>
|
||||
</Col>
|
||||
@ -508,7 +539,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
|
||||
<Heading
|
||||
url='/datasets/{dataset_id}/documents/{document_id}/update-by-file'
|
||||
method='POST'
|
||||
title='通过文件更新文档 '
|
||||
title='通过文件更新文档'
|
||||
name='#update-by-file'
|
||||
/>
|
||||
<Row>
|
||||
@ -546,6 +577,11 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
|
||||
- <code>segmentation</code> (object) 分段规则
|
||||
- <code>separator</code> 自定义分段标识符,目前仅允许设置一个分隔符。默认为 \n
|
||||
- <code>max_tokens</code> 最大长度(token)默认为 1000
|
||||
- <code>parent_mode</code> 父分段的召回模式 <code>full-doc</code> 全文召回 / <code>paragraph</code> 段落召回
|
||||
- <code>subchunk_segmentation</code> (object) 子分段规则
|
||||
- <code>separator</code> 分段标识符,目前仅允许设置一个分隔符。默认为 <code>***</code>
|
||||
- <code>max_tokens</code> 最大长度 (token) 需要校验小于父级的长度
|
||||
- <code>chunk_overlap</code> 分段重叠指的是在对数据进行分段时,段与段之间存在一定的重叠部分(选填)
|
||||
</Property>
|
||||
</Properties>
|
||||
</Col>
|
||||
@ -1009,6 +1045,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
|
||||
- <code>answer</code> (text) 答案内容,非必填,如果知识库的模式为 Q&A 模式则传值
|
||||
- <code>keywords</code> (list) 关键字,非必填
|
||||
- <code>enabled</code> (bool) false/true,非必填
|
||||
- <code>regenerate_child_chunks</code> (bool) 是否重新生成子分段,非必填
|
||||
</Property>
|
||||
</Properties>
|
||||
</Col>
|
||||
|
@ -26,13 +26,15 @@ const PromptEditorHeightResizeWrap: FC<Props> = ({
|
||||
const [clientY, setClientY] = useState(0)
|
||||
const [isResizing, setIsResizing] = useState(false)
|
||||
const [prevUserSelectStyle, setPrevUserSelectStyle] = useState(getComputedStyle(document.body).userSelect)
|
||||
const [oldHeight, setOldHeight] = useState(height)
|
||||
|
||||
const handleStartResize = useCallback((e: React.MouseEvent<HTMLElement>) => {
|
||||
setClientY(e.clientY)
|
||||
setIsResizing(true)
|
||||
setOldHeight(height)
|
||||
setPrevUserSelectStyle(getComputedStyle(document.body).userSelect)
|
||||
document.body.style.userSelect = 'none'
|
||||
}, [])
|
||||
}, [height])
|
||||
|
||||
const handleStopResize = useCallback(() => {
|
||||
setIsResizing(false)
|
||||
@ -44,8 +46,7 @@ const PromptEditorHeightResizeWrap: FC<Props> = ({
|
||||
return
|
||||
|
||||
const offset = e.clientY - clientY
|
||||
let newHeight = height + offset
|
||||
setClientY(e.clientY)
|
||||
let newHeight = oldHeight + offset
|
||||
if (newHeight < minHeight)
|
||||
newHeight = minHeight
|
||||
onHeightChange(newHeight)
|
||||
|
@ -27,6 +27,7 @@ import { ADD_EXTERNAL_DATA_TOOL } from '@/app/components/app/configuration/confi
|
||||
import { INSERT_VARIABLE_VALUE_BLOCK_COMMAND } from '@/app/components/base/prompt-editor/plugins/variable-block'
|
||||
import { PROMPT_EDITOR_UPDATE_VALUE_BY_EVENT_EMITTER } from '@/app/components/base/prompt-editor/plugins/update-block'
|
||||
import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints'
|
||||
import { useFeaturesStore } from '@/app/components/base/features/hooks'
|
||||
|
||||
export type ISimplePromptInput = {
|
||||
mode: AppType
|
||||
@ -54,6 +55,11 @@ const Prompt: FC<ISimplePromptInput> = ({
|
||||
const { t } = useTranslation()
|
||||
const media = useBreakpoints()
|
||||
const isMobile = media === MediaType.mobile
|
||||
const featuresStore = useFeaturesStore()
|
||||
const {
|
||||
features,
|
||||
setFeatures,
|
||||
} = featuresStore!.getState()
|
||||
|
||||
const { eventEmitter } = useEventEmitterContextContext()
|
||||
const {
|
||||
@ -137,8 +143,18 @@ const Prompt: FC<ISimplePromptInput> = ({
|
||||
})
|
||||
setModelConfig(newModelConfig)
|
||||
setPrevPromptConfig(modelConfig.configs)
|
||||
if (mode !== AppType.completion)
|
||||
|
||||
if (mode !== AppType.completion) {
|
||||
setIntroduction(res.opening_statement)
|
||||
const newFeatures = produce(features, (draft) => {
|
||||
draft.opening = {
|
||||
...draft.opening,
|
||||
enabled: !!res.opening_statement,
|
||||
opening_statement: res.opening_statement,
|
||||
}
|
||||
})
|
||||
setFeatures(newFeatures)
|
||||
}
|
||||
showAutomaticFalse()
|
||||
}
|
||||
const minHeight = initEditorHeight || 228
|
||||
|
@ -59,36 +59,24 @@ const ConfigContent: FC<Props> = ({
|
||||
|
||||
const {
|
||||
modelList: rerankModelList,
|
||||
defaultModel: rerankDefaultModel,
|
||||
currentModel: isRerankDefaultModelValid,
|
||||
} = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)
|
||||
|
||||
const {
|
||||
currentModel: currentRerankModel,
|
||||
} = useCurrentProviderAndModel(
rerankModelList,
rerankDefaultModel
? {
...rerankDefaultModel,
provider: rerankDefaultModel.provider.provider,
}
: undefined,
{
provider: datasetConfigs.reranking_model?.reranking_provider_name,
model: datasetConfigs.reranking_model?.reranking_model_name,
},
)

const rerankModel = (() => {
if (datasetConfigs.reranking_model?.reranking_provider_name) {
return {
provider_name: datasetConfigs.reranking_model.reranking_provider_name,
model_name: datasetConfigs.reranking_model.reranking_model_name,
}
const rerankModel = useMemo(() => {
return {
provider_name: datasetConfigs?.reranking_model?.reranking_provider_name ?? '',
model_name: datasetConfigs?.reranking_model?.reranking_model_name ?? '',
}
else if (rerankDefaultModel) {
return {
provider_name: rerankDefaultModel.provider.provider,
model_name: rerankDefaultModel.model,
}
}
})()
}, [datasetConfigs.reranking_model])

const handleParamChange = (key: string, value: number) => {
if (key === 'top_k') {
@@ -133,6 +121,12 @@ const ConfigContent: FC<Props> = ({
}

const handleRerankModeChange = (mode: RerankingModeEnum) => {
if (mode === datasetConfigs.reranking_mode)
return

if (mode === RerankingModeEnum.RerankingModel && !currentRerankModel)
Toast.notify({ type: 'error', message: t('workflow.errorMsg.rerankModelRequired') })

onChange({
...datasetConfigs,
reranking_mode: mode,
@@ -162,31 +156,25 @@ const ConfigContent: FC<Props> = ({

const canManuallyToggleRerank = useMemo(() => {
return (selectedDatasetsMode.allInternal && selectedDatasetsMode.allEconomic)
|| selectedDatasetsMode.allExternal
|| selectedDatasetsMode.allExternal
}, [selectedDatasetsMode.allEconomic, selectedDatasetsMode.allExternal, selectedDatasetsMode.allInternal])

const showRerankModel = useMemo(() => {
if (!canManuallyToggleRerank)
return true
else if (canManuallyToggleRerank && !isRerankDefaultModelValid)
return false

return datasetConfigs.reranking_enable
}, [canManuallyToggleRerank, datasetConfigs.reranking_enable, isRerankDefaultModelValid])
}, [datasetConfigs.reranking_enable, canManuallyToggleRerank])

const handleDisabledSwitchClick = useCallback(() => {
if (!currentRerankModel && !showRerankModel)
const handleDisabledSwitchClick = useCallback((enable: boolean) => {
if (!currentRerankModel && enable)
Toast.notify({ type: 'error', message: t('workflow.errorMsg.rerankModelRequired') })
}, [currentRerankModel, showRerankModel, t])

useEffect(() => {
if (canManuallyToggleRerank && showRerankModel !== datasetConfigs.reranking_enable) {
onChange({
...datasetConfigs,
reranking_enable: showRerankModel,
})
}
}, [canManuallyToggleRerank, showRerankModel, datasetConfigs, onChange])
onChange({
...datasetConfigs,
reranking_enable: enable,
})
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [currentRerankModel, datasetConfigs, onChange])

return (
<div>
@@ -267,24 +255,12 @@ const ConfigContent: FC<Props> = ({
<div className='flex items-center'>
{
selectedDatasetsMode.allEconomic && !selectedDatasetsMode.mixtureInternalAndExternal && (
<div
className='flex items-center'
onClick={handleDisabledSwitchClick}
>
<Switch
size='md'
defaultValue={showRerankModel}
disabled={!currentRerankModel || !canManuallyToggleRerank}
onChange={(v) => {
if (canManuallyToggleRerank) {
onChange({
...datasetConfigs,
reranking_enable: v,
})
}
}}
/>
</div>
<Switch
size='md'
defaultValue={showRerankModel}
disabled={!canManuallyToggleRerank}
onChange={handleDisabledSwitchClick}
/>
)
}
<div className='leading-[32px] ml-1 text-text-secondary system-sm-semibold'>{t('common.modelProvider.rerankModel.key')}</div>
@@ -298,21 +274,24 @@ const ConfigContent: FC<Props> = ({
triggerClassName='ml-1 w-4 h-4'
/>
</div>
<div>
<ModelSelector
defaultModel={rerankModel && { provider: rerankModel?.provider_name, model: rerankModel?.model_name }}
onSelect={(v) => {
onChange({
...datasetConfigs,
reranking_model: {
reranking_provider_name: v.provider,
reranking_model_name: v.model,
},
})
}}
modelList={rerankModelList}
/>
</div>
{
showRerankModel && (
<div>
<ModelSelector
defaultModel={rerankModel && { provider: rerankModel?.provider_name, model: rerankModel?.model_name }}
onSelect={(v) => {
onChange({
...datasetConfigs,
reranking_model: {
reranking_provider_name: v.provider,
reranking_model_name: v.model,
},
})
}}
modelList={rerankModelList}
/>
</div>
)}
</div>
)
}
@@ -10,7 +10,7 @@ import Modal from '@/app/components/base/modal'
import Button from '@/app/components/base/button'
import { RETRIEVE_TYPE } from '@/types/app'
import Toast from '@/app/components/base/toast'
import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
import { useCurrentProviderAndModel, useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
import { RerankingModeEnum } from '@/models/datasets'
import type { DataSet } from '@/models/datasets'
@@ -41,17 +41,27 @@ const ParamsConfig = ({
}, [datasetConfigs])

const {
defaultModel: rerankDefaultModel,
currentModel: isRerankDefaultModelValid,
modelList: rerankModelList,
currentModel: rerankDefaultModel,
currentProvider: rerankDefaultProvider,
} = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)

const {
currentModel: isCurrentRerankModelValid,
} = useCurrentProviderAndModel(
rerankModelList,
{
provider: tempDataSetConfigs.reranking_model?.reranking_provider_name ?? '',
model: tempDataSetConfigs.reranking_model?.reranking_model_name ?? '',
},
)

const isValid = () => {
let errMsg = ''
if (tempDataSetConfigs.retrieval_model === RETRIEVE_TYPE.multiWay) {
if (tempDataSetConfigs.reranking_enable
&& tempDataSetConfigs.reranking_mode === RerankingModeEnum.RerankingModel
&& !isRerankDefaultModelValid
&& !isCurrentRerankModelValid
)
errMsg = t('appDebug.datasetConfig.rerankModelRequired')
}
@@ -66,16 +76,7 @@ const ParamsConfig = ({
const handleSave = () => {
if (!isValid())
return
const config = { ...tempDataSetConfigs }
if (config.retrieval_model === RETRIEVE_TYPE.multiWay
&& config.reranking_mode === RerankingModeEnum.RerankingModel
&& !config.reranking_model) {
config.reranking_model = {
reranking_provider_name: rerankDefaultModel?.provider?.provider,
reranking_model_name: rerankDefaultModel?.model,
} as any
}
setDatasetConfigs(config)
setDatasetConfigs(tempDataSetConfigs)
setRerankSettingModalOpen(false)
}

@@ -94,14 +95,14 @@ const ParamsConfig = ({
reranking_enable: restConfigs.reranking_enable,
}, selectedDatasets, selectedDatasets, {
provider: rerankDefaultProvider?.provider,
model: isRerankDefaultModelValid?.model,
model: rerankDefaultModel?.model,
})

setTempDataSetConfigs({
...retrievalConfig,
reranking_model: restConfigs.reranking_model && {
reranking_provider_name: restConfigs.reranking_model.reranking_provider_name,
reranking_model_name: restConfigs.reranking_model.reranking_model_name,
reranking_model: {
reranking_provider_name: retrievalConfig.reranking_model?.provider || '',
reranking_model_name: retrievalConfig.reranking_model?.model || '',
},
retrieval_model,
score_threshold_enabled,
@@ -29,7 +29,7 @@ const WeightedScore = ({

return (
<div>
<div className='px-3 pt-5 h-[52px] space-x-3 rounded-lg border border-components-panel-border'>
<div className='px-3 pt-5 pb-2 space-x-3 rounded-lg border border-components-panel-border'>
<Slider
className={cn('grow h-0.5 !bg-util-colors-teal-teal-500 rounded-full')}
max={1.0}
@@ -39,7 +39,7 @@ const WeightedScore = ({
onChange={v => onChange({ value: [v, (10 - v * 10) / 10] })}
trackClassName='weightedScoreSliderTrack'
/>
<div className='flex justify-between mt-1'>
<div className='flex justify-between mt-3'>
<div className='shrink-0 flex items-center w-[90px] system-xs-semibold-uppercase text-util-colors-blue-light-blue-light-500'>
<div className='mr-1 truncate uppercase' title={t('dataset.weightedScore.semantic') || ''}>
{t('dataset.weightedScore.semantic')}
@ -12,7 +12,7 @@ import Divider from '@/app/components/base/divider'
|
||||
import Button from '@/app/components/base/button'
|
||||
import Input from '@/app/components/base/input'
|
||||
import Textarea from '@/app/components/base/textarea'
|
||||
import { type DataSet, RerankingModeEnum } from '@/models/datasets'
|
||||
import { type DataSet } from '@/models/datasets'
|
||||
import { useToastContext } from '@/app/components/base/toast'
|
||||
import { updateDatasetSetting } from '@/service/datasets'
|
||||
import { useAppContext } from '@/context/app-context'
|
||||
@ -21,7 +21,7 @@ import type { RetrievalConfig } from '@/types/app'
|
||||
import RetrievalSettings from '@/app/components/datasets/external-knowledge-base/create/RetrievalSettings'
|
||||
import RetrievalMethodConfig from '@/app/components/datasets/common/retrieval-method-config'
|
||||
import EconomicalRetrievalMethodConfig from '@/app/components/datasets/common/economical-retrieval-method-config'
|
||||
import { ensureRerankModelSelected, isReRankModelSelected } from '@/app/components/datasets/common/check-rerank-model'
|
||||
import { isReRankModelSelected } from '@/app/components/datasets/common/check-rerank-model'
|
||||
import { AlertTriangle } from '@/app/components/base/icons/src/vender/solid/alertsAndFeedback'
|
||||
import PermissionSelector from '@/app/components/datasets/settings/permission-selector'
|
||||
import ModelSelector from '@/app/components/header/account-setting/model-provider-page/model-selector'
|
||||
@ -99,8 +99,6 @@ const SettingsModal: FC<SettingsModalProps> = ({
|
||||
}
|
||||
if (
|
||||
!isReRankModelSelected({
|
||||
rerankDefaultModel,
|
||||
isRerankDefaultModelValid: !!isRerankDefaultModelValid,
|
||||
rerankModelList,
|
||||
retrievalConfig,
|
||||
indexMethod,
|
||||
@ -109,14 +107,6 @@ const SettingsModal: FC<SettingsModalProps> = ({
|
||||
notify({ type: 'error', message: t('appDebug.datasetConfig.rerankModelRequired') })
|
||||
return
|
||||
}
|
||||
const postRetrievalConfig = ensureRerankModelSelected({
|
||||
rerankDefaultModel: rerankDefaultModel!,
|
||||
retrievalConfig: {
|
||||
...retrievalConfig,
|
||||
reranking_enable: retrievalConfig.reranking_mode === RerankingModeEnum.RerankingModel,
|
||||
},
|
||||
indexMethod,
|
||||
})
|
||||
try {
|
||||
setLoading(true)
|
||||
const { id, name, description, permission } = localeCurrentDataset
|
||||
@ -128,8 +118,8 @@ const SettingsModal: FC<SettingsModalProps> = ({
|
||||
permission,
|
||||
indexing_technique: indexMethod,
|
||||
retrieval_model: {
|
||||
...postRetrievalConfig,
|
||||
score_threshold: postRetrievalConfig.score_threshold_enabled ? postRetrievalConfig.score_threshold : 0,
|
||||
...retrievalConfig,
|
||||
score_threshold: retrievalConfig.score_threshold_enabled ? retrievalConfig.score_threshold : 0,
|
||||
},
|
||||
embedding_model: localeCurrentDataset.embedding_model,
|
||||
embedding_model_provider: localeCurrentDataset.embedding_model_provider,
|
||||
@ -157,7 +147,7 @@ const SettingsModal: FC<SettingsModalProps> = ({
|
||||
onSave({
|
||||
...localeCurrentDataset,
|
||||
indexing_technique: indexMethod,
|
||||
retrieval_model_dict: postRetrievalConfig,
|
||||
retrieval_model_dict: retrievalConfig,
|
||||
})
|
||||
}
|
||||
catch (e) {
|
||||
|
@@ -287,9 +287,9 @@ const Configuration: FC = () => {

setDatasetConfigs({
...retrievalConfig,
reranking_model: restConfigs.reranking_model && {
reranking_provider_name: restConfigs.reranking_model.reranking_provider_name,
reranking_model_name: restConfigs.reranking_model.reranking_model_name,
reranking_model: {
reranking_provider_name: retrievalConfig?.reranking_model?.provider || '',
reranking_model_name: retrievalConfig?.reranking_model?.model || '',
},
retrieval_model,
score_threshold_enabled,
@ -39,6 +39,7 @@ type ChatInputAreaProps = {
|
||||
inputs?: Record<string, any>
|
||||
inputsForm?: InputForm[]
|
||||
theme?: Theme | null
|
||||
isResponding?: boolean
|
||||
}
|
||||
const ChatInputArea = ({
|
||||
showFeatureBar,
|
||||
@ -51,6 +52,7 @@ const ChatInputArea = ({
|
||||
inputs = {},
|
||||
inputsForm = [],
|
||||
theme,
|
||||
isResponding,
|
||||
}: ChatInputAreaProps) => {
|
||||
const { t } = useTranslation()
|
||||
const { notify } = useToastContext()
|
||||
@ -77,6 +79,11 @@ const ChatInputArea = ({
|
||||
const historyRef = useRef([''])
|
||||
const [currentIndex, setCurrentIndex] = useState(-1)
|
||||
const handleSend = () => {
|
||||
if (isResponding) {
|
||||
notify({ type: 'info', message: t('appDebug.errorMessage.waitForResponse') })
|
||||
return
|
||||
}
|
||||
|
||||
if (onSend) {
|
||||
const { files, setFiles } = filesStore.getState()
|
||||
if (files.find(item => item.transferMethod === TransferMethod.local_file && !item.uploadedId)) {
|
||||
@ -116,7 +123,7 @@ const ChatInputArea = ({
|
||||
setQuery(historyRef.current[currentIndex + 1])
|
||||
}
|
||||
else if (currentIndex === historyRef.current.length - 1) {
|
||||
// If it is the last element, clear the input box
|
||||
// If it is the last element, clear the input box
|
||||
setCurrentIndex(historyRef.current.length)
|
||||
setQuery('')
|
||||
}
|
||||
@ -169,6 +176,7 @@ const ChatInputArea = ({
|
||||
'p-1 w-full leading-6 body-lg-regular text-text-tertiary outline-none',
|
||||
)}
|
||||
placeholder={t('common.chat.inputPlaceholder') || ''}
|
||||
autoFocus
|
||||
autoSize={{ minRows: 1 }}
|
||||
onResize={handleTextareaResize}
|
||||
value={query}
|
||||
|
@@ -292,6 +292,7 @@ const Chat: FC<ChatProps> = ({
inputs={inputs}
inputsForm={inputsForm}
theme={themeBuilder?.theme}
isResponding={isResponding}
/>
)
}
@@ -28,8 +28,8 @@ const Question: FC<QuestionProps> = ({
} = item

return (
<div className='flex justify-end mb-2 last:mb-0 pl-10'>
<div className='group relative mr-4'>
<div className='flex justify-end mb-2 last:mb-0 pl-14'>
<div className='group relative mr-4 max-w-full'>
<div
className='px-4 py-3 bg-[#D1E9FF]/50 rounded-2xl text-sm text-gray-900'
style={theme?.chatBubbleColorStyle ? CssTransform(theme.chatBubbleColorStyle) : {}}
@@ -111,9 +111,9 @@ const CodeBlock: CodeComponent = memo(({ inline, className, children, ...props }
}
else if (language === 'echarts') {
return (
<div style={{ minHeight: '350px', minWidth: '700px' }}>
<div style={{ minHeight: '350px', minWidth: '100%', overflowX: 'scroll' }}>
<ErrorBoundary>
<ReactEcharts option={chartData} />
<ReactEcharts option={chartData} style={{ minWidth: '700px' }} />
</ErrorBoundary>
</div>
)
@@ -11,11 +11,17 @@ type Props = {
enable: boolean
}

const maxTopK = (() => {
const configValue = parseInt(globalThis.document?.body?.getAttribute('data-public-top-k-max-value') || '', 10)
if (configValue && !isNaN(configValue))
return configValue
return 10
})()
const VALUE_LIMIT = {
default: 2,
step: 1,
min: 1,
max: 10,
max: maxTopK,
}

const key = 'top_k'
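Note: the hunk above replaces the hard-coded top_k ceiling of 10 with a value read from a data attribute on the document body. A minimal standalone sketch of that lookup follows (attribute name taken from the hunk; the fallback of 10 is assumed to mirror the previous hard-coded maximum, and the helper name is illustrative only):

// Sketch: resolve the configurable top-k ceiling, falling back to the old default of 10.
const resolveMaxTopK = (): number => {
  const raw = globalThis.document?.body?.getAttribute('data-public-top-k-max-value') || ''
  const parsed = Number.parseInt(raw, 10)
  return Number.isNaN(parsed) || parsed <= 0 ? 10 : parsed
}

// Usage mirroring the VALUE_LIMIT object in the hunk above.
const TOP_K_LIMIT = { default: 2, step: 1, min: 1, max: resolveMaxTopK() }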
@@ -6,14 +6,10 @@ import type {
import { RerankingModeEnum } from '@/models/datasets'

export const isReRankModelSelected = ({
rerankDefaultModel,
isRerankDefaultModelValid,
retrievalConfig,
rerankModelList,
indexMethod,
}: {
rerankDefaultModel?: DefaultModelResponse
isRerankDefaultModelValid: boolean
retrievalConfig: RetrievalConfig
rerankModelList: Model[]
indexMethod?: string
@@ -25,12 +21,17 @@ export const isReRankModelSelected = ({
return provider?.models.find(({ model }) => model === retrievalConfig.reranking_model?.reranking_model_name)
}

if (isRerankDefaultModelValid)
return !!rerankDefaultModel

return false
})()

if (
indexMethod === 'high_quality'
&& ([RETRIEVE_METHOD.semantic, RETRIEVE_METHOD.fullText].includes(retrievalConfig.search_method))
&& retrievalConfig.reranking_enable
&& !rerankModelSelected
)
return false

if (
indexMethod === 'high_quality'
&& (retrievalConfig.search_method === RETRIEVE_METHOD.hybrid && retrievalConfig.reranking_mode !== RerankingModeEnum.WeightedScore)
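The hunk above drops the default-model fallback from isReRankModelSelected, so the check only passes when the configured reranking model actually exists in the available model list. A simplified sketch of that validation under assumed, flattened types (field names follow the hunk; the type shapes, function name, and search-method string values are placeholders, and the hybrid-search branch is omitted):

// Sketch: reranking on a high-quality index with semantic/full-text search requires an explicitly selected model.
type RerankSelection = { reranking_provider_name: string; reranking_model_name: string }

const isRerankModelSelectedSketch = (
  config: { search_method: string; reranking_enable: boolean; reranking_model?: RerankSelection },
  modelList: { provider: string; models: string[] }[],
  indexMethod?: string,
): boolean => {
  // The selected model counts only if its provider and model name are present in the model list.
  const selected = !!modelList
    .find(p => p.provider === config.reranking_model?.reranking_provider_name)
    ?.models.includes(config.reranking_model?.reranking_model_name ?? '')

  const mustHaveModel = indexMethod === 'high_quality'
    && ['semantic_search', 'full_text_search'].includes(config.search_method)
    && config.reranking_enable

  return mustHaveModel ? selected : true
}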
@@ -10,11 +10,13 @@ import { RETRIEVE_METHOD } from '@/types/app'
import type { RetrievalConfig } from '@/types/app'

type Props = {
disabled?: boolean
value: RetrievalConfig
onChange: (value: RetrievalConfig) => void
}

const EconomicalRetrievalMethodConfig: FC<Props> = ({
disabled = false,
value,
onChange,
}) => {
@@ -22,7 +24,8 @@ const EconomicalRetrievalMethodConfig: FC<Props> = ({

return (
<div className='space-y-2'>
<OptionCard icon={<Image className='w-4 h-4' src={retrievalIcon.vector} alt='' />}
<OptionCard
disabled={disabled} icon={<Image className='w-4 h-4' src={retrievalIcon.vector} alt='' />}
title={t('dataset.retrieval.invertedIndex.title')}
description={t('dataset.retrieval.invertedIndex.description')} isActive
activeHeaderClassName='bg-dataset-option-card-purple-gradient'
@ -1,6 +1,6 @@
|
||||
'use client'
|
||||
import type { FC } from 'react'
|
||||
import React from 'react'
|
||||
import React, { useCallback } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import Image from 'next/image'
|
||||
import RetrievalParamConfig from '../retrieval-param-config'
|
||||
@ -10,7 +10,7 @@ import { retrievalIcon } from '../../create/icons'
|
||||
import type { RetrievalConfig } from '@/types/app'
|
||||
import { RETRIEVE_METHOD } from '@/types/app'
|
||||
import { useProviderContext } from '@/context/provider-context'
|
||||
import { useDefaultModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
|
||||
import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
|
||||
import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
|
||||
import {
|
||||
DEFAULT_WEIGHTED_SCORE,
|
||||
@ -20,54 +20,87 @@ import {
|
||||
import Badge from '@/app/components/base/badge'
|
||||
|
||||
type Props = {
|
||||
disabled?: boolean
|
||||
value: RetrievalConfig
|
||||
onChange: (value: RetrievalConfig) => void
|
||||
}
|
||||
|
||||
const RetrievalMethodConfig: FC<Props> = ({
|
||||
value: passValue,
|
||||
disabled = false,
|
||||
value,
|
||||
onChange,
|
||||
}) => {
|
||||
const { t } = useTranslation()
|
||||
const { supportRetrievalMethods } = useProviderContext()
|
||||
const { data: rerankDefaultModel } = useDefaultModel(ModelTypeEnum.rerank)
|
||||
const value = (() => {
|
||||
if (!passValue.reranking_model.reranking_model_name) {
|
||||
return {
|
||||
...passValue,
|
||||
reranking_model: {
|
||||
reranking_provider_name: rerankDefaultModel?.provider.provider || '',
|
||||
reranking_model_name: rerankDefaultModel?.model || '',
|
||||
},
|
||||
reranking_mode: passValue.reranking_mode || (rerankDefaultModel ? RerankingModeEnum.RerankingModel : RerankingModeEnum.WeightedScore),
|
||||
weights: passValue.weights || {
|
||||
weight_type: WeightedScoreEnum.Customized,
|
||||
vector_setting: {
|
||||
vector_weight: DEFAULT_WEIGHTED_SCORE.other.semantic,
|
||||
embedding_provider_name: '',
|
||||
embedding_model_name: '',
|
||||
},
|
||||
keyword_setting: {
|
||||
keyword_weight: DEFAULT_WEIGHTED_SCORE.other.keyword,
|
||||
},
|
||||
},
|
||||
}
|
||||
const {
|
||||
defaultModel: rerankDefaultModel,
|
||||
currentModel: isRerankDefaultModelValid,
|
||||
} = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)
|
||||
|
||||
const onSwitch = useCallback((retrieveMethod: RETRIEVE_METHOD) => {
|
||||
if ([RETRIEVE_METHOD.semantic, RETRIEVE_METHOD.fullText].includes(retrieveMethod)) {
|
||||
onChange({
|
||||
...value,
|
||||
search_method: retrieveMethod,
|
||||
...(!value.reranking_model.reranking_model_name
|
||||
? {
|
||||
reranking_model: {
|
||||
reranking_provider_name: isRerankDefaultModelValid ? rerankDefaultModel?.provider?.provider ?? '' : '',
|
||||
reranking_model_name: isRerankDefaultModelValid ? rerankDefaultModel?.model ?? '' : '',
|
||||
},
|
||||
reranking_enable: !!isRerankDefaultModelValid,
|
||||
}
|
||||
: {
|
||||
reranking_enable: true,
|
||||
}),
|
||||
})
|
||||
}
|
||||
return passValue
|
||||
})()
|
||||
if (retrieveMethod === RETRIEVE_METHOD.hybrid) {
|
||||
onChange({
|
||||
...value,
|
||||
search_method: retrieveMethod,
|
||||
...(!value.reranking_model.reranking_model_name
|
||||
? {
|
||||
reranking_model: {
|
||||
reranking_provider_name: isRerankDefaultModelValid ? rerankDefaultModel?.provider?.provider ?? '' : '',
|
||||
reranking_model_name: isRerankDefaultModelValid ? rerankDefaultModel?.model ?? '' : '',
|
||||
},
|
||||
reranking_enable: !!isRerankDefaultModelValid,
|
||||
reranking_mode: isRerankDefaultModelValid ? RerankingModeEnum.RerankingModel : RerankingModeEnum.WeightedScore,
|
||||
}
|
||||
: {
|
||||
reranking_enable: true,
|
||||
reranking_mode: RerankingModeEnum.RerankingModel,
|
||||
}),
|
||||
...(!value.weights
|
||||
? {
|
||||
weights: {
|
||||
weight_type: WeightedScoreEnum.Customized,
|
||||
vector_setting: {
|
||||
vector_weight: DEFAULT_WEIGHTED_SCORE.other.semantic,
|
||||
embedding_provider_name: '',
|
||||
embedding_model_name: '',
|
||||
},
|
||||
keyword_setting: {
|
||||
keyword_weight: DEFAULT_WEIGHTED_SCORE.other.keyword,
|
||||
},
|
||||
},
|
||||
}
|
||||
: {}),
|
||||
})
|
||||
}
|
||||
}, [value, rerankDefaultModel, isRerankDefaultModelValid, onChange])
|
||||
|
||||
return (
|
||||
<div className='space-y-2'>
|
||||
{supportRetrievalMethods.includes(RETRIEVE_METHOD.semantic) && (
|
||||
<OptionCard icon={<Image className='w-4 h-4' src={retrievalIcon.vector} alt='' />}
|
||||
<OptionCard disabled={disabled} icon={<Image className='w-4 h-4' src={retrievalIcon.vector} alt='' />}
|
||||
title={t('dataset.retrieval.semantic_search.title')}
|
||||
description={t('dataset.retrieval.semantic_search.description')}
|
||||
isActive={
|
||||
value.search_method === RETRIEVE_METHOD.semantic
|
||||
}
|
||||
onSwitched={() => onChange({
|
||||
...value,
|
||||
search_method: RETRIEVE_METHOD.semantic,
|
||||
})}
|
||||
onSwitched={() => onSwitch(RETRIEVE_METHOD.semantic)}
|
||||
effectImg={Effect.src}
|
||||
activeHeaderClassName='bg-dataset-option-card-purple-gradient'
|
||||
>
|
||||
@ -78,17 +111,14 @@ const RetrievalMethodConfig: FC<Props> = ({
|
||||
/>
|
||||
</OptionCard>
|
||||
)}
|
||||
{supportRetrievalMethods.includes(RETRIEVE_METHOD.semantic) && (
|
||||
<OptionCard icon={<Image className='w-4 h-4' src={retrievalIcon.fullText} alt='' />}
|
||||
{supportRetrievalMethods.includes(RETRIEVE_METHOD.fullText) && (
|
||||
<OptionCard disabled={disabled} icon={<Image className='w-4 h-4' src={retrievalIcon.fullText} alt='' />}
|
||||
title={t('dataset.retrieval.full_text_search.title')}
|
||||
description={t('dataset.retrieval.full_text_search.description')}
|
||||
isActive={
|
||||
value.search_method === RETRIEVE_METHOD.fullText
|
||||
}
|
||||
onSwitched={() => onChange({
|
||||
...value,
|
||||
search_method: RETRIEVE_METHOD.fullText,
|
||||
})}
|
||||
onSwitched={() => onSwitch(RETRIEVE_METHOD.fullText)}
|
||||
effectImg={Effect.src}
|
||||
activeHeaderClassName='bg-dataset-option-card-purple-gradient'
|
||||
>
|
||||
@ -99,8 +129,8 @@ const RetrievalMethodConfig: FC<Props> = ({
|
||||
/>
|
||||
</OptionCard>
|
||||
)}
|
||||
{supportRetrievalMethods.includes(RETRIEVE_METHOD.semantic) && (
|
||||
<OptionCard icon={<Image className='w-4 h-4' src={retrievalIcon.hybrid} alt='' />}
|
||||
{supportRetrievalMethods.includes(RETRIEVE_METHOD.hybrid) && (
|
||||
<OptionCard disabled={disabled} icon={<Image className='w-4 h-4' src={retrievalIcon.hybrid} alt='' />}
|
||||
title={
|
||||
<div className='flex items-center space-x-1'>
|
||||
<div>{t('dataset.retrieval.hybrid_search.title')}</div>
|
||||
@ -110,11 +140,7 @@ const RetrievalMethodConfig: FC<Props> = ({
|
||||
description={t('dataset.retrieval.hybrid_search.description')} isActive={
|
||||
value.search_method === RETRIEVE_METHOD.hybrid
|
||||
}
|
||||
onSwitched={() => onChange({
|
||||
...value,
|
||||
search_method: RETRIEVE_METHOD.hybrid,
|
||||
reranking_enable: true,
|
||||
})}
|
||||
onSwitched={() => onSwitch(RETRIEVE_METHOD.hybrid)}
|
||||
effectImg={Effect.src}
|
||||
activeHeaderClassName='bg-dataset-option-card-purple-gradient'
|
||||
>
|
||||
|
@ -1,6 +1,6 @@
|
||||
'use client'
|
||||
import type { FC } from 'react'
|
||||
import React, { useCallback } from 'react'
|
||||
import React, { useCallback, useMemo } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
|
||||
import Image from 'next/image'
|
||||
@ -39,8 +39,8 @@ const RetrievalParamConfig: FC<Props> = ({
|
||||
const { t } = useTranslation()
|
||||
const canToggleRerankModalEnable = type !== RETRIEVE_METHOD.hybrid
|
||||
const isEconomical = type === RETRIEVE_METHOD.invertedIndex
|
||||
const isHybridSearch = type === RETRIEVE_METHOD.hybrid
|
||||
const {
|
||||
defaultModel: rerankDefaultModel,
|
||||
modelList: rerankModelList,
|
||||
} = useModelListAndDefaultModel(ModelTypeEnum.rerank)
|
||||
|
||||
@ -48,35 +48,28 @@ const RetrievalParamConfig: FC<Props> = ({
|
||||
currentModel,
|
||||
} = useCurrentProviderAndModel(
|
||||
rerankModelList,
|
||||
rerankDefaultModel
|
||||
? {
|
||||
...rerankDefaultModel,
|
||||
provider: rerankDefaultModel.provider.provider,
|
||||
}
|
||||
: undefined,
|
||||
{
|
||||
provider: value.reranking_model?.reranking_provider_name ?? '',
|
||||
model: value.reranking_model?.reranking_model_name ?? '',
|
||||
},
|
||||
)
|
||||
|
||||
const handleDisabledSwitchClick = useCallback(() => {
|
||||
if (!currentModel)
|
||||
const handleDisabledSwitchClick = useCallback((enable: boolean) => {
|
||||
if (enable && !currentModel)
|
||||
Toast.notify({ type: 'error', message: t('workflow.errorMsg.rerankModelRequired') })
|
||||
}, [currentModel, rerankDefaultModel, t])
|
||||
onChange({
|
||||
...value,
|
||||
reranking_enable: enable,
|
||||
})
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, [currentModel, onChange, value])
|
||||
|
||||
const isHybridSearch = type === RETRIEVE_METHOD.hybrid
|
||||
|
||||
const rerankModel = (() => {
|
||||
if (value.reranking_model) {
|
||||
return {
|
||||
provider_name: value.reranking_model.reranking_provider_name,
|
||||
model_name: value.reranking_model.reranking_model_name,
|
||||
}
|
||||
const rerankModel = useMemo(() => {
|
||||
return {
|
||||
provider_name: value.reranking_model.reranking_provider_name,
|
||||
model_name: value.reranking_model.reranking_model_name,
|
||||
}
|
||||
else if (rerankDefaultModel) {
|
||||
return {
|
||||
provider_name: rerankDefaultModel.provider.provider,
|
||||
model_name: rerankDefaultModel.model,
|
||||
}
|
||||
}
|
||||
})()
|
||||
}, [value.reranking_model])
|
||||
|
||||
const handleChangeRerankMode = (v: RerankingModeEnum) => {
|
||||
if (v === value.reranking_mode)
|
||||
@ -100,6 +93,8 @@ const RetrievalParamConfig: FC<Props> = ({
|
||||
},
|
||||
}
|
||||
}
|
||||
if (v === RerankingModeEnum.RerankingModel && !currentModel)
|
||||
Toast.notify({ type: 'error', message: t('workflow.errorMsg.rerankModelRequired') })
|
||||
onChange(result)
|
||||
}
|
||||
|
||||
@ -122,22 +117,11 @@ const RetrievalParamConfig: FC<Props> = ({
|
||||
<div>
|
||||
<div className='flex items-center space-x-2 mb-2'>
|
||||
{canToggleRerankModalEnable && (
|
||||
<div
|
||||
className='flex items-center'
|
||||
onClick={handleDisabledSwitchClick}
|
||||
>
|
||||
<Switch
|
||||
size='md'
|
||||
defaultValue={currentModel ? value.reranking_enable : false}
|
||||
onChange={(v) => {
|
||||
onChange({
|
||||
...value,
|
||||
reranking_enable: v,
|
||||
})
|
||||
}}
|
||||
disabled={!currentModel}
|
||||
/>
|
||||
</div>
|
||||
<Switch
|
||||
size='md'
|
||||
defaultValue={value.reranking_enable}
|
||||
onChange={handleDisabledSwitchClick}
|
||||
/>
|
||||
)}
|
||||
<div className='flex items-center'>
|
||||
<span className='mr-0.5 system-sm-semibold text-text-secondary'>{t('common.modelProvider.rerankModel.key')}</span>
|
||||
@ -148,21 +132,23 @@ const RetrievalParamConfig: FC<Props> = ({
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
<ModelSelector
|
||||
triggerClassName={`${!value.reranking_enable && '!opacity-60 !cursor-not-allowed'}`}
|
||||
defaultModel={rerankModel && { provider: rerankModel.provider_name, model: rerankModel.model_name }}
|
||||
modelList={rerankModelList}
|
||||
readonly={!value.reranking_enable}
|
||||
onSelect={(v) => {
|
||||
onChange({
|
||||
...value,
|
||||
reranking_model: {
|
||||
reranking_provider_name: v.provider,
|
||||
reranking_model_name: v.model,
|
||||
},
|
||||
})
|
||||
}}
|
||||
/>
|
||||
{
|
||||
value.reranking_enable && (
|
||||
<ModelSelector
|
||||
defaultModel={rerankModel && { provider: rerankModel.provider_name, model: rerankModel.model_name }}
|
||||
modelList={rerankModelList}
|
||||
onSelect={(v) => {
|
||||
onChange({
|
||||
...value,
|
||||
reranking_model: {
|
||||
reranking_provider_name: v.provider,
|
||||
reranking_model_name: v.model,
|
||||
},
|
||||
})
|
||||
}}
|
||||
/>
|
||||
)
|
||||
}
|
||||
</div>
|
||||
)}
|
||||
{
|
||||
@ -255,10 +241,8 @@ const RetrievalParamConfig: FC<Props> = ({
|
||||
{
|
||||
value.reranking_mode !== RerankingModeEnum.WeightedScore && (
|
||||
<ModelSelector
|
||||
triggerClassName={`${!value.reranking_enable && '!opacity-60 !cursor-not-allowed'}`}
|
||||
defaultModel={rerankModel && { provider: rerankModel.provider_name, model: rerankModel.model_name }}
|
||||
modelList={rerankModelList}
|
||||
readonly={!value.reranking_enable}
|
||||
onSelect={(v) => {
|
||||
onChange({
|
||||
...value,
|
||||
|
@@ -30,6 +30,7 @@ import { useProviderContext } from '@/context/provider-context'
import { sleep } from '@/utils'
import { RETRIEVE_METHOD } from '@/types/app'
import Tooltip from '@/app/components/base/tooltip'
import { useInvalidDocumentList } from '@/service/knowledge/use-document'

type Props = {
datasetId: string
@@ -207,7 +208,9 @@ const EmbeddingProcess: FC<Props> = ({ datasetId, batchId, documents = [], index
})

const router = useRouter()
const invalidDocumentList = useInvalidDocumentList()
const navToDocumentList = () => {
invalidDocumentList()
router.push(`/datasets/${datasetId}/documents`)
}
const navToApiDocs = () => {
@ -31,17 +31,17 @@ import LanguageSelect from './language-select'
|
||||
import { DelimiterInput, MaxLengthInput, OverlapInput } from './inputs'
|
||||
import cn from '@/utils/classnames'
|
||||
import type { CrawlOptions, CrawlResultItem, CreateDocumentReq, CustomFile, DocumentItem, FullDocumentDetail, ParentMode, PreProcessingRule, ProcessRule, Rules, createDocumentResponse } from '@/models/datasets'
|
||||
import { ChunkingMode, DataSourceType, ProcessMode } from '@/models/datasets'
|
||||
|
||||
import Button from '@/app/components/base/button'
|
||||
import FloatRightContainer from '@/app/components/base/float-right-container'
|
||||
import RetrievalMethodConfig from '@/app/components/datasets/common/retrieval-method-config'
|
||||
import EconomicalRetrievalMethodConfig from '@/app/components/datasets/common/economical-retrieval-method-config'
|
||||
import { type RetrievalConfig } from '@/types/app'
|
||||
import { ensureRerankModelSelected, isReRankModelSelected } from '@/app/components/datasets/common/check-rerank-model'
|
||||
import { isReRankModelSelected } from '@/app/components/datasets/common/check-rerank-model'
|
||||
import Toast from '@/app/components/base/toast'
|
||||
import type { NotionPage } from '@/models/common'
|
||||
import { DataSourceProvider } from '@/models/common'
|
||||
import { ChunkingMode, DataSourceType, RerankingModeEnum } from '@/models/datasets'
|
||||
import { useDatasetDetailContext } from '@/context/dataset-detail'
|
||||
import I18n from '@/context/i18n'
|
||||
import { RETRIEVE_METHOD } from '@/types/app'
|
||||
@ -53,7 +53,7 @@ import type { DefaultModel } from '@/app/components/header/account-setting/model
|
||||
import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
|
||||
import Checkbox from '@/app/components/base/checkbox'
|
||||
import RadioCard from '@/app/components/base/radio-card'
|
||||
import { IS_CE_EDITION } from '@/config'
|
||||
import { FULL_DOC_PREVIEW_LENGTH, IS_CE_EDITION } from '@/config'
|
||||
import Divider from '@/app/components/base/divider'
|
||||
import { getNotionInfo, getWebsiteInfo, useCreateDocument, useCreateFirstDocument, useFetchDefaultProcessRule, useFetchFileIndexingEstimateForFile, useFetchFileIndexingEstimateForNotion, useFetchFileIndexingEstimateForWeb } from '@/service/knowledge/use-create-dataset'
|
||||
import Badge from '@/app/components/base/badge'
|
||||
@ -90,17 +90,13 @@ type StepTwoProps = {
|
||||
onCancel?: () => void
|
||||
}
|
||||
|
||||
export enum SegmentType {
|
||||
AUTO = 'automatic',
|
||||
CUSTOM = 'custom',
|
||||
}
|
||||
export enum IndexingType {
|
||||
QUALIFIED = 'high_quality',
|
||||
ECONOMICAL = 'economy',
|
||||
}
|
||||
|
||||
const DEFAULT_SEGMENT_IDENTIFIER = '\\n\\n'
|
||||
const DEFAULT_MAXMIMUM_CHUNK_LENGTH = 500
|
||||
const DEFAULT_MAXIMUM_CHUNK_LENGTH = 500
|
||||
const DEFAULT_OVERLAP = 50
|
||||
|
||||
type ParentChildConfig = {
|
||||
@ -131,7 +127,6 @@ const StepTwo = ({
|
||||
isSetting,
|
||||
documentDetail,
|
||||
isAPIKeySet,
|
||||
onSetting,
|
||||
datasetId,
|
||||
indexingType,
|
||||
dataSourceType: inCreatePageDataSourceType,
|
||||
@ -162,12 +157,12 @@ const StepTwo = ({
|
||||
|
||||
const isInCreatePage = !datasetId || (datasetId && !currentDataset?.data_source_type)
|
||||
const dataSourceType = isInCreatePage ? inCreatePageDataSourceType : currentDataset?.data_source_type
|
||||
const [segmentationType, setSegmentationType] = useState<SegmentType>(SegmentType.CUSTOM)
|
||||
const [segmentationType, setSegmentationType] = useState<ProcessMode>(ProcessMode.general)
|
||||
const [segmentIdentifier, doSetSegmentIdentifier] = useState(DEFAULT_SEGMENT_IDENTIFIER)
|
||||
const setSegmentIdentifier = useCallback((value: string, canEmpty?: boolean) => {
|
||||
doSetSegmentIdentifier(value ? escape(value) : (canEmpty ? '' : DEFAULT_SEGMENT_IDENTIFIER))
|
||||
}, [])
|
||||
const [maxChunkLength, setMaxChunkLength] = useState(DEFAULT_MAXMIMUM_CHUNK_LENGTH) // default chunk length
|
||||
const [maxChunkLength, setMaxChunkLength] = useState(DEFAULT_MAXIMUM_CHUNK_LENGTH) // default chunk length
|
||||
const [limitMaxChunkLength, setLimitMaxChunkLength] = useState(4000)
|
||||
const [overlap, setOverlap] = useState(DEFAULT_OVERLAP)
|
||||
const [rules, setRules] = useState<PreProcessingRule[]>([])
|
||||
@ -198,7 +193,6 @@ const StepTwo = ({
|
||||
)
|
||||
|
||||
// QA Related
|
||||
const [isLanguageSelectDisabled, _setIsLanguageSelectDisabled] = useState(false)
|
||||
const [isQAConfirmDialogOpen, setIsQAConfirmDialogOpen] = useState(false)
|
||||
const [docForm, setDocForm] = useState<ChunkingMode>(
|
||||
(datasetId && documentDetail) ? documentDetail.doc_form as ChunkingMode : ChunkingMode.text,
|
||||
@ -348,7 +342,7 @@ const StepTwo = ({
|
||||
}
|
||||
|
||||
const updatePreview = () => {
|
||||
if (segmentationType === SegmentType.CUSTOM && maxChunkLength > 4000) {
|
||||
if (segmentationType === ProcessMode.general && maxChunkLength > 4000) {
|
||||
Toast.notify({ type: 'error', message: t('datasetCreation.stepTwo.maxLengthCheck') })
|
||||
return
|
||||
}
|
||||
@ -373,13 +367,42 @@ const StepTwo = ({
|
||||
model: defaultEmbeddingModel?.model || '',
|
||||
},
|
||||
)
|
||||
const [retrievalConfig, setRetrievalConfig] = useState(currentDataset?.retrieval_model_dict || {
|
||||
search_method: RETRIEVE_METHOD.semantic,
|
||||
reranking_enable: false,
|
||||
reranking_model: {
|
||||
reranking_provider_name: '',
|
||||
reranking_model_name: '',
|
||||
},
|
||||
top_k: 3,
|
||||
score_threshold_enabled: false,
|
||||
score_threshold: 0.5,
|
||||
} as RetrievalConfig)
|
||||
|
||||
useEffect(() => {
|
||||
if (currentDataset?.retrieval_model_dict)
|
||||
return
|
||||
setRetrievalConfig({
|
||||
search_method: RETRIEVE_METHOD.semantic,
|
||||
reranking_enable: !!isRerankDefaultModelValid,
|
||||
reranking_model: {
|
||||
reranking_provider_name: isRerankDefaultModelValid ? rerankDefaultModel?.provider.provider ?? '' : '',
|
||||
reranking_model_name: isRerankDefaultModelValid ? rerankDefaultModel?.model ?? '' : '',
|
||||
},
|
||||
top_k: 3,
|
||||
score_threshold_enabled: false,
|
||||
score_threshold: 0.5,
|
||||
})
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, [rerankDefaultModel, isRerankDefaultModelValid])
|
||||
|
||||
const getCreationParams = () => {
|
||||
let params
|
||||
if (segmentationType === SegmentType.CUSTOM && overlap > maxChunkLength) {
|
||||
if (segmentationType === ProcessMode.general && overlap > maxChunkLength) {
|
||||
Toast.notify({ type: 'error', message: t('datasetCreation.stepTwo.overlapCheck') })
|
||||
return
|
||||
}
|
||||
if (segmentationType === SegmentType.CUSTOM && maxChunkLength > limitMaxChunkLength) {
|
||||
if (segmentationType === ProcessMode.general && maxChunkLength > limitMaxChunkLength) {
|
||||
Toast.notify({ type: 'error', message: t('datasetCreation.stepTwo.maxLengthCheck', { limit: limitMaxChunkLength }) })
|
||||
return
|
||||
}
|
||||
@ -389,7 +412,6 @@ const StepTwo = ({
|
||||
doc_form: currentDocForm,
|
||||
doc_language: docLanguage,
|
||||
process_rule: getProcessRule(),
|
||||
// eslint-disable-next-line @typescript-eslint/no-use-before-define
|
||||
retrieval_model: retrievalConfig, // Readonly. If want to changed, just go to settings page.
|
||||
embedding_model: embeddingModel.model, // Readonly
|
||||
embedding_model_provider: embeddingModel.provider, // Readonly
|
||||
@ -400,10 +422,7 @@ const StepTwo = ({
|
||||
const indexMethod = getIndexing_technique()
|
||||
if (
|
||||
!isReRankModelSelected({
|
||||
rerankDefaultModel,
|
||||
isRerankDefaultModelValid: !!isRerankDefaultModelValid,
|
||||
rerankModelList,
|
||||
// eslint-disable-next-line @typescript-eslint/no-use-before-define
|
||||
retrievalConfig,
|
||||
indexMethod: indexMethod as string,
|
||||
})
|
||||
@ -411,16 +430,6 @@ const StepTwo = ({
|
||||
Toast.notify({ type: 'error', message: t('appDebug.datasetConfig.rerankModelRequired') })
|
||||
return
|
||||
}
|
||||
const postRetrievalConfig = ensureRerankModelSelected({
|
||||
rerankDefaultModel: rerankDefaultModel!,
|
||||
retrievalConfig: {
|
||||
// eslint-disable-next-line @typescript-eslint/no-use-before-define
|
||||
...retrievalConfig,
|
||||
// eslint-disable-next-line @typescript-eslint/no-use-before-define
|
||||
reranking_enable: retrievalConfig.reranking_mode === RerankingModeEnum.RerankingModel,
|
||||
},
|
||||
indexMethod: indexMethod as string,
|
||||
})
|
||||
params = {
|
||||
data_source: {
|
||||
type: dataSourceType,
|
||||
@ -432,8 +441,7 @@ const StepTwo = ({
|
||||
process_rule: getProcessRule(),
|
||||
doc_form: currentDocForm,
|
||||
doc_language: docLanguage,
|
||||
|
||||
retrieval_model: postRetrievalConfig,
|
||||
retrieval_model: retrievalConfig,
|
||||
embedding_model: embeddingModel.model,
|
||||
embedding_model_provider: embeddingModel.provider,
|
||||
} as CreateDocumentReq
|
||||
@ -490,7 +498,6 @@ const StepTwo = ({
|
||||
|
||||
const getDefaultMode = () => {
|
||||
if (documentDetail)
|
||||
// @ts-expect-error fix after api refactored
|
||||
setSegmentationType(documentDetail.dataset_process_rule.mode)
|
||||
}
|
||||
|
||||
@ -525,7 +532,6 @@ const StepTwo = ({
|
||||
onSuccess(data) {
|
||||
updateIndexingTypeCache && updateIndexingTypeCache(indexType as string)
|
||||
updateResultCache && updateResultCache(data)
|
||||
// eslint-disable-next-line @typescript-eslint/no-use-before-define
|
||||
updateRetrievalMethodCache && updateRetrievalMethodCache(retrievalConfig.search_method as string)
|
||||
},
|
||||
},
|
||||
@ -545,14 +551,6 @@ const StepTwo = ({
|
||||
isSetting && onSave && onSave()
|
||||
}
|
||||
|
||||
const changeToEconomicalType = () => {
|
||||
if (docForm !== ChunkingMode.text)
|
||||
return
|
||||
|
||||
if (!hasSetIndexType)
|
||||
setIndexType(IndexingType.ECONOMICAL)
|
||||
}
|
||||
|
||||
useEffect(() => {
|
||||
// fetch rules
|
||||
if (!isSetting) {
|
||||
@ -574,18 +572,6 @@ const StepTwo = ({
|
||||
setIndexType(isAPIKeySet ? IndexingType.QUALIFIED : IndexingType.ECONOMICAL)
|
||||
}, [isAPIKeySet, indexingType, datasetId])
|
||||
|
||||
const [retrievalConfig, setRetrievalConfig] = useState(currentDataset?.retrieval_model_dict || {
|
||||
search_method: RETRIEVE_METHOD.semantic,
|
||||
reranking_enable: false,
|
||||
reranking_model: {
|
||||
reranking_provider_name: rerankDefaultModel?.provider.provider,
|
||||
reranking_model_name: rerankDefaultModel?.model,
|
||||
},
|
||||
top_k: 3,
|
||||
score_threshold_enabled: false,
|
||||
score_threshold: 0.5,
|
||||
} as RetrievalConfig)
|
||||
|
||||
const economyDomRef = useRef<HTMLDivElement>(null)
|
||||
const isHoveringEconomy = useHover(economyDomRef)
|
||||
|
||||
@ -946,6 +932,7 @@ const StepTwo = ({
|
||||
<div className={cn('system-md-semibold mb-1', datasetId && 'flex justify-between items-center')}>{t('datasetSettings.form.embeddingModel')}</div>
|
||||
<ModelSelector
|
||||
readonly={!!datasetId}
|
||||
triggerClassName={datasetId ? 'opacity-50' : ''}
|
||||
defaultModel={embeddingModel}
|
||||
modelList={embeddingModelList}
|
||||
onSelect={(model: DefaultModel) => {
|
||||
@ -984,12 +971,14 @@ const StepTwo = ({
|
||||
getIndexing_technique() === IndexingType.QUALIFIED
|
||||
? (
|
||||
<RetrievalMethodConfig
|
||||
disabled={!!datasetId}
|
||||
value={retrievalConfig}
|
||||
onChange={setRetrievalConfig}
|
||||
/>
|
||||
)
|
||||
: (
|
||||
<EconomicalRetrievalMethodConfig
|
||||
disabled={!!datasetId}
|
||||
value={retrievalConfig}
|
||||
onChange={setRetrievalConfig}
|
||||
/>
|
||||
@ -1010,7 +999,7 @@ const StepTwo = ({
|
||||
)
|
||||
: (
|
||||
<div className='flex items-center mt-8 py-2'>
|
||||
<Button loading={isCreating} variant='primary' onClick={createHandle}>{t('datasetCreation.stepTwo.save')}</Button>
|
||||
{!datasetId && <Button loading={isCreating} variant='primary' onClick={createHandle}>{t('datasetCreation.stepTwo.save')}</Button>}
|
||||
<Button className='ml-2' onClick={onCancel}>{t('datasetCreation.stepTwo.cancel')}</Button>
|
||||
</div>
|
||||
)}
|
||||
@ -1081,11 +1070,11 @@ const StepTwo = ({
|
||||
}
|
||||
{
|
||||
currentDocForm !== ChunkingMode.qa
|
||||
&& <Badge text={t(
|
||||
'datasetCreation.stepTwo.previewChunkCount', {
|
||||
count: estimate?.total_segments || 0,
|
||||
}) as string}
|
||||
/>
|
||||
&& <Badge text={t(
|
||||
'datasetCreation.stepTwo.previewChunkCount', {
|
||||
count: estimate?.total_segments || 0,
|
||||
}) as string}
|
||||
/>
|
||||
}
|
||||
</div>
|
||||
</PreviewHeader>}
|
||||
@ -1117,6 +1106,9 @@ const StepTwo = ({
|
||||
{currentDocForm === ChunkingMode.parentChild && currentEstimateMutation.data?.preview && (
|
||||
estimate?.preview?.map((item, index) => {
|
||||
const indexForLabel = index + 1
|
||||
const childChunks = parentChildConfig.chunkForContext === 'full-doc'
|
||||
? item.child_chunks.slice(0, FULL_DOC_PREVIEW_LENGTH)
|
||||
: item.child_chunks
|
||||
return (
|
||||
<ChunkContainer
|
||||
key={item.content}
|
||||
@ -1124,7 +1116,7 @@ const StepTwo = ({
|
||||
characterCount={item.content.length}
|
||||
>
|
||||
<FormattedText>
|
||||
{item.child_chunks.map((child, index) => {
|
||||
{childChunks.map((child, index) => {
|
||||
const indexForLabel = index + 1
|
||||
return (
|
||||
<PreviewSlice
|
||||
|
@@ -4,7 +4,7 @@ import classNames from '@/utils/classnames'

const TriangleArrow: FC<ComponentProps<'svg'>> = props => (
<svg xmlns="http://www.w3.org/2000/svg" width="24" height="11" viewBox="0 0 24 11" fill="none" {...props}>
<path d="M9.87868 1.12132C11.0503 -0.0502525 12.9497 -0.0502525 14.1213 1.12132L23.3137 10.3137H0.686292L9.87868 1.12132Z" fill="currentColor"/>
<path d="M9.87868 1.12132C11.0503 -0.0502525 12.9497 -0.0502525 14.1213 1.12132L23.3137 10.3137H0.686292L9.87868 1.12132Z" fill="currentColor" />
</svg>
)
@@ -65,7 +65,7 @@ export const OptionCard: FC<OptionCardProps> = forwardRef((props, ref) => {
(isActive && !noHighlight)
? 'border-[1.5px] border-components-option-card-option-selected-border'
: 'border border-components-option-card-option-border',
disabled && 'opacity-50 cursor-not-allowed',
disabled && 'opacity-50 pointer-events-none',
className,
)}
style={{
@ -32,6 +32,9 @@ import Checkbox from '@/app/components/base/checkbox'
|
||||
import {
|
||||
useChildSegmentList,
|
||||
useChildSegmentListKey,
|
||||
useChunkListAllKey,
|
||||
useChunkListDisabledKey,
|
||||
useChunkListEnabledKey,
|
||||
useDeleteChildSegment,
|
||||
useDeleteSegment,
|
||||
useDisableSegment,
|
||||
@ -156,18 +159,18 @@ const Completed: FC<ICompletedProps> = ({
|
||||
page: isFullDocMode ? 1 : currentPage,
|
||||
limit: isFullDocMode ? 10 : limit,
|
||||
keyword: isFullDocMode ? '' : searchValue,
|
||||
enabled: selectedStatus === 'all' ? 'all' : !!selectedStatus,
|
||||
enabled: selectedStatus,
|
||||
},
|
||||
},
|
||||
currentPage === 0,
|
||||
)
|
||||
const invalidSegmentList = useInvalid(useSegmentListKey)
|
||||
|
||||
useEffect(() => {
|
||||
if (segmentListData) {
|
||||
setSegments(segmentListData.data || [])
|
||||
if (segmentListData.total_pages < currentPage)
|
||||
setCurrentPage(segmentListData.total_pages)
|
||||
const totalPages = segmentListData.total_pages
|
||||
if (totalPages < currentPage)
|
||||
setCurrentPage(totalPages === 0 ? 1 : totalPages)
|
||||
}
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, [segmentListData])
|
||||
@ -185,12 +188,12 @@ const Completed: FC<ICompletedProps> = ({
|
||||
documentId,
|
||||
segmentId: segments[0]?.id || '',
|
||||
params: {
|
||||
page: currentPage,
|
||||
page: currentPage === 0 ? 1 : currentPage,
|
||||
limit,
|
||||
keyword: searchValue,
|
||||
},
|
||||
},
|
||||
!isFullDocMode || segments.length === 0 || currentPage === 0,
|
||||
!isFullDocMode || segments.length === 0,
|
||||
)
|
||||
const invalidChildSegmentList = useInvalid(useChildSegmentListKey)
|
||||
|
||||
@ -204,21 +207,20 @@ const Completed: FC<ICompletedProps> = ({
|
||||
useEffect(() => {
|
||||
if (childChunkListData) {
|
||||
setChildSegments(childChunkListData.data || [])
|
||||
if (childChunkListData.total_pages < currentPage)
|
||||
setCurrentPage(childChunkListData.total_pages)
|
||||
const totalPages = childChunkListData.total_pages
|
||||
if (totalPages < currentPage)
|
||||
setCurrentPage(totalPages === 0 ? 1 : totalPages)
|
||||
}
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, [childChunkListData])
|
||||
|
||||
const resetList = useCallback(() => {
|
||||
setSegments([])
|
||||
setSelectedSegmentIds([])
|
||||
invalidSegmentList()
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, [])
|
||||
|
||||
const resetChildList = useCallback(() => {
|
||||
setChildSegments([])
|
||||
invalidChildSegmentList()
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, [])
|
||||
@ -232,8 +234,32 @@ const Completed: FC<ICompletedProps> = ({
|
||||
setFullScreen(false)
|
||||
}, [])
|
||||
|
||||
const onCloseNewSegmentModal = useCallback(() => {
|
||||
onNewSegmentModalChange(false)
|
||||
setFullScreen(false)
|
||||
}, [onNewSegmentModalChange])
|
||||
|
||||
const onCloseNewChildChunkModal = useCallback(() => {
|
||||
setShowNewChildSegmentModal(false)
|
||||
setFullScreen(false)
|
||||
}, [])
|
||||
|
||||
const { mutateAsync: enableSegment } = useEnableSegment()
|
||||
const { mutateAsync: disableSegment } = useDisableSegment()
|
||||
const invalidChunkListAll = useInvalid(useChunkListAllKey)
|
||||
const invalidChunkListEnabled = useInvalid(useChunkListEnabledKey)
|
||||
const invalidChunkListDisabled = useInvalid(useChunkListDisabledKey)
|
||||
|
||||
const refreshChunkListWithStatusChanged = () => {
|
||||
switch (selectedStatus) {
|
||||
case 'all':
|
||||
invalidChunkListDisabled()
|
||||
invalidChunkListEnabled()
|
||||
break
|
||||
default:
|
||||
invalidSegmentList()
|
||||
}
|
||||
}
|
||||
|
||||
const onChangeSwitch = useCallback(async (enable: boolean, segId?: string) => {
|
||||
const operationApi = enable ? enableSegment : disableSegment
|
||||
@ -245,6 +271,7 @@ const Completed: FC<ICompletedProps> = ({
|
||||
seg.enabled = enable
|
||||
}
|
||||
setSegments([...segments])
|
||||
refreshChunkListWithStatusChanged()
|
||||
},
|
||||
onError: () => {
|
||||
notify({ type: 'error', message: t('common.actionMsg.modifiedUnsuccessfully') })
|
||||
@ -271,6 +298,23 @@ const Completed: FC<ICompletedProps> = ({
|
||||
|
||||
const { mutateAsync: updateSegment } = useUpdateSegment()
|
||||
|
||||
const refreshChunkListDataWithDetailChanged = () => {
|
||||
switch (selectedStatus) {
|
||||
case 'all':
|
||||
invalidChunkListDisabled()
|
||||
invalidChunkListEnabled()
|
||||
break
|
||||
case true:
|
||||
invalidChunkListAll()
|
||||
invalidChunkListDisabled()
|
||||
break
|
||||
case false:
|
||||
invalidChunkListAll()
|
||||
invalidChunkListEnabled()
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
const handleUpdateSegment = useCallback(async (
|
||||
segmentId: string,
|
||||
question: string,
|
||||
@ -320,6 +364,7 @@ const Completed: FC<ICompletedProps> = ({
|
||||
}
|
||||
}
|
||||
setSegments([...segments])
|
||||
refreshChunkListDataWithDetailChanged()
|
||||
eventEmitter?.emit('update-segment-success')
|
||||
},
|
||||
onSettled() {
|
||||
@ -432,6 +477,7 @@ const Completed: FC<ICompletedProps> = ({
|
||||
seg.child_chunks?.push(newChildChunk!)
|
||||
}
|
||||
setSegments([...segments])
|
||||
refreshChunkListDataWithDetailChanged()
|
||||
}
|
||||
else {
|
||||
resetChildList()
|
||||
@ -496,17 +542,10 @@ const Completed: FC<ICompletedProps> = ({
|
||||
}
|
||||
}
|
||||
setSegments([...segments])
|
||||
refreshChunkListDataWithDetailChanged()
|
||||
}
|
||||
else {
|
||||
for (const childSeg of childSegments) {
|
||||
if (childSeg.id === childChunkId) {
|
||||
childSeg.content = res.data.content
|
||||
childSeg.type = res.data.type
|
||||
childSeg.word_count = res.data.word_count
|
||||
childSeg.updated_at = res.data.updated_at
|
||||
}
|
||||
}
|
||||
setChildSegments([...childSegments])
|
||||
resetChildList()
|
||||
}
|
||||
},
|
||||
onSettled: () => {
|
||||
@ -544,12 +583,13 @@ const Completed: FC<ICompletedProps> = ({
|
||||
<SimpleSelect
|
||||
onSelect={onChangeStatus}
|
||||
items={statusList.current}
|
||||
defaultValue={'all'}
|
||||
defaultValue={selectedStatus === 'all' ? 'all' : selectedStatus ? 1 : 0}
|
||||
className={s.select}
|
||||
wrapperClassName='h-fit mr-2'
|
||||
optionWrapClassName='w-[160px]'
|
||||
optionClassName='p-0'
|
||||
renderOption={({ item, selected }) => <StatusItem item={item} selected={selected} />}
|
||||
notClearable
|
||||
/>
|
||||
<Input
|
||||
showLeftIcon
|
||||
@ -623,6 +663,7 @@ const Completed: FC<ICompletedProps> = ({
|
||||
<FullScreenDrawer
|
||||
isOpen={currSegment.showModal}
|
||||
fullScreen={fullScreen}
|
||||
onClose={onCloseSegmentDetail}
|
||||
>
|
||||
<SegmentDetail
|
||||
segInfo={currSegment.segInfo ?? { id: '' }}
|
||||
@ -636,13 +677,11 @@ const Completed: FC<ICompletedProps> = ({
|
||||
<FullScreenDrawer
|
||||
isOpen={showNewSegmentModal}
|
||||
fullScreen={fullScreen}
|
||||
onClose={onCloseNewSegmentModal}
|
||||
>
|
||||
<NewSegment
|
||||
docForm={docForm}
|
||||
onCancel={() => {
|
||||
onNewSegmentModalChange(false)
|
||||
setFullScreen(false)
|
||||
}}
|
||||
onCancel={onCloseNewSegmentModal}
|
||||
onSave={resetList}
|
||||
viewNewlyAddedChunk={viewNewlyAddedChunk}
|
||||
/>
|
||||
@ -651,6 +690,7 @@ const Completed: FC<ICompletedProps> = ({
|
||||
<FullScreenDrawer
|
||||
isOpen={currChildChunk.showModal}
|
||||
fullScreen={fullScreen}
|
||||
onClose={onCloseChildSegmentDetail}
|
||||
>
|
||||
<ChildSegmentDetail
|
||||
chunkId={currChunkId}
|
||||
@ -664,13 +704,11 @@ const Completed: FC<ICompletedProps> = ({
|
||||
<FullScreenDrawer
|
||||
isOpen={showNewChildSegmentModal}
|
||||
fullScreen={fullScreen}
|
||||
onClose={onCloseNewChildChunkModal}
|
||||
>
|
||||
<NewChildSegment
|
||||
chunkId={currChunkId}
|
||||
onCancel={() => {
|
||||
setShowNewChildSegmentModal(false)
|
||||
setFullScreen(false)
|
||||
}}
|
||||
onCancel={onCloseNewChildChunkModal}
|
||||
onSave={onSaveNewChildChunk}
|
||||
viewNewlyAddedChildChunk={viewNewlyAddedChildChunk}
|
||||
/>
|
||||
|