Mirror of https://github.com/langgenius/dify.git (synced via https://git.mirrors.martin98.com, 2025-08-17 18:55:55 +08:00)

Commit fb309462ad: Merge branch 'main' into fix/chore-fix
@@ -23,6 +23,9 @@ FILES_ACCESS_TIMEOUT=300
 # Access token expiration time in minutes
 ACCESS_TOKEN_EXPIRE_MINUTES=60
 
+# Refresh token expiration time in days
+REFRESH_TOKEN_EXPIRE_DAYS=30
+
 # celery configuration
 CELERY_BROKER_URL=redis://:difyai123456@localhost:6379/1
 
@@ -14,7 +14,10 @@ if is_db_command():
 
     app = create_migrations_app()
 else:
-    if os.environ.get("FLASK_DEBUG", "False") != "True":
+    # It seems that JetBrains Python debugger does not work well with gevent,
+    # so we need to disable gevent in debug mode.
+    # If you are using debugpy and set GEVENT_SUPPORT=True, you can debug with gevent.
+    if (flask_debug := os.environ.get("FLASK_DEBUG", "0")) and flask_debug.lower() in {"false", "0", "no"}:
         from gevent import monkey  # type: ignore
 
         # gevent
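
The rewritten guard above only applies gevent monkey-patching when FLASK_DEBUG is unset or explicitly falsy. A minimal standalone sketch of that check (illustrative only, not part of the diff):

    import os

    def should_patch_gevent() -> bool:
        # Mirrors the condition above: patch only when FLASK_DEBUG defaults to "0"
        # or is explicitly "false"/"0"/"no" (case-insensitive).
        flask_debug = os.environ.get("FLASK_DEBUG", "0")
        return bool(flask_debug) and flask_debug.lower() in {"false", "0", "no"}

    # Unset        -> True  (patch)
    # FLASK_DEBUG=1 -> False (skip patching, debugger-friendly)
    # FLASK_DEBUG=False -> True
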
@@ -546,6 +546,11 @@ class AuthConfig(BaseSettings):
         default=60,
     )
 
+    REFRESH_TOKEN_EXPIRE_DAYS: PositiveFloat = Field(
+        description="Expiration time for refresh tokens in days",
+        default=30,
+    )
+
     LOGIN_LOCKOUT_DURATION: PositiveInt = Field(
         description="Time (in seconds) a user must wait before retrying login after exceeding the rate limit.",
         default=86400,
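
For reference, a self-contained sketch of how the new setting resolves: pydantic-settings reads REFRESH_TOKEN_EXPIRE_DAYS from the environment (matching the new .env entry above) and falls back to the declared default. Illustrative only:

    from pydantic import Field, PositiveFloat
    from pydantic_settings import BaseSettings

    class AuthConfig(BaseSettings):
        REFRESH_TOKEN_EXPIRE_DAYS: PositiveFloat = Field(
            description="Expiration time for refresh tokens in days",
            default=30,
        )

    # Prints 30.0 unless REFRESH_TOKEN_EXPIRE_DAYS is set in the environment.
    print(AuthConfig().REFRESH_TOKEN_EXPIRE_DAYS)
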
@@ -725,6 +730,11 @@ class IndexingConfig(BaseSettings):
         default=4000,
     )
 
+    CHILD_CHUNKS_PREVIEW_NUMBER: PositiveInt = Field(
+        description="Maximum number of child chunks to preview",
+        default=50,
+    )
+
 
 class MultiModalTransferConfig(BaseSettings):
     MULTIMODAL_SEND_FORMAT: Literal["base64", "url"] = Field(
@@ -33,3 +33,9 @@ class MilvusConfig(BaseSettings):
         description="Name of the Milvus database to connect to (default is 'default')",
         default="default",
     )
+
+    MILVUS_ENABLE_HYBRID_SEARCH: bool = Field(
+        description="Enable hybrid search features (requires Milvus >= 2.5.0). Set to false for compatibility with "
+        "older versions",
+        default=True,
+    )
@@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):
 
     CURRENT_VERSION: str = Field(
         description="Dify version",
-        default="0.14.2",
+        default="0.15.0",
     )
 
     COMMIT_SHA: str = Field(
@@ -57,12 +57,13 @@ class AppListApi(Resource):
         )
         parser.add_argument("name", type=str, location="args", required=False)
         parser.add_argument("tag_ids", type=uuid_list, location="args", required=False)
+        parser.add_argument("is_created_by_me", type=inputs.boolean, location="args", required=False)
 
         args = parser.parse_args()
 
         # get app list
         app_service = AppService()
-        app_pagination = app_service.get_paginate_apps(current_user.current_tenant_id, args)
+        app_pagination = app_service.get_paginate_apps(current_user.id, current_user.current_tenant_id, args)
         if not app_pagination:
             return {"data": [], "total": 0, "page": 1, "limit": 20, "has_more": False}
 
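
A small, self-contained illustration of how the new flag is parsed by flask-restful; the endpoint path shown is an assumption for the example, not taken from the diff:

    from flask import Flask
    from flask_restful import inputs, reqparse

    app = Flask(__name__)
    parser = reqparse.RequestParser()
    parser.add_argument("is_created_by_me", type=inputs.boolean, location="args", required=False)

    # inputs.boolean accepts "true"/"false"/"1"/"0" (case-insensitive) and rejects other values.
    with app.test_request_context("/console/api/apps?is_created_by_me=true"):
        print(parser.parse_args())  # {'is_created_by_me': True}
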
@@ -20,7 +20,6 @@ from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpErr
 from core.app.apps.base_app_queue_manager import AppQueueManager
 from core.app.entities.app_invoke_entities import InvokeFrom
 from core.errors.error import (
-    AppInvokeQuotaExceededError,
     ModelCurrentlyNotSupportError,
     ProviderTokenNotInitError,
     QuotaExceededError,
@@ -76,7 +75,7 @@ class CompletionMessageApi(Resource):
             raise ProviderModelCurrentlyNotSupportError()
         except InvokeError as e:
             raise CompletionRequestError(e.description)
-        except (ValueError, AppInvokeQuotaExceededError) as e:
+        except ValueError as e:
             raise e
         except Exception as e:
             logging.exception("internal server error.")
@@ -141,7 +140,7 @@ class ChatMessageApi(Resource):
             raise InvokeRateLimitHttpError(ex.description)
         except InvokeError as e:
             raise CompletionRequestError(e.description)
-        except (ValueError, AppInvokeQuotaExceededError) as e:
+        except ValueError as e:
             raise e
         except Exception as e:
             logging.exception("internal server error.")
@@ -273,8 +273,7 @@ FROM
     messages m
         ON c.id = m.conversation_id
 WHERE
-    c.override_model_configs IS NULL
-    AND c.app_id = :app_id"""
+    c.app_id = :app_id"""
         arg_dict = {"tz": account.timezone, "app_id": app_model.id}
 
         timezone = pytz.timezone(account.timezone)
@@ -640,6 +640,7 @@ class DatasetRetrievalSettingApi(Resource):
                 | VectorType.MYSCALE
                 | VectorType.ORACLE
                 | VectorType.ELASTICSEARCH
+                | VectorType.ELASTICSEARCH_JA
                 | VectorType.PGVECTOR
                 | VectorType.TIDB_ON_QDRANT
                 | VectorType.LINDORM
@@ -683,6 +684,7 @@ class DatasetRetrievalSettingMockApi(Resource):
                 | VectorType.MYSCALE
                 | VectorType.ORACLE
                 | VectorType.ELASTICSEARCH
+                | VectorType.ELASTICSEARCH_JA
                 | VectorType.COUCHBASE
                 | VectorType.PGVECTOR
                 | VectorType.LINDORM
@@ -269,7 +269,8 @@ class DatasetDocumentListApi(Resource):
         parser.add_argument("original_document_id", type=str, required=False, location="json")
         parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
         parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
+        parser.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
+        parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
         parser.add_argument(
             "doc_language", type=str, default="English", required=False, nullable=False, location="json"
         )
@@ -18,7 +18,11 @@ from controllers.console.explore.error import NotChatAppError, NotCompletionAppE
 from controllers.console.explore.wraps import InstalledAppResource
 from core.app.apps.base_app_queue_manager import AppQueueManager
 from core.app.entities.app_invoke_entities import InvokeFrom
-from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
+from core.errors.error import (
+    ModelCurrentlyNotSupportError,
+    ProviderTokenNotInitError,
+    QuotaExceededError,
+)
 from core.model_runtime.errors.invoke import InvokeError
 from extensions.ext_database import db
 from libs import helper
@@ -13,7 +13,11 @@ from controllers.console.explore.error import NotWorkflowAppError
 from controllers.console.explore.wraps import InstalledAppResource
 from core.app.apps.base_app_queue_manager import AppQueueManager
 from core.app.entities.app_invoke_entities import InvokeFrom
-from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
+from core.errors.error import (
+    ModelCurrentlyNotSupportError,
+    ProviderTokenNotInitError,
+    QuotaExceededError,
+)
 from core.model_runtime.errors.invoke import InvokeError
 from libs import helper
 from libs.login import current_user
@@ -18,7 +18,6 @@ from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate
 from core.app.apps.base_app_queue_manager import AppQueueManager
 from core.app.entities.app_invoke_entities import InvokeFrom
 from core.errors.error import (
-    AppInvokeQuotaExceededError,
     ModelCurrentlyNotSupportError,
     ProviderTokenNotInitError,
     QuotaExceededError,
@@ -74,7 +73,7 @@ class CompletionApi(Resource):
             raise ProviderModelCurrentlyNotSupportError()
         except InvokeError as e:
             raise CompletionRequestError(e.description)
-        except (ValueError, AppInvokeQuotaExceededError) as e:
+        except ValueError as e:
             raise e
         except Exception as e:
             logging.exception("internal server error.")
@@ -133,7 +132,7 @@ class ChatApi(Resource):
             raise ProviderModelCurrentlyNotSupportError()
         except InvokeError as e:
             raise CompletionRequestError(e.description)
-        except (ValueError, AppInvokeQuotaExceededError) as e:
+        except ValueError as e:
             raise e
         except Exception as e:
             logging.exception("internal server error.")
@@ -16,7 +16,6 @@ from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate
 from core.app.apps.base_app_queue_manager import AppQueueManager
 from core.app.entities.app_invoke_entities import InvokeFrom
 from core.errors.error import (
-    AppInvokeQuotaExceededError,
     ModelCurrentlyNotSupportError,
     ProviderTokenNotInitError,
     QuotaExceededError,
@@ -94,7 +93,7 @@ class WorkflowRunApi(Resource):
             raise ProviderModelCurrentlyNotSupportError()
         except InvokeError as e:
             raise CompletionRequestError(e.description)
-        except (ValueError, AppInvokeQuotaExceededError) as e:
+        except ValueError as e:
             raise e
         except Exception as e:
             logging.exception("internal server error.")
@@ -190,7 +190,10 @@ class DocumentAddByFileApi(DatasetApiResource):
                 user=current_user,
                 source="datasets",
             )
-            data_source = {"type": "upload_file", "info_list": {"file_info_list": {"file_ids": [upload_file.id]}}}
+            data_source = {
+                "type": "upload_file",
+                "info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
+            }
             args["data_source"] = data_source
             # validate args
             knowledge_config = KnowledgeConfig(**args)
@@ -254,7 +257,10 @@ class DocumentUpdateByFileApi(DatasetApiResource):
                 raise FileTooLargeError(file_too_large_error.description)
             except services.errors.file.UnsupportedFileTypeError:
                 raise UnsupportedFileTypeError()
-            data_source = {"type": "upload_file", "info_list": {"file_info_list": {"file_ids": [upload_file.id]}}}
+            data_source = {
+                "type": "upload_file",
+                "info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
+            }
             args["data_source"] = data_source
             # validate args
             args["original_document_id"] = str(document_id)
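
Both DocumentAddByFileApi and DocumentUpdateByFileApi now build the same nested payload; the semantic change is the added "data_source_type" key inside "info_list". Shape for reference (the file id value is illustrative):

    data_source = {
        "type": "upload_file",
        "info_list": {
            "data_source_type": "upload_file",
            "file_info_list": {"file_ids": ["<uploaded-file-id>"]},
        },
    }
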
@@ -1,5 +1,5 @@
 from collections.abc import Callable
-from datetime import UTC, datetime
+from datetime import UTC, datetime, timedelta
 from enum import Enum
 from functools import wraps
 from typing import Optional
@@ -8,6 +8,8 @@ from flask import current_app, request
 from flask_login import user_logged_in  # type: ignore
 from flask_restful import Resource  # type: ignore
 from pydantic import BaseModel
+from sqlalchemy import select, update
+from sqlalchemy.orm import Session
 from werkzeug.exceptions import Forbidden, Unauthorized
 
 from extensions.ext_database import db
@@ -174,7 +176,7 @@ def validate_dataset_token(view=None):
     return decorator
 
 
-def validate_and_get_api_token(scope=None):
+def validate_and_get_api_token(scope: str | None = None):
     """
     Validate and get API token.
     """
@@ -188,20 +190,25 @@ def validate_and_get_api_token(scope=None):
     if auth_scheme != "bearer":
         raise Unauthorized("Authorization scheme must be 'Bearer'")
 
-    api_token = (
-        db.session.query(ApiToken)
-        .filter(
-            ApiToken.token == auth_token,
-            ApiToken.type == scope,
-        )
-        .first()
-    )
-
-    if not api_token:
-        raise Unauthorized("Access token is invalid")
-
-    api_token.last_used_at = datetime.now(UTC).replace(tzinfo=None)
-    db.session.commit()
+    current_time = datetime.now(UTC).replace(tzinfo=None)
+    cutoff_time = current_time - timedelta(minutes=1)
+    with Session(db.engine, expire_on_commit=False) as session:
+        update_stmt = (
+            update(ApiToken)
+            .where(ApiToken.token == auth_token, ApiToken.last_used_at < cutoff_time, ApiToken.type == scope)
+            .values(last_used_at=current_time)
+            .returning(ApiToken)
+        )
+        result = session.execute(update_stmt)
+        api_token = result.scalar_one_or_none()
+
+        if not api_token:
+            stmt = select(ApiToken).where(ApiToken.token == auth_token, ApiToken.type == scope)
+            api_token = session.scalar(stmt)
+            if not api_token:
+                raise Unauthorized("Access token is invalid")
+            else:
+                session.commit()
 
     return api_token
 
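
The rewrite above replaces an unconditional read-then-write on every request with a single conditional UPDATE ... RETURNING, falling back to a plain SELECT when nothing was updated, so last_used_at is written at most once per minute per token. A minimal sketch of the pattern (the function name and the generic `model` parameter are illustrative; the real code operates on ApiToken and db.engine):

    from datetime import UTC, datetime, timedelta

    from sqlalchemy import select, update
    from sqlalchemy.orm import Session

    def touch_and_fetch(engine, model, token_value, scope):
        now = datetime.now(UTC).replace(tzinfo=None)
        cutoff = now - timedelta(minutes=1)
        with Session(engine, expire_on_commit=False) as session:
            # One round trip: bump last_used_at only when it is stale, returning the row.
            stmt = (
                update(model)
                .where(model.token == token_value, model.last_used_at < cutoff, model.type == scope)
                .values(last_used_at=now)
                .returning(model)
            )
            row = session.execute(stmt).scalar_one_or_none()
            if row is None:
                # Token missing, or touched within the last minute: a SELECT decides which.
                row = session.scalar(select(model).where(model.token == token_value, model.type == scope))
            else:
                session.commit()
            return row
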
@@ -19,7 +19,11 @@ from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpErr
 from controllers.web.wraps import WebApiResource
 from core.app.apps.base_app_queue_manager import AppQueueManager
 from core.app.entities.app_invoke_entities import InvokeFrom
-from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
+from core.errors.error import (
+    ModelCurrentlyNotSupportError,
+    ProviderTokenNotInitError,
+    QuotaExceededError,
+)
 from core.model_runtime.errors.invoke import InvokeError
 from libs import helper
 from libs.helper import uuid_value
@@ -14,7 +14,11 @@ from controllers.web.error import (
 from controllers.web.wraps import WebApiResource
 from core.app.apps.base_app_queue_manager import AppQueueManager
 from core.app.entities.app_invoke_entities import InvokeFrom
-from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
+from core.errors.error import (
+    ModelCurrentlyNotSupportError,
+    ProviderTokenNotInitError,
+    QuotaExceededError,
+)
 from core.model_runtime.errors.invoke import InvokeError
 from libs import helper
 from models.model import App, AppMode, EndUser
@@ -21,7 +21,7 @@ from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
 from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
 from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
 from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse
-from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
+from core.model_runtime.errors.invoke import InvokeAuthorizationError
 from core.ops.ops_trace_manager import TraceQueueManager
 from core.prompt.utils.get_thread_messages_length import get_thread_messages_length
 from extensions.ext_database import db
|
|||||||
except ValidationError as e:
|
except ValidationError as e:
|
||||||
logger.exception("Validation Error when generating")
|
logger.exception("Validation Error when generating")
|
||||||
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
||||||
except (ValueError, InvokeError) as e:
|
except ValueError as e:
|
||||||
if dify_config.DEBUG:
|
if dify_config.DEBUG:
|
||||||
logger.exception("Error when generating")
|
logger.exception("Error when generating")
|
||||||
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
||||||
|
@ -68,24 +68,17 @@ from models.account import Account
|
|||||||
from models.enums import CreatedByRole
|
from models.enums import CreatedByRole
|
||||||
from models.workflow import (
|
from models.workflow import (
|
||||||
Workflow,
|
Workflow,
|
||||||
WorkflowNodeExecution,
|
|
||||||
WorkflowRunStatus,
|
WorkflowRunStatus,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleManage, MessageCycleManage):
|
class AdvancedChatAppGenerateTaskPipeline:
|
||||||
"""
|
"""
|
||||||
AdvancedChatAppGenerateTaskPipeline is a class that generate stream output and state management for Application.
|
AdvancedChatAppGenerateTaskPipeline is a class that generate stream output and state management for Application.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_task_state: WorkflowTaskState
|
|
||||||
_application_generate_entity: AdvancedChatAppGenerateEntity
|
|
||||||
_workflow_system_variables: dict[SystemVariableKey, Any]
|
|
||||||
_wip_workflow_node_executions: dict[str, WorkflowNodeExecution]
|
|
||||||
_conversation_name_generate_thread: Optional[Thread] = None
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
application_generate_entity: AdvancedChatAppGenerateEntity,
|
application_generate_entity: AdvancedChatAppGenerateEntity,
|
||||||
@ -97,7 +90,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||||||
stream: bool,
|
stream: bool,
|
||||||
dialogue_count: int,
|
dialogue_count: int,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(
|
self._base_task_pipeline = BasedGenerateTaskPipeline(
|
||||||
application_generate_entity=application_generate_entity,
|
application_generate_entity=application_generate_entity,
|
||||||
queue_manager=queue_manager,
|
queue_manager=queue_manager,
|
||||||
stream=stream,
|
stream=stream,
|
||||||
@ -114,33 +107,35 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||||||
else:
|
else:
|
||||||
raise NotImplementedError(f"User type not supported: {type(user)}")
|
raise NotImplementedError(f"User type not supported: {type(user)}")
|
||||||
|
|
||||||
self._workflow_id = workflow.id
|
self._workflow_cycle_manager = WorkflowCycleManage(
|
||||||
self._workflow_features_dict = workflow.features_dict
|
application_generate_entity=application_generate_entity,
|
||||||
|
workflow_system_variables={
|
||||||
self._conversation_id = conversation.id
|
SystemVariableKey.QUERY: message.query,
|
||||||
self._conversation_mode = conversation.mode
|
SystemVariableKey.FILES: application_generate_entity.files,
|
||||||
|
SystemVariableKey.CONVERSATION_ID: conversation.id,
|
||||||
self._message_id = message.id
|
SystemVariableKey.USER_ID: user_session_id,
|
||||||
self._message_created_at = int(message.created_at.timestamp())
|
SystemVariableKey.DIALOGUE_COUNT: dialogue_count,
|
||||||
|
SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
|
||||||
self._workflow_system_variables = {
|
SystemVariableKey.WORKFLOW_ID: workflow.id,
|
||||||
SystemVariableKey.QUERY: message.query,
|
SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
|
||||||
SystemVariableKey.FILES: application_generate_entity.files,
|
},
|
||||||
SystemVariableKey.CONVERSATION_ID: conversation.id,
|
)
|
||||||
SystemVariableKey.USER_ID: user_session_id,
|
|
||||||
SystemVariableKey.DIALOGUE_COUNT: dialogue_count,
|
|
||||||
SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
|
|
||||||
SystemVariableKey.WORKFLOW_ID: workflow.id,
|
|
||||||
SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
|
|
||||||
}
|
|
||||||
|
|
||||||
self._task_state = WorkflowTaskState()
|
self._task_state = WorkflowTaskState()
|
||||||
self._wip_workflow_node_executions = {}
|
self._message_cycle_manager = MessageCycleManage(
|
||||||
self._wip_workflow_agent_logs = {}
|
application_generate_entity=application_generate_entity, task_state=self._task_state
|
||||||
|
)
|
||||||
|
|
||||||
self._conversation_name_generate_thread = None
|
self._application_generate_entity = application_generate_entity
|
||||||
|
self._workflow_id = workflow.id
|
||||||
|
self._workflow_features_dict = workflow.features_dict
|
||||||
|
self._conversation_id = conversation.id
|
||||||
|
self._conversation_mode = conversation.mode
|
||||||
|
self._message_id = message.id
|
||||||
|
self._message_created_at = int(message.created_at.timestamp())
|
||||||
|
self._conversation_name_generate_thread: Thread | None = None
|
||||||
self._recorded_files: list[Mapping[str, Any]] = []
|
self._recorded_files: list[Mapping[str, Any]] = []
|
||||||
self._workflow_run_id = ""
|
self._workflow_run_id: str = ""
|
||||||
|
|
||||||
def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
|
def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
|
||||||
"""
|
"""
|
||||||
@ -148,13 +143,13 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
# start generate conversation name thread
|
# start generate conversation name thread
|
||||||
self._conversation_name_generate_thread = self._generate_conversation_name(
|
self._conversation_name_generate_thread = self._message_cycle_manager._generate_conversation_name(
|
||||||
conversation_id=self._conversation_id, query=self._application_generate_entity.query
|
conversation_id=self._conversation_id, query=self._application_generate_entity.query
|
||||||
)
|
)
|
||||||
|
|
||||||
generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
|
generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
|
||||||
|
|
||||||
if self._stream:
|
if self._base_task_pipeline._stream:
|
||||||
return self._to_stream_response(generator)
|
return self._to_stream_response(generator)
|
||||||
else:
|
else:
|
||||||
return self._to_blocking_response(generator)
|
return self._to_blocking_response(generator)
|
||||||
@ -273,24 +268,26 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||||||
# init fake graph runtime state
|
# init fake graph runtime state
|
||||||
graph_runtime_state: Optional[GraphRuntimeState] = None
|
graph_runtime_state: Optional[GraphRuntimeState] = None
|
||||||
|
|
||||||
for queue_message in self._queue_manager.listen():
|
for queue_message in self._base_task_pipeline._queue_manager.listen():
|
||||||
event = queue_message.event
|
event = queue_message.event
|
||||||
|
|
||||||
if isinstance(event, QueuePingEvent):
|
if isinstance(event, QueuePingEvent):
|
||||||
yield self._ping_stream_response()
|
yield self._base_task_pipeline._ping_stream_response()
|
||||||
elif isinstance(event, QueueErrorEvent):
|
elif isinstance(event, QueueErrorEvent):
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
err = self._handle_error(event=event, session=session, message_id=self._message_id)
|
err = self._base_task_pipeline._handle_error(
|
||||||
|
event=event, session=session, message_id=self._message_id
|
||||||
|
)
|
||||||
session.commit()
|
session.commit()
|
||||||
yield self._error_to_stream_response(err)
|
yield self._base_task_pipeline._error_to_stream_response(err)
|
||||||
break
|
break
|
||||||
elif isinstance(event, QueueWorkflowStartedEvent):
|
elif isinstance(event, QueueWorkflowStartedEvent):
|
||||||
# override graph runtime state
|
# override graph runtime state
|
||||||
graph_runtime_state = event.graph_runtime_state
|
graph_runtime_state = event.graph_runtime_state
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
# init workflow run
|
# init workflow run
|
||||||
workflow_run = self._handle_workflow_run_start(
|
workflow_run = self._workflow_cycle_manager._handle_workflow_run_start(
|
||||||
session=session,
|
session=session,
|
||||||
workflow_id=self._workflow_id,
|
workflow_id=self._workflow_id,
|
||||||
user_id=self._user_id,
|
user_id=self._user_id,
|
||||||
@ -301,7 +298,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||||||
if not message:
|
if not message:
|
||||||
raise ValueError(f"Message not found: {self._message_id}")
|
raise ValueError(f"Message not found: {self._message_id}")
|
||||||
message.workflow_run_id = workflow_run.id
|
message.workflow_run_id = workflow_run.id
|
||||||
workflow_start_resp = self._workflow_start_to_stream_response(
|
workflow_start_resp = self._workflow_cycle_manager._workflow_start_to_stream_response(
|
||||||
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
|
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
|
||||||
)
|
)
|
||||||
session.commit()
|
session.commit()
|
||||||
@ -314,12 +311,14 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||||||
if not self._workflow_run_id:
|
if not self._workflow_run_id:
|
||||||
raise ValueError("workflow run not initialized.")
|
raise ValueError("workflow run not initialized.")
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
|
workflow_run = self._workflow_cycle_manager._get_workflow_run(
|
||||||
workflow_node_execution = self._handle_workflow_node_execution_retried(
|
session=session, workflow_run_id=self._workflow_run_id
|
||||||
|
)
|
||||||
|
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_retried(
|
||||||
session=session, workflow_run=workflow_run, event=event
|
session=session, workflow_run=workflow_run, event=event
|
||||||
)
|
)
|
||||||
node_retry_resp = self._workflow_node_retry_to_stream_response(
|
node_retry_resp = self._workflow_cycle_manager._workflow_node_retry_to_stream_response(
|
||||||
session=session,
|
session=session,
|
||||||
event=event,
|
event=event,
|
||||||
task_id=self._application_generate_entity.task_id,
|
task_id=self._application_generate_entity.task_id,
|
||||||
@ -333,13 +332,15 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||||||
if not self._workflow_run_id:
|
if not self._workflow_run_id:
|
||||||
raise ValueError("workflow run not initialized.")
|
raise ValueError("workflow run not initialized.")
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
|
workflow_run = self._workflow_cycle_manager._get_workflow_run(
|
||||||
workflow_node_execution = self._handle_node_execution_start(
|
session=session, workflow_run_id=self._workflow_run_id
|
||||||
|
)
|
||||||
|
workflow_node_execution = self._workflow_cycle_manager._handle_node_execution_start(
|
||||||
session=session, workflow_run=workflow_run, event=event
|
session=session, workflow_run=workflow_run, event=event
|
||||||
)
|
)
|
||||||
|
|
||||||
node_start_resp = self._workflow_node_start_to_stream_response(
|
node_start_resp = self._workflow_cycle_manager._workflow_node_start_to_stream_response(
|
||||||
session=session,
|
session=session,
|
||||||
event=event,
|
event=event,
|
||||||
task_id=self._application_generate_entity.task_id,
|
task_id=self._application_generate_entity.task_id,
|
||||||
@ -352,12 +353,16 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||||||
elif isinstance(event, QueueNodeSucceededEvent):
|
elif isinstance(event, QueueNodeSucceededEvent):
|
||||||
# Record files if it's an answer node or end node
|
# Record files if it's an answer node or end node
|
||||||
if event.node_type in [NodeType.ANSWER, NodeType.END]:
|
if event.node_type in [NodeType.ANSWER, NodeType.END]:
|
||||||
self._recorded_files.extend(self._fetch_files_from_node_outputs(event.outputs or {}))
|
self._recorded_files.extend(
|
||||||
|
self._workflow_cycle_manager._fetch_files_from_node_outputs(event.outputs or {})
|
||||||
|
)
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
workflow_node_execution = self._handle_workflow_node_execution_success(session=session, event=event)
|
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_success(
|
||||||
|
session=session, event=event
|
||||||
|
)
|
||||||
|
|
||||||
node_finish_resp = self._workflow_node_finish_to_stream_response(
|
node_finish_resp = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
|
||||||
session=session,
|
session=session,
|
||||||
event=event,
|
event=event,
|
||||||
task_id=self._application_generate_entity.task_id,
|
task_id=self._application_generate_entity.task_id,
|
||||||
@ -368,10 +373,12 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||||||
if node_finish_resp:
|
if node_finish_resp:
|
||||||
yield node_finish_resp
|
yield node_finish_resp
|
||||||
elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent):
|
elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent):
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
workflow_node_execution = self._handle_workflow_node_execution_failed(session=session, event=event)
|
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_failed(
|
||||||
|
session=session, event=event
|
||||||
|
)
|
||||||
|
|
||||||
node_finish_resp = self._workflow_node_finish_to_stream_response(
|
node_finish_resp = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
|
||||||
session=session,
|
session=session,
|
||||||
event=event,
|
event=event,
|
||||||
task_id=self._application_generate_entity.task_id,
|
task_id=self._application_generate_entity.task_id,
|
||||||
@ -385,13 +392,17 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||||||
if not self._workflow_run_id:
|
if not self._workflow_run_id:
|
||||||
raise ValueError("workflow run not initialized.")
|
raise ValueError("workflow run not initialized.")
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
|
workflow_run = self._workflow_cycle_manager._get_workflow_run(
|
||||||
parallel_start_resp = self._workflow_parallel_branch_start_to_stream_response(
|
session=session, workflow_run_id=self._workflow_run_id
|
||||||
session=session,
|
)
|
||||||
task_id=self._application_generate_entity.task_id,
|
parallel_start_resp = (
|
||||||
workflow_run=workflow_run,
|
self._workflow_cycle_manager._workflow_parallel_branch_start_to_stream_response(
|
||||||
event=event,
|
session=session,
|
||||||
|
task_id=self._application_generate_entity.task_id,
|
||||||
|
workflow_run=workflow_run,
|
||||||
|
event=event,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
yield parallel_start_resp
|
yield parallel_start_resp
|
||||||
@ -399,13 +410,17 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||||||
if not self._workflow_run_id:
|
if not self._workflow_run_id:
|
||||||
raise ValueError("workflow run not initialized.")
|
raise ValueError("workflow run not initialized.")
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
|
workflow_run = self._workflow_cycle_manager._get_workflow_run(
|
||||||
parallel_finish_resp = self._workflow_parallel_branch_finished_to_stream_response(
|
session=session, workflow_run_id=self._workflow_run_id
|
||||||
session=session,
|
)
|
||||||
task_id=self._application_generate_entity.task_id,
|
parallel_finish_resp = (
|
||||||
workflow_run=workflow_run,
|
self._workflow_cycle_manager._workflow_parallel_branch_finished_to_stream_response(
|
||||||
event=event,
|
session=session,
|
||||||
|
task_id=self._application_generate_entity.task_id,
|
||||||
|
workflow_run=workflow_run,
|
||||||
|
event=event,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
yield parallel_finish_resp
|
yield parallel_finish_resp
|
||||||
@ -413,9 +428,11 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||||||
if not self._workflow_run_id:
|
if not self._workflow_run_id:
|
||||||
raise ValueError("workflow run not initialized.")
|
raise ValueError("workflow run not initialized.")
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
|
workflow_run = self._workflow_cycle_manager._get_workflow_run(
|
||||||
iter_start_resp = self._workflow_iteration_start_to_stream_response(
|
session=session, workflow_run_id=self._workflow_run_id
|
||||||
|
)
|
||||||
|
iter_start_resp = self._workflow_cycle_manager._workflow_iteration_start_to_stream_response(
|
||||||
session=session,
|
session=session,
|
||||||
task_id=self._application_generate_entity.task_id,
|
task_id=self._application_generate_entity.task_id,
|
||||||
workflow_run=workflow_run,
|
workflow_run=workflow_run,
|
||||||
@ -427,9 +444,11 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||||||
if not self._workflow_run_id:
|
if not self._workflow_run_id:
|
||||||
raise ValueError("workflow run not initialized.")
|
raise ValueError("workflow run not initialized.")
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
|
workflow_run = self._workflow_cycle_manager._get_workflow_run(
|
||||||
iter_next_resp = self._workflow_iteration_next_to_stream_response(
|
session=session, workflow_run_id=self._workflow_run_id
|
||||||
|
)
|
||||||
|
iter_next_resp = self._workflow_cycle_manager._workflow_iteration_next_to_stream_response(
|
||||||
session=session,
|
session=session,
|
||||||
task_id=self._application_generate_entity.task_id,
|
task_id=self._application_generate_entity.task_id,
|
||||||
workflow_run=workflow_run,
|
workflow_run=workflow_run,
|
||||||
@ -441,9 +460,11 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||||||
if not self._workflow_run_id:
|
if not self._workflow_run_id:
|
||||||
raise ValueError("workflow run not initialized.")
|
raise ValueError("workflow run not initialized.")
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
|
workflow_run = self._workflow_cycle_manager._get_workflow_run(
|
||||||
iter_finish_resp = self._workflow_iteration_completed_to_stream_response(
|
session=session, workflow_run_id=self._workflow_run_id
|
||||||
|
)
|
||||||
|
iter_finish_resp = self._workflow_cycle_manager._workflow_iteration_completed_to_stream_response(
|
||||||
session=session,
|
session=session,
|
||||||
task_id=self._application_generate_entity.task_id,
|
task_id=self._application_generate_entity.task_id,
|
||||||
workflow_run=workflow_run,
|
workflow_run=workflow_run,
|
||||||
@ -458,8 +479,8 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||||||
if not graph_runtime_state:
|
if not graph_runtime_state:
|
||||||
raise ValueError("workflow run not initialized.")
|
raise ValueError("workflow run not initialized.")
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
workflow_run = self._handle_workflow_run_success(
|
workflow_run = self._workflow_cycle_manager._handle_workflow_run_success(
|
||||||
session=session,
|
session=session,
|
||||||
workflow_run_id=self._workflow_run_id,
|
workflow_run_id=self._workflow_run_id,
|
||||||
start_at=graph_runtime_state.start_at,
|
start_at=graph_runtime_state.start_at,
|
||||||
@ -470,21 +491,23 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||||||
trace_manager=trace_manager,
|
trace_manager=trace_manager,
|
||||||
)
|
)
|
||||||
|
|
||||||
workflow_finish_resp = self._workflow_finish_to_stream_response(
|
workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
|
||||||
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
|
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
|
||||||
)
|
)
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
yield workflow_finish_resp
|
yield workflow_finish_resp
|
||||||
self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
|
self._base_task_pipeline._queue_manager.publish(
|
||||||
|
QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE
|
||||||
|
)
|
||||||
elif isinstance(event, QueueWorkflowPartialSuccessEvent):
|
elif isinstance(event, QueueWorkflowPartialSuccessEvent):
|
||||||
if not self._workflow_run_id:
|
if not self._workflow_run_id:
|
||||||
raise ValueError("workflow run not initialized.")
|
raise ValueError("workflow run not initialized.")
|
||||||
if not graph_runtime_state:
|
if not graph_runtime_state:
|
||||||
raise ValueError("graph runtime state not initialized.")
|
raise ValueError("graph runtime state not initialized.")
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
workflow_run = self._handle_workflow_run_partial_success(
|
workflow_run = self._workflow_cycle_manager._handle_workflow_run_partial_success(
|
||||||
session=session,
|
session=session,
|
||||||
workflow_run_id=self._workflow_run_id,
|
workflow_run_id=self._workflow_run_id,
|
||||||
start_at=graph_runtime_state.start_at,
|
start_at=graph_runtime_state.start_at,
|
||||||
@ -495,21 +518,23 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||||||
conversation_id=None,
|
conversation_id=None,
|
||||||
trace_manager=trace_manager,
|
trace_manager=trace_manager,
|
||||||
)
|
)
|
||||||
workflow_finish_resp = self._workflow_finish_to_stream_response(
|
workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
|
||||||
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
|
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
|
||||||
)
|
)
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
yield workflow_finish_resp
|
yield workflow_finish_resp
|
||||||
self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
|
self._base_task_pipeline._queue_manager.publish(
|
||||||
|
QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE
|
||||||
|
)
|
||||||
elif isinstance(event, QueueWorkflowFailedEvent):
|
elif isinstance(event, QueueWorkflowFailedEvent):
|
||||||
if not self._workflow_run_id:
|
if not self._workflow_run_id:
|
||||||
raise ValueError("workflow run not initialized.")
|
raise ValueError("workflow run not initialized.")
|
||||||
if not graph_runtime_state:
|
if not graph_runtime_state:
|
||||||
raise ValueError("graph runtime state not initialized.")
|
raise ValueError("graph runtime state not initialized.")
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
workflow_run = self._handle_workflow_run_failed(
|
workflow_run = self._workflow_cycle_manager._handle_workflow_run_failed(
|
||||||
session=session,
|
session=session,
|
||||||
workflow_run_id=self._workflow_run_id,
|
workflow_run_id=self._workflow_run_id,
|
||||||
start_at=graph_runtime_state.start_at,
|
start_at=graph_runtime_state.start_at,
|
||||||
@ -521,20 +546,22 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||||||
trace_manager=trace_manager,
|
trace_manager=trace_manager,
|
||||||
exceptions_count=event.exceptions_count,
|
exceptions_count=event.exceptions_count,
|
||||||
)
|
)
|
||||||
workflow_finish_resp = self._workflow_finish_to_stream_response(
|
workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
|
||||||
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
|
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
|
||||||
)
|
)
|
||||||
err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_run.error}"))
|
err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_run.error}"))
|
||||||
err = self._handle_error(event=err_event, session=session, message_id=self._message_id)
|
err = self._base_task_pipeline._handle_error(
|
||||||
|
event=err_event, session=session, message_id=self._message_id
|
||||||
|
)
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
yield workflow_finish_resp
|
yield workflow_finish_resp
|
||||||
yield self._error_to_stream_response(err)
|
yield self._base_task_pipeline._error_to_stream_response(err)
|
||||||
break
|
break
|
||||||
elif isinstance(event, QueueStopEvent):
|
elif isinstance(event, QueueStopEvent):
|
||||||
if self._workflow_run_id and graph_runtime_state:
|
if self._workflow_run_id and graph_runtime_state:
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
workflow_run = self._handle_workflow_run_failed(
|
workflow_run = self._workflow_cycle_manager._handle_workflow_run_failed(
|
||||||
session=session,
|
session=session,
|
||||||
workflow_run_id=self._workflow_run_id,
|
workflow_run_id=self._workflow_run_id,
|
||||||
start_at=graph_runtime_state.start_at,
|
start_at=graph_runtime_state.start_at,
|
||||||
@ -545,7 +572,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||||||
conversation_id=self._conversation_id,
|
conversation_id=self._conversation_id,
|
||||||
trace_manager=trace_manager,
|
trace_manager=trace_manager,
|
||||||
)
|
)
|
||||||
workflow_finish_resp = self._workflow_finish_to_stream_response(
|
workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
|
||||||
session=session,
|
session=session,
|
||||||
task_id=self._application_generate_entity.task_id,
|
task_id=self._application_generate_entity.task_id,
|
||||||
workflow_run=workflow_run,
|
workflow_run=workflow_run,
|
||||||
@ -559,18 +586,18 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||||||
yield self._message_end_to_stream_response()
|
yield self._message_end_to_stream_response()
|
||||||
break
|
break
|
||||||
elif isinstance(event, QueueRetrieverResourcesEvent):
|
elif isinstance(event, QueueRetrieverResourcesEvent):
|
||||||
self._handle_retriever_resources(event)
|
self._message_cycle_manager._handle_retriever_resources(event)
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
message = self._get_message(session=session)
|
message = self._get_message(session=session)
|
||||||
message.message_metadata = (
|
message.message_metadata = (
|
||||||
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
|
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
|
||||||
)
|
)
|
||||||
session.commit()
|
session.commit()
|
||||||
elif isinstance(event, QueueAnnotationReplyEvent):
|
elif isinstance(event, QueueAnnotationReplyEvent):
|
||||||
self._handle_annotation_reply(event)
|
self._message_cycle_manager._handle_annotation_reply(event)
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
message = self._get_message(session=session)
|
message = self._get_message(session=session)
|
||||||
message.message_metadata = (
|
message.message_metadata = (
|
||||||
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
|
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
|
||||||
@ -591,23 +618,27 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||||||
tts_publisher.publish(queue_message)
|
tts_publisher.publish(queue_message)
|
||||||
|
|
||||||
self._task_state.answer += delta_text
|
self._task_state.answer += delta_text
|
||||||
yield self._message_to_stream_response(
|
yield self._message_cycle_manager._message_to_stream_response(
|
||||||
answer=delta_text, message_id=self._message_id, from_variable_selector=event.from_variable_selector
|
answer=delta_text, message_id=self._message_id, from_variable_selector=event.from_variable_selector
|
||||||
)
|
)
|
||||||
elif isinstance(event, QueueMessageReplaceEvent):
|
elif isinstance(event, QueueMessageReplaceEvent):
|
||||||
# published by moderation
|
# published by moderation
|
||||||
yield self._message_replace_to_stream_response(answer=event.text)
|
yield self._message_cycle_manager._message_replace_to_stream_response(answer=event.text)
|
||||||
elif isinstance(event, QueueAdvancedChatMessageEndEvent):
|
elif isinstance(event, QueueAdvancedChatMessageEndEvent):
|
||||||
if not graph_runtime_state:
|
if not graph_runtime_state:
|
||||||
raise ValueError("graph runtime state not initialized.")
|
raise ValueError("graph runtime state not initialized.")
|
||||||
|
|
||||||
output_moderation_answer = self._handle_output_moderation_when_task_finished(self._task_state.answer)
|
output_moderation_answer = self._base_task_pipeline._handle_output_moderation_when_task_finished(
|
||||||
|
self._task_state.answer
|
||||||
|
)
|
||||||
if output_moderation_answer:
|
if output_moderation_answer:
|
||||||
self._task_state.answer = output_moderation_answer
|
self._task_state.answer = output_moderation_answer
|
||||||
yield self._message_replace_to_stream_response(answer=output_moderation_answer)
|
yield self._message_cycle_manager._message_replace_to_stream_response(
|
||||||
|
answer=output_moderation_answer
|
||||||
|
)
|
||||||
|
|
||||||
# Save message
|
# Save message
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
self._save_message(session=session, graph_runtime_state=graph_runtime_state)
|
self._save_message(session=session, graph_runtime_state=graph_runtime_state)
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
@ -627,7 +658,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||||||
def _save_message(self, *, session: Session, graph_runtime_state: Optional[GraphRuntimeState] = None) -> None:
|
def _save_message(self, *, session: Session, graph_runtime_state: Optional[GraphRuntimeState] = None) -> None:
|
||||||
message = self._get_message(session=session)
|
message = self._get_message(session=session)
|
||||||
message.answer = self._task_state.answer
|
message.answer = self._task_state.answer
|
||||||
message.provider_response_latency = time.perf_counter() - self._start_at
|
message.provider_response_latency = time.perf_counter() - self._base_task_pipeline._start_at
|
||||||
message.message_metadata = (
|
message.message_metadata = (
|
||||||
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
|
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
|
||||||
)
|
)
|
||||||
@ -691,20 +722,20 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||||||
:param text: text
|
:param text: text
|
||||||
:return: True if output moderation should direct output, otherwise False
|
:return: True if output moderation should direct output, otherwise False
|
||||||
"""
|
"""
|
||||||
if self._output_moderation_handler:
|
if self._base_task_pipeline._output_moderation_handler:
|
||||||
if self._output_moderation_handler.should_direct_output():
|
if self._base_task_pipeline._output_moderation_handler.should_direct_output():
|
||||||
# stop subscribe new token when output moderation should direct output
|
# stop subscribe new token when output moderation should direct output
|
||||||
self._task_state.answer = self._output_moderation_handler.get_final_output()
|
self._task_state.answer = self._base_task_pipeline._output_moderation_handler.get_final_output()
|
||||||
self._queue_manager.publish(
|
self._base_task_pipeline._queue_manager.publish(
|
||||||
QueueTextChunkEvent(text=self._task_state.answer), PublishFrom.TASK_PIPELINE
|
QueueTextChunkEvent(text=self._task_state.answer), PublishFrom.TASK_PIPELINE
|
||||||
)
|
)
|
||||||
|
|
||||||
self._queue_manager.publish(
|
self._base_task_pipeline._queue_manager.publish(
|
||||||
QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION), PublishFrom.TASK_PIPELINE
|
QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION), PublishFrom.TASK_PIPELINE
|
||||||
)
|
)
|
||||||
return True
|
return True
|
||||||
else:
|
else:
|
||||||
self._output_moderation_handler.append_new_token(text)
|
self._base_task_pipeline._output_moderation_handler.append_new_token(text)
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
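The moderation branch above follows a simple contract: while streaming, each new chunk is handed to the output moderation handler; once the handler decides to take over, the accumulated answer is replaced with the handler's final output and a stop event is published. The sketch below mirrors that control flow with a stub handler; only should_direct_output, get_final_output and append_new_token come from the code above, everything else (StubModerationHandler, handle_chunk) is illustrative.

from dataclasses import dataclass, field


@dataclass
class StubModerationHandler:
    # Hypothetical stand-in for the real output moderation handler; it models only
    # the three methods used by the pipeline above.
    final_output: str = "[moderated answer]"
    tokens: list[str] = field(default_factory=list)
    direct_output: bool = False

    def should_direct_output(self) -> bool:
        return self.direct_output

    def get_final_output(self) -> str:
        return self.final_output

    def append_new_token(self, text: str) -> None:
        self.tokens.append(text)


def handle_chunk(handler: StubModerationHandler, answer: str, chunk: str) -> tuple[str, bool]:
    # Mirrors the branch above: either replace the whole answer and stop streaming,
    # or keep feeding tokens to the handler and continue.
    if handler.should_direct_output():
        return handler.get_final_output(), True
    handler.append_new_token(chunk)
    return answer + chunk, False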
@ -19,7 +19,7 @@ from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskSt
|
|||||||
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
|
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
|
||||||
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
|
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
|
||||||
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, InvokeFrom
|
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, InvokeFrom
|
||||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||||
from core.ops.ops_trace_manager import TraceQueueManager
|
from core.ops.ops_trace_manager import TraceQueueManager
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from factories import file_factory
|
from factories import file_factory
|
||||||
@ -251,7 +251,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
|
|||||||
except ValidationError as e:
|
except ValidationError as e:
|
||||||
logger.exception("Validation Error when generating")
|
logger.exception("Validation Error when generating")
|
||||||
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
||||||
except (ValueError, InvokeError) as e:
|
except ValueError as e:
|
||||||
if dify_config.DEBUG:
|
if dify_config.DEBUG:
|
||||||
logger.exception("Error when generating")
|
logger.exception("Error when generating")
|
||||||
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
||||||
|
@ -18,7 +18,7 @@ from core.app.apps.chat.generate_response_converter import ChatAppGenerateRespon
|
|||||||
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
|
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
|
||||||
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
|
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
|
||||||
from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, InvokeFrom
|
from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, InvokeFrom
|
||||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||||
from core.ops.ops_trace_manager import TraceQueueManager
|
from core.ops.ops_trace_manager import TraceQueueManager
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from factories import file_factory
|
from factories import file_factory
|
||||||
@ -237,7 +237,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
|
|||||||
except ValidationError as e:
|
except ValidationError as e:
|
||||||
logger.exception("Validation Error when generating")
|
logger.exception("Validation Error when generating")
|
||||||
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
||||||
except (ValueError, InvokeError) as e:
|
except ValueError as e:
|
||||||
if dify_config.DEBUG:
|
if dify_config.DEBUG:
|
||||||
logger.exception("Error when generating")
|
logger.exception("Error when generating")
|
||||||
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
||||||
|
@ -17,7 +17,7 @@ from core.app.apps.completion.generate_response_converter import CompletionAppGe
|
|||||||
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
|
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
|
||||||
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
|
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
|
||||||
from core.app.entities.app_invoke_entities import CompletionAppGenerateEntity, InvokeFrom
|
from core.app.entities.app_invoke_entities import CompletionAppGenerateEntity, InvokeFrom
|
||||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||||
from core.ops.ops_trace_manager import TraceQueueManager
|
from core.ops.ops_trace_manager import TraceQueueManager
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from factories import file_factory
|
from factories import file_factory
|
||||||
@ -214,7 +214,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
|||||||
except ValidationError as e:
|
except ValidationError as e:
|
||||||
logger.exception("Validation Error when generating")
|
logger.exception("Validation Error when generating")
|
||||||
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
||||||
except (ValueError, InvokeError) as e:
|
except ValueError as e:
|
||||||
if dify_config.DEBUG:
|
if dify_config.DEBUG:
|
||||||
logger.exception("Error when generating")
|
logger.exception("Error when generating")
|
||||||
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
||||||
|
@ -20,7 +20,7 @@ from core.app.apps.workflow.generate_response_converter import WorkflowAppGenera
|
|||||||
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
|
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
|
||||||
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
|
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
|
||||||
from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
|
from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
|
||||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||||
from core.ops.ops_trace_manager import TraceQueueManager
|
from core.ops.ops_trace_manager import TraceQueueManager
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from factories import file_factory
|
from factories import file_factory
|
||||||
@ -235,6 +235,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||||||
single_iteration_run=WorkflowAppGenerateEntity.SingleIterationRunEntity(
|
single_iteration_run=WorkflowAppGenerateEntity.SingleIterationRunEntity(
|
||||||
node_id=node_id, inputs=args["inputs"]
|
node_id=node_id, inputs=args["inputs"]
|
||||||
),
|
),
|
||||||
|
workflow_run_id=str(uuid.uuid4()),
|
||||||
)
|
)
|
||||||
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
|
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
|
||||||
contexts.plugin_tool_providers.set({})
|
contexts.plugin_tool_providers.set({})
|
||||||
@ -286,7 +287,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||||||
except ValidationError as e:
|
except ValidationError as e:
|
||||||
logger.exception("Validation Error when generating")
|
logger.exception("Validation Error when generating")
|
||||||
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
||||||
except (ValueError, InvokeError) as e:
|
except ValueError as e:
|
||||||
if dify_config.DEBUG:
|
if dify_config.DEBUG:
|
||||||
logger.exception("Error when generating")
|
logger.exception("Error when generating")
|
||||||
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
from typing import Any, Optional, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
@ -59,7 +59,6 @@ from models.workflow import (
|
|||||||
Workflow,
|
Workflow,
|
||||||
WorkflowAppLog,
|
WorkflowAppLog,
|
||||||
WorkflowAppLogCreatedFrom,
|
WorkflowAppLogCreatedFrom,
|
||||||
WorkflowNodeExecution,
|
|
||||||
WorkflowRun,
|
WorkflowRun,
|
||||||
WorkflowRunStatus,
|
WorkflowRunStatus,
|
||||||
)
|
)
|
||||||
@ -67,16 +66,11 @@ from models.workflow import (
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleManage):
|
class WorkflowAppGenerateTaskPipeline:
|
||||||
"""
|
"""
|
||||||
WorkflowAppGenerateTaskPipeline is a class that generates stream output and manages state for the application.
|
WorkflowAppGenerateTaskPipeline is a class that generates stream output and manages state for the application.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_task_state: WorkflowTaskState
|
|
||||||
_application_generate_entity: WorkflowAppGenerateEntity
|
|
||||||
_workflow_system_variables: dict[SystemVariableKey, Any]
|
|
||||||
_wip_workflow_node_executions: dict[str, WorkflowNodeExecution]
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
application_generate_entity: WorkflowAppGenerateEntity,
|
application_generate_entity: WorkflowAppGenerateEntity,
|
||||||
@ -85,7 +79,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||||||
user: Union[Account, EndUser],
|
user: Union[Account, EndUser],
|
||||||
stream: bool,
|
stream: bool,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(
|
self._base_task_pipeline = BasedGenerateTaskPipeline(
|
||||||
application_generate_entity=application_generate_entity,
|
application_generate_entity=application_generate_entity,
|
||||||
queue_manager=queue_manager,
|
queue_manager=queue_manager,
|
||||||
stream=stream,
|
stream=stream,
|
||||||
@ -102,17 +96,20 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid user type: {type(user)}")
|
raise ValueError(f"Invalid user type: {type(user)}")
|
||||||
|
|
||||||
|
self._workflow_cycle_manager = WorkflowCycleManage(
|
||||||
|
application_generate_entity=application_generate_entity,
|
||||||
|
workflow_system_variables={
|
||||||
|
SystemVariableKey.FILES: application_generate_entity.files,
|
||||||
|
SystemVariableKey.USER_ID: user_session_id,
|
||||||
|
SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
|
||||||
|
SystemVariableKey.WORKFLOW_ID: workflow.id,
|
||||||
|
SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
self._application_generate_entity = application_generate_entity
|
||||||
self._workflow_id = workflow.id
|
self._workflow_id = workflow.id
|
||||||
self._workflow_features_dict = workflow.features_dict
|
self._workflow_features_dict = workflow.features_dict
|
||||||
|
|
||||||
self._workflow_system_variables = {
|
|
||||||
SystemVariableKey.FILES: application_generate_entity.files,
|
|
||||||
SystemVariableKey.USER_ID: user_session_id,
|
|
||||||
SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
|
|
||||||
SystemVariableKey.WORKFLOW_ID: workflow.id,
|
|
||||||
SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
|
|
||||||
}
|
|
||||||
|
|
||||||
self._task_state = WorkflowTaskState()
|
self._task_state = WorkflowTaskState()
|
||||||
self._workflow_run_id = ""
|
self._workflow_run_id = ""
|
||||||
|
|
||||||
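The constructor rewrite above replaces multiple inheritance (BasedGenerateTaskPipeline, WorkflowCycleManage) with composition: the pipeline now owns a _base_task_pipeline and a _workflow_cycle_manager and forwards to them explicitly. A minimal sketch of the same shape, with simplified stand-in classes (BasePipeline, CycleManager and GenerateTaskPipeline here are illustrative, not the real Dify classes):

class BasePipeline:
    def __init__(self, *, queue_manager, stream: bool) -> None:
        self._queue_manager = queue_manager
        self._stream = stream


class CycleManager:
    def __init__(self, *, system_variables: dict) -> None:
        self._system_variables = system_variables


class GenerateTaskPipeline:
    # Composition: collaborators are held as attributes and called explicitly,
    # instead of being mixed in via inheritance.
    def __init__(self, *, queue_manager, stream: bool, system_variables: dict) -> None:
        self._base_task_pipeline = BasePipeline(queue_manager=queue_manager, stream=stream)
        self._workflow_cycle_manager = CycleManager(system_variables=system_variables)

    def process(self) -> str:
        # Every use of shared behaviour now names the collaborator it comes from.
        return "stream" if self._base_task_pipeline._stream else "blocking"


pipeline = GenerateTaskPipeline(queue_manager=object(), stream=True, system_variables={})
print(pipeline.process())  # "stream"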
@ -122,7 +119,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
|
generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
|
||||||
if self._stream:
|
if self._base_task_pipeline._stream:
|
||||||
return self._to_stream_response(generator)
|
return self._to_stream_response(generator)
|
||||||
else:
|
else:
|
||||||
return self._to_blocking_response(generator)
|
return self._to_blocking_response(generator)
|
||||||
@ -239,29 +236,29 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||||||
"""
|
"""
|
||||||
graph_runtime_state = None
|
graph_runtime_state = None
|
||||||
|
|
||||||
for queue_message in self._queue_manager.listen():
|
for queue_message in self._base_task_pipeline._queue_manager.listen():
|
||||||
event = queue_message.event
|
event = queue_message.event
|
||||||
|
|
||||||
if isinstance(event, QueuePingEvent):
|
if isinstance(event, QueuePingEvent):
|
||||||
yield self._ping_stream_response()
|
yield self._base_task_pipeline._ping_stream_response()
|
||||||
elif isinstance(event, QueueErrorEvent):
|
elif isinstance(event, QueueErrorEvent):
|
||||||
err = self._handle_error(event=event)
|
err = self._base_task_pipeline._handle_error(event=event)
|
||||||
yield self._error_to_stream_response(err)
|
yield self._base_task_pipeline._error_to_stream_response(err)
|
||||||
break
|
break
|
||||||
elif isinstance(event, QueueWorkflowStartedEvent):
|
elif isinstance(event, QueueWorkflowStartedEvent):
|
||||||
# override graph runtime state
|
# override graph runtime state
|
||||||
graph_runtime_state = event.graph_runtime_state
|
graph_runtime_state = event.graph_runtime_state
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
# init workflow run
|
# init workflow run
|
||||||
workflow_run = self._handle_workflow_run_start(
|
workflow_run = self._workflow_cycle_manager._handle_workflow_run_start(
|
||||||
session=session,
|
session=session,
|
||||||
workflow_id=self._workflow_id,
|
workflow_id=self._workflow_id,
|
||||||
user_id=self._user_id,
|
user_id=self._user_id,
|
||||||
created_by_role=self._created_by_role,
|
created_by_role=self._created_by_role,
|
||||||
)
|
)
|
||||||
self._workflow_run_id = workflow_run.id
|
self._workflow_run_id = workflow_run.id
|
||||||
start_resp = self._workflow_start_to_stream_response(
|
start_resp = self._workflow_cycle_manager._workflow_start_to_stream_response(
|
||||||
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
|
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
|
||||||
)
|
)
|
||||||
session.commit()
|
session.commit()
|
||||||
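The repeated switch to Session(db.engine, expire_on_commit=False) matches how the results are used: objects such as workflow_run are read and turned into stream responses after session.commit(), and with the default expire_on_commit=True every attribute access after the commit triggers a refresh, which fails once the session is closed. A small self-contained illustration of the difference; the Run table and model are made up for the example.

from sqlalchemy import create_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Run(Base):
    __tablename__ = "run"
    id: Mapped[int] = mapped_column(primary_key=True)
    status: Mapped[str]


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine, expire_on_commit=False) as session:
    run = Run(id=1, status="running")
    session.add(run)
    session.commit()

# With expire_on_commit=False the instance keeps its loaded state after the commit,
# so it can still be read once the session is closed.
print(run.status)  # "running", no lazy refresh needed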
@ -273,12 +270,14 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||||||
):
|
):
|
||||||
if not self._workflow_run_id:
|
if not self._workflow_run_id:
|
||||||
raise ValueError("workflow run not initialized.")
|
raise ValueError("workflow run not initialized.")
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
|
workflow_run = self._workflow_cycle_manager._get_workflow_run(
|
||||||
workflow_node_execution = self._handle_workflow_node_execution_retried(
|
session=session, workflow_run_id=self._workflow_run_id
|
||||||
|
)
|
||||||
|
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_retried(
|
||||||
session=session, workflow_run=workflow_run, event=event
|
session=session, workflow_run=workflow_run, event=event
|
||||||
)
|
)
|
||||||
response = self._workflow_node_retry_to_stream_response(
|
response = self._workflow_cycle_manager._workflow_node_retry_to_stream_response(
|
||||||
session=session,
|
session=session,
|
||||||
event=event,
|
event=event,
|
||||||
task_id=self._application_generate_entity.task_id,
|
task_id=self._application_generate_entity.task_id,
|
||||||
@ -292,12 +291,14 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||||||
if not self._workflow_run_id:
|
if not self._workflow_run_id:
|
||||||
raise ValueError("workflow run not initialized.")
|
raise ValueError("workflow run not initialized.")
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
|
workflow_run = self._workflow_cycle_manager._get_workflow_run(
|
||||||
workflow_node_execution = self._handle_node_execution_start(
|
session=session, workflow_run_id=self._workflow_run_id
|
||||||
|
)
|
||||||
|
workflow_node_execution = self._workflow_cycle_manager._handle_node_execution_start(
|
||||||
session=session, workflow_run=workflow_run, event=event
|
session=session, workflow_run=workflow_run, event=event
|
||||||
)
|
)
|
||||||
node_start_response = self._workflow_node_start_to_stream_response(
|
node_start_response = self._workflow_cycle_manager._workflow_node_start_to_stream_response(
|
||||||
session=session,
|
session=session,
|
||||||
event=event,
|
event=event,
|
||||||
task_id=self._application_generate_entity.task_id,
|
task_id=self._application_generate_entity.task_id,
|
||||||
@ -308,9 +309,11 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||||||
if node_start_response:
|
if node_start_response:
|
||||||
yield node_start_response
|
yield node_start_response
|
||||||
elif isinstance(event, QueueNodeSucceededEvent):
|
elif isinstance(event, QueueNodeSucceededEvent):
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
workflow_node_execution = self._handle_workflow_node_execution_success(session=session, event=event)
|
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_success(
|
||||||
node_success_response = self._workflow_node_finish_to_stream_response(
|
session=session, event=event
|
||||||
|
)
|
||||||
|
node_success_response = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
|
||||||
session=session,
|
session=session,
|
||||||
event=event,
|
event=event,
|
||||||
task_id=self._application_generate_entity.task_id,
|
task_id=self._application_generate_entity.task_id,
|
||||||
@ -321,12 +324,12 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||||||
if node_success_response:
|
if node_success_response:
|
||||||
yield node_success_response
|
yield node_success_response
|
||||||
elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent):
|
elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent):
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
workflow_node_execution = self._handle_workflow_node_execution_failed(
|
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_failed(
|
||||||
session=session,
|
session=session,
|
||||||
event=event,
|
event=event,
|
||||||
)
|
)
|
||||||
node_failed_response = self._workflow_node_finish_to_stream_response(
|
node_failed_response = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
|
||||||
session=session,
|
session=session,
|
||||||
event=event,
|
event=event,
|
||||||
task_id=self._application_generate_entity.task_id,
|
task_id=self._application_generate_entity.task_id,
|
||||||
@ -341,13 +344,17 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||||||
if not self._workflow_run_id:
|
if not self._workflow_run_id:
|
||||||
raise ValueError("workflow run not initialized.")
|
raise ValueError("workflow run not initialized.")
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
|
workflow_run = self._workflow_cycle_manager._get_workflow_run(
|
||||||
parallel_start_resp = self._workflow_parallel_branch_start_to_stream_response(
|
session=session, workflow_run_id=self._workflow_run_id
|
||||||
session=session,
|
)
|
||||||
task_id=self._application_generate_entity.task_id,
|
parallel_start_resp = (
|
||||||
workflow_run=workflow_run,
|
self._workflow_cycle_manager._workflow_parallel_branch_start_to_stream_response(
|
||||||
event=event,
|
session=session,
|
||||||
|
task_id=self._application_generate_entity.task_id,
|
||||||
|
workflow_run=workflow_run,
|
||||||
|
event=event,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
yield parallel_start_resp
|
yield parallel_start_resp
|
||||||
@ -356,13 +363,17 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||||||
if not self._workflow_run_id:
|
if not self._workflow_run_id:
|
||||||
raise ValueError("workflow run not initialized.")
|
raise ValueError("workflow run not initialized.")
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
|
workflow_run = self._workflow_cycle_manager._get_workflow_run(
|
||||||
parallel_finish_resp = self._workflow_parallel_branch_finished_to_stream_response(
|
session=session, workflow_run_id=self._workflow_run_id
|
||||||
session=session,
|
)
|
||||||
task_id=self._application_generate_entity.task_id,
|
parallel_finish_resp = (
|
||||||
workflow_run=workflow_run,
|
self._workflow_cycle_manager._workflow_parallel_branch_finished_to_stream_response(
|
||||||
event=event,
|
session=session,
|
||||||
|
task_id=self._application_generate_entity.task_id,
|
||||||
|
workflow_run=workflow_run,
|
||||||
|
event=event,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
yield parallel_finish_resp
|
yield parallel_finish_resp
|
||||||
@ -371,9 +382,11 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||||||
if not self._workflow_run_id:
|
if not self._workflow_run_id:
|
||||||
raise ValueError("workflow run not initialized.")
|
raise ValueError("workflow run not initialized.")
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
|
workflow_run = self._workflow_cycle_manager._get_workflow_run(
|
||||||
iter_start_resp = self._workflow_iteration_start_to_stream_response(
|
session=session, workflow_run_id=self._workflow_run_id
|
||||||
|
)
|
||||||
|
iter_start_resp = self._workflow_cycle_manager._workflow_iteration_start_to_stream_response(
|
||||||
session=session,
|
session=session,
|
||||||
task_id=self._application_generate_entity.task_id,
|
task_id=self._application_generate_entity.task_id,
|
||||||
workflow_run=workflow_run,
|
workflow_run=workflow_run,
|
||||||
@ -386,9 +399,11 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||||||
if not self._workflow_run_id:
|
if not self._workflow_run_id:
|
||||||
raise ValueError("workflow run not initialized.")
|
raise ValueError("workflow run not initialized.")
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
|
workflow_run = self._workflow_cycle_manager._get_workflow_run(
|
||||||
iter_next_resp = self._workflow_iteration_next_to_stream_response(
|
session=session, workflow_run_id=self._workflow_run_id
|
||||||
|
)
|
||||||
|
iter_next_resp = self._workflow_cycle_manager._workflow_iteration_next_to_stream_response(
|
||||||
session=session,
|
session=session,
|
||||||
task_id=self._application_generate_entity.task_id,
|
task_id=self._application_generate_entity.task_id,
|
||||||
workflow_run=workflow_run,
|
workflow_run=workflow_run,
|
||||||
@ -401,9 +416,11 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||||||
if not self._workflow_run_id:
|
if not self._workflow_run_id:
|
||||||
raise ValueError("workflow run not initialized.")
|
raise ValueError("workflow run not initialized.")
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
|
workflow_run = self._workflow_cycle_manager._get_workflow_run(
|
||||||
iter_finish_resp = self._workflow_iteration_completed_to_stream_response(
|
session=session, workflow_run_id=self._workflow_run_id
|
||||||
|
)
|
||||||
|
iter_finish_resp = self._workflow_cycle_manager._workflow_iteration_completed_to_stream_response(
|
||||||
session=session,
|
session=session,
|
||||||
task_id=self._application_generate_entity.task_id,
|
task_id=self._application_generate_entity.task_id,
|
||||||
workflow_run=workflow_run,
|
workflow_run=workflow_run,
|
||||||
@ -418,8 +435,8 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||||||
if not graph_runtime_state:
|
if not graph_runtime_state:
|
||||||
raise ValueError("graph runtime state not initialized.")
|
raise ValueError("graph runtime state not initialized.")
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
workflow_run = self._handle_workflow_run_success(
|
workflow_run = self._workflow_cycle_manager._handle_workflow_run_success(
|
||||||
session=session,
|
session=session,
|
||||||
workflow_run_id=self._workflow_run_id,
|
workflow_run_id=self._workflow_run_id,
|
||||||
start_at=graph_runtime_state.start_at,
|
start_at=graph_runtime_state.start_at,
|
||||||
@ -433,7 +450,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||||||
# save workflow app log
|
# save workflow app log
|
||||||
self._save_workflow_app_log(session=session, workflow_run=workflow_run)
|
self._save_workflow_app_log(session=session, workflow_run=workflow_run)
|
||||||
|
|
||||||
workflow_finish_resp = self._workflow_finish_to_stream_response(
|
workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
|
||||||
session=session,
|
session=session,
|
||||||
task_id=self._application_generate_entity.task_id,
|
task_id=self._application_generate_entity.task_id,
|
||||||
workflow_run=workflow_run,
|
workflow_run=workflow_run,
|
||||||
@ -447,8 +464,8 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||||||
if not graph_runtime_state:
|
if not graph_runtime_state:
|
||||||
raise ValueError("graph runtime state not initialized.")
|
raise ValueError("graph runtime state not initialized.")
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
workflow_run = self._handle_workflow_run_partial_success(
|
workflow_run = self._workflow_cycle_manager._handle_workflow_run_partial_success(
|
||||||
session=session,
|
session=session,
|
||||||
workflow_run_id=self._workflow_run_id,
|
workflow_run_id=self._workflow_run_id,
|
||||||
start_at=graph_runtime_state.start_at,
|
start_at=graph_runtime_state.start_at,
|
||||||
@ -463,7 +480,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||||||
# save workflow app log
|
# save workflow app log
|
||||||
self._save_workflow_app_log(session=session, workflow_run=workflow_run)
|
self._save_workflow_app_log(session=session, workflow_run=workflow_run)
|
||||||
|
|
||||||
workflow_finish_resp = self._workflow_finish_to_stream_response(
|
workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
|
||||||
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
|
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
|
||||||
)
|
)
|
||||||
session.commit()
|
session.commit()
|
||||||
@ -475,8 +492,8 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||||||
if not graph_runtime_state:
|
if not graph_runtime_state:
|
||||||
raise ValueError("graph runtime state not initialized.")
|
raise ValueError("graph runtime state not initialized.")
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
workflow_run = self._handle_workflow_run_failed(
|
workflow_run = self._workflow_cycle_manager._handle_workflow_run_failed(
|
||||||
session=session,
|
session=session,
|
||||||
workflow_run_id=self._workflow_run_id,
|
workflow_run_id=self._workflow_run_id,
|
||||||
start_at=graph_runtime_state.start_at,
|
start_at=graph_runtime_state.start_at,
|
||||||
@ -494,7 +511,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||||||
# save workflow app log
|
# save workflow app log
|
||||||
self._save_workflow_app_log(session=session, workflow_run=workflow_run)
|
self._save_workflow_app_log(session=session, workflow_run=workflow_run)
|
||||||
|
|
||||||
workflow_finish_resp = self._workflow_finish_to_stream_response(
|
workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
|
||||||
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
|
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
|
||||||
)
|
)
|
||||||
session.commit()
|
session.commit()
|
||||||
|
@ -195,7 +195,7 @@ class WorkflowAppGenerateEntity(AppGenerateEntity):
|
|||||||
|
|
||||||
# app config
|
# app config
|
||||||
app_config: WorkflowUIBasedAppConfig
|
app_config: WorkflowUIBasedAppConfig
|
||||||
workflow_run_id: Optional[str] = None
|
workflow_run_id: str
|
||||||
|
|
||||||
class SingleIterationRunEntity(BaseModel):
|
class SingleIterationRunEntity(BaseModel):
|
||||||
"""
|
"""
|
||||||
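The entity change above turns workflow_run_id from an optional field into a required one, and the generator change earlier in this diff pre-assigns it with str(uuid.uuid4()) before the run starts, so downstream code (system variables, run records) can rely on it being present. A minimal pydantic sketch of the before/after contract; the two entities below are cut-down stand-ins, not the real WorkflowAppGenerateEntity.

import uuid
from typing import Optional

from pydantic import BaseModel, ValidationError


class EntityBefore(BaseModel):
    workflow_run_id: Optional[str] = None  # may be missing


class EntityAfter(BaseModel):
    workflow_run_id: str  # must be provided up front


EntityBefore()                                  # fine, field defaults to None
EntityAfter(workflow_run_id=str(uuid.uuid4()))  # caller pre-assigns the id

try:
    EntityAfter()  # type: ignore[call-arg]
except ValidationError as exc:
    print(exc.errors()[0]["type"])  # "missing"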
|
@ -15,7 +15,6 @@ from core.app.entities.queue_entities import (
|
|||||||
from core.app.entities.task_entities import (
|
from core.app.entities.task_entities import (
|
||||||
ErrorStreamResponse,
|
ErrorStreamResponse,
|
||||||
PingStreamResponse,
|
PingStreamResponse,
|
||||||
TaskState,
|
|
||||||
)
|
)
|
||||||
from core.errors.error import QuotaExceededError
|
from core.errors.error import QuotaExceededError
|
||||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
||||||
@ -30,22 +29,12 @@ class BasedGenerateTaskPipeline:
|
|||||||
BasedGenerateTaskPipeline is a class that generates stream output and manages state for the application.
|
BasedGenerateTaskPipeline is a class that generates stream output and manages state for the application.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_task_state: TaskState
|
|
||||||
_application_generate_entity: AppGenerateEntity
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
application_generate_entity: AppGenerateEntity,
|
application_generate_entity: AppGenerateEntity,
|
||||||
queue_manager: AppQueueManager,
|
queue_manager: AppQueueManager,
|
||||||
stream: bool,
|
stream: bool,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
|
||||||
Initialize GenerateTaskPipeline.
|
|
||||||
:param application_generate_entity: application generate entity
|
|
||||||
:param queue_manager: queue manager
|
|
||||||
:param user: user
|
|
||||||
:param stream: stream
|
|
||||||
"""
|
|
||||||
self._application_generate_entity = application_generate_entity
|
self._application_generate_entity = application_generate_entity
|
||||||
self._queue_manager = queue_manager
|
self._queue_manager = queue_manager
|
||||||
self._start_at = time.perf_counter()
|
self._start_at = time.perf_counter()
|
||||||
|
@ -31,10 +31,19 @@ from services.annotation_service import AppAnnotationService
|
|||||||
|
|
||||||
|
|
||||||
class MessageCycleManage:
|
class MessageCycleManage:
|
||||||
_application_generate_entity: Union[
|
def __init__(
|
||||||
ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity, AdvancedChatAppGenerateEntity
|
self,
|
||||||
]
|
*,
|
||||||
_task_state: Union[EasyUITaskState, WorkflowTaskState]
|
application_generate_entity: Union[
|
||||||
|
ChatAppGenerateEntity,
|
||||||
|
CompletionAppGenerateEntity,
|
||||||
|
AgentChatAppGenerateEntity,
|
||||||
|
AdvancedChatAppGenerateEntity,
|
||||||
|
],
|
||||||
|
task_state: Union[EasyUITaskState, WorkflowTaskState],
|
||||||
|
) -> None:
|
||||||
|
self._application_generate_entity = application_generate_entity
|
||||||
|
self._task_state = task_state
|
||||||
|
|
||||||
def _generate_conversation_name(self, *, conversation_id: str, query: str) -> Optional[Thread]:
|
def _generate_conversation_name(self, *, conversation_id: str, query: str) -> Optional[Thread]:
|
||||||
"""
|
"""
|
||||||
|
@ -36,7 +36,6 @@ from core.app.entities.task_entities import (
|
|||||||
ParallelBranchStartStreamResponse,
|
ParallelBranchStartStreamResponse,
|
||||||
WorkflowFinishStreamResponse,
|
WorkflowFinishStreamResponse,
|
||||||
WorkflowStartStreamResponse,
|
WorkflowStartStreamResponse,
|
||||||
WorkflowTaskState,
|
|
||||||
)
|
)
|
||||||
from core.file import FILE_MODEL_IDENTITY, File
|
from core.file import FILE_MODEL_IDENTITY, File
|
||||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||||
@ -60,13 +59,20 @@ from models.workflow import (
|
|||||||
WorkflowRunStatus,
|
WorkflowRunStatus,
|
||||||
)
|
)
|
||||||
|
|
||||||
from .exc import WorkflowNodeExecutionNotFoundError, WorkflowRunNotFoundError
|
from .exc import WorkflowRunNotFoundError
|
||||||
|
|
||||||
|
|
||||||
class WorkflowCycleManage:
|
class WorkflowCycleManage:
|
||||||
_application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity]
|
def __init__(
|
||||||
_task_state: WorkflowTaskState
|
self,
|
||||||
_workflow_system_variables: dict[SystemVariableKey, Any]
|
*,
|
||||||
|
application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity],
|
||||||
|
workflow_system_variables: dict[SystemVariableKey, Any],
|
||||||
|
) -> None:
|
||||||
|
self._workflow_run: WorkflowRun | None = None
|
||||||
|
self._workflow_node_executions: dict[str, WorkflowNodeExecution] = {}
|
||||||
|
self._application_generate_entity = application_generate_entity
|
||||||
|
self._workflow_system_variables = workflow_system_variables
|
||||||
|
|
||||||
def _handle_workflow_run_start(
|
def _handle_workflow_run_start(
|
||||||
self,
|
self,
|
||||||
@ -104,7 +110,8 @@ class WorkflowCycleManage:
|
|||||||
inputs = dict(WorkflowEntry.handle_special_values(inputs) or {})
|
inputs = dict(WorkflowEntry.handle_special_values(inputs) or {})
|
||||||
|
|
||||||
# init workflow run
|
# init workflow run
|
||||||
workflow_run_id = str(self._workflow_system_variables.get(SystemVariableKey.WORKFLOW_RUN_ID, uuid4()))
|
# TODO: This workflow_run_id should never be None; maybe we can find a more elegant way to handle this
|
||||||
|
workflow_run_id = str(self._workflow_system_variables.get(SystemVariableKey.WORKFLOW_RUN_ID) or uuid4())
|
||||||
|
|
||||||
workflow_run = WorkflowRun()
|
workflow_run = WorkflowRun()
|
||||||
workflow_run.id = workflow_run_id
|
workflow_run.id = workflow_run_id
|
||||||
@ -241,7 +248,7 @@ class WorkflowCycleManage:
|
|||||||
workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None)
|
workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None)
|
||||||
workflow_run.exceptions_count = exceptions_count
|
workflow_run.exceptions_count = exceptions_count
|
||||||
|
|
||||||
stmt = select(WorkflowNodeExecution).where(
|
stmt = select(WorkflowNodeExecution.node_execution_id).where(
|
||||||
WorkflowNodeExecution.tenant_id == workflow_run.tenant_id,
|
WorkflowNodeExecution.tenant_id == workflow_run.tenant_id,
|
||||||
WorkflowNodeExecution.app_id == workflow_run.app_id,
|
WorkflowNodeExecution.app_id == workflow_run.app_id,
|
||||||
WorkflowNodeExecution.workflow_id == workflow_run.workflow_id,
|
WorkflowNodeExecution.workflow_id == workflow_run.workflow_id,
|
||||||
@ -249,15 +256,18 @@ class WorkflowCycleManage:
|
|||||||
WorkflowNodeExecution.workflow_run_id == workflow_run.id,
|
WorkflowNodeExecution.workflow_run_id == workflow_run.id,
|
||||||
WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING.value,
|
WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING.value,
|
||||||
)
|
)
|
||||||
|
ids = session.scalars(stmt).all()
|
||||||
running_workflow_node_executions = session.scalars(stmt).all()
|
# Use self._get_workflow_node_execution here to make sure the cache is updated
|
||||||
|
running_workflow_node_executions = [
|
||||||
|
self._get_workflow_node_execution(session=session, node_execution_id=id) for id in ids if id
|
||||||
|
]
|
||||||
|
|
||||||
for workflow_node_execution in running_workflow_node_executions:
|
for workflow_node_execution in running_workflow_node_executions:
|
||||||
|
now = datetime.now(UTC).replace(tzinfo=None)
|
||||||
workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
|
workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
|
||||||
workflow_node_execution.error = error
|
workflow_node_execution.error = error
|
||||||
finish_at = datetime.now(UTC).replace(tzinfo=None)
|
workflow_node_execution.finished_at = now
|
||||||
workflow_node_execution.finished_at = finish_at
|
workflow_node_execution.elapsed_time = (now - workflow_node_execution.created_at).total_seconds()
|
||||||
workflow_node_execution.elapsed_time = (finish_at - workflow_node_execution.created_at).total_seconds()
|
|
||||||
|
|
||||||
if trace_manager:
|
if trace_manager:
|
||||||
trace_manager.add_trace_task(
|
trace_manager.add_trace_task(
|
||||||
@ -299,6 +309,8 @@ class WorkflowCycleManage:
|
|||||||
workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None)
|
workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None)
|
||||||
|
|
||||||
session.add(workflow_node_execution)
|
session.add(workflow_node_execution)
|
||||||
|
|
||||||
|
self._workflow_node_executions[event.node_execution_id] = workflow_node_execution
|
||||||
return workflow_node_execution
|
return workflow_node_execution
|
||||||
|
|
||||||
def _handle_workflow_node_execution_success(
|
def _handle_workflow_node_execution_success(
|
||||||
@ -325,6 +337,7 @@ class WorkflowCycleManage:
|
|||||||
workflow_node_execution.finished_at = finished_at
|
workflow_node_execution.finished_at = finished_at
|
||||||
workflow_node_execution.elapsed_time = elapsed_time
|
workflow_node_execution.elapsed_time = elapsed_time
|
||||||
|
|
||||||
|
workflow_node_execution = session.merge(workflow_node_execution)
|
||||||
return workflow_node_execution
|
return workflow_node_execution
|
||||||
|
|
||||||
def _handle_workflow_node_execution_failed(
|
def _handle_workflow_node_execution_failed(
|
||||||
@ -364,6 +377,7 @@ class WorkflowCycleManage:
|
|||||||
workflow_node_execution.elapsed_time = elapsed_time
|
workflow_node_execution.elapsed_time = elapsed_time
|
||||||
workflow_node_execution.execution_metadata = execution_metadata
|
workflow_node_execution.execution_metadata = execution_metadata
|
||||||
|
|
||||||
|
workflow_node_execution = session.merge(workflow_node_execution)
|
||||||
return workflow_node_execution
|
return workflow_node_execution
|
||||||
|
|
||||||
def _handle_workflow_node_execution_retried(
|
def _handle_workflow_node_execution_retried(
|
||||||
@ -415,6 +429,8 @@ class WorkflowCycleManage:
|
|||||||
workflow_node_execution.index = event.node_run_index
|
workflow_node_execution.index = event.node_run_index
|
||||||
|
|
||||||
session.add(workflow_node_execution)
|
session.add(workflow_node_execution)
|
||||||
|
|
||||||
|
self._workflow_node_executions[event.node_execution_id] = workflow_node_execution
|
||||||
return workflow_node_execution
|
return workflow_node_execution
|
||||||
|
|
||||||
#################################################
|
#################################################
|
||||||
@ -811,25 +827,23 @@ class WorkflowCycleManage:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
def _get_workflow_run(self, *, session: Session, workflow_run_id: str) -> WorkflowRun:
|
def _get_workflow_run(self, *, session: Session, workflow_run_id: str) -> WorkflowRun:
|
||||||
"""
|
if self._workflow_run and self._workflow_run.id == workflow_run_id:
|
||||||
Refetch workflow run
|
cached_workflow_run = self._workflow_run
|
||||||
:param workflow_run_id: workflow run id
|
cached_workflow_run = session.merge(cached_workflow_run)
|
||||||
:return:
|
return cached_workflow_run
|
||||||
"""
|
|
||||||
stmt = select(WorkflowRun).where(WorkflowRun.id == workflow_run_id)
|
stmt = select(WorkflowRun).where(WorkflowRun.id == workflow_run_id)
|
||||||
workflow_run = session.scalar(stmt)
|
workflow_run = session.scalar(stmt)
|
||||||
if not workflow_run:
|
if not workflow_run:
|
||||||
raise WorkflowRunNotFoundError(workflow_run_id)
|
raise WorkflowRunNotFoundError(workflow_run_id)
|
||||||
|
self._workflow_run = workflow_run
|
||||||
|
|
||||||
return workflow_run
|
return workflow_run
|
||||||
|
|
||||||
def _get_workflow_node_execution(self, session: Session, node_execution_id: str) -> WorkflowNodeExecution:
|
def _get_workflow_node_execution(self, session: Session, node_execution_id: str) -> WorkflowNodeExecution:
|
||||||
stmt = select(WorkflowNodeExecution).where(WorkflowNodeExecution.node_execution_id == node_execution_id)
|
if node_execution_id not in self._workflow_node_executions:
|
||||||
workflow_node_execution = session.scalar(stmt)
|
raise ValueError(f"Workflow node execution not found: {node_execution_id}")
|
||||||
if not workflow_node_execution:
|
cached_workflow_node_execution = self._workflow_node_executions[node_execution_id]
|
||||||
raise WorkflowNodeExecutionNotFoundError(node_execution_id)
|
return cached_workflow_node_execution
|
||||||
|
|
||||||
return workflow_node_execution
|
|
||||||
|
|
||||||
def _handle_agent_log(self, task_id: str, event: QueueAgentLogEvent) -> AgentLogStreamResponse:
|
def _handle_agent_log(self, task_id: str, event: QueueAgentLogEvent) -> AgentLogStreamResponse:
|
||||||
"""
|
"""
|
||||||
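The caching change above keeps WorkflowRun and WorkflowNodeExecution objects in plain in-memory caches across queue events, and re-attaches a cached instance to whichever session is currently open via session.merge() instead of issuing a fresh SELECT every time. The sketch below shows the same re-attachment pattern on a made-up NodeExecution model; merge() returns the copy that is tracked by the new session.

from sqlalchemy import create_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class NodeExecution(Base):
    __tablename__ = "node_execution"
    id: Mapped[int] = mapped_column(primary_key=True)
    status: Mapped[str]


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

# In-memory cache keyed by execution id, filled when the execution is first created.
cache: dict[int, NodeExecution] = {}

with Session(engine, expire_on_commit=False) as session:
    execution = NodeExecution(id=1, status="running")
    session.add(execution)
    session.commit()
    cache[execution.id] = execution

# A later event arrives with its own session: look the object up in the cache
# and merge it so this session tracks it, instead of running another SELECT.
with Session(engine, expire_on_commit=False) as session:
    tracked = session.merge(cache[1])
    tracked.status = "succeeded"
    session.commit()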
|
@ -1,9 +1,47 @@
|
|||||||
import tiktoken
|
from threading import Lock
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
_tokenizer: Any = None
|
||||||
|
_lock = Lock()
|
||||||
|
|
||||||
|
|
||||||
class GPT2Tokenizer:
|
class GPT2Tokenizer:
|
||||||
|
@staticmethod
|
||||||
|
def _get_num_tokens_by_gpt2(text: str) -> int:
|
||||||
|
"""
|
||||||
|
use gpt2 tokenizer to get num tokens
|
||||||
|
"""
|
||||||
|
_tokenizer = GPT2Tokenizer.get_encoder()
|
||||||
|
tokens = _tokenizer.encode(text)
|
||||||
|
return len(tokens)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_num_tokens(text: str) -> int:
|
def get_num_tokens(text: str) -> int:
|
||||||
encoding = tiktoken.encoding_for_model("gpt2")
|
# Because this process needs more CPU resources, we have reverted to running it in the calling thread until we find a better way to handle it.
|
||||||
tiktoken_vec = encoding.encode(text)
|
#
|
||||||
return len(tiktoken_vec)
|
# future = _executor.submit(GPT2Tokenizer._get_num_tokens_by_gpt2, text)
|
||||||
|
# result = future.result()
|
||||||
|
# return cast(int, result)
|
||||||
|
return GPT2Tokenizer._get_num_tokens_by_gpt2(text)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_encoder() -> Any:
|
||||||
|
global _tokenizer, _lock
|
||||||
|
with _lock:
|
||||||
|
if _tokenizer is None:
|
||||||
|
# Try to use tiktoken to get the tokenizer because it is faster
|
||||||
|
#
|
||||||
|
try:
|
||||||
|
import tiktoken
|
||||||
|
|
||||||
|
_tokenizer = tiktoken.get_encoding("gpt2")
|
||||||
|
except Exception:
|
||||||
|
from os.path import abspath, dirname, join
|
||||||
|
|
||||||
|
from transformers import GPT2Tokenizer as TransformerGPT2Tokenizer # type: ignore
|
||||||
|
|
||||||
|
base_path = abspath(__file__)
|
||||||
|
gpt2_tokenizer_path = join(dirname(base_path), "gpt2")
|
||||||
|
_tokenizer = TransformerGPT2Tokenizer.from_pretrained(gpt2_tokenizer_path)
|
||||||
|
|
||||||
|
return _tokenizer
|
||||||
|
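The rewritten tokenizer above lazily builds one shared encoder under a lock (tiktoken's gpt2 encoding when available, otherwise the bundled transformers GPT2Tokenizer) and counts tokens in the calling thread again, keeping the old executor path only as a comment. One consequence of the current shape is that every get_encoder() call acquires the lock even after initialisation; a common variant that only locks on first use is sketched below. This is an alternative pattern, not the code in this commit.

from threading import Lock
from typing import Any

_tokenizer: Any = None
_lock = Lock()


def get_encoder() -> Any:
    global _tokenizer
    # Fast path: once the encoder exists, skip the lock entirely.
    if _tokenizer is None:
        with _lock:
            # Re-check inside the lock so only one thread builds the encoder.
            if _tokenizer is None:
                import tiktoken

                _tokenizer = tiktoken.get_encoding("gpt2")
    return _tokenizer


def get_num_tokens(text: str) -> int:
    return len(get_encoder().encode(text))


print(get_num_tokens("hello world"))  # exact count depends on the encoding that was loaded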
@ -113,6 +113,8 @@ class BaiduVector(BaseVector):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
def delete_by_ids(self, ids: list[str]) -> None:
|
def delete_by_ids(self, ids: list[str]) -> None:
|
||||||
|
if not ids:
|
||||||
|
return
|
||||||
quoted_ids = [f"'{id}'" for id in ids]
|
quoted_ids = [f"'{id}'" for id in ids]
|
||||||
self._db.table(self._collection_name).delete(filter=f"id IN({', '.join(quoted_ids)})")
|
self._db.table(self._collection_name).delete(filter=f"id IN({', '.join(quoted_ids)})")
|
||||||
|
|
||||||
|
@ -83,6 +83,8 @@ class ChromaVector(BaseVector):
|
|||||||
self._client.delete_collection(self._collection_name)
|
self._client.delete_collection(self._collection_name)
|
||||||
|
|
||||||
def delete_by_ids(self, ids: list[str]) -> None:
|
def delete_by_ids(self, ids: list[str]) -> None:
|
||||||
|
if not ids:
|
||||||
|
return
|
||||||
collection = self._client.get_or_create_collection(self._collection_name)
|
collection = self._client.get_or_create_collection(self._collection_name)
|
||||||
collection.delete(ids=ids)
|
collection.delete(ids=ids)
|
||||||
|
|
||||||
|
@ -0,0 +1,104 @@
|
|||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
from flask import current_app
|
||||||
|
|
||||||
|
from core.rag.datasource.vdb.elasticsearch.elasticsearch_vector import (
|
||||||
|
ElasticSearchConfig,
|
||||||
|
ElasticSearchVector,
|
||||||
|
ElasticSearchVectorFactory,
|
||||||
|
)
|
||||||
|
from core.rag.datasource.vdb.field import Field
|
||||||
|
from core.rag.datasource.vdb.vector_type import VectorType
|
||||||
|
from core.rag.embedding.embedding_base import Embeddings
|
||||||
|
from extensions.ext_redis import redis_client
|
||||||
|
from models.dataset import Dataset
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ElasticSearchJaVector(ElasticSearchVector):
|
||||||
|
def create_collection(
|
||||||
|
self,
|
||||||
|
embeddings: list[list[float]],
|
||||||
|
metadatas: Optional[list[dict[Any, Any]]] = None,
|
||||||
|
index_params: Optional[dict] = None,
|
||||||
|
):
|
||||||
|
lock_name = f"vector_indexing_lock_{self._collection_name}"
|
||||||
|
with redis_client.lock(lock_name, timeout=20):
|
||||||
|
collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
|
||||||
|
if redis_client.get(collection_exist_cache_key):
|
||||||
|
logger.info(f"Collection {self._collection_name} already exists.")
|
||||||
|
return
|
||||||
|
|
||||||
|
if not self._client.indices.exists(index=self._collection_name):
|
||||||
|
dim = len(embeddings[0])
|
||||||
|
settings = {
|
||||||
|
"analysis": {
|
||||||
|
"analyzer": {
|
||||||
|
"ja_analyzer": {
|
||||||
|
"type": "custom",
|
||||||
|
"char_filter": [
|
||||||
|
"icu_normalizer",
|
||||||
|
"kuromoji_iteration_mark",
|
||||||
|
],
|
||||||
|
"tokenizer": "kuromoji_tokenizer",
|
||||||
|
"filter": [
|
||||||
|
"kuromoji_baseform",
|
||||||
|
"kuromoji_part_of_speech",
|
||||||
|
"ja_stop",
|
||||||
|
"kuromoji_number",
|
||||||
|
"kuromoji_stemmer",
|
||||||
|
],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
mappings = {
|
||||||
|
"properties": {
|
||||||
|
Field.CONTENT_KEY.value: {
|
||||||
|
"type": "text",
|
||||||
|
"analyzer": "ja_analyzer",
|
||||||
|
"search_analyzer": "ja_analyzer",
|
||||||
|
},
|
||||||
|
Field.VECTOR.value: { # Make sure the dimension is correct here
|
||||||
|
"type": "dense_vector",
|
||||||
|
"dims": dim,
|
||||||
|
"index": True,
|
||||||
|
"similarity": "cosine",
|
||||||
|
},
|
||||||
|
Field.METADATA_KEY.value: {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"doc_id": {"type": "keyword"} # Map doc_id to keyword type
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
self._client.indices.create(index=self._collection_name, settings=settings, mappings=mappings)
|
||||||
|
|
||||||
|
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
||||||
|
|
||||||
|
|
||||||
|
class ElasticSearchJaVectorFactory(ElasticSearchVectorFactory):
|
||||||
|
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> ElasticSearchJaVector:
|
||||||
|
if dataset.index_struct_dict:
|
||||||
|
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
|
||||||
|
collection_name = class_prefix
|
||||||
|
else:
|
||||||
|
dataset_id = dataset.id
|
||||||
|
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
||||||
|
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.ELASTICSEARCH, collection_name))
|
||||||
|
|
||||||
|
config = current_app.config
|
||||||
|
return ElasticSearchJaVector(
|
||||||
|
index_name=collection_name,
|
||||||
|
config=ElasticSearchConfig(
|
||||||
|
host=config.get("ELASTICSEARCH_HOST", "localhost"),
|
||||||
|
port=config.get("ELASTICSEARCH_PORT", 9200),
|
||||||
|
username=config.get("ELASTICSEARCH_USERNAME", ""),
|
||||||
|
password=config.get("ELASTICSEARCH_PASSWORD", ""),
|
||||||
|
),
|
||||||
|
attributes=[],
|
||||||
|
)
|
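The new ElasticSearchJaVector above only changes how the index is created: the content field is analysed with a custom kuromoji-based ja_analyzer (ICU normalisation plus baseform, part-of-speech, stop-word, number and stemmer filters), while vectors and metadata keep the parent class behaviour. If the kuromoji and ICU analysis plugins are installed, the analyzer can be checked directly against an existing index. A rough sketch with the standard elasticsearch-py client; the host and index name are placeholders, and the exact token output depends on the plugin versions.

from elasticsearch import Elasticsearch

client = Elasticsearch("http://localhost:9200")  # placeholder endpoint

# Ask Elasticsearch to run the custom analyzer defined on the index and return
# the tokens it would produce for a Japanese sentence.
response = client.indices.analyze(
    index="vector_index_example_node",  # placeholder collection name
    analyzer="ja_analyzer",
    text="東京都に住んでいます",
)
print([token["token"] for token in response["tokens"]])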
@ -98,6 +98,8 @@ class ElasticSearchVector(BaseVector):
|
|||||||
return bool(self._client.exists(index=self._collection_name, id=id))
|
return bool(self._client.exists(index=self._collection_name, id=id))
|
||||||
|
|
||||||
def delete_by_ids(self, ids: list[str]) -> None:
|
def delete_by_ids(self, ids: list[str]) -> None:
|
||||||
|
if not ids:
|
||||||
|
return
|
||||||
for id in ids:
|
for id in ids:
|
||||||
self._client.delete(index=self._collection_name, id=id)
|
self._client.delete(index=self._collection_name, id=id)
|
||||||
|
|
||||||
|
@ -6,6 +6,8 @@ class Field(Enum):
|
|||||||
METADATA_KEY = "metadata"
|
METADATA_KEY = "metadata"
|
||||||
GROUP_KEY = "group_id"
|
GROUP_KEY = "group_id"
|
||||||
VECTOR = "vector"
|
VECTOR = "vector"
|
||||||
|
# Sparse Vector aims to support full text search
|
||||||
|
SPARSE_VECTOR = "sparse_vector"
|
||||||
TEXT_KEY = "text"
|
TEXT_KEY = "text"
|
||||||
PRIMARY_KEY = "id"
|
PRIMARY_KEY = "id"
|
||||||
DOC_ID = "metadata.doc_id"
|
DOC_ID = "metadata.doc_id"
|
||||||
|
@ -2,6 +2,7 @@ import json
|
|||||||
import logging
|
import logging
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
from packaging import version
|
||||||
from pydantic import BaseModel, model_validator
|
from pydantic import BaseModel, model_validator
|
||||||
from pymilvus import MilvusClient, MilvusException # type: ignore
|
from pymilvus import MilvusClient, MilvusException # type: ignore
|
||||||
from pymilvus.milvus_client import IndexParams # type: ignore
|
from pymilvus.milvus_client import IndexParams # type: ignore
|
||||||
@ -20,16 +21,25 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class MilvusConfig(BaseModel):
|
class MilvusConfig(BaseModel):
|
||||||
uri: str
|
"""
|
||||||
token: Optional[str] = None
|
Configuration class for Milvus connection.
|
||||||
user: str
|
"""
|
||||||
password: str
|
|
||||||
batch_size: int = 100
|
uri: str # Milvus server URI
|
||||||
database: str = "default"
|
token: Optional[str] = None # Optional token for authentication
|
||||||
|
user: str # Username for authentication
|
||||||
|
password: str # Password for authentication
|
||||||
|
batch_size: int = 100 # Batch size for operations
|
||||||
|
database: str = "default" # Database name
|
||||||
|
enable_hybrid_search: bool = False # Flag to enable hybrid search
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_config(cls, values: dict) -> dict:
|
def validate_config(cls, values: dict) -> dict:
|
||||||
|
"""
|
||||||
|
Validate the configuration values.
|
||||||
|
Raises ValueError if required fields are missing.
|
||||||
|
"""
|
||||||
if not values.get("uri"):
|
if not values.get("uri"):
|
||||||
raise ValueError("config MILVUS_URI is required")
|
raise ValueError("config MILVUS_URI is required")
|
||||||
if not values.get("user"):
|
if not values.get("user"):
|
||||||
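With the new enable_hybrid_search flag, hybrid (dense plus full-text) search is only attempted when the config allows it and the server turns out to be new enough. Constructing the config directly looks roughly like this; the import path is assumed from the project layout shown elsewhere in this diff, and the URI and credentials are placeholders.

# Module path assumed from the surrounding files in this diff.
from core.rag.datasource.vdb.milvus.milvus_vector import MilvusConfig

config = MilvusConfig(
    uri="http://localhost:19530",  # placeholder Milvus endpoint
    user="root",                   # placeholder credentials
    password="Milvus",
    enable_hybrid_search=True,     # opt in; MilvusVector still checks the server version
)
print(config.to_milvus_params())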
@ -39,6 +49,9 @@ class MilvusConfig(BaseModel):
|
|||||||
return values
|
return values
|
||||||
|
|
||||||
def to_milvus_params(self):
|
def to_milvus_params(self):
|
||||||
|
"""
|
||||||
|
Convert the configuration to a dictionary of Milvus connection parameters.
|
||||||
|
"""
|
||||||
return {
|
return {
|
||||||
"uri": self.uri,
|
"uri": self.uri,
|
||||||
"token": self.token,
|
"token": self.token,
|
||||||
@ -49,26 +62,57 @@ class MilvusConfig(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class MilvusVector(BaseVector):
|
class MilvusVector(BaseVector):
|
||||||
|
"""
|
||||||
|
Milvus vector storage implementation.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, collection_name: str, config: MilvusConfig):
|
def __init__(self, collection_name: str, config: MilvusConfig):
|
||||||
super().__init__(collection_name)
|
super().__init__(collection_name)
|
||||||
self._client_config = config
|
self._client_config = config
|
||||||
self._client = self._init_client(config)
|
self._client = self._init_client(config)
|
||||||
self._consistency_level = "Session"
|
self._consistency_level = "Session" # Consistency level for Milvus operations
|
||||||
self._fields: list[str] = []
|
self._fields: list[str] = [] # List of fields in the collection
|
||||||
|
self._hybrid_search_enabled = self._check_hybrid_search_support() # Check if hybrid search is supported
|
||||||
|
|
||||||
|
def _check_hybrid_search_support(self) -> bool:
|
||||||
|
"""
|
||||||
|
Check if the current Milvus version supports hybrid search.
|
||||||
|
Returns True if the version is >= 2.5.0, otherwise False.
|
||||||
|
"""
|
||||||
|
if not self._client_config.enable_hybrid_search:
|
||||||
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
milvus_version = self._client.get_server_version()
|
||||||
|
return version.parse(milvus_version).base_version >= version.parse("2.5.0").base_version
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to check Milvus version: {str(e)}. Disabling hybrid search.")
|
||||||
|
return False
|
||||||
|
|
||||||
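The support check above combines the config flag with a runtime look at the server version via the packaging library. Note that base_version is a string, so the comparison there is lexicographic; comparing parsed Version objects is the numeric-aware form, which matters if versions like 2.10.x ever appear. A tiny illustration:

from packaging import version

server = version.parse("2.5.3")
required = version.parse("2.5.0")

# Version objects compare numerically, component by component.
print(server >= required)                   # True
# Plain strings compare lexicographically, which misorders 2.10.x.
print("2.10.0" >= "2.5.0")                  # False
print(version.parse("2.10.0") >= required)  # True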
def get_type(self) -> str:
|
def get_type(self) -> str:
|
||||||
|
"""
|
||||||
|
Get the type of vector storage (Milvus).
|
||||||
|
"""
|
||||||
return VectorType.MILVUS
|
return VectorType.MILVUS
|
||||||
|
|
||||||
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
||||||
|
"""
|
||||||
|
Create a collection and add texts with embeddings.
|
||||||
|
"""
|
||||||
index_params = {"metric_type": "IP", "index_type": "HNSW", "params": {"M": 8, "efConstruction": 64}}
|
index_params = {"metric_type": "IP", "index_type": "HNSW", "params": {"M": 8, "efConstruction": 64}}
|
||||||
metadatas = [d.metadata if d.metadata is not None else {} for d in texts]
|
metadatas = [d.metadata if d.metadata is not None else {} for d in texts]
|
||||||
self.create_collection(embeddings, metadatas, index_params)
|
self.create_collection(embeddings, metadatas, index_params)
|
||||||
self.add_texts(texts, embeddings)
|
self.add_texts(texts, embeddings)
|
||||||
|
|
||||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
||||||
|
"""
|
||||||
|
Add texts and their embeddings to the collection.
|
||||||
|
"""
|
||||||
insert_dict_list = []
|
insert_dict_list = []
|
||||||
for i in range(len(documents)):
|
for i in range(len(documents)):
|
||||||
insert_dict = {
|
insert_dict = {
|
||||||
|
# Do not need to insert the sparse_vector field separately, as the text_bm25_emb
|
||||||
|
# function will automatically convert the native text into a sparse vector for us.
|
||||||
Field.CONTENT_KEY.value: documents[i].page_content,
|
Field.CONTENT_KEY.value: documents[i].page_content,
|
||||||
Field.VECTOR.value: embeddings[i],
|
Field.VECTOR.value: embeddings[i],
|
||||||
Field.METADATA_KEY.value: documents[i].metadata,
|
Field.METADATA_KEY.value: documents[i].metadata,
|
||||||
@ -76,12 +120,11 @@ class MilvusVector(BaseVector):
|
|||||||
insert_dict_list.append(insert_dict)
|
insert_dict_list.append(insert_dict)
|
||||||
# Total insert count
|
# Total insert count
|
||||||
total_count = len(insert_dict_list)
|
total_count = len(insert_dict_list)
|
||||||
|
|
||||||
pks: list[str] = []
|
pks: list[str] = []
|
||||||
|
|
||||||
for i in range(0, total_count, 1000):
|
for i in range(0, total_count, 1000):
|
||||||
batch_insert_list = insert_dict_list[i : i + 1000]
|
|
||||||
# Insert into the collection.
|
# Insert into the collection.
|
||||||
|
batch_insert_list = insert_dict_list[i : i + 1000]
|
||||||
try:
|
try:
|
||||||
ids = self._client.insert(collection_name=self._collection_name, data=batch_insert_list)
|
ids = self._client.insert(collection_name=self._collection_name, data=batch_insert_list)
|
||||||
pks.extend(ids)
|
pks.extend(ids)
|
||||||
@ -91,6 +134,9 @@ class MilvusVector(BaseVector):
|
|||||||
return pks
|
return pks
|
||||||
|
|
||||||
def get_ids_by_metadata_field(self, key: str, value: str):
|
def get_ids_by_metadata_field(self, key: str, value: str):
|
||||||
|
"""
|
||||||
|
Get document IDs by metadata field key and value.
|
||||||
|
"""
|
||||||
result = self._client.query(
|
result = self._client.query(
|
||||||
collection_name=self._collection_name, filter=f'metadata["{key}"] == "{value}"', output_fields=["id"]
|
collection_name=self._collection_name, filter=f'metadata["{key}"] == "{value}"', output_fields=["id"]
|
||||||
)
|
)
|
||||||
@ -100,12 +146,18 @@ class MilvusVector(BaseVector):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
def delete_by_metadata_field(self, key: str, value: str):
|
def delete_by_metadata_field(self, key: str, value: str):
|
||||||
|
"""
|
||||||
|
Delete documents by metadata field key and value.
|
||||||
|
"""
|
||||||
if self._client.has_collection(self._collection_name):
|
if self._client.has_collection(self._collection_name):
|
||||||
ids = self.get_ids_by_metadata_field(key, value)
|
ids = self.get_ids_by_metadata_field(key, value)
|
||||||
if ids:
|
if ids:
|
||||||
self._client.delete(collection_name=self._collection_name, pks=ids)
|
self._client.delete(collection_name=self._collection_name, pks=ids)
|
||||||
|
|
||||||
def delete_by_ids(self, ids: list[str]) -> None:
|
def delete_by_ids(self, ids: list[str]) -> None:
|
||||||
|
"""
|
||||||
|
Delete documents by their IDs.
|
||||||
|
"""
|
||||||
if self._client.has_collection(self._collection_name):
|
if self._client.has_collection(self._collection_name):
|
||||||
result = self._client.query(
|
result = self._client.query(
|
||||||
collection_name=self._collection_name, filter=f'metadata["doc_id"] in {ids}', output_fields=["id"]
|
collection_name=self._collection_name, filter=f'metadata["doc_id"] in {ids}', output_fields=["id"]
|
||||||
@ -115,10 +167,16 @@ class MilvusVector(BaseVector):
|
|||||||
self._client.delete(collection_name=self._collection_name, pks=ids)
|
self._client.delete(collection_name=self._collection_name, pks=ids)
|
||||||
|
|
||||||
def delete(self) -> None:
|
def delete(self) -> None:
|
||||||
|
"""
|
||||||
|
Delete the entire collection.
|
||||||
|
"""
|
||||||
if self._client.has_collection(self._collection_name):
|
if self._client.has_collection(self._collection_name):
|
||||||
self._client.drop_collection(self._collection_name, None)
|
self._client.drop_collection(self._collection_name, None)
|
||||||
|
|
||||||
def text_exists(self, id: str) -> bool:
|
def text_exists(self, id: str) -> bool:
|
||||||
|
"""
|
||||||
|
Check if a text with the given ID exists in the collection.
|
||||||
|
"""
|
||||||
if not self._client.has_collection(self._collection_name):
|
if not self._client.has_collection(self._collection_name):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@ -128,32 +186,80 @@ class MilvusVector(BaseVector):
|
|||||||
|
|
||||||
return len(result) > 0
|
return len(result) > 0
|
||||||
|
|
||||||
|
def field_exists(self, field: str) -> bool:
|
||||||
|
"""
|
||||||
|
Check if a field exists in the collection.
|
||||||
|
"""
|
||||||
|
return field in self._fields
|
||||||
|
|
||||||
|
def _process_search_results(
|
||||||
|
self, results: list[Any], output_fields: list[str], score_threshold: float = 0.0
|
||||||
|
) -> list[Document]:
|
||||||
|
"""
|
||||||
|
Common method to process search results
|
||||||
|
|
||||||
|
:param results: Search results
|
||||||
|
:param output_fields: Fields to be output
|
||||||
|
:param score_threshold: Score threshold for filtering
|
||||||
|
:return: List of documents
|
||||||
|
"""
|
||||||
|
docs = []
|
||||||
|
for result in results[0]:
|
||||||
|
metadata = result["entity"].get(output_fields[1], {})
|
||||||
|
metadata["score"] = result["distance"]
|
||||||
|
|
||||||
|
if result["distance"] > score_threshold:
|
||||||
|
doc = Document(page_content=result["entity"].get(output_fields[0], ""), metadata=metadata)
|
||||||
|
docs.append(doc)
|
||||||
|
|
||||||
|
return docs
|
||||||
|
|
||||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||||
# Set search parameters.
|
"""
|
||||||
|
Search for documents by vector similarity.
|
||||||
|
"""
|
||||||
results = self._client.search(
|
results = self._client.search(
|
||||||
collection_name=self._collection_name,
|
collection_name=self._collection_name,
|
||||||
data=[query_vector],
|
data=[query_vector],
|
||||||
|
anns_field=Field.VECTOR.value,
|
||||||
limit=kwargs.get("top_k", 4),
|
limit=kwargs.get("top_k", 4),
|
||||||
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
|
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
|
||||||
)
|
)
|
||||||
# Organize results.
|
|
||||||
docs = []
|
return self._process_search_results(
|
||||||
for result in results[0]:
|
results,
|
||||||
metadata = result["entity"].get(Field.METADATA_KEY.value)
|
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
|
||||||
metadata["score"] = result["distance"]
|
score_threshold=float(kwargs.get("score_threshold") or 0.0),
|
||||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
)
|
||||||
if result["distance"] > score_threshold:
|
|
||||||
doc = Document(page_content=result["entity"].get(Field.CONTENT_KEY.value), metadata=metadata)
|
|
||||||
docs.append(doc)
|
|
||||||
return docs
|
|
||||||
|
|
||||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||||
# milvus/zilliz doesn't support bm25 search
|
"""
|
||||||
return []
|
Search for documents by full-text search (if hybrid search is enabled).
|
||||||
|
"""
|
||||||
|
if not self._hybrid_search_enabled or not self.field_exists(Field.SPARSE_VECTOR.value):
|
||||||
|
logger.warning("Full-text search is not supported in current Milvus version (requires >= 2.5.0)")
|
||||||
|
return []
|
||||||
|
|
||||||
|
results = self._client.search(
|
||||||
|
collection_name=self._collection_name,
|
||||||
|
data=[query],
|
||||||
|
anns_field=Field.SPARSE_VECTOR.value,
|
||||||
|
limit=kwargs.get("top_k", 4),
|
||||||
|
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
|
||||||
|
)
|
||||||
|
|
||||||
|
return self._process_search_results(
|
||||||
|
results,
|
||||||
|
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
|
||||||
|
score_threshold=float(kwargs.get("score_threshold") or 0.0),
|
||||||
|
)
|
||||||
|
|
||||||
def create_collection(
|
def create_collection(
|
||||||
self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None
|
self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
Create a new collection in Milvus with the specified schema and index parameters.
|
||||||
|
"""
|
||||||
lock_name = "vector_indexing_lock_{}".format(self._collection_name)
|
lock_name = "vector_indexing_lock_{}".format(self._collection_name)
|
||||||
with redis_client.lock(lock_name, timeout=20):
|
with redis_client.lock(lock_name, timeout=20):
|
||||||
collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
|
collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
|
||||||
@ -161,7 +267,7 @@ class MilvusVector(BaseVector):
|
|||||||
return
|
return
|
||||||
# Grab the existing collection if it exists
|
# Grab the existing collection if it exists
|
||||||
if not self._client.has_collection(self._collection_name):
|
if not self._client.has_collection(self._collection_name):
|
||||||
from pymilvus import CollectionSchema, DataType, FieldSchema # type: ignore
|
from pymilvus import CollectionSchema, DataType, FieldSchema, Function, FunctionType # type: ignore
|
||||||
from pymilvus.orm.types import infer_dtype_bydata # type: ignore
|
from pymilvus.orm.types import infer_dtype_bydata # type: ignore
|
||||||
|
|
||||||
# Determine embedding dim
|
# Determine embedding dim
|
||||||
@ -170,16 +276,36 @@ class MilvusVector(BaseVector):
|
|||||||
if metadatas:
|
if metadatas:
|
||||||
fields.append(FieldSchema(Field.METADATA_KEY.value, DataType.JSON, max_length=65_535))
|
fields.append(FieldSchema(Field.METADATA_KEY.value, DataType.JSON, max_length=65_535))
|
||||||
|
|
||||||
# Create the text field
|
# Create the text field, enable_analyzer will be set True to support milvus automatically
|
||||||
fields.append(FieldSchema(Field.CONTENT_KEY.value, DataType.VARCHAR, max_length=65_535))
|
# transfer text to sparse_vector, reference: https://milvus.io/docs/full-text-search.md
|
||||||
|
fields.append(
|
||||||
|
FieldSchema(
|
||||||
|
Field.CONTENT_KEY.value,
|
||||||
|
DataType.VARCHAR,
|
||||||
|
max_length=65_535,
|
||||||
|
enable_analyzer=self._hybrid_search_enabled,
|
||||||
|
)
|
||||||
|
)
|
||||||
# Create the primary key field
|
# Create the primary key field
|
||||||
fields.append(FieldSchema(Field.PRIMARY_KEY.value, DataType.INT64, is_primary=True, auto_id=True))
|
fields.append(FieldSchema(Field.PRIMARY_KEY.value, DataType.INT64, is_primary=True, auto_id=True))
|
||||||
# Create the vector field, supports binary or float vectors
|
# Create the vector field, supports binary or float vectors
|
||||||
fields.append(FieldSchema(Field.VECTOR.value, infer_dtype_bydata(embeddings[0]), dim=dim))
|
fields.append(FieldSchema(Field.VECTOR.value, infer_dtype_bydata(embeddings[0]), dim=dim))
|
||||||
|
# Create Sparse Vector Index for the collection
|
||||||
|
if self._hybrid_search_enabled:
|
||||||
|
fields.append(FieldSchema(Field.SPARSE_VECTOR.value, DataType.SPARSE_FLOAT_VECTOR))
|
||||||
|
|
||||||
# Create the schema for the collection
|
|
||||||
schema = CollectionSchema(fields)
|
schema = CollectionSchema(fields)
|
||||||
|
|
||||||
|
# Create custom function to support text to sparse vector by BM25
|
||||||
|
if self._hybrid_search_enabled:
|
||||||
|
bm25_function = Function(
|
||||||
|
name="text_bm25_emb",
|
||||||
|
input_field_names=[Field.CONTENT_KEY.value],
|
||||||
|
output_field_names=[Field.SPARSE_VECTOR.value],
|
||||||
|
function_type=FunctionType.BM25,
|
||||||
|
)
|
||||||
|
schema.add_function(bm25_function)
|
||||||
|
|
||||||
for x in schema.fields:
|
for x in schema.fields:
|
||||||
self._fields.append(x.name)
|
self._fields.append(x.name)
|
||||||
# Since primary field is auto-id, no need to track it
|
# Since primary field is auto-id, no need to track it
|
||||||
@ -189,10 +315,15 @@ class MilvusVector(BaseVector):
|
|||||||
index_params_obj = IndexParams()
|
index_params_obj = IndexParams()
|
||||||
index_params_obj.add_index(field_name=Field.VECTOR.value, **index_params)
|
index_params_obj.add_index(field_name=Field.VECTOR.value, **index_params)
|
||||||
|
|
||||||
|
# Create Sparse Vector Index for the collection
|
||||||
|
if self._hybrid_search_enabled:
|
||||||
|
index_params_obj.add_index(
|
||||||
|
field_name=Field.SPARSE_VECTOR.value, index_type="AUTOINDEX", metric_type="BM25"
|
||||||
|
)
|
||||||
|
|
||||||
# Create the collection
|
# Create the collection
|
||||||
collection_name = self._collection_name
|
|
||||||
self._client.create_collection(
|
self._client.create_collection(
|
||||||
collection_name=collection_name,
|
collection_name=self._collection_name,
|
||||||
schema=schema,
|
schema=schema,
|
||||||
index_params=index_params_obj,
|
index_params=index_params_obj,
|
||||||
consistency_level=self._consistency_level,
|
consistency_level=self._consistency_level,
|
||||||
@ -200,12 +331,22 @@ class MilvusVector(BaseVector):
|
|||||||
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
||||||
|
|
||||||
def _init_client(self, config) -> MilvusClient:
|
def _init_client(self, config) -> MilvusClient:
|
||||||
|
"""
|
||||||
|
Initialize and return a Milvus client.
|
||||||
|
"""
|
||||||
client = MilvusClient(uri=config.uri, user=config.user, password=config.password, db_name=config.database)
|
client = MilvusClient(uri=config.uri, user=config.user, password=config.password, db_name=config.database)
|
||||||
return client
|
return client
|
||||||
|
|
||||||
|
|
||||||
class MilvusVectorFactory(AbstractVectorFactory):
|
class MilvusVectorFactory(AbstractVectorFactory):
|
||||||
|
"""
|
||||||
|
Factory class for creating MilvusVector instances.
|
||||||
|
"""
|
||||||
|
|
||||||
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> MilvusVector:
|
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> MilvusVector:
|
||||||
|
"""
|
||||||
|
Initialize a MilvusVector instance for the given dataset.
|
||||||
|
"""
|
||||||
if dataset.index_struct_dict:
|
if dataset.index_struct_dict:
|
||||||
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
|
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
|
||||||
collection_name = class_prefix
|
collection_name = class_prefix
|
||||||
@ -222,5 +363,6 @@ class MilvusVectorFactory(AbstractVectorFactory):
|
|||||||
user=dify_config.MILVUS_USER or "",
|
user=dify_config.MILVUS_USER or "",
|
||||||
password=dify_config.MILVUS_PASSWORD or "",
|
password=dify_config.MILVUS_PASSWORD or "",
|
||||||
database=dify_config.MILVUS_DATABASE or "",
|
database=dify_config.MILVUS_DATABASE or "",
|
||||||
|
enable_hybrid_search=dify_config.MILVUS_ENABLE_HYBRID_SEARCH or False,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
@ -100,6 +100,8 @@ class MyScaleVector(BaseVector):
|
|||||||
return results.row_count > 0
|
return results.row_count > 0
|
||||||
|
|
||||||
def delete_by_ids(self, ids: list[str]) -> None:
|
def delete_by_ids(self, ids: list[str]) -> None:
|
||||||
|
if not ids:
|
||||||
|
return
|
||||||
self._client.command(
|
self._client.command(
|
||||||
f"DELETE FROM {self._config.database}.{self._collection_name} WHERE id IN {str(tuple(ids))}"
|
f"DELETE FROM {self._config.database}.{self._collection_name} WHERE id IN {str(tuple(ids))}"
|
||||||
)
|
)
|
||||||
|
@ -134,6 +134,8 @@ class OceanBaseVector(BaseVector):
|
|||||||
return bool(cur.rowcount != 0)
|
return bool(cur.rowcount != 0)
|
||||||
|
|
||||||
def delete_by_ids(self, ids: list[str]) -> None:
|
def delete_by_ids(self, ids: list[str]) -> None:
|
||||||
|
if not ids:
|
||||||
|
return
|
||||||
self._client.delete(table_name=self._collection_name, ids=ids)
|
self._client.delete(table_name=self._collection_name, ids=ids)
|
||||||
|
|
||||||
def get_ids_by_metadata_field(self, key: str, value: str) -> list[str]:
|
def get_ids_by_metadata_field(self, key: str, value: str) -> list[str]:
|
||||||
|
@ -167,6 +167,8 @@ class OracleVector(BaseVector):
|
|||||||
return docs
|
return docs
|
||||||
|
|
||||||
def delete_by_ids(self, ids: list[str]) -> None:
|
def delete_by_ids(self, ids: list[str]) -> None:
|
||||||
|
if not ids:
|
||||||
|
return
|
||||||
with self._get_cursor() as cur:
|
with self._get_cursor() as cur:
|
||||||
cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s" % (tuple(ids),))
|
cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s" % (tuple(ids),))
|
||||||
|
|
||||||
|
@ -129,6 +129,11 @@ class PGVector(BaseVector):
|
|||||||
return docs
|
return docs
|
||||||
|
|
||||||
def delete_by_ids(self, ids: list[str]) -> None:
|
def delete_by_ids(self, ids: list[str]) -> None:
|
||||||
|
# Avoiding crashes caused by performing delete operations on empty lists in certain scenarios
|
||||||
|
# Scenario 1: extract a document fails, resulting in a table not being created.
|
||||||
|
# Then clicking the retry button triggers a delete operation on an empty list.
|
||||||
|
if not ids:
|
||||||
|
return
|
||||||
with self._get_cursor() as cur:
|
with self._get_cursor() as cur:
|
||||||
cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s", (tuple(ids),))
|
cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s", (tuple(ids),))
|
||||||
|
|
||||||
|
@ -140,6 +140,8 @@ class TencentVector(BaseVector):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
def delete_by_ids(self, ids: list[str]) -> None:
|
def delete_by_ids(self, ids: list[str]) -> None:
|
||||||
|
if not ids:
|
||||||
|
return
|
||||||
self._db.collection(self._collection_name).delete(document_ids=ids)
|
self._db.collection(self._collection_name).delete(document_ids=ids)
|
||||||
|
|
||||||
def delete_by_metadata_field(self, key: str, value: str) -> None:
|
def delete_by_metadata_field(self, key: str, value: str) -> None:
|
||||||
|
@ -409,27 +409,27 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory):
|
|||||||
db.session.query(TidbAuthBinding).filter(TidbAuthBinding.tenant_id == dataset.tenant_id).one_or_none()
|
db.session.query(TidbAuthBinding).filter(TidbAuthBinding.tenant_id == dataset.tenant_id).one_or_none()
|
||||||
)
|
)
|
||||||
if not tidb_auth_binding:
|
if not tidb_auth_binding:
|
||||||
idle_tidb_auth_binding = (
|
with redis_client.lock("create_tidb_serverless_cluster_lock", timeout=900):
|
||||||
db.session.query(TidbAuthBinding)
|
tidb_auth_binding = (
|
||||||
.filter(TidbAuthBinding.active == False, TidbAuthBinding.status == "ACTIVE")
|
db.session.query(TidbAuthBinding)
|
||||||
.limit(1)
|
.filter(TidbAuthBinding.tenant_id == dataset.tenant_id)
|
||||||
.one_or_none()
|
.one_or_none()
|
||||||
)
|
)
|
||||||
if idle_tidb_auth_binding:
|
if tidb_auth_binding:
|
||||||
idle_tidb_auth_binding.active = True
|
TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}"
|
||||||
idle_tidb_auth_binding.tenant_id = dataset.tenant_id
|
|
||||||
db.session.commit()
|
else:
|
||||||
TIDB_ON_QDRANT_API_KEY = f"{idle_tidb_auth_binding.account}:{idle_tidb_auth_binding.password}"
|
idle_tidb_auth_binding = (
|
||||||
else:
|
|
||||||
with redis_client.lock("create_tidb_serverless_cluster_lock", timeout=900):
|
|
||||||
tidb_auth_binding = (
|
|
||||||
db.session.query(TidbAuthBinding)
|
db.session.query(TidbAuthBinding)
|
||||||
.filter(TidbAuthBinding.tenant_id == dataset.tenant_id)
|
.filter(TidbAuthBinding.active == False, TidbAuthBinding.status == "ACTIVE")
|
||||||
|
.limit(1)
|
||||||
.one_or_none()
|
.one_or_none()
|
||||||
)
|
)
|
||||||
if tidb_auth_binding:
|
if idle_tidb_auth_binding:
|
||||||
TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}"
|
idle_tidb_auth_binding.active = True
|
||||||
|
idle_tidb_auth_binding.tenant_id = dataset.tenant_id
|
||||||
|
db.session.commit()
|
||||||
|
TIDB_ON_QDRANT_API_KEY = f"{idle_tidb_auth_binding.account}:{idle_tidb_auth_binding.password}"
|
||||||
else:
|
else:
|
||||||
new_cluster = TidbService.create_tidb_serverless_cluster(
|
new_cluster = TidbService.create_tidb_serverless_cluster(
|
||||||
dify_config.TIDB_PROJECT_ID or "",
|
dify_config.TIDB_PROJECT_ID or "",
|
||||||
@ -451,7 +451,6 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory):
|
|||||||
db.session.add(new_tidb_auth_binding)
|
db.session.add(new_tidb_auth_binding)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
TIDB_ON_QDRANT_API_KEY = f"{new_tidb_auth_binding.account}:{new_tidb_auth_binding.password}"
|
TIDB_ON_QDRANT_API_KEY = f"{new_tidb_auth_binding.account}:{new_tidb_auth_binding.password}"
|
||||||
|
|
||||||
else:
|
else:
|
||||||
TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}"
|
TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}"
|
||||||
|
|
||||||
|
@ -90,6 +90,12 @@ class Vector:
|
|||||||
from core.rag.datasource.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
|
from core.rag.datasource.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
|
||||||
|
|
||||||
return ElasticSearchVectorFactory
|
return ElasticSearchVectorFactory
|
||||||
|
case VectorType.ELASTICSEARCH_JA:
|
||||||
|
from core.rag.datasource.vdb.elasticsearch.elasticsearch_ja_vector import (
|
||||||
|
ElasticSearchJaVectorFactory,
|
||||||
|
)
|
||||||
|
|
||||||
|
return ElasticSearchJaVectorFactory
|
||||||
case VectorType.TIDB_VECTOR:
|
case VectorType.TIDB_VECTOR:
|
||||||
from core.rag.datasource.vdb.tidb_vector.tidb_vector import TiDBVectorFactory
|
from core.rag.datasource.vdb.tidb_vector.tidb_vector import TiDBVectorFactory
|
||||||
|
|
||||||
|
@ -16,6 +16,7 @@ class VectorType(StrEnum):
|
|||||||
TENCENT = "tencent"
|
TENCENT = "tencent"
|
||||||
ORACLE = "oracle"
|
ORACLE = "oracle"
|
||||||
ELASTICSEARCH = "elasticsearch"
|
ELASTICSEARCH = "elasticsearch"
|
||||||
|
ELASTICSEARCH_JA = "elasticsearch-ja"
|
||||||
LINDORM = "lindorm"
|
LINDORM = "lindorm"
|
||||||
COUCHBASE = "couchbase"
|
COUCHBASE = "couchbase"
|
||||||
BAIDU = "baidu"
|
BAIDU = "baidu"
|
||||||
|
@ -23,7 +23,6 @@ class PdfExtractor(BaseExtractor):
|
|||||||
self._file_cache_key = file_cache_key
|
self._file_cache_key = file_cache_key
|
||||||
|
|
||||||
def extract(self) -> list[Document]:
|
def extract(self) -> list[Document]:
|
||||||
plaintext_file_key = ""
|
|
||||||
plaintext_file_exists = False
|
plaintext_file_exists = False
|
||||||
if self._file_cache_key:
|
if self._file_cache_key:
|
||||||
try:
|
try:
|
||||||
@ -39,8 +38,8 @@ class PdfExtractor(BaseExtractor):
|
|||||||
text = "\n\n".join(text_list)
|
text = "\n\n".join(text_list)
|
||||||
|
|
||||||
# save plaintext file for caching
|
# save plaintext file for caching
|
||||||
if not plaintext_file_exists and plaintext_file_key:
|
if not plaintext_file_exists and self._file_cache_key:
|
||||||
storage.save(plaintext_file_key, text.encode("utf-8"))
|
storage.save(self._file_cache_key, text.encode("utf-8"))
|
||||||
|
|
||||||
return documents
|
return documents
|
||||||
|
|
||||||
|
@ -3,6 +3,7 @@
|
|||||||
import uuid
|
import uuid
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
from configs import dify_config
|
||||||
from core.model_manager import ModelInstance
|
from core.model_manager import ModelInstance
|
||||||
from core.rag.cleaner.clean_processor import CleanProcessor
|
from core.rag.cleaner.clean_processor import CleanProcessor
|
||||||
from core.rag.datasource.retrieval_service import RetrievalService
|
from core.rag.datasource.retrieval_service import RetrievalService
|
||||||
@ -80,6 +81,10 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
|
|||||||
child_nodes = self._split_child_nodes(
|
child_nodes = self._split_child_nodes(
|
||||||
document, rules, process_rule.get("mode"), kwargs.get("embedding_model_instance")
|
document, rules, process_rule.get("mode"), kwargs.get("embedding_model_instance")
|
||||||
)
|
)
|
||||||
|
if kwargs.get("preview"):
|
||||||
|
if len(child_nodes) > dify_config.CHILD_CHUNKS_PREVIEW_NUMBER:
|
||||||
|
child_nodes = child_nodes[: dify_config.CHILD_CHUNKS_PREVIEW_NUMBER]
|
||||||
|
|
||||||
document.children = child_nodes
|
document.children = child_nodes
|
||||||
doc_id = str(uuid.uuid4())
|
doc_id = str(uuid.uuid4())
|
||||||
hash = helper.generate_text_hash(document.page_content)
|
hash = helper.generate_text_hash(document.page_content)
|
||||||
|
@ -212,8 +212,23 @@ class ApiTool(Tool):
|
|||||||
else:
|
else:
|
||||||
body = body
|
body = body
|
||||||
|
|
||||||
if method in {"get", "head", "post", "put", "delete", "patch"}:
|
if method in {
|
||||||
response: httpx.Response = getattr(ssrf_proxy, method)(
|
"get",
|
||||||
|
"head",
|
||||||
|
"post",
|
||||||
|
"put",
|
||||||
|
"delete",
|
||||||
|
"patch",
|
||||||
|
"options",
|
||||||
|
"GET",
|
||||||
|
"POST",
|
||||||
|
"PUT",
|
||||||
|
"PATCH",
|
||||||
|
"DELETE",
|
||||||
|
"HEAD",
|
||||||
|
"OPTIONS",
|
||||||
|
}:
|
||||||
|
response: httpx.Response = getattr(ssrf_proxy, method.lower())(
|
||||||
url,
|
url,
|
||||||
params=params,
|
params=params,
|
||||||
headers=headers,
|
headers=headers,
|
||||||
|
@ -2,14 +2,18 @@ import csv
|
|||||||
import io
|
import io
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import operator
|
||||||
import os
|
import os
|
||||||
import tempfile
|
import tempfile
|
||||||
from typing import cast
|
from collections.abc import Mapping, Sequence
|
||||||
|
from typing import Any, cast
|
||||||
|
|
||||||
import docx
|
import docx
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import pypdfium2 # type: ignore
|
import pypdfium2 # type: ignore
|
||||||
import yaml # type: ignore
|
import yaml # type: ignore
|
||||||
|
from docx.table import Table
|
||||||
|
from docx.text.paragraph import Paragraph
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from core.file import File, FileTransferMethod, file_manager
|
from core.file import File, FileTransferMethod, file_manager
|
||||||
@ -78,6 +82,23 @@ class DocumentExtractorNode(BaseNode[DocumentExtractorNodeData]):
|
|||||||
process_data=process_data,
|
process_data=process_data,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _extract_variable_selector_to_variable_mapping(
|
||||||
|
cls,
|
||||||
|
*,
|
||||||
|
graph_config: Mapping[str, Any],
|
||||||
|
node_id: str,
|
||||||
|
node_data: DocumentExtractorNodeData,
|
||||||
|
) -> Mapping[str, Sequence[str]]:
|
||||||
|
"""
|
||||||
|
Extract variable selector to variable mapping
|
||||||
|
:param graph_config: graph config
|
||||||
|
:param node_id: node id
|
||||||
|
:param node_data: node data
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
return {node_id + ".files": node_data.variable_selector}
|
||||||
|
|
||||||
|
|
||||||
def _extract_text_by_mime_type(*, file_content: bytes, mime_type: str) -> str:
|
def _extract_text_by_mime_type(*, file_content: bytes, mime_type: str) -> str:
|
||||||
"""Extract text from a file based on its MIME type."""
|
"""Extract text from a file based on its MIME type."""
|
||||||
@ -189,35 +210,56 @@ def _extract_text_from_doc(file_content: bytes) -> str:
|
|||||||
doc_file = io.BytesIO(file_content)
|
doc_file = io.BytesIO(file_content)
|
||||||
doc = docx.Document(doc_file)
|
doc = docx.Document(doc_file)
|
||||||
text = []
|
text = []
|
||||||
# Process paragraphs
|
|
||||||
for paragraph in doc.paragraphs:
|
|
||||||
if paragraph.text.strip():
|
|
||||||
text.append(paragraph.text)
|
|
||||||
|
|
||||||
# Process tables
|
# Keep track of paragraph and table positions
|
||||||
for table in doc.tables:
|
content_items: list[tuple[int, str, Table | Paragraph]] = []
|
||||||
# Table header
|
|
||||||
try:
|
# Process paragraphs and tables
|
||||||
# table maybe cause errors so ignore it.
|
for i, paragraph in enumerate(doc.paragraphs):
|
||||||
if len(table.rows) > 0 and table.rows[0].cells is not None:
|
if paragraph.text.strip():
|
||||||
|
content_items.append((i, "paragraph", paragraph))
|
||||||
|
|
||||||
|
for i, table in enumerate(doc.tables):
|
||||||
|
content_items.append((i, "table", table))
|
||||||
|
|
||||||
|
# Sort content items based on their original position
|
||||||
|
content_items.sort(key=operator.itemgetter(0))
|
||||||
|
|
||||||
|
# Process sorted content
|
||||||
|
for _, item_type, item in content_items:
|
||||||
|
if item_type == "paragraph":
|
||||||
|
if isinstance(item, Table):
|
||||||
|
continue
|
||||||
|
text.append(item.text)
|
||||||
|
elif item_type == "table":
|
||||||
|
# Process tables
|
||||||
|
if not isinstance(item, Table):
|
||||||
|
continue
|
||||||
|
try:
|
||||||
# Check if any cell in the table has text
|
# Check if any cell in the table has text
|
||||||
has_content = False
|
has_content = False
|
||||||
for row in table.rows:
|
for row in item.rows:
|
||||||
if any(cell.text.strip() for cell in row.cells):
|
if any(cell.text.strip() for cell in row.cells):
|
||||||
has_content = True
|
has_content = True
|
||||||
break
|
break
|
||||||
|
|
||||||
if has_content:
|
if has_content:
|
||||||
markdown_table = "| " + " | ".join(cell.text for cell in table.rows[0].cells) + " |\n"
|
cell_texts = [cell.text.replace("\n", "<br>") for cell in item.rows[0].cells]
|
||||||
markdown_table += "| " + " | ".join(["---"] * len(table.rows[0].cells)) + " |\n"
|
markdown_table = f"| {' | '.join(cell_texts)} |\n"
|
||||||
for row in table.rows[1:]:
|
markdown_table += f"| {' | '.join(['---'] * len(item.rows[0].cells))} |\n"
|
||||||
markdown_table += "| " + " | ".join(cell.text for cell in row.cells) + " |\n"
|
|
||||||
|
for row in item.rows[1:]:
|
||||||
|
# Replace newlines with <br> in each cell
|
||||||
|
row_cells = [cell.text.replace("\n", "<br>") for cell in row.cells]
|
||||||
|
markdown_table += "| " + " | ".join(row_cells) + " |\n"
|
||||||
|
|
||||||
text.append(markdown_table)
|
text.append(markdown_table)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Failed to extract table from DOC/DOCX: {e}")
|
logger.warning(f"Failed to extract table from DOC/DOCX: {e}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
return "\n".join(text)
|
return "\n".join(text)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise TextExtractionError(f"Failed to extract text from DOC/DOCX: {str(e)}") from e
|
raise TextExtractionError(f"Failed to extract text from DOC/DOCX: {str(e)}") from e
|
||||||
|
|
||||||
|
@ -68,7 +68,22 @@ class HttpRequestNodeData(BaseNodeData):
|
|||||||
Code Node Data.
|
Code Node Data.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
method: Literal["get", "post", "put", "patch", "delete", "head"]
|
method: Literal[
|
||||||
|
"get",
|
||||||
|
"post",
|
||||||
|
"put",
|
||||||
|
"patch",
|
||||||
|
"delete",
|
||||||
|
"head",
|
||||||
|
"options",
|
||||||
|
"GET",
|
||||||
|
"POST",
|
||||||
|
"PUT",
|
||||||
|
"PATCH",
|
||||||
|
"DELETE",
|
||||||
|
"HEAD",
|
||||||
|
"OPTIONS",
|
||||||
|
]
|
||||||
url: str
|
url: str
|
||||||
authorization: HttpRequestNodeAuthorization
|
authorization: HttpRequestNodeAuthorization
|
||||||
headers: str
|
headers: str
|
||||||
|
@ -37,7 +37,22 @@ BODY_TYPE_TO_CONTENT_TYPE = {
|
|||||||
|
|
||||||
|
|
||||||
class Executor:
|
class Executor:
|
||||||
method: Literal["get", "head", "post", "put", "delete", "patch"]
|
method: Literal[
|
||||||
|
"get",
|
||||||
|
"head",
|
||||||
|
"post",
|
||||||
|
"put",
|
||||||
|
"delete",
|
||||||
|
"patch",
|
||||||
|
"options",
|
||||||
|
"GET",
|
||||||
|
"POST",
|
||||||
|
"PUT",
|
||||||
|
"PATCH",
|
||||||
|
"DELETE",
|
||||||
|
"HEAD",
|
||||||
|
"OPTIONS",
|
||||||
|
]
|
||||||
url: str
|
url: str
|
||||||
params: list[tuple[str, str]] | None
|
params: list[tuple[str, str]] | None
|
||||||
content: str | bytes | None
|
content: str | bytes | None
|
||||||
@ -67,12 +82,6 @@ class Executor:
|
|||||||
node_data.authorization.config.api_key
|
node_data.authorization.config.api_key
|
||||||
).text
|
).text
|
||||||
|
|
||||||
# check if node_data.url is a valid URL
|
|
||||||
if not node_data.url:
|
|
||||||
raise InvalidURLError("url is required")
|
|
||||||
if not node_data.url.startswith(("http://", "https://")):
|
|
||||||
raise InvalidURLError("url should start with http:// or https://")
|
|
||||||
|
|
||||||
self.url: str = node_data.url
|
self.url: str = node_data.url
|
||||||
self.method = node_data.method
|
self.method = node_data.method
|
||||||
self.auth = node_data.authorization
|
self.auth = node_data.authorization
|
||||||
@ -99,6 +108,12 @@ class Executor:
|
|||||||
def _init_url(self):
|
def _init_url(self):
|
||||||
self.url = self.variable_pool.convert_template(self.node_data.url).text
|
self.url = self.variable_pool.convert_template(self.node_data.url).text
|
||||||
|
|
||||||
|
# check if url is a valid URL
|
||||||
|
if not self.url:
|
||||||
|
raise InvalidURLError("url is required")
|
||||||
|
if not self.url.startswith(("http://", "https://")):
|
||||||
|
raise InvalidURLError("url should start with http:// or https://")
|
||||||
|
|
||||||
def _init_params(self):
|
def _init_params(self):
|
||||||
"""
|
"""
|
||||||
Almost same as _init_headers(), difference:
|
Almost same as _init_headers(), difference:
|
||||||
@ -158,7 +173,10 @@ class Executor:
|
|||||||
if len(data) != 1:
|
if len(data) != 1:
|
||||||
raise RequestBodyError("json body type should have exactly one item")
|
raise RequestBodyError("json body type should have exactly one item")
|
||||||
json_string = self.variable_pool.convert_template(data[0].value).text
|
json_string = self.variable_pool.convert_template(data[0].value).text
|
||||||
json_object = json.loads(json_string, strict=False)
|
try:
|
||||||
|
json_object = json.loads(json_string, strict=False)
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
raise RequestBodyError(f"Failed to parse JSON: {json_string}") from e
|
||||||
self.json = json_object
|
self.json = json_object
|
||||||
# self.json = self._parse_object_contains_variables(json_object)
|
# self.json = self._parse_object_contains_variables(json_object)
|
||||||
case "binary":
|
case "binary":
|
||||||
@ -246,7 +264,22 @@ class Executor:
|
|||||||
"""
|
"""
|
||||||
do http request depending on api bundle
|
do http request depending on api bundle
|
||||||
"""
|
"""
|
||||||
if self.method not in {"get", "head", "post", "put", "delete", "patch"}:
|
if self.method not in {
|
||||||
|
"get",
|
||||||
|
"head",
|
||||||
|
"post",
|
||||||
|
"put",
|
||||||
|
"delete",
|
||||||
|
"patch",
|
||||||
|
"options",
|
||||||
|
"GET",
|
||||||
|
"POST",
|
||||||
|
"PUT",
|
||||||
|
"PATCH",
|
||||||
|
"DELETE",
|
||||||
|
"HEAD",
|
||||||
|
"OPTIONS",
|
||||||
|
}:
|
||||||
raise InvalidHttpMethodError(f"Invalid http method {self.method}")
|
raise InvalidHttpMethodError(f"Invalid http method {self.method}")
|
||||||
|
|
||||||
request_args = {
|
request_args = {
|
||||||
@ -263,7 +296,7 @@ class Executor:
|
|||||||
}
|
}
|
||||||
# request_args = {k: v for k, v in request_args.items() if v is not None}
|
# request_args = {k: v for k, v in request_args.items() if v is not None}
|
||||||
try:
|
try:
|
||||||
response = getattr(ssrf_proxy, self.method)(**request_args)
|
response = getattr(ssrf_proxy, self.method.lower())(**request_args)
|
||||||
except (ssrf_proxy.MaxRetriesExceededError, httpx.RequestError) as e:
|
except (ssrf_proxy.MaxRetriesExceededError, httpx.RequestError) as e:
|
||||||
raise HttpRequestNodeError(str(e))
|
raise HttpRequestNodeError(str(e))
|
||||||
# FIXME: fix type ignore, this maybe httpx type issue
|
# FIXME: fix type ignore, this maybe httpx type issue
|
||||||
|
@ -340,6 +340,10 @@ class WorkflowEntry:
|
|||||||
):
|
):
|
||||||
raise ValueError(f"Variable key {node_variable} not found in user inputs.")
|
raise ValueError(f"Variable key {node_variable} not found in user inputs.")
|
||||||
|
|
||||||
|
# environment variable already exist in variable pool, not from user inputs
|
||||||
|
if variable_pool.get(variable_selector):
|
||||||
|
continue
|
||||||
|
|
||||||
# fetch variable node id from variable selector
|
# fetch variable node id from variable selector
|
||||||
variable_node_id = variable_selector[0]
|
variable_node_id = variable_selector[0]
|
||||||
variable_key_list = variable_selector[1:]
|
variable_key_list = variable_selector[1:]
|
||||||
|
@ -33,6 +33,7 @@ else
|
|||||||
--bind "${DIFY_BIND_ADDRESS:-0.0.0.0}:${DIFY_PORT:-5001}" \
|
--bind "${DIFY_BIND_ADDRESS:-0.0.0.0}:${DIFY_PORT:-5001}" \
|
||||||
--workers ${SERVER_WORKER_AMOUNT:-1} \
|
--workers ${SERVER_WORKER_AMOUNT:-1} \
|
||||||
--worker-class ${SERVER_WORKER_CLASS:-gevent} \
|
--worker-class ${SERVER_WORKER_CLASS:-gevent} \
|
||||||
|
--worker-connections ${SERVER_WORKER_CONNECTIONS:-10} \
|
||||||
--timeout ${GUNICORN_TIMEOUT:-200} \
|
--timeout ${GUNICORN_TIMEOUT:-200} \
|
||||||
app:app
|
app:app
|
||||||
fi
|
fi
|
||||||
|
@ -46,7 +46,7 @@ def init_app(app: DifyApp):
|
|||||||
timezone = pytz.timezone(log_tz)
|
timezone = pytz.timezone(log_tz)
|
||||||
|
|
||||||
def time_converter(seconds):
|
def time_converter(seconds):
|
||||||
return datetime.utcfromtimestamp(seconds).astimezone(timezone).timetuple()
|
return datetime.fromtimestamp(seconds, tz=timezone).timetuple()
|
||||||
|
|
||||||
for handler in logging.root.handlers:
|
for handler in logging.root.handlers:
|
||||||
if handler.formatter:
|
if handler.formatter:
|
||||||
|
@ -158,7 +158,7 @@ def _build_from_remote_url(
|
|||||||
tenant_id: str,
|
tenant_id: str,
|
||||||
transfer_method: FileTransferMethod,
|
transfer_method: FileTransferMethod,
|
||||||
) -> File:
|
) -> File:
|
||||||
url = mapping.get("url")
|
url = mapping.get("url") or mapping.get("remote_url")
|
||||||
if not url:
|
if not url:
|
||||||
raise ValueError("Invalid file url")
|
raise ValueError("Invalid file url")
|
||||||
|
|
||||||
|
@ -255,7 +255,8 @@ class NotionOAuth(OAuthDataSource):
|
|||||||
response = requests.get(url=f"{self._NOTION_BLOCK_SEARCH}/{block_id}", headers=headers)
|
response = requests.get(url=f"{self._NOTION_BLOCK_SEARCH}/{block_id}", headers=headers)
|
||||||
response_json = response.json()
|
response_json = response.json()
|
||||||
if response.status_code != 200:
|
if response.status_code != 200:
|
||||||
raise ValueError(f"Error fetching block parent page ID: {response_json.message}")
|
message = response_json.get("message", "unknown error")
|
||||||
|
raise ValueError(f"Error fetching block parent page ID: {message}")
|
||||||
parent = response_json["parent"]
|
parent = response_json["parent"]
|
||||||
parent_type = parent["type"]
|
parent_type = parent["type"]
|
||||||
if parent_type == "block_id":
|
if parent_type == "block_id":
|
||||||
|
@ -0,0 +1,41 @@
|
|||||||
|
"""change workflow_runs.total_tokens to bigint
|
||||||
|
|
||||||
|
Revision ID: a91b476a53de
|
||||||
|
Revises: 923752d42eb6
|
||||||
|
Create Date: 2025-01-01 20:00:01.207369
|
||||||
|
|
||||||
|
"""
|
||||||
|
from alembic import op
|
||||||
|
import models as models
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision = 'a91b476a53de'
|
||||||
|
down_revision = '923752d42eb6'
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade():
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
with op.batch_alter_table('workflow_runs', schema=None) as batch_op:
|
||||||
|
batch_op.alter_column('total_tokens',
|
||||||
|
existing_type=sa.INTEGER(),
|
||||||
|
type_=sa.BigInteger(),
|
||||||
|
existing_nullable=False,
|
||||||
|
existing_server_default=sa.text('0'))
|
||||||
|
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade():
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
with op.batch_alter_table('workflow_runs', schema=None) as batch_op:
|
||||||
|
batch_op.alter_column('total_tokens',
|
||||||
|
existing_type=sa.BigInteger(),
|
||||||
|
type_=sa.INTEGER(),
|
||||||
|
existing_nullable=False,
|
||||||
|
existing_server_default=sa.text('0'))
|
||||||
|
|
||||||
|
# ### end Alembic commands ###
|
@ -415,8 +415,8 @@ class WorkflowRun(Base):
|
|||||||
status: Mapped[str] = mapped_column(db.String(255)) # running, succeeded, failed, stopped, partial-succeeded
|
status: Mapped[str] = mapped_column(db.String(255)) # running, succeeded, failed, stopped, partial-succeeded
|
||||||
outputs: Mapped[Optional[str]] = mapped_column(sa.Text, default="{}")
|
outputs: Mapped[Optional[str]] = mapped_column(sa.Text, default="{}")
|
||||||
error: Mapped[Optional[str]] = mapped_column(db.Text)
|
error: Mapped[Optional[str]] = mapped_column(db.Text)
|
||||||
elapsed_time = db.Column(db.Float, nullable=False, server_default=db.text("0"))
|
elapsed_time = db.Column(db.Float, nullable=False, server_default=sa.text("0"))
|
||||||
total_tokens: Mapped[int] = mapped_column(server_default=db.text("0"))
|
total_tokens: Mapped[int] = mapped_column(sa.BigInteger, server_default=sa.text("0"))
|
||||||
total_steps = db.Column(db.Integer, server_default=db.text("0"))
|
total_steps = db.Column(db.Integer, server_default=db.text("0"))
|
||||||
created_by_role: Mapped[str] = mapped_column(db.String(255)) # account, end_user
|
created_by_role: Mapped[str] = mapped_column(db.String(255)) # account, end_user
|
||||||
created_by = db.Column(StringUUID, nullable=False)
|
created_by = db.Column(StringUUID, nullable=False)
|
||||||
|
2397
api/poetry.lock
generated
2397
api/poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@ -71,7 +71,7 @@ pyjwt = "~2.8.0"
|
|||||||
pypdfium2 = "~4.30.0"
|
pypdfium2 = "~4.30.0"
|
||||||
python = ">=3.11,<3.13"
|
python = ">=3.11,<3.13"
|
||||||
python-docx = "~1.1.0"
|
python-docx = "~1.1.0"
|
||||||
python-dotenv = "1.0.0"
|
python-dotenv = "1.0.1"
|
||||||
pyyaml = "~6.0.1"
|
pyyaml = "~6.0.1"
|
||||||
readabilipy = "0.2.0"
|
readabilipy = "0.2.0"
|
||||||
redis = { version = "~5.0.3", extras = ["hiredis"] }
|
redis = { version = "~5.0.3", extras = ["hiredis"] }
|
||||||
@ -82,7 +82,7 @@ scikit-learn = "~1.5.1"
|
|||||||
sentry-sdk = { version = "~1.44.1", extras = ["flask"] }
|
sentry-sdk = { version = "~1.44.1", extras = ["flask"] }
|
||||||
sqlalchemy = "~2.0.29"
|
sqlalchemy = "~2.0.29"
|
||||||
starlette = "0.41.0"
|
starlette = "0.41.0"
|
||||||
tencentcloud-sdk-python-hunyuan = "~3.0.1158"
|
tencentcloud-sdk-python-hunyuan = "~3.0.1294"
|
||||||
tiktoken = "~0.8.0"
|
tiktoken = "~0.8.0"
|
||||||
tokenizers = "~0.15.0"
|
tokenizers = "~0.15.0"
|
||||||
transformers = "~4.35.0"
|
transformers = "~4.35.0"
|
||||||
@ -92,7 +92,7 @@ validators = "0.21.0"
|
|||||||
volcengine-python-sdk = {extras = ["ark"], version = "~1.0.98"}
|
volcengine-python-sdk = {extras = ["ark"], version = "~1.0.98"}
|
||||||
websocket-client = "~1.7.0"
|
websocket-client = "~1.7.0"
|
||||||
xinference-client = "0.15.2"
|
xinference-client = "0.15.2"
|
||||||
yarl = "~1.9.4"
|
yarl = "~1.18.3"
|
||||||
youtube-transcript-api = "~0.6.2"
|
youtube-transcript-api = "~0.6.2"
|
||||||
zhipuai = "~2.1.5"
|
zhipuai = "~2.1.5"
|
||||||
# Before adding new dependency, consider place it in alphabet order (a-z) and suitable group.
|
# Before adding new dependency, consider place it in alphabet order (a-z) and suitable group.
|
||||||
@ -157,7 +157,7 @@ opensearch-py = "2.4.0"
|
|||||||
oracledb = "~2.2.1"
|
oracledb = "~2.2.1"
|
||||||
pgvecto-rs = { version = "~0.2.1", extras = ['sqlalchemy'] }
|
pgvecto-rs = { version = "~0.2.1", extras = ['sqlalchemy'] }
|
||||||
pgvector = "0.2.5"
|
pgvector = "0.2.5"
|
||||||
pymilvus = "~2.4.4"
|
pymilvus = "~2.5.0"
|
||||||
pymochow = "1.3.1"
|
pymochow = "1.3.1"
|
||||||
pyobvector = "~0.1.6"
|
pyobvector = "~0.1.6"
|
||||||
qdrant-client = "1.7.3"
|
qdrant-client = "1.7.3"
|
||||||
|
@ -168,23 +168,6 @@ def clean_unused_datasets_task():
|
|||||||
else:
|
else:
|
||||||
plan = plan_cache.decode()
|
plan = plan_cache.decode()
|
||||||
if plan == "sandbox":
|
if plan == "sandbox":
|
||||||
# add auto disable log
|
|
||||||
documents = (
|
|
||||||
db.session.query(Document)
|
|
||||||
.filter(
|
|
||||||
Document.dataset_id == dataset.id,
|
|
||||||
Document.enabled == True,
|
|
||||||
Document.archived == False,
|
|
||||||
)
|
|
||||||
.all()
|
|
||||||
)
|
|
||||||
for document in documents:
|
|
||||||
dataset_auto_disable_log = DatasetAutoDisableLog(
|
|
||||||
tenant_id=dataset.tenant_id,
|
|
||||||
dataset_id=dataset.id,
|
|
||||||
document_id=document.id,
|
|
||||||
)
|
|
||||||
db.session.add(dataset_auto_disable_log)
|
|
||||||
# remove index
|
# remove index
|
||||||
index_processor = IndexProcessorFactory(dataset.doc_form).init_index_processor()
|
index_processor = IndexProcessorFactory(dataset.doc_form).init_index_processor()
|
||||||
index_processor.clean(dataset, None)
|
index_processor.clean(dataset, None)
|
||||||
|
@ -66,7 +66,7 @@ class TokenPair(BaseModel):
|
|||||||
|
|
||||||
REFRESH_TOKEN_PREFIX = "refresh_token:"
|
REFRESH_TOKEN_PREFIX = "refresh_token:"
|
||||||
ACCOUNT_REFRESH_TOKEN_PREFIX = "account_refresh_token:"
|
ACCOUNT_REFRESH_TOKEN_PREFIX = "account_refresh_token:"
|
||||||
REFRESH_TOKEN_EXPIRY = timedelta(days=30)
|
REFRESH_TOKEN_EXPIRY = timedelta(days=dify_config.REFRESH_TOKEN_EXPIRE_DAYS)
|
||||||
|
|
||||||
|
|
||||||
class AccountService:
|
class AccountService:
|
||||||
|
@ -2,6 +2,7 @@ import logging
|
|||||||
import uuid
|
import uuid
|
||||||
from enum import StrEnum
|
from enum import StrEnum
|
||||||
from typing import Optional, cast
|
from typing import Optional, cast
|
||||||
|
from urllib.parse import urlparse
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
import yaml # type: ignore
|
import yaml # type: ignore
|
||||||
@ -124,7 +125,7 @@ class AppDslService:
|
|||||||
raise ValueError(f"Invalid import_mode: {import_mode}")
|
raise ValueError(f"Invalid import_mode: {import_mode}")
|
||||||
|
|
||||||
# Get YAML content
|
# Get YAML content
|
||||||
content: bytes | str = b""
|
content: str = ""
|
||||||
if mode == ImportMode.YAML_URL:
|
if mode == ImportMode.YAML_URL:
|
||||||
if not yaml_url:
|
if not yaml_url:
|
||||||
return Import(
|
return Import(
|
||||||
@ -133,13 +134,17 @@ class AppDslService:
|
|||||||
error="yaml_url is required when import_mode is yaml-url",
|
error="yaml_url is required when import_mode is yaml-url",
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
# tricky way to handle url from github to github raw url
|
parsed_url = urlparse(yaml_url)
|
||||||
if yaml_url.startswith("https://github.com") and yaml_url.endswith((".yml", ".yaml")):
|
if (
|
||||||
|
parsed_url.scheme == "https"
|
||||||
|
and parsed_url.netloc == "github.com"
|
||||||
|
and parsed_url.path.endswith((".yml", ".yaml"))
|
||||||
|
):
|
||||||
yaml_url = yaml_url.replace("https://github.com", "https://raw.githubusercontent.com")
|
yaml_url = yaml_url.replace("https://github.com", "https://raw.githubusercontent.com")
|
||||||
yaml_url = yaml_url.replace("/blob/", "/")
|
yaml_url = yaml_url.replace("/blob/", "/")
|
||||||
response = ssrf_proxy.get(yaml_url.strip(), follow_redirects=True, timeout=(10, 10))
|
response = ssrf_proxy.get(yaml_url.strip(), follow_redirects=True, timeout=(10, 10))
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
content = response.content
|
content = response.content.decode()
|
||||||
|
|
||||||
if len(content) > DSL_MAX_SIZE:
|
if len(content) > DSL_MAX_SIZE:
|
||||||
return Import(
|
return Import(
|
||||||
|
@ -26,9 +26,10 @@ from tasks.remove_app_and_related_data_task import remove_app_and_related_data_t
|
|||||||
|
|
||||||
|
|
||||||
class AppService:
|
class AppService:
|
||||||
def get_paginate_apps(self, tenant_id: str, args: dict) -> Pagination | None:
|
def get_paginate_apps(self, user_id: str, tenant_id: str, args: dict) -> Pagination | None:
|
||||||
"""
|
"""
|
||||||
Get app list with pagination
|
Get app list with pagination
|
||||||
|
:param user_id: user id
|
||||||
:param tenant_id: tenant id
|
:param tenant_id: tenant id
|
||||||
:param args: request args
|
:param args: request args
|
||||||
:return:
|
:return:
|
||||||
@ -44,6 +45,8 @@ class AppService:
|
|||||||
elif args["mode"] == "channel":
|
elif args["mode"] == "channel":
|
||||||
filters.append(App.mode == AppMode.CHANNEL.value)
|
filters.append(App.mode == AppMode.CHANNEL.value)
|
||||||
|
|
||||||
|
if args.get("is_created_by_me", False):
|
||||||
|
filters.append(App.created_by == user_id)
|
||||||
if args.get("name"):
|
if args.get("name"):
|
||||||
name = args["name"][:30]
|
name = args["name"][:30]
|
||||||
filters.append(App.name.ilike(f"%{name}%"))
|
filters.append(App.name.ilike(f"%{name}%"))
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
import os
|
import os
|
||||||
from typing import Optional
|
from typing import Literal, Optional
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from tenacity import retry, retry_if_exception_type, stop_before_delay, wait_fixed
|
from tenacity import retry, retry_if_exception_type, stop_before_delay, wait_fixed
|
||||||
@ -17,7 +17,6 @@ class BillingService:
|
|||||||
params = {"tenant_id": tenant_id}
|
params = {"tenant_id": tenant_id}
|
||||||
|
|
||||||
billing_info = cls._send_request("GET", "/subscription/info", params=params)
|
billing_info = cls._send_request("GET", "/subscription/info", params=params)
|
||||||
|
|
||||||
return billing_info
|
return billing_info
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -47,12 +46,13 @@ class BillingService:
|
|||||||
retry=retry_if_exception_type(httpx.RequestError),
|
retry=retry_if_exception_type(httpx.RequestError),
|
||||||
reraise=True,
|
reraise=True,
|
||||||
)
|
)
|
||||||
def _send_request(cls, method, endpoint, json=None, params=None):
|
def _send_request(cls, method: Literal["GET", "POST", "DELETE"], endpoint: str, json=None, params=None):
|
||||||
headers = {"Content-Type": "application/json", "Billing-Api-Secret-Key": cls.secret_key}
|
headers = {"Content-Type": "application/json", "Billing-Api-Secret-Key": cls.secret_key}
|
||||||
|
|
||||||
url = f"{cls.base_url}{endpoint}"
|
url = f"{cls.base_url}{endpoint}"
|
||||||
response = httpx.request(method, url, json=json, params=params, headers=headers)
|
response = httpx.request(method, url, json=json, params=params, headers=headers)
|
||||||
|
if method == "GET" and response.status_code != httpx.codes.OK:
|
||||||
|
raise ValueError("Unable to retrieve billing information. Please try again later or contact support.")
|
||||||
return response.json()
|
return response.json()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@ -86,7 +86,7 @@ class DatasetService:
|
|||||||
else:
|
else:
|
||||||
return [], 0
|
return [], 0
|
||||||
else:
|
else:
|
||||||
if user.current_role not in (TenantAccountRole.OWNER, TenantAccountRole.ADMIN):
|
if user.current_role != TenantAccountRole.OWNER:
|
||||||
# show all datasets that the user has permission to access
|
# show all datasets that the user has permission to access
|
||||||
if permitted_dataset_ids:
|
if permitted_dataset_ids:
|
||||||
query = query.filter(
|
query = query.filter(
|
||||||
@ -382,7 +382,7 @@ class DatasetService:
|
|||||||
if dataset.tenant_id != user.current_tenant_id:
|
if dataset.tenant_id != user.current_tenant_id:
|
||||||
logging.debug(f"User {user.id} does not have permission to access dataset {dataset.id}")
|
logging.debug(f"User {user.id} does not have permission to access dataset {dataset.id}")
|
||||||
raise NoPermissionError("You do not have permission to access this dataset.")
|
raise NoPermissionError("You do not have permission to access this dataset.")
|
||||||
if user.current_role not in (TenantAccountRole.OWNER, TenantAccountRole.ADMIN):
|
if user.current_role != TenantAccountRole.OWNER:
|
||||||
if dataset.permission == DatasetPermissionEnum.ONLY_ME and dataset.created_by != user.id:
|
if dataset.permission == DatasetPermissionEnum.ONLY_ME and dataset.created_by != user.id:
|
||||||
logging.debug(f"User {user.id} does not have permission to access dataset {dataset.id}")
|
logging.debug(f"User {user.id} does not have permission to access dataset {dataset.id}")
|
||||||
raise NoPermissionError("You do not have permission to access this dataset.")
|
raise NoPermissionError("You do not have permission to access this dataset.")
|
||||||
@ -404,7 +404,7 @@ class DatasetService:
|
|||||||
if not user:
|
if not user:
|
||||||
raise ValueError("User not found")
|
raise ValueError("User not found")
|
||||||
|
|
||||||
if user.current_role not in (TenantAccountRole.OWNER, TenantAccountRole.ADMIN):
|
if user.current_role != TenantAccountRole.OWNER:
|
||||||
if dataset.permission == DatasetPermissionEnum.ONLY_ME:
|
if dataset.permission == DatasetPermissionEnum.ONLY_ME:
|
||||||
if dataset.created_by != user.id:
|
if dataset.created_by != user.id:
|
||||||
raise NoPermissionError("You do not have permission to access this dataset.")
|
raise NoPermissionError("You do not have permission to access this dataset.")
|
||||||
@ -434,6 +434,12 @@ class DatasetService:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_dataset_auto_disable_logs(dataset_id: str) -> dict:
|
def get_dataset_auto_disable_logs(dataset_id: str) -> dict:
|
||||||
|
features = FeatureService.get_features(current_user.current_tenant_id)
|
||||||
|
if not features.billing.enabled or features.billing.subscription.plan == "sandbox":
|
||||||
|
return {
|
||||||
|
"document_ids": [],
|
||||||
|
"count": 0,
|
||||||
|
}
|
||||||
# get recent 30 days auto disable logs
|
# get recent 30 days auto disable logs
|
||||||
start_date = datetime.datetime.now() - datetime.timedelta(days=30)
|
start_date = datetime.datetime.now() - datetime.timedelta(days=30)
|
||||||
dataset_auto_disable_logs = DatasetAutoDisableLog.query.filter(
|
dataset_auto_disable_logs = DatasetAutoDisableLog.query.filter(
|
||||||
@@ -786,13 +792,19 @@ class DocumentService:
             dataset.indexing_technique = knowledge_config.indexing_technique
             if knowledge_config.indexing_technique == "high_quality":
                 model_manager = ModelManager()
-                embedding_model = model_manager.get_default_model_instance(
-                    tenant_id=current_user.current_tenant_id, model_type=ModelType.TEXT_EMBEDDING
-                )
-                dataset.embedding_model = embedding_model.model
-                dataset.embedding_model_provider = embedding_model.provider
+                if knowledge_config.embedding_model and knowledge_config.embedding_model_provider:
+                    dataset_embedding_model = knowledge_config.embedding_model
+                    dataset_embedding_model_provider = knowledge_config.embedding_model_provider
+                else:
+                    embedding_model = model_manager.get_default_model_instance(
+                        tenant_id=current_user.current_tenant_id, model_type=ModelType.TEXT_EMBEDDING
+                    )
+                    dataset_embedding_model = embedding_model.model
+                    dataset_embedding_model_provider = embedding_model.provider
+                dataset.embedding_model = dataset_embedding_model
+                dataset.embedding_model_provider = dataset_embedding_model_provider
                 dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
-                    embedding_model.provider, embedding_model.model
+                    dataset_embedding_model_provider, dataset_embedding_model
                 )
                 dataset.collection_binding_id = dataset_collection_binding.id
                 if not dataset.retrieval_model:
@@ -804,7 +816,11 @@ class DocumentService:
                         "score_threshold_enabled": False,
                     }

-                    dataset.retrieval_model = knowledge_config.retrieval_model.model_dump() or default_retrieval_model  # type: ignore
+                    dataset.retrieval_model = (
+                        knowledge_config.retrieval_model.model_dump()
+                        if knowledge_config.retrieval_model
+                        else default_retrieval_model
+                    )  # type: ignore

         documents = []
         if knowledge_config.original_document_id:
@@ -27,7 +27,7 @@ class WorkflowAppService:
             query = query.join(WorkflowRun, WorkflowRun.id == WorkflowAppLog.workflow_run_id)

         if keyword:
-            keyword_like_val = f"%{args['keyword'][:30]}%"
+            keyword_like_val = f"%{keyword[:30].encode('unicode_escape').decode('utf-8')}%".replace(r"\u", r"\\u")
             keyword_conditions = [
                 WorkflowRun.inputs.ilike(keyword_like_val),
                 WorkflowRun.outputs.ilike(keyword_like_val),
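For reference, a minimal sketch of what the new keyword escaping above does, assuming the workflow run inputs/outputs columns store non-ASCII text as JSON \uXXXX escape sequences; the helper name is illustrative only, not part of the codebase:

def to_like_pattern(keyword: str) -> str:
    # Encode non-ASCII characters as \uXXXX escapes, mirroring how they appear in the stored JSON.
    escaped = keyword[:30].encode("unicode_escape").decode("utf-8")
    # Double the backslash so the SQL LIKE pattern matches the literal backslash-u sequence in the column.
    return f"%{escaped}%".replace(r"\u", r"\\u")

print(to_like_pattern("你好"))   # -> %\\u4f60\\u597d%
print(to_like_pattern("hello"))  # -> %hello%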
@@ -28,7 +28,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):

         if not dataset:
             raise Exception("Dataset not found")
-        index_type = dataset.doc_form
+        index_type = dataset.doc_form or IndexType.PARAGRAPH_INDEX
         index_processor = IndexProcessorFactory(index_type).init_index_processor()
         if action == "remove":
             index_processor.clean(dataset, None, with_keywords=False)
@@ -157,6 +157,9 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
                 {"indexing_status": "error", "error": str(e)}, synchronize_session=False
             )
             db.session.commit()
+        else:
+            # clean collection
+            index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False)

         end_at = time.perf_counter()
         logging.info(
@@ -0,0 +1,55 @@
+import os
+from pathlib import Path
+
+import pytest
+
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.gpustack.speech2text.speech2text import GPUStackSpeech2TextModel
+
+
+def test_validate_credentials():
+    model = GPUStackSpeech2TextModel()
+
+    with pytest.raises(CredentialsValidateFailedError):
+        model.validate_credentials(
+            model="faster-whisper-medium",
+            credentials={
+                "endpoint_url": "invalid_url",
+                "api_key": "invalid_api_key",
+            },
+        )
+
+    model.validate_credentials(
+        model="faster-whisper-medium",
+        credentials={
+            "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
+            "api_key": os.environ.get("GPUSTACK_API_KEY"),
+        },
+    )
+
+
+def test_invoke_model():
+    model = GPUStackSpeech2TextModel()
+
+    # Get the directory of the current file
+    current_dir = os.path.dirname(os.path.abspath(__file__))
+
+    # Get assets directory
+    assets_dir = os.path.join(os.path.dirname(current_dir), "assets")
+
+    # Construct the path to the audio file
+    audio_file_path = os.path.join(assets_dir, "audio.mp3")
+
+    file = Path(audio_file_path).read_bytes()
+
+    result = model.invoke(
+        model="faster-whisper-medium",
+        credentials={
+            "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
+            "api_key": os.environ.get("GPUSTACK_API_KEY"),
+        },
+        file=file,
+    )
+
+    assert isinstance(result, str)
+    assert result == "1, 2, 3, 4, 5, 6, 7, 8, 9, 10"
@@ -0,0 +1,24 @@
+import os
+
+from core.model_runtime.model_providers.gpustack.tts.tts import GPUStackText2SpeechModel
+
+
+def test_invoke_model():
+    model = GPUStackText2SpeechModel()
+
+    result = model.invoke(
+        model="cosyvoice-300m-sft",
+        tenant_id="test",
+        credentials={
+            "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
+            "api_key": os.environ.get("GPUSTACK_API_KEY"),
+        },
+        content_text="Hello world",
+        voice="Chinese Female",
+    )
+
+    content = b""
+    for chunk in result:
+        content += chunk
+
+    assert content != b""
@@ -19,9 +19,9 @@ class MilvusVectorTest(AbstractVectorTest):
         )

     def search_by_full_text(self):
-        # milvus dos not support full text searching yet in < 2.3.x
+        # milvus support BM25 full text search after version 2.5.0-beta
         hits_by_full_text = self.vector.search_by_full_text(query=get_example_text())
-        assert len(hits_by_full_text) == 0
+        assert len(hits_by_full_text) >= 0

     def get_ids_by_metadata_field(self):
         ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id)
@@ -2,7 +2,7 @@ version: '3'
 services:
   # API service
   api:
-    image: langgenius/dify-api:0.14.2
+    image: langgenius/dify-api:0.15.0
     restart: always
     environment:
       # Startup mode, 'api' starts the API server.
@@ -227,7 +227,7 @@ services:
   # worker service
   # The Celery worker for processing the queue.
   worker:
-    image: langgenius/dify-api:0.14.2
+    image: langgenius/dify-api:0.15.0
     restart: always
     environment:
       CONSOLE_WEB_URL: ''
@@ -397,7 +397,7 @@ services:

   # Frontend web application.
   web:
-    image: langgenius/dify-web:0.14.2
+    image: langgenius/dify-web:0.15.0
     restart: always
     environment:
       # The base URL of console application api server, refers to the Console base URL of WEB service if console domain is
@@ -105,6 +105,9 @@ FILES_ACCESS_TIMEOUT=300
 # Access token expiration time in minutes
 ACCESS_TOKEN_EXPIRE_MINUTES=60

+# Refresh token expiration time in days
+REFRESH_TOKEN_EXPIRE_DAYS=30
+
 # The maximum number of active requests for the application, where 0 means unlimited, should be a non-negative integer.
 APP_MAX_ACTIVE_REQUESTS=0
 APP_MAX_EXECUTION_TIME=1200
@@ -123,10 +126,13 @@ DIFY_PORT=5001
 # The number of API server workers, i.e., the number of workers.
 # Formula: number of cpu cores x 2 + 1 for sync, 1 for Gevent
 # Reference: https://docs.gunicorn.org/en/stable/design.html#how-many-workers
-SERVER_WORKER_AMOUNT=
+SERVER_WORKER_AMOUNT=1

 # Defaults to gevent. If using windows, it can be switched to sync or solo.
-SERVER_WORKER_CLASS=
+SERVER_WORKER_CLASS=gevent
+
+# Default number of worker connections, the default is 10.
+SERVER_WORKER_CONNECTIONS=10

 # Similar to SERVER_WORKER_CLASS.
 # If using windows, it can be switched to sync or solo.
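As a quick illustration of the worker-sizing comment above (cpu cores x 2 + 1 for sync workers, a single worker when using gevent), a minimal sketch; the helper name is illustrative and not part of the codebase:

import multiprocessing

def suggested_server_worker_amount(worker_class: str = "gevent") -> int:
    # One gevent worker handles concurrency with greenlets; sync workers follow the rule of thumb above.
    if worker_class == "gevent":
        return 1
    return multiprocessing.cpu_count() * 2 + 1

print(suggested_server_worker_amount("sync"))  # e.g. 9 on a 4-core machine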
@@ -377,7 +383,7 @@ SUPABASE_URL=your-server-url
 # ------------------------------

 # The type of vector store to use.
-# Supported values are `weaviate`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `tidb_vector`, `oracle`, `tencent`, `elasticsearch`, `analyticdb`, `couchbase`, `vikingdb`, `oceanbase`.
+# Supported values are `weaviate`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `tidb_vector`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `oceanbase`.
 VECTOR_STORE=weaviate

 # The Weaviate endpoint URL. Only available when VECTOR_STORE is `weaviate`.
@@ -397,6 +403,7 @@ MILVUS_URI=http://127.0.0.1:19530
 MILVUS_TOKEN=
 MILVUS_USER=root
 MILVUS_PASSWORD=Milvus
+MILVUS_ENABLE_HYBRID_SEARCH=False

 # MyScale configuration, only available when VECTOR_STORE is `myscale`
 # For multi-language support, please set MYSCALE_FTS_PARAMS with referring to:
@@ -506,7 +513,7 @@ TENCENT_VECTOR_DB_SHARD=1
 TENCENT_VECTOR_DB_REPLICAS=2

 # ElasticSearch configuration, only available when VECTOR_STORE is `elasticsearch`
-ELASTICSEARCH_HOST=0.0.0.0
+ELASTICSEARCH_HOST=elasticsearch
 ELASTICSEARCH_PORT=9200
 ELASTICSEARCH_USERNAME=elastic
 ELASTICSEARCH_PASSWORD=elastic
@@ -923,6 +930,9 @@ CREATE_TIDB_SERVICE_JOB_ENABLED=false
 # Maximum number of submitted thread count in a ThreadPool for parallel node execution
 MAX_SUBMIT_COUNT=100

+# The maximum number of top-k value for RAG.
+TOP_K_MAX_VALUE=10
+
 # ------------------------------
 # Plugin Daemon Configuration
 # ------------------------------
@@ -947,3 +957,4 @@ ENDPOINT_URL_TEMPLATE=http://localhost/e/{hook_id}

 MARKETPLACE_ENABLED=true
 MARKETPLACE_API_URL=https://marketplace-plugin.dify.dev
+
@@ -73,6 +73,7 @@ services:
       CSP_WHITELIST: ${CSP_WHITELIST:-}
       MARKETPLACE_API_URL: ${MARKETPLACE_API_URL:-https://marketplace-plugin.dify.dev}
       MARKETPLACE_URL: ${MARKETPLACE_URL:-https://marketplace-plugin.dify.dev}
+      TOP_K_MAX_VALUE: ${TOP_K_MAX_VALUE:-}

   # The postgres database.
   db:
@@ -92,7 +93,7 @@ services:
     volumes:
       - ./volumes/db/data:/var/lib/postgresql/data
     healthcheck:
-      test: ['CMD', 'pg_isready']
+      test: [ 'CMD', 'pg_isready' ]
       interval: 1s
       timeout: 3s
       retries: 30
@@ -111,7 +112,7 @@ services:
     # Set the redis password when startup redis server.
     command: redis-server --requirepass ${REDIS_PASSWORD:-difyai123456}
     healthcheck:
-      test: ['CMD', 'redis-cli', 'ping']
+      test: [ 'CMD', 'redis-cli', 'ping' ]

   # The DifySandbox
   sandbox:
@@ -131,7 +132,7 @@ services:
     volumes:
       - ./volumes/sandbox/dependencies:/dependencies
     healthcheck:
-      test: ['CMD', 'curl', '-f', 'http://localhost:8194/health']
+      test: [ 'CMD', 'curl', '-f', 'http://localhost:8194/health' ]
     networks:
       - ssrf_proxy_network

@@ -167,12 +168,7 @@ services:
     volumes:
       - ./ssrf_proxy/squid.conf.template:/etc/squid/squid.conf.template
       - ./ssrf_proxy/docker-entrypoint.sh:/docker-entrypoint-mount.sh
-    entrypoint:
-      [
-        'sh',
-        '-c',
-        "cp /docker-entrypoint-mount.sh /docker-entrypoint.sh && sed -i 's/\r$$//' /docker-entrypoint.sh && chmod +x /docker-entrypoint.sh && /docker-entrypoint.sh",
-      ]
+    entrypoint: [ 'sh', '-c', "cp /docker-entrypoint-mount.sh /docker-entrypoint.sh && sed -i 's/\r$$//' /docker-entrypoint.sh && chmod +x /docker-entrypoint.sh && /docker-entrypoint.sh" ]
     environment:
       # pls clearly modify the squid env vars to fit your network environment.
       HTTP_PORT: ${SSRF_HTTP_PORT:-3128}
@@ -201,8 +197,8 @@ services:
       - CERTBOT_EMAIL=${CERTBOT_EMAIL}
       - CERTBOT_DOMAIN=${CERTBOT_DOMAIN}
      - CERTBOT_OPTIONS=${CERTBOT_OPTIONS:-}
-    entrypoint: ['/docker-entrypoint.sh']
-    command: ['tail', '-f', '/dev/null']
+    entrypoint: [ '/docker-entrypoint.sh' ]
+    command: [ 'tail', '-f', '/dev/null' ]

   # The nginx reverse proxy.
   # used for reverse proxying the API service and Web service.
@@ -219,12 +215,7 @@ services:
       - ./volumes/certbot/conf/live:/etc/letsencrypt/live # cert dir (with certbot container)
       - ./volumes/certbot/conf:/etc/letsencrypt
       - ./volumes/certbot/www:/var/www/html
-    entrypoint:
-      [
-        'sh',
-        '-c',
-        "cp /docker-entrypoint-mount.sh /docker-entrypoint.sh && sed -i 's/\r$$//' /docker-entrypoint.sh && chmod +x /docker-entrypoint.sh && /docker-entrypoint.sh",
-      ]
+    entrypoint: [ 'sh', '-c', "cp /docker-entrypoint-mount.sh /docker-entrypoint.sh && sed -i 's/\r$$//' /docker-entrypoint.sh && chmod +x /docker-entrypoint.sh && /docker-entrypoint.sh" ]
     environment:
       NGINX_SERVER_NAME: ${NGINX_SERVER_NAME:-_}
       NGINX_HTTPS_ENABLED: ${NGINX_HTTPS_ENABLED:-false}
@@ -316,7 +307,7 @@ services:
     working_dir: /opt/couchbase
     stdin_open: true
     tty: true
-    entrypoint: [""]
+    entrypoint: [ "" ]
     command: sh -c "/opt/couchbase/init/init-cbserver.sh"
     volumes:
       - ./volumes/couchbase/data:/opt/couchbase/var/lib/couchbase/data
@@ -345,7 +336,7 @@ services:
     volumes:
       - ./volumes/pgvector/data:/var/lib/postgresql/data
     healthcheck:
-      test: ['CMD', 'pg_isready']
+      test: [ 'CMD', 'pg_isready' ]
       interval: 1s
       timeout: 3s
       retries: 30
@@ -367,7 +358,7 @@ services:
     volumes:
       - ./volumes/pgvecto_rs/data:/var/lib/postgresql/data
     healthcheck:
-      test: ['CMD', 'pg_isready']
+      test: [ 'CMD', 'pg_isready' ]
       interval: 1s
       timeout: 3s
       retries: 30
@@ -432,7 +423,7 @@ services:
       - ./volumes/milvus/etcd:/etcd
     command: etcd -advertise-client-urls=http://127.0.0.1:2379 -listen-client-urls http://0.0.0.0:2379 --data-dir /etcd
     healthcheck:
-      test: ['CMD', 'etcdctl', 'endpoint', 'health']
+      test: [ 'CMD', 'etcdctl', 'endpoint', 'health' ]
       interval: 30s
       timeout: 20s
       retries: 3
@@ -451,7 +442,7 @@ services:
       - ./volumes/milvus/minio:/minio_data
     command: minio server /minio_data --console-address ":9001"
     healthcheck:
-      test: ['CMD', 'curl', '-f', 'http://localhost:9000/minio/health/live']
+      test: [ 'CMD', 'curl', '-f', 'http://localhost:9000/minio/health/live' ]
       interval: 30s
       timeout: 20s
       retries: 3
@@ -463,7 +454,7 @@ services:
     image: milvusdb/milvus:v2.3.1
     profiles:
       - milvus
-    command: ['milvus', 'run', 'standalone']
+    command: [ 'milvus', 'run', 'standalone' ]
     environment:
       ETCD_ENDPOINTS: ${ETCD_ENDPOINTS:-etcd:2379}
       MINIO_ADDRESS: ${MINIO_ADDRESS:-minio:9000}
@@ -471,7 +462,7 @@ services:
     volumes:
       - ./volumes/milvus/milvus:/var/lib/milvus
     healthcheck:
-      test: ['CMD', 'curl', '-f', 'http://localhost:9091/healthz']
+      test: [ 'CMD', 'curl', '-f', 'http://localhost:9091/healthz' ]
       interval: 30s
       start_period: 90s
       timeout: 20s
@@ -559,7 +550,7 @@ services:
     ports:
       - ${ELASTICSEARCH_PORT:-9200}:9200
     healthcheck:
-      test: ['CMD', 'curl', '-s', 'http://localhost:9200/_cluster/health?pretty']
+      test: [ 'CMD', 'curl', '-s', 'http://localhost:9200/_cluster/health?pretty' ]
       interval: 30s
       timeout: 10s
       retries: 50
@@ -587,7 +578,7 @@ services:
     ports:
       - ${KIBANA_PORT:-5601}:5601
     healthcheck:
-      test: ['CMD-SHELL', 'curl -s http://localhost:5601 >/dev/null || exit 1']
+      test: [ 'CMD-SHELL', 'curl -s http://localhost:5601 >/dev/null || exit 1' ]
       interval: 30s
       timeout: 10s
       retries: 3
@@ -27,12 +27,14 @@ x-shared-env: &shared-api-worker-env
   MIGRATION_ENABLED: ${MIGRATION_ENABLED:-true}
   FILES_ACCESS_TIMEOUT: ${FILES_ACCESS_TIMEOUT:-300}
   ACCESS_TOKEN_EXPIRE_MINUTES: ${ACCESS_TOKEN_EXPIRE_MINUTES:-60}
+  REFRESH_TOKEN_EXPIRE_DAYS: ${REFRESH_TOKEN_EXPIRE_DAYS:-30}
   APP_MAX_ACTIVE_REQUESTS: ${APP_MAX_ACTIVE_REQUESTS:-0}
   APP_MAX_EXECUTION_TIME: ${APP_MAX_EXECUTION_TIME:-1200}
   DIFY_BIND_ADDRESS: ${DIFY_BIND_ADDRESS:-0.0.0.0}
   DIFY_PORT: ${DIFY_PORT:-5001}
-  SERVER_WORKER_AMOUNT: ${SERVER_WORKER_AMOUNT:-}
-  SERVER_WORKER_CLASS: ${SERVER_WORKER_CLASS:-}
+  SERVER_WORKER_AMOUNT: ${SERVER_WORKER_AMOUNT:-1}
+  SERVER_WORKER_CLASS: ${SERVER_WORKER_CLASS:-gevent}
+  SERVER_WORKER_CONNECTIONS: ${SERVER_WORKER_CONNECTIONS:-10}
   CELERY_WORKER_CLASS: ${CELERY_WORKER_CLASS:-}
   GUNICORN_TIMEOUT: ${GUNICORN_TIMEOUT:-360}
   CELERY_WORKER_AMOUNT: ${CELERY_WORKER_AMOUNT:-}
@@ -136,6 +138,7 @@ x-shared-env: &shared-api-worker-env
   MILVUS_TOKEN: ${MILVUS_TOKEN:-}
   MILVUS_USER: ${MILVUS_USER:-root}
   MILVUS_PASSWORD: ${MILVUS_PASSWORD:-Milvus}
+  MILVUS_ENABLE_HYBRID_SEARCH: ${MILVUS_ENABLE_HYBRID_SEARCH:-False}
   MYSCALE_HOST: ${MYSCALE_HOST:-myscale}
   MYSCALE_PORT: ${MYSCALE_PORT:-8123}
   MYSCALE_USER: ${MYSCALE_USER:-default}
@@ -401,6 +404,7 @@ x-shared-env: &shared-api-worker-env
   ENDPOINT_URL_TEMPLATE: ${ENDPOINT_URL_TEMPLATE:-http://localhost/e/{hook_id}}
   MARKETPLACE_ENABLED: ${MARKETPLACE_ENABLED:-true}
   MARKETPLACE_API_URL: ${MARKETPLACE_API_URL:-https://marketplace-plugin.dify.dev}
+  TOP_K_MAX_VALUE: ${TOP_K_MAX_VALUE:-10}

 services:
   # API service
@@ -476,6 +480,7 @@ services:
       CSP_WHITELIST: ${CSP_WHITELIST:-}
       MARKETPLACE_API_URL: ${MARKETPLACE_API_URL:-https://marketplace-plugin.dify.dev}
       MARKETPLACE_URL: ${MARKETPLACE_URL:-https://marketplace-plugin.dify.dev}
+      TOP_K_MAX_VALUE: ${TOP_K_MAX_VALUE:-}

   # The postgres database.
   db:
@@ -495,7 +500,7 @@ services:
     volumes:
       - ./volumes/db/data:/var/lib/postgresql/data
     healthcheck:
-      test: ['CMD', 'pg_isready']
+      test: [ 'CMD', 'pg_isready' ]
       interval: 1s
       timeout: 3s
       retries: 30
@@ -514,7 +519,7 @@ services:
     # Set the redis password when startup redis server.
     command: redis-server --requirepass ${REDIS_PASSWORD:-difyai123456}
     healthcheck:
-      test: ['CMD', 'redis-cli', 'ping']
+      test: [ 'CMD', 'redis-cli', 'ping' ]

   # The DifySandbox
   sandbox:
@@ -534,7 +539,7 @@ services:
     volumes:
       - ./volumes/sandbox/dependencies:/dependencies
     healthcheck:
-      test: ['CMD', 'curl', '-f', 'http://localhost:8194/health']
+      test: [ 'CMD', 'curl', '-f', 'http://localhost:8194/health' ]
     networks:
       - ssrf_proxy_network

@@ -571,12 +576,7 @@ services:
     volumes:
       - ./ssrf_proxy/squid.conf.template:/etc/squid/squid.conf.template
       - ./ssrf_proxy/docker-entrypoint.sh:/docker-entrypoint-mount.sh
-    entrypoint:
-      [
-        'sh',
-        '-c',
-        "cp /docker-entrypoint-mount.sh /docker-entrypoint.sh && sed -i 's/\r$$//' /docker-entrypoint.sh && chmod +x /docker-entrypoint.sh && /docker-entrypoint.sh",
-      ]
+    entrypoint: [ 'sh', '-c', "cp /docker-entrypoint-mount.sh /docker-entrypoint.sh && sed -i 's/\r$$//' /docker-entrypoint.sh && chmod +x /docker-entrypoint.sh && /docker-entrypoint.sh" ]
     environment:
       # pls clearly modify the squid env vars to fit your network environment.
       HTTP_PORT: ${SSRF_HTTP_PORT:-3128}
@@ -605,8 +605,8 @@ services:
       - CERTBOT_EMAIL=${CERTBOT_EMAIL}
       - CERTBOT_DOMAIN=${CERTBOT_DOMAIN}
       - CERTBOT_OPTIONS=${CERTBOT_OPTIONS:-}
-    entrypoint: ['/docker-entrypoint.sh']
-    command: ['tail', '-f', '/dev/null']
+    entrypoint: [ '/docker-entrypoint.sh' ]
+    command: [ 'tail', '-f', '/dev/null' ]

   # The nginx reverse proxy.
   # used for reverse proxying the API service and Web service.
@@ -623,12 +623,7 @@ services:
       - ./volumes/certbot/conf/live:/etc/letsencrypt/live # cert dir (with certbot container)
       - ./volumes/certbot/conf:/etc/letsencrypt
       - ./volumes/certbot/www:/var/www/html
-    entrypoint:
-      [
-        'sh',
-        '-c',
-        "cp /docker-entrypoint-mount.sh /docker-entrypoint.sh && sed -i 's/\r$$//' /docker-entrypoint.sh && chmod +x /docker-entrypoint.sh && /docker-entrypoint.sh",
-      ]
+    entrypoint: [ 'sh', '-c', "cp /docker-entrypoint-mount.sh /docker-entrypoint.sh && sed -i 's/\r$$//' /docker-entrypoint.sh && chmod +x /docker-entrypoint.sh && /docker-entrypoint.sh" ]
     environment:
       NGINX_SERVER_NAME: ${NGINX_SERVER_NAME:-_}
       NGINX_HTTPS_ENABLED: ${NGINX_HTTPS_ENABLED:-false}
@@ -720,7 +715,7 @@ services:
     working_dir: /opt/couchbase
     stdin_open: true
     tty: true
-    entrypoint: [""]
+    entrypoint: [ "" ]
     command: sh -c "/opt/couchbase/init/init-cbserver.sh"
     volumes:
       - ./volumes/couchbase/data:/opt/couchbase/var/lib/couchbase/data
@@ -749,7 +744,7 @@ services:
     volumes:
       - ./volumes/pgvector/data:/var/lib/postgresql/data
     healthcheck:
-      test: ['CMD', 'pg_isready']
+      test: [ 'CMD', 'pg_isready' ]
       interval: 1s
       timeout: 3s
       retries: 30
@@ -771,7 +766,7 @@ services:
     volumes:
       - ./volumes/pgvecto_rs/data:/var/lib/postgresql/data
     healthcheck:
-      test: ['CMD', 'pg_isready']
+      test: [ 'CMD', 'pg_isready' ]
       interval: 1s
       timeout: 3s
       retries: 30
@@ -836,7 +831,7 @@ services:
       - ./volumes/milvus/etcd:/etcd
     command: etcd -advertise-client-urls=http://127.0.0.1:2379 -listen-client-urls http://0.0.0.0:2379 --data-dir /etcd
     healthcheck:
-      test: ['CMD', 'etcdctl', 'endpoint', 'health']
+      test: [ 'CMD', 'etcdctl', 'endpoint', 'health' ]
       interval: 30s
       timeout: 20s
       retries: 3
@@ -855,7 +850,7 @@ services:
       - ./volumes/milvus/minio:/minio_data
     command: minio server /minio_data --console-address ":9001"
     healthcheck:
-      test: ['CMD', 'curl', '-f', 'http://localhost:9000/minio/health/live']
+      test: [ 'CMD', 'curl', '-f', 'http://localhost:9000/minio/health/live' ]
       interval: 30s
       timeout: 20s
       retries: 3
@@ -864,10 +859,10 @@ services:

   milvus-standalone:
     container_name: milvus-standalone
-    image: milvusdb/milvus:v2.3.1
+    image: milvusdb/milvus:v2.5.0-beta
     profiles:
       - milvus
-    command: ['milvus', 'run', 'standalone']
+    command: [ 'milvus', 'run', 'standalone' ]
     environment:
       ETCD_ENDPOINTS: ${ETCD_ENDPOINTS:-etcd:2379}
       MINIO_ADDRESS: ${MINIO_ADDRESS:-minio:9000}
@@ -875,7 +870,7 @@ services:
     volumes:
       - ./volumes/milvus/milvus:/var/lib/milvus
     healthcheck:
-      test: ['CMD', 'curl', '-f', 'http://localhost:9091/healthz']
+      test: [ 'CMD', 'curl', '-f', 'http://localhost:9091/healthz' ]
       interval: 30s
       start_period: 90s
       timeout: 20s
@@ -948,22 +943,30 @@ services:
     container_name: elasticsearch
     profiles:
       - elasticsearch
+      - elasticsearch-ja
     restart: always
     volumes:
+      - ./elasticsearch/docker-entrypoint.sh:/docker-entrypoint-mount.sh
       - dify_es01_data:/usr/share/elasticsearch/data
     environment:
       ELASTIC_PASSWORD: ${ELASTICSEARCH_PASSWORD:-elastic}
+      VECTOR_STORE: ${VECTOR_STORE:-}
       cluster.name: dify-es-cluster
       node.name: dify-es0
       discovery.type: single-node
-      xpack.license.self_generated.type: trial
+      xpack.license.self_generated.type: basic
       xpack.security.enabled: 'true'
       xpack.security.enrollment.enabled: 'false'
       xpack.security.http.ssl.enabled: 'false'
     ports:
       - ${ELASTICSEARCH_PORT:-9200}:9200
+    deploy:
+      resources:
+        limits:
+          memory: 2g
+    entrypoint: [ 'sh', '-c', "sh /docker-entrypoint-mount.sh" ]
     healthcheck:
-      test: ['CMD', 'curl', '-s', 'http://localhost:9200/_cluster/health?pretty']
+      test: [ 'CMD', 'curl', '-s', 'http://localhost:9200/_cluster/health?pretty' ]
       interval: 30s
       timeout: 10s
       retries: 50
@@ -991,7 +994,7 @@ services:
     ports:
       - ${KIBANA_PORT:-5601}:5601
     healthcheck:
-      test: ['CMD-SHELL', 'curl -s http://localhost:5601 >/dev/null || exit 1']
+      test: [ 'CMD-SHELL', 'curl -s http://localhost:5601 >/dev/null || exit 1' ]
       interval: 30s
       timeout: 10s
       retries: 3
docker/elasticsearch/docker-entrypoint.sh (new executable file, 25 lines)
@@ -0,0 +1,25 @@
+#!/bin/bash
+
+set -e
+
+if [ "${VECTOR_STORE}" = "elasticsearch-ja" ]; then
+    # Check if the ICU tokenizer plugin is installed
+    if ! /usr/share/elasticsearch/bin/elasticsearch-plugin list | grep -q analysis-icu; then
+        printf '%s\n' "Installing the ICU tokenizer plugin"
+        if ! /usr/share/elasticsearch/bin/elasticsearch-plugin install analysis-icu; then
+            printf '%s\n' "Failed to install the ICU tokenizer plugin"
+            exit 1
+        fi
+    fi
+    # Check if the Japanese language analyzer plugin is installed
+    if ! /usr/share/elasticsearch/bin/elasticsearch-plugin list | grep -q analysis-kuromoji; then
+        printf '%s\n' "Installing the Japanese language analyzer plugin"
+        if ! /usr/share/elasticsearch/bin/elasticsearch-plugin install analysis-kuromoji; then
+            printf '%s\n' "Failed to install the Japanese language analyzer plugin"
+            exit 1
+        fi
+    fi
+fi
+
+# Run the original entrypoint script
+exec /bin/tini -- /usr/local/bin/docker-entrypoint.sh
@@ -25,3 +25,6 @@ NEXT_PUBLIC_TEXT_GENERATION_TIMEOUT_MS=60000

 # CSP https://developer.mozilla.org/en-US/docs/Web/HTTP/CSP
 NEXT_PUBLIC_CSP_WHITELIST=
+
+# The maximum number of top-k value for RAG.
+NEXT_PUBLIC_TOP_K_MAX_VALUE=10
@@ -25,16 +25,18 @@ import Input from '@/app/components/base/input'
 import { useStore as useTagStore } from '@/app/components/base/tag-management/store'
 import TagManagementModal from '@/app/components/base/tag-management'
 import TagFilter from '@/app/components/base/tag-management/filter'
+import CheckboxWithLabel from '@/app/components/datasets/create/website/base/checkbox-with-label'

 const getKey = (
   pageIndex: number,
   previousPageData: AppListResponse,
   activeTab: string,
+  isCreatedByMe: boolean,
   tags: string[],
   keywords: string,
 ) => {
   if (!pageIndex || previousPageData.has_more) {
-    const params: any = { url: 'apps', params: { page: pageIndex + 1, limit: 30, name: keywords } }
+    const params: any = { url: 'apps', params: { page: pageIndex + 1, limit: 30, name: keywords, is_created_by_me: isCreatedByMe } }

     if (activeTab !== 'all')
       params.params.mode = activeTab
@@ -58,6 +60,7 @@ const Apps = () => {
     defaultTab: 'all',
   })
   const { query: { tagIDs = [], keywords = '' }, setQuery } = useAppsQueryState()
+  const [isCreatedByMe, setIsCreatedByMe] = useState(false)
   const [tagFilterValue, setTagFilterValue] = useState<string[]>(tagIDs)
   const [searchKeywords, setSearchKeywords] = useState(keywords)
   const setKeywords = useCallback((keywords: string) => {
@@ -68,7 +71,7 @@ const Apps = () => {
   }, [setQuery])

   const { data, isLoading, setSize, mutate } = useSWRInfinite(
-    (pageIndex: number, previousPageData: AppListResponse) => getKey(pageIndex, previousPageData, activeTab, tagIDs, searchKeywords),
+    (pageIndex: number, previousPageData: AppListResponse) => getKey(pageIndex, previousPageData, activeTab, isCreatedByMe, tagIDs, searchKeywords),
     fetchAppList,
     { revalidateFirstPage: true },
   )
@@ -132,6 +135,12 @@ const Apps = () => {
           options={options}
         />
         <div className='flex items-center gap-2'>
+          <CheckboxWithLabel
+            className='mr-2'
+            label={t('app.showMyCreatedAppsOnly')}
+            isChecked={isCreatedByMe}
+            onChange={() => setIsCreatedByMe(!isCreatedByMe)}
+          />
           <TagFilter type='app' value={tagFilterValue} onChange={handleTagsChange} />
           <Input
             showLeftIcon
@@ -52,6 +52,15 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
       - <code>high_quality</code> High quality: embedding using embedding model, built as vector database index
       - <code>economy</code> Economy: Build using inverted index of keyword table index
     </Property>
+    <Property name='doc_form' type='string' key='doc_form'>
+      Format of indexed content
+      - <code>text_model</code> Text documents are directly embedded; `economy` mode defaults to using this form
+      - <code>hierarchical_model</code> Parent-child mode
+      - <code>qa_model</code> Q&A Mode: Generates Q&A pairs for segmented documents and then embeds the questions
+    </Property>
+    <Property name='doc_language' type='string' key='doc_language'>
+      In Q&A mode, specify the language of the document, for example: <code>English</code>, <code>Chinese</code>
+    </Property>
     <Property name='process_rule' type='object' key='process_rule'>
       Processing rules
       - <code>mode</code> (string) Cleaning, segmentation mode, automatic / custom
@@ -65,6 +74,11 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
       - <code>segmentation</code> (object) Segmentation rules
         - <code>separator</code> Custom segment identifier, currently only allows one delimiter to be set. Default is \n
         - <code>max_tokens</code> Maximum length (token) defaults to 1000
+      - <code>parent_mode</code> Retrieval mode of parent chunks: <code>full-doc</code> full text retrieval / <code>paragraph</code> paragraph retrieval
+      - <code>subchunk_segmentation</code> (object) Child chunk rules
+        - <code>separator</code> Segmentation identifier. Currently, only one delimiter is allowed. The default is <code>***</code>
+        - <code>max_tokens</code> The maximum length (tokens) must be validated to be shorter than the length of the parent chunk
+        - <code>chunk_overlap</code> Define the overlap between adjacent chunks (optional)
     </Property>
   </Properties>
 </Col>
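To make the new parent-child options concrete, a hedged sketch of a create-document-by-text request that uses doc_form and the child-chunk rules; the base URL, API key, dataset ID, and the exact placement of the fields are assumptions for illustration, not taken from this diff:

import requests

API_BASE = "https://api.dify.example/v1"  # placeholder
API_KEY = "dataset-xxx"                   # placeholder
DATASET_ID = "your-dataset-id"            # placeholder

payload = {
    "name": "manual.txt",
    "text": "Long document text ...",
    "indexing_technique": "high_quality",
    "doc_form": "hierarchical_model",  # parent-child mode
    "process_rule": {
        "mode": "custom",
        "rules": {
            "pre_processing_rules": [{"id": "remove_extra_spaces", "enabled": True}],
            "segmentation": {"separator": "\n", "max_tokens": 1000},
            "parent_mode": "paragraph",  # or "full-doc"
            "subchunk_segmentation": {"separator": "***", "max_tokens": 200},
        },
    },
}

resp = requests.post(
    f"{API_BASE}/datasets/{DATASET_ID}/document/create-by-text",
    headers={"Authorization": f"Bearer {API_KEY}"},
    json=payload,
)
print(resp.status_code)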
@@ -155,6 +169,13 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
   - <code>high_quality</code> High quality: embedding using embedding model, built as vector database index
   - <code>economy</code> Economy: Build using inverted index of keyword table index

+  - <code>doc_form</code> Format of indexed content
+    - <code>text_model</code> Text documents are directly embedded; `economy` mode defaults to using this form
+    - <code>hierarchical_model</code> Parent-child mode
+    - <code>qa_model</code> Q&A Mode: Generates Q&A pairs for segmented documents and then embeds the questions
+
+  - <code>doc_language</code> In Q&A mode, specify the language of the document, for example: <code>English</code>, <code>Chinese</code>
+
   - <code>process_rule</code> Processing rules
     - <code>mode</code> (string) Cleaning, segmentation mode, automatic / custom
     - <code>rules</code> (object) Custom rules (in automatic mode, this field is empty)
@@ -167,6 +188,11 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
       - <code>segmentation</code> (object) Segmentation rules
         - <code>separator</code> Custom segment identifier, currently only allows one delimiter to be set. Default is \n
         - <code>max_tokens</code> Maximum length (token) defaults to 1000
+      - <code>parent_mode</code> Retrieval mode of parent chunks: <code>full-doc</code> full text retrieval / <code>paragraph</code> paragraph retrieval
+      - <code>subchunk_segmentation</code> (object) Child chunk rules
+        - <code>separator</code> Segmentation identifier. Currently, only one delimiter is allowed. The default is <code>***</code>
+        - <code>max_tokens</code> The maximum length (tokens) must be validated to be shorter than the length of the parent chunk
+        - <code>chunk_overlap</code> Define the overlap between adjacent chunks (optional)
     </Property>
     <Property name='file' type='multipart/form-data' key='file'>
       Files that need to be uploaded.
@@ -449,6 +475,11 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
       - <code>segmentation</code> (object) Segmentation rules
         - <code>separator</code> Custom segment identifier, currently only allows one delimiter to be set. Default is \n
         - <code>max_tokens</code> Maximum length (token) defaults to 1000
+      - <code>parent_mode</code> Retrieval mode of parent chunks: <code>full-doc</code> full text retrieval / <code>paragraph</code> paragraph retrieval
+      - <code>subchunk_segmentation</code> (object) Child chunk rules
+        - <code>separator</code> Segmentation identifier. Currently, only one delimiter is allowed. The default is <code>***</code>
+        - <code>max_tokens</code> The maximum length (tokens) must be validated to be shorter than the length of the parent chunk
+        - <code>chunk_overlap</code> Define the overlap between adjacent chunks (optional)
     </Property>
   </Properties>
 </Col>
@@ -546,6 +577,11 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
       - <code>segmentation</code> (object) Segmentation rules
         - <code>separator</code> Custom segment identifier, currently only allows one delimiter to be set. Default is \n
         - <code>max_tokens</code> Maximum length (token) defaults to 1000
+      - <code>parent_mode</code> Retrieval mode of parent chunks: <code>full-doc</code> full text retrieval / <code>paragraph</code> paragraph retrieval
+      - <code>subchunk_segmentation</code> (object) Child chunk rules
+        - <code>separator</code> Segmentation identifier. Currently, only one delimiter is allowed. The default is <code>***</code>
+        - <code>max_tokens</code> The maximum length (tokens) must be validated to be shorter than the length of the parent chunk
+        - <code>chunk_overlap</code> Define the overlap between adjacent chunks (optional)
     </Property>
   </Properties>
 </Col>
@@ -984,7 +1020,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
 <Heading
   url='/datasets/{dataset_id}/documents/{document_id}/segments/{segment_id}'
   method='POST'
-  title='Update a Chunk in a Document '
+  title='Update a Chunk in a Document'
   name='#update_segment'
 />
 <Row>
@@ -1009,6 +1045,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
   - <code>answer</code> (text) Answer content, passed if the knowledge is in Q&A mode (optional)
   - <code>keywords</code> (list) Keyword (optional)
   - <code>enabled</code> (bool) False / true (optional)
+  - <code>regenerate_child_chunks</code> (bool) Whether to regenerate child chunks (optional)
 </Property>
 </Properties>
 </Col>
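A similar hedged sketch of the chunk-update call that exercises the new regenerate_child_chunks flag; the base URL, API key, and IDs are placeholders, and the request-body shape follows the parameter list above as an assumption:

import requests

API_BASE = "https://api.dify.example/v1"  # placeholder
API_KEY = "dataset-xxx"                   # placeholder
dataset_id, document_id, segment_id = "d-id", "doc-id", "seg-id"  # placeholders

body = {
    "segment": {
        "content": "Updated parent chunk text",
        "keywords": ["updated"],
        "enabled": True,
        "regenerate_child_chunks": True,  # re-split child chunks after editing the parent chunk
    }
}
resp = requests.post(
    f"{API_BASE}/datasets/{dataset_id}/documents/{document_id}/segments/{segment_id}",
    headers={"Authorization": f"Bearer {API_KEY}"},
    json=body,
)
print(resp.status_code)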
@ -52,6 +52,15 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
|
|||||||
- <code>high_quality</code> 高质量:使用 embedding 模型进行嵌入,构建为向量数据库索引
|
- <code>high_quality</code> 高质量:使用 embedding 模型进行嵌入,构建为向量数据库索引
|
||||||
- <code>economy</code> 经济:使用 keyword table index 的倒排索引进行构建
|
- <code>economy</code> 经济:使用 keyword table index 的倒排索引进行构建
|
||||||
</Property>
|
</Property>
|
||||||
|
<Property name='doc_form' type='string' key='doc_form'>
|
||||||
|
索引内容的形式
|
||||||
|
- <code>text_model</code> text 文档直接 embedding,经济模式默认为该模式
|
||||||
|
- <code>hierarchical_model</code> parent-child 模式
|
||||||
|
- <code>qa_model</code> Q&A 模式:为分片文档生成 Q&A 对,然后对问题进行 embedding
|
||||||
|
</Property>
|
||||||
|
<Property name='doc_language' type='string' key='doc_language'>
|
||||||
|
在 Q&A 模式下,指定文档的语言,例如:<code>English</code>、<code>Chinese</code>
|
||||||
|
</Property>
|
||||||
<Property name='process_rule' type='object' key='process_rule'>
|
<Property name='process_rule' type='object' key='process_rule'>
|
||||||
处理规则
|
处理规则
|
||||||
- <code>mode</code> (string) 清洗、分段模式 ,automatic 自动 / custom 自定义
|
- <code>mode</code> (string) 清洗、分段模式 ,automatic 自动 / custom 自定义
|
||||||
@ -63,8 +72,13 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
|
|||||||
- <code>remove_urls_emails</code> 删除 URL、电子邮件地址
|
- <code>remove_urls_emails</code> 删除 URL、电子邮件地址
|
||||||
- <code>enabled</code> (bool) 是否选中该规则,不传入文档 ID 时代表默认值
|
- <code>enabled</code> (bool) 是否选中该规则,不传入文档 ID 时代表默认值
|
||||||
- <code>segmentation</code> (object) 分段规则
|
- <code>segmentation</code> (object) 分段规则
|
||||||
- <code>separator</code> 自定义分段标识符,目前仅允许设置一个分隔符。默认为 \n
|
- <code>separator</code> 自定义分段标识符,目前仅允许设置一个分隔符。默认为 <code>\n</code>
|
||||||
- <code>max_tokens</code> 最大长度(token)默认为 1000
|
- <code>max_tokens</code> 最大长度(token)默认为 1000
|
||||||
|
- <code>parent_mode</code> 父分段的召回模式 <code>full-doc</code> 全文召回 / <code>paragraph</code> 段落召回
|
||||||
|
- <code>subchunk_segmentation</code> (object) 子分段规则
|
||||||
|
- <code>separator</code> 分段标识符,目前仅允许设置一个分隔符。默认为 <code>***</code>
|
||||||
|
- <code>max_tokens</code> 最大长度 (token) 需要校验小于父级的长度
|
||||||
|
- <code>chunk_overlap</code> 分段重叠指的是在对数据进行分段时,段与段之间存在一定的重叠部分(选填)
|
||||||
</Property>
|
</Property>
|
||||||
</Properties>
|
</Properties>
|
||||||
</Col>
|
</Col>
|
||||||
@ -155,6 +169,13 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
|
|||||||
- <code>high_quality</code> 高质量:使用 embedding 模型进行嵌入,构建为向量数据库索引
|
- <code>high_quality</code> 高质量:使用 embedding 模型进行嵌入,构建为向量数据库索引
|
||||||
- <code>economy</code> 经济:使用 keyword table index 的倒排索引进行构建
|
- <code>economy</code> 经济:使用 keyword table index 的倒排索引进行构建
|
||||||
|
|
||||||
|
- <code>doc_form</code> 索引内容的形式
|
||||||
|
- <code>text_model</code> text 文档直接 embedding,经济模式默认为该模式
|
||||||
|
- <code>hierarchical_model</code> parent-child 模式
|
||||||
|
- <code>qa_model</code> Q&A 模式:为分片文档生成 Q&A 对,然后对问题进行 embedding
|
||||||
|
|
||||||
|
- <code>doc_language</code> 在 Q&A 模式下,指定文档的语言,例如:<code>English</code>、<code>Chinese</code>
|
||||||
|
|
||||||
- <code>process_rule</code> 处理规则
|
- <code>process_rule</code> 处理规则
|
||||||
- <code>mode</code> (string) 清洗、分段模式 ,automatic 自动 / custom 自定义
|
- <code>mode</code> (string) 清洗、分段模式 ,automatic 自动 / custom 自定义
|
||||||
- <code>rules</code> (object) 自定义规则(自动模式下,该字段为空)
|
- <code>rules</code> (object) 自定义规则(自动模式下,该字段为空)
|
||||||
@ -167,6 +188,11 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
|
|||||||
- <code>segmentation</code> (object) 分段规则
|
- <code>segmentation</code> (object) 分段规则
|
||||||
- <code>separator</code> 自定义分段标识符,目前仅允许设置一个分隔符。默认为 \n
|
- <code>separator</code> 自定义分段标识符,目前仅允许设置一个分隔符。默认为 \n
|
||||||
- <code>max_tokens</code> 最大长度(token)默认为 1000
|
- <code>max_tokens</code> 最大长度(token)默认为 1000
|
||||||
|
+    - <code>parent_mode</code> Retrieval mode for parent chunks: <code>full-doc</code> for full-document retrieval / <code>paragraph</code> for paragraph retrieval
+    - <code>subchunk_segmentation</code> (object) Child chunk rules
+      - <code>separator</code> Segmentation delimiter; currently only one delimiter may be set. Defaults to <code>***</code>
+      - <code>max_tokens</code> Maximum length (tokens); must be validated to be smaller than the parent chunk's length
+      - <code>chunk_overlap</code> Chunk overlap: the amount of content shared between adjacent chunks when the data is segmented (optional)
 </Property>
 <Property name='file' type='multipart/form-data' key='file'>
   The file to upload.
@@ -411,7 +437,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
 <Heading
   url='/datasets/{dataset_id}/documents/{document_id}/update-by-text'
   method='POST'
-  title='Update a Document by Text '
+  title='Update a Document by Text'
   name='#update-by-text'
 />
 <Row>
@@ -449,6 +475,11 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
   - <code>segmentation</code> (object) Segmentation rules
     - <code>separator</code> Custom segmentation delimiter; currently only one delimiter may be set. Defaults to \n
     - <code>max_tokens</code> Maximum length (tokens); defaults to 1000
+  - <code>parent_mode</code> Retrieval mode for parent chunks: <code>full-doc</code> for full-document retrieval / <code>paragraph</code> for paragraph retrieval
+  - <code>subchunk_segmentation</code> (object) Child chunk rules
+    - <code>separator</code> Segmentation delimiter; currently only one delimiter may be set. Defaults to <code>***</code>
+    - <code>max_tokens</code> Maximum length (tokens); must be validated to be smaller than the parent chunk's length
+    - <code>chunk_overlap</code> Chunk overlap: the amount of content shared between adjacent chunks when the data is segmented (optional)
 </Property>
 </Properties>
 </Col>
@@ -508,7 +539,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
 <Heading
   url='/datasets/{dataset_id}/documents/{document_id}/update-by-file'
   method='POST'
-  title='Update a Document by File '
+  title='Update a Document by File'
   name='#update-by-file'
 />
 <Row>
@@ -546,6 +577,11 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
   - <code>segmentation</code> (object) Segmentation rules
     - <code>separator</code> Custom segmentation delimiter; currently only one delimiter may be set. Defaults to \n
     - <code>max_tokens</code> Maximum length (tokens); defaults to 1000
+  - <code>parent_mode</code> Retrieval mode for parent chunks: <code>full-doc</code> for full-document retrieval / <code>paragraph</code> for paragraph retrieval
+  - <code>subchunk_segmentation</code> (object) Child chunk rules
+    - <code>separator</code> Segmentation delimiter; currently only one delimiter may be set. Defaults to <code>***</code>
+    - <code>max_tokens</code> Maximum length (tokens); must be validated to be smaller than the parent chunk's length
+    - <code>chunk_overlap</code> Chunk overlap: the amount of content shared between adjacent chunks when the data is segmented (optional)
 </Property>
 </Properties>
 </Col>
@@ -1009,6 +1045,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
   - <code>answer</code> (text) Answer content; optional, set it when the knowledge base is in Q&A mode
   - <code>keywords</code> (list) Keywords; optional
   - <code>enabled</code> (bool) false/true; optional
+  - <code>regenerate_child_chunks</code> (bool) Whether to regenerate child chunks; optional
 </Property>
 </Properties>
 </Col>
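A minimal TypeScript sketch of an update-by-text call that exercises the parent/child chunking fields documented above. The endpoint path is taken from this diff; the exact nesting of `parent_mode` and `subchunk_segmentation` under `process_rule.rules`, the base URL, and the API key are assumptions for illustration only.

```ts
const DIFY_BASE = 'https://api.dify.ai/v1' // assumed base URL
const API_KEY = 'dataset-xxx'              // hypothetical dataset API key

async function updateDocumentByText(datasetId: string, documentId: string) {
  const res = await fetch(`${DIFY_BASE}/datasets/${datasetId}/documents/${documentId}/update-by-text`, {
    method: 'POST',
    headers: {
      'Authorization': `Bearer ${API_KEY}`,
      'Content-Type': 'application/json',
    },
    body: JSON.stringify({
      name: 'updated doc',
      text: 'new document text ...',
      process_rule: {
        mode: 'custom',
        rules: {
          // general segmentation, as documented above
          segmentation: { separator: '\n', max_tokens: 1000 },
          // parent/child chunking fields added in this change (nesting assumed)
          parent_mode: 'paragraph', // or 'full-doc'
          subchunk_segmentation: { separator: '***', max_tokens: 256, chunk_overlap: 32 },
        },
      },
    }),
  })
  return res.json()
}
```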
@@ -26,13 +26,15 @@ const PromptEditorHeightResizeWrap: FC<Props> = ({
   const [clientY, setClientY] = useState(0)
   const [isResizing, setIsResizing] = useState(false)
   const [prevUserSelectStyle, setPrevUserSelectStyle] = useState(getComputedStyle(document.body).userSelect)
+  const [oldHeight, setOldHeight] = useState(height)

   const handleStartResize = useCallback((e: React.MouseEvent<HTMLElement>) => {
     setClientY(e.clientY)
     setIsResizing(true)
+    setOldHeight(height)
     setPrevUserSelectStyle(getComputedStyle(document.body).userSelect)
     document.body.style.userSelect = 'none'
-  }, [])
+  }, [height])

   const handleStopResize = useCallback(() => {
     setIsResizing(false)
@@ -44,8 +46,7 @@ const PromptEditorHeightResizeWrap: FC<Props> = ({
       return

     const offset = e.clientY - clientY
-    let newHeight = height + offset
-    setClientY(e.clientY)
+    let newHeight = oldHeight + offset
     if (newHeight < minHeight)
       newHeight = minHeight
     onHeightChange(newHeight)
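The hunk above switches the drag calculation to a height snapshot taken at mousedown (`oldHeight`) instead of the live `height` prop, so every mousemove derives the new height from the same baseline. A minimal framework-free sketch of that pattern, with illustrative element handles and a placeholder minimum height:

```ts
// Capture the starting height on mousedown and compute every new height from
// that snapshot plus the total pointer offset; repeated mousemove events then
// cannot compound stale-state or rounding errors.
function attachResizeHandle(handle: HTMLElement, panel: HTMLElement, minHeight = 228) {
  let startY = 0
  let startHeight = 0

  const onMove = (e: MouseEvent) => {
    const next = Math.max(minHeight, startHeight + (e.clientY - startY))
    panel.style.height = `${next}px`
  }
  const onUp = () => {
    document.removeEventListener('mousemove', onMove)
    document.removeEventListener('mouseup', onUp)
    document.body.style.userSelect = ''
  }

  handle.addEventListener('mousedown', (e) => {
    startY = e.clientY                      // pointer position at drag start
    startHeight = panel.offsetHeight        // height snapshot, like oldHeight above
    document.body.style.userSelect = 'none' // avoid text selection while dragging
    document.addEventListener('mousemove', onMove)
    document.addEventListener('mouseup', onUp)
  })
}
```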
@@ -27,6 +27,7 @@ import { ADD_EXTERNAL_DATA_TOOL } from '@/app/components/app/configuration/confi
 import { INSERT_VARIABLE_VALUE_BLOCK_COMMAND } from '@/app/components/base/prompt-editor/plugins/variable-block'
 import { PROMPT_EDITOR_UPDATE_VALUE_BY_EVENT_EMITTER } from '@/app/components/base/prompt-editor/plugins/update-block'
 import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints'
+import { useFeaturesStore } from '@/app/components/base/features/hooks'

 export type ISimplePromptInput = {
   mode: AppType
@@ -54,6 +55,11 @@ const Prompt: FC<ISimplePromptInput> = ({
   const { t } = useTranslation()
   const media = useBreakpoints()
   const isMobile = media === MediaType.mobile
+  const featuresStore = useFeaturesStore()
+  const {
+    features,
+    setFeatures,
+  } = featuresStore!.getState()

   const { eventEmitter } = useEventEmitterContextContext()
   const {
@@ -137,8 +143,18 @@ const Prompt: FC<ISimplePromptInput> = ({
       })
       setModelConfig(newModelConfig)
       setPrevPromptConfig(modelConfig.configs)
-      if (mode !== AppType.completion)
+      if (mode !== AppType.completion) {
         setIntroduction(res.opening_statement)
+        const newFeatures = produce(features, (draft) => {
+          draft.opening = {
+            ...draft.opening,
+            enabled: !!res.opening_statement,
+            opening_statement: res.opening_statement,
+          }
+        })
+        setFeatures(newFeatures)
+      }
       showAutomaticFalse()
     }
     const minHeight = initEditorHeight || 228
@@ -59,36 +59,24 @@ const ConfigContent: FC<Props> = ({

   const {
     modelList: rerankModelList,
-    defaultModel: rerankDefaultModel,
-    currentModel: isRerankDefaultModelValid,
   } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)

   const {
     currentModel: currentRerankModel,
   } = useCurrentProviderAndModel(
     rerankModelList,
-    rerankDefaultModel
-      ? {
-        ...rerankDefaultModel,
-        provider: rerankDefaultModel.provider.provider,
-      }
-      : undefined,
+    {
+      provider: datasetConfigs.reranking_model?.reranking_provider_name,
+      model: datasetConfigs.reranking_model?.reranking_model_name,
+    },
   )

-  const rerankModel = (() => {
-    if (datasetConfigs.reranking_model?.reranking_provider_name) {
-      return {
-        provider_name: datasetConfigs.reranking_model.reranking_provider_name,
-        model_name: datasetConfigs.reranking_model.reranking_model_name,
-      }
-    }
-    else if (rerankDefaultModel) {
-      return {
-        provider_name: rerankDefaultModel.provider.provider,
-        model_name: rerankDefaultModel.model,
-      }
-    }
-  })()
+  const rerankModel = useMemo(() => {
+    return {
+      provider_name: datasetConfigs?.reranking_model?.reranking_provider_name ?? '',
+      model_name: datasetConfigs?.reranking_model?.reranking_model_name ?? '',
+    }
+  }, [datasetConfigs.reranking_model])

   const handleParamChange = (key: string, value: number) => {
     if (key === 'top_k') {
@@ -133,6 +121,12 @@ const ConfigContent: FC<Props> = ({
   }

   const handleRerankModeChange = (mode: RerankingModeEnum) => {
+    if (mode === datasetConfigs.reranking_mode)
+      return
+
+    if (mode === RerankingModeEnum.RerankingModel && !currentRerankModel)
+      Toast.notify({ type: 'error', message: t('workflow.errorMsg.rerankModelRequired') })
+
     onChange({
       ...datasetConfigs,
       reranking_mode: mode,
@@ -162,31 +156,25 @@ const ConfigContent: FC<Props> = ({

   const canManuallyToggleRerank = useMemo(() => {
     return (selectedDatasetsMode.allInternal && selectedDatasetsMode.allEconomic)
       || selectedDatasetsMode.allExternal
   }, [selectedDatasetsMode.allEconomic, selectedDatasetsMode.allExternal, selectedDatasetsMode.allInternal])

   const showRerankModel = useMemo(() => {
     if (!canManuallyToggleRerank)
       return true
-    else if (canManuallyToggleRerank && !isRerankDefaultModelValid)
-      return false

     return datasetConfigs.reranking_enable
-  }, [canManuallyToggleRerank, datasetConfigs.reranking_enable, isRerankDefaultModelValid])
+  }, [datasetConfigs.reranking_enable, canManuallyToggleRerank])

-  const handleDisabledSwitchClick = useCallback(() => {
-    if (!currentRerankModel && !showRerankModel)
+  const handleDisabledSwitchClick = useCallback((enable: boolean) => {
+    if (!currentRerankModel && enable)
       Toast.notify({ type: 'error', message: t('workflow.errorMsg.rerankModelRequired') })
-  }, [currentRerankModel, showRerankModel, t])
-
-  useEffect(() => {
-    if (canManuallyToggleRerank && showRerankModel !== datasetConfigs.reranking_enable) {
-      onChange({
-        ...datasetConfigs,
-        reranking_enable: showRerankModel,
-      })
-    }
-  }, [canManuallyToggleRerank, showRerankModel, datasetConfigs, onChange])
+    onChange({
+      ...datasetConfigs,
+      reranking_enable: enable,
+    })
+    // eslint-disable-next-line react-hooks/exhaustive-deps
+  }, [currentRerankModel, datasetConfigs, onChange])

   return (
     <div>
@@ -267,24 +255,12 @@ const ConfigContent: FC<Props> = ({
           <div className='flex items-center'>
             {
               selectedDatasetsMode.allEconomic && !selectedDatasetsMode.mixtureInternalAndExternal && (
-                <div
-                  className='flex items-center'
-                  onClick={handleDisabledSwitchClick}
-                >
-                  <Switch
-                    size='md'
-                    defaultValue={showRerankModel}
-                    disabled={!currentRerankModel || !canManuallyToggleRerank}
-                    onChange={(v) => {
-                      if (canManuallyToggleRerank) {
-                        onChange({
-                          ...datasetConfigs,
-                          reranking_enable: v,
-                        })
-                      }
-                    }}
-                  />
-                </div>
+                <Switch
+                  size='md'
+                  defaultValue={showRerankModel}
+                  disabled={!canManuallyToggleRerank}
+                  onChange={handleDisabledSwitchClick}
+                />
               )
             }
             <div className='leading-[32px] ml-1 text-text-secondary system-sm-semibold'>{t('common.modelProvider.rerankModel.key')}</div>
@@ -298,21 +274,24 @@ const ConfigContent: FC<Props> = ({
                 triggerClassName='ml-1 w-4 h-4'
               />
             </div>
-            <div>
-              <ModelSelector
-                defaultModel={rerankModel && { provider: rerankModel?.provider_name, model: rerankModel?.model_name }}
-                onSelect={(v) => {
-                  onChange({
-                    ...datasetConfigs,
-                    reranking_model: {
-                      reranking_provider_name: v.provider,
-                      reranking_model_name: v.model,
-                    },
-                  })
-                }}
-                modelList={rerankModelList}
-              />
-            </div>
+            {
+              showRerankModel && (
+                <div>
+                  <ModelSelector
+                    defaultModel={rerankModel && { provider: rerankModel?.provider_name, model: rerankModel?.model_name }}
+                    onSelect={(v) => {
+                      onChange({
+                        ...datasetConfigs,
+                        reranking_model: {
+                          reranking_provider_name: v.provider,
+                          reranking_model_name: v.model,
+                        },
+                      })
+                    }}
+                    modelList={rerankModelList}
+                  />
+                </div>
+              )}
           </div>
         )
       }
@@ -10,7 +10,7 @@ import Modal from '@/app/components/base/modal'
 import Button from '@/app/components/base/button'
 import { RETRIEVE_TYPE } from '@/types/app'
 import Toast from '@/app/components/base/toast'
-import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
+import { useCurrentProviderAndModel, useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
 import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
 import { RerankingModeEnum } from '@/models/datasets'
 import type { DataSet } from '@/models/datasets'
@@ -41,17 +41,27 @@ const ParamsConfig = ({
   }, [datasetConfigs])

   const {
-    defaultModel: rerankDefaultModel,
-    currentModel: isRerankDefaultModelValid,
+    modelList: rerankModelList,
+    currentModel: rerankDefaultModel,
     currentProvider: rerankDefaultProvider,
   } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)

+  const {
+    currentModel: isCurrentRerankModelValid,
+  } = useCurrentProviderAndModel(
+    rerankModelList,
+    {
+      provider: tempDataSetConfigs.reranking_model?.reranking_provider_name ?? '',
+      model: tempDataSetConfigs.reranking_model?.reranking_model_name ?? '',
+    },
+  )
+
   const isValid = () => {
     let errMsg = ''
     if (tempDataSetConfigs.retrieval_model === RETRIEVE_TYPE.multiWay) {
       if (tempDataSetConfigs.reranking_enable
         && tempDataSetConfigs.reranking_mode === RerankingModeEnum.RerankingModel
-        && !isRerankDefaultModelValid
+        && !isCurrentRerankModelValid
       )
         errMsg = t('appDebug.datasetConfig.rerankModelRequired')
     }
@@ -66,16 +76,7 @@ const ParamsConfig = ({
   const handleSave = () => {
     if (!isValid())
       return
-    const config = { ...tempDataSetConfigs }
-    if (config.retrieval_model === RETRIEVE_TYPE.multiWay
-      && config.reranking_mode === RerankingModeEnum.RerankingModel
-      && !config.reranking_model) {
-      config.reranking_model = {
-        reranking_provider_name: rerankDefaultModel?.provider?.provider,
-        reranking_model_name: rerankDefaultModel?.model,
-      } as any
-    }
-    setDatasetConfigs(config)
+    setDatasetConfigs(tempDataSetConfigs)
     setRerankSettingModalOpen(false)
   }

@@ -94,14 +95,14 @@ const ParamsConfig = ({
       reranking_enable: restConfigs.reranking_enable,
     }, selectedDatasets, selectedDatasets, {
       provider: rerankDefaultProvider?.provider,
-      model: isRerankDefaultModelValid?.model,
+      model: rerankDefaultModel?.model,
     })

     setTempDataSetConfigs({
       ...retrievalConfig,
-      reranking_model: restConfigs.reranking_model && {
-        reranking_provider_name: restConfigs.reranking_model.reranking_provider_name,
-        reranking_model_name: restConfigs.reranking_model.reranking_model_name,
+      reranking_model: {
+        reranking_provider_name: retrievalConfig.reranking_model?.provider || '',
+        reranking_model_name: retrievalConfig.reranking_model?.model || '',
       },
       retrieval_model,
       score_threshold_enabled,
@@ -29,7 +29,7 @@ const WeightedScore = ({

   return (
     <div>
-      <div className='px-3 pt-5 h-[52px] space-x-3 rounded-lg border border-components-panel-border'>
+      <div className='px-3 pt-5 pb-2 space-x-3 rounded-lg border border-components-panel-border'>
         <Slider
           className={cn('grow h-0.5 !bg-util-colors-teal-teal-500 rounded-full')}
           max={1.0}
@@ -39,7 +39,7 @@ const WeightedScore = ({
           onChange={v => onChange({ value: [v, (10 - v * 10) / 10] })}
           trackClassName='weightedScoreSliderTrack'
         />
-        <div className='flex justify-between mt-1'>
+        <div className='flex justify-between mt-3'>
           <div className='shrink-0 flex items-center w-[90px] system-xs-semibold-uppercase text-util-colors-blue-light-blue-light-500'>
             <div className='mr-1 truncate uppercase' title={t('dataset.weightedScore.semantic') || ''}>
               {t('dataset.weightedScore.semantic')}
@@ -12,7 +12,7 @@ import Divider from '@/app/components/base/divider'
 import Button from '@/app/components/base/button'
 import Input from '@/app/components/base/input'
 import Textarea from '@/app/components/base/textarea'
-import { type DataSet, RerankingModeEnum } from '@/models/datasets'
+import { type DataSet } from '@/models/datasets'
 import { useToastContext } from '@/app/components/base/toast'
 import { updateDatasetSetting } from '@/service/datasets'
 import { useAppContext } from '@/context/app-context'
@@ -21,7 +21,7 @@ import type { RetrievalConfig } from '@/types/app'
 import RetrievalSettings from '@/app/components/datasets/external-knowledge-base/create/RetrievalSettings'
 import RetrievalMethodConfig from '@/app/components/datasets/common/retrieval-method-config'
 import EconomicalRetrievalMethodConfig from '@/app/components/datasets/common/economical-retrieval-method-config'
-import { ensureRerankModelSelected, isReRankModelSelected } from '@/app/components/datasets/common/check-rerank-model'
+import { isReRankModelSelected } from '@/app/components/datasets/common/check-rerank-model'
 import { AlertTriangle } from '@/app/components/base/icons/src/vender/solid/alertsAndFeedback'
 import PermissionSelector from '@/app/components/datasets/settings/permission-selector'
 import ModelSelector from '@/app/components/header/account-setting/model-provider-page/model-selector'
@@ -99,8 +99,6 @@ const SettingsModal: FC<SettingsModalProps> = ({
     }
     if (
       !isReRankModelSelected({
-        rerankDefaultModel,
-        isRerankDefaultModelValid: !!isRerankDefaultModelValid,
         rerankModelList,
         retrievalConfig,
         indexMethod,
@@ -109,14 +107,6 @@ const SettingsModal: FC<SettingsModalProps> = ({
       notify({ type: 'error', message: t('appDebug.datasetConfig.rerankModelRequired') })
       return
     }
-    const postRetrievalConfig = ensureRerankModelSelected({
-      rerankDefaultModel: rerankDefaultModel!,
-      retrievalConfig: {
-        ...retrievalConfig,
-        reranking_enable: retrievalConfig.reranking_mode === RerankingModeEnum.RerankingModel,
-      },
-      indexMethod,
-    })
     try {
       setLoading(true)
       const { id, name, description, permission } = localeCurrentDataset
@@ -128,8 +118,8 @@ const SettingsModal: FC<SettingsModalProps> = ({
         permission,
         indexing_technique: indexMethod,
         retrieval_model: {
-          ...postRetrievalConfig,
-          score_threshold: postRetrievalConfig.score_threshold_enabled ? postRetrievalConfig.score_threshold : 0,
+          ...retrievalConfig,
+          score_threshold: retrievalConfig.score_threshold_enabled ? retrievalConfig.score_threshold : 0,
         },
         embedding_model: localeCurrentDataset.embedding_model,
         embedding_model_provider: localeCurrentDataset.embedding_model_provider,
@@ -157,7 +147,7 @@ const SettingsModal: FC<SettingsModalProps> = ({
       onSave({
         ...localeCurrentDataset,
         indexing_technique: indexMethod,
-        retrieval_model_dict: postRetrievalConfig,
+        retrieval_model_dict: retrievalConfig,
       })
     }
     catch (e) {
@@ -287,9 +287,9 @@ const Configuration: FC = () => {

     setDatasetConfigs({
       ...retrievalConfig,
-      reranking_model: restConfigs.reranking_model && {
-        reranking_provider_name: restConfigs.reranking_model.reranking_provider_name,
-        reranking_model_name: restConfigs.reranking_model.reranking_model_name,
+      reranking_model: {
+        reranking_provider_name: retrievalConfig?.reranking_model?.provider || '',
+        reranking_model_name: retrievalConfig?.reranking_model?.model || '',
       },
       retrieval_model,
       score_threshold_enabled,
@@ -39,6 +39,7 @@ type ChatInputAreaProps = {
   inputs?: Record<string, any>
   inputsForm?: InputForm[]
   theme?: Theme | null
+  isResponding?: boolean
 }
 const ChatInputArea = ({
   showFeatureBar,
@@ -51,6 +52,7 @@ const ChatInputArea = ({
   inputs = {},
   inputsForm = [],
   theme,
+  isResponding,
 }: ChatInputAreaProps) => {
   const { t } = useTranslation()
   const { notify } = useToastContext()
@@ -77,6 +79,11 @@ const ChatInputArea = ({
   const historyRef = useRef([''])
   const [currentIndex, setCurrentIndex] = useState(-1)
   const handleSend = () => {
+    if (isResponding) {
+      notify({ type: 'info', message: t('appDebug.errorMessage.waitForResponse') })
+      return
+    }
+
     if (onSend) {
       const { files, setFiles } = filesStore.getState()
       if (files.find(item => item.transferMethod === TransferMethod.local_file && !item.uploadedId)) {
@@ -116,7 +123,7 @@ const ChatInputArea = ({
       setQuery(historyRef.current[currentIndex + 1])
     }
     else if (currentIndex === historyRef.current.length - 1) {
       // If it is the last element, clear the input box
       setCurrentIndex(historyRef.current.length)
       setQuery('')
     }
@@ -169,6 +176,7 @@ const ChatInputArea = ({
               'p-1 w-full leading-6 body-lg-regular text-text-tertiary outline-none',
             )}
             placeholder={t('common.chat.inputPlaceholder') || ''}
+            autoFocus
             autoSize={{ minRows: 1 }}
             onResize={handleTextareaResize}
             value={query}
@@ -292,6 +292,7 @@ const Chat: FC<ChatProps> = ({
             inputs={inputs}
             inputsForm={inputsForm}
             theme={themeBuilder?.theme}
+            isResponding={isResponding}
           />
         )
       }
@@ -28,8 +28,8 @@ const Question: FC<QuestionProps> = ({
   } = item

   return (
-    <div className='flex justify-end mb-2 last:mb-0 pl-10'>
-      <div className='group relative mr-4'>
+    <div className='flex justify-end mb-2 last:mb-0 pl-14'>
+      <div className='group relative mr-4 max-w-full'>
         <div
           className='px-4 py-3 bg-[#D1E9FF]/50 rounded-2xl text-sm text-gray-900'
           style={theme?.chatBubbleColorStyle ? CssTransform(theme.chatBubbleColorStyle) : {}}
@@ -111,9 +111,9 @@ const CodeBlock: CodeComponent = memo(({ inline, className, children, ...props }
   }
   else if (language === 'echarts') {
     return (
-      <div style={{ minHeight: '350px', minWidth: '700px' }}>
+      <div style={{ minHeight: '350px', minWidth: '100%', overflowX: 'scroll' }}>
         <ErrorBoundary>
-          <ReactEcharts option={chartData} />
+          <ReactEcharts option={chartData} style={{ minWidth: '700px' }} />
         </ErrorBoundary>
       </div>
     )
@@ -11,11 +11,17 @@ type Props = {
   enable: boolean
 }

+const maxTopK = (() => {
+  const configValue = parseInt(globalThis.document?.body?.getAttribute('data-public-top-k-max-value') || '', 10)
+  if (configValue && !isNaN(configValue))
+    return configValue
+  return 10
+})()
 const VALUE_LIMIT = {
   default: 2,
   step: 1,
   min: 1,
-  max: 10,
+  max: maxTopK,
 }

 const key = 'top_k'
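The hunk above makes the top_k ceiling configurable by reading `data-public-top-k-max-value` from `<body>`, falling back to 10. A hedged sketch of that pattern as a reusable helper; where Dify actually stamps the attribute is not shown in this diff, so the setter side is an assumption:

```ts
// Parse an integer limit from a body data attribute, with a safe fallback.
function readIntAttr(name: string, fallback: number): number {
  const raw = globalThis.document?.body?.getAttribute(name) ?? ''
  const parsed = Number.parseInt(raw, 10)
  return Number.isNaN(parsed) || parsed <= 0 ? fallback : parsed
}

// Hypothetical setter side (e.g. in a server-rendered layout):
// document.body.setAttribute('data-public-top-k-max-value', '20')
const maxTopK = readIntAttr('data-public-top-k-max-value', 10)
```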
@@ -6,14 +6,10 @@ import type {
 import { RerankingModeEnum } from '@/models/datasets'

 export const isReRankModelSelected = ({
-  rerankDefaultModel,
-  isRerankDefaultModelValid,
   retrievalConfig,
   rerankModelList,
   indexMethod,
 }: {
-  rerankDefaultModel?: DefaultModelResponse
-  isRerankDefaultModelValid: boolean
   retrievalConfig: RetrievalConfig
   rerankModelList: Model[]
   indexMethod?: string
@@ -25,12 +21,17 @@ export const isReRankModelSelected = ({
     return provider?.models.find(({ model }) => model === retrievalConfig.reranking_model?.reranking_model_name)
   }

-  if (isRerankDefaultModelValid)
-    return !!rerankDefaultModel

   return false
 })()

+  if (
+    indexMethod === 'high_quality'
+    && ([RETRIEVE_METHOD.semantic, RETRIEVE_METHOD.fullText].includes(retrievalConfig.search_method))
+    && retrievalConfig.reranking_enable
+    && !rerankModelSelected
+  )
+    return false
+
   if (
     indexMethod === 'high_quality'
     && (retrievalConfig.search_method === RETRIEVE_METHOD.hybrid && retrievalConfig.reranking_mode !== RerankingModeEnum.WeightedScore)
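With the hunk above, `isReRankModelSelected` no longer considers a workspace default rerank model; it validates only what the user actually configured. A hedged usage sketch based solely on the parameter list visible in this diff; the concrete argument values and the error message are illustrative:

```ts
import Toast from '@/app/components/base/toast'
import { isReRankModelSelected } from '@/app/components/datasets/common/check-rerank-model'

// Guard a save action: refuse to persist a retrieval config that requires a
// rerank model but has none selected.
function validateBeforeSave(retrievalConfig: any, rerankModelList: any[], indexMethod: string) {
  if (!isReRankModelSelected({ retrievalConfig, rerankModelList, indexMethod })) {
    Toast.notify({ type: 'error', message: 'Rerank model is required for this retrieval setup' })
    return false
  }
  return true
}
```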
@@ -10,11 +10,13 @@ import { RETRIEVE_METHOD } from '@/types/app'
 import type { RetrievalConfig } from '@/types/app'

 type Props = {
+  disabled?: boolean
   value: RetrievalConfig
   onChange: (value: RetrievalConfig) => void
 }

 const EconomicalRetrievalMethodConfig: FC<Props> = ({
+  disabled = false,
   value,
   onChange,
 }) => {
@@ -22,7 +24,8 @@ const EconomicalRetrievalMethodConfig: FC<Props> = ({

   return (
     <div className='space-y-2'>
-      <OptionCard icon={<Image className='w-4 h-4' src={retrievalIcon.vector} alt='' />}
+      <OptionCard
+        disabled={disabled} icon={<Image className='w-4 h-4' src={retrievalIcon.vector} alt='' />}
         title={t('dataset.retrieval.invertedIndex.title')}
         description={t('dataset.retrieval.invertedIndex.description')} isActive
         activeHeaderClassName='bg-dataset-option-card-purple-gradient'
@@ -1,6 +1,6 @@
 'use client'
 import type { FC } from 'react'
-import React from 'react'
+import React, { useCallback } from 'react'
 import { useTranslation } from 'react-i18next'
 import Image from 'next/image'
 import RetrievalParamConfig from '../retrieval-param-config'
@@ -10,7 +10,7 @@ import { retrievalIcon } from '../../create/icons'
 import type { RetrievalConfig } from '@/types/app'
 import { RETRIEVE_METHOD } from '@/types/app'
 import { useProviderContext } from '@/context/provider-context'
-import { useDefaultModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
+import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
 import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
 import {
   DEFAULT_WEIGHTED_SCORE,
@@ -20,54 +20,87 @@ import {
 import Badge from '@/app/components/base/badge'

 type Props = {
+  disabled?: boolean
   value: RetrievalConfig
   onChange: (value: RetrievalConfig) => void
 }

 const RetrievalMethodConfig: FC<Props> = ({
-  value: passValue,
+  disabled = false,
+  value,
   onChange,
 }) => {
   const { t } = useTranslation()
   const { supportRetrievalMethods } = useProviderContext()
-  const { data: rerankDefaultModel } = useDefaultModel(ModelTypeEnum.rerank)
-  const value = (() => {
-    if (!passValue.reranking_model.reranking_model_name) {
-      return {
-        ...passValue,
-        reranking_model: {
-          reranking_provider_name: rerankDefaultModel?.provider.provider || '',
-          reranking_model_name: rerankDefaultModel?.model || '',
-        },
-        reranking_mode: passValue.reranking_mode || (rerankDefaultModel ? RerankingModeEnum.RerankingModel : RerankingModeEnum.WeightedScore),
-        weights: passValue.weights || {
-          weight_type: WeightedScoreEnum.Customized,
-          vector_setting: {
-            vector_weight: DEFAULT_WEIGHTED_SCORE.other.semantic,
-            embedding_provider_name: '',
-            embedding_model_name: '',
-          },
-          keyword_setting: {
-            keyword_weight: DEFAULT_WEIGHTED_SCORE.other.keyword,
-          },
-        },
-      }
-    }
-    return passValue
-  })()
+  const {
+    defaultModel: rerankDefaultModel,
+    currentModel: isRerankDefaultModelValid,
+  } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)
+
+  const onSwitch = useCallback((retrieveMethod: RETRIEVE_METHOD) => {
+    if ([RETRIEVE_METHOD.semantic, RETRIEVE_METHOD.fullText].includes(retrieveMethod)) {
+      onChange({
+        ...value,
+        search_method: retrieveMethod,
+        ...(!value.reranking_model.reranking_model_name
+          ? {
+            reranking_model: {
+              reranking_provider_name: isRerankDefaultModelValid ? rerankDefaultModel?.provider?.provider ?? '' : '',
+              reranking_model_name: isRerankDefaultModelValid ? rerankDefaultModel?.model ?? '' : '',
+            },
+            reranking_enable: !!isRerankDefaultModelValid,
+          }
+          : {
+            reranking_enable: true,
+          }),
+      })
+    }
+    if (retrieveMethod === RETRIEVE_METHOD.hybrid) {
+      onChange({
+        ...value,
+        search_method: retrieveMethod,
+        ...(!value.reranking_model.reranking_model_name
+          ? {
+            reranking_model: {
+              reranking_provider_name: isRerankDefaultModelValid ? rerankDefaultModel?.provider?.provider ?? '' : '',
+              reranking_model_name: isRerankDefaultModelValid ? rerankDefaultModel?.model ?? '' : '',
+            },
+            reranking_enable: !!isRerankDefaultModelValid,
+            reranking_mode: isRerankDefaultModelValid ? RerankingModeEnum.RerankingModel : RerankingModeEnum.WeightedScore,
+          }
+          : {
+            reranking_enable: true,
+            reranking_mode: RerankingModeEnum.RerankingModel,
+          }),
+        ...(!value.weights
+          ? {
+            weights: {
+              weight_type: WeightedScoreEnum.Customized,
+              vector_setting: {
+                vector_weight: DEFAULT_WEIGHTED_SCORE.other.semantic,
+                embedding_provider_name: '',
+                embedding_model_name: '',
+              },
+              keyword_setting: {
+                keyword_weight: DEFAULT_WEIGHTED_SCORE.other.keyword,
+              },
+            },
+          }
+          : {}),
+      })
+    }
+  }, [value, rerankDefaultModel, isRerankDefaultModelValid, onChange])

   return (
     <div className='space-y-2'>
       {supportRetrievalMethods.includes(RETRIEVE_METHOD.semantic) && (
-        <OptionCard icon={<Image className='w-4 h-4' src={retrievalIcon.vector} alt='' />}
+        <OptionCard disabled={disabled} icon={<Image className='w-4 h-4' src={retrievalIcon.vector} alt='' />}
           title={t('dataset.retrieval.semantic_search.title')}
           description={t('dataset.retrieval.semantic_search.description')}
           isActive={
             value.search_method === RETRIEVE_METHOD.semantic
           }
-          onSwitched={() => onChange({
-            ...value,
-            search_method: RETRIEVE_METHOD.semantic,
-          })}
+          onSwitched={() => onSwitch(RETRIEVE_METHOD.semantic)}
           effectImg={Effect.src}
           activeHeaderClassName='bg-dataset-option-card-purple-gradient'
         >
@@ -78,17 +111,14 @@ const RetrievalMethodConfig: FC<Props> = ({
           />
         </OptionCard>
       )}
-      {supportRetrievalMethods.includes(RETRIEVE_METHOD.semantic) && (
-        <OptionCard icon={<Image className='w-4 h-4' src={retrievalIcon.fullText} alt='' />}
+      {supportRetrievalMethods.includes(RETRIEVE_METHOD.fullText) && (
+        <OptionCard disabled={disabled} icon={<Image className='w-4 h-4' src={retrievalIcon.fullText} alt='' />}
           title={t('dataset.retrieval.full_text_search.title')}
           description={t('dataset.retrieval.full_text_search.description')}
           isActive={
             value.search_method === RETRIEVE_METHOD.fullText
           }
-          onSwitched={() => onChange({
-            ...value,
-            search_method: RETRIEVE_METHOD.fullText,
-          })}
+          onSwitched={() => onSwitch(RETRIEVE_METHOD.fullText)}
           effectImg={Effect.src}
           activeHeaderClassName='bg-dataset-option-card-purple-gradient'
         >
@@ -99,8 +129,8 @@ const RetrievalMethodConfig: FC<Props> = ({
           />
         </OptionCard>
       )}
-      {supportRetrievalMethods.includes(RETRIEVE_METHOD.semantic) && (
-        <OptionCard icon={<Image className='w-4 h-4' src={retrievalIcon.hybrid} alt='' />}
+      {supportRetrievalMethods.includes(RETRIEVE_METHOD.hybrid) && (
+        <OptionCard disabled={disabled} icon={<Image className='w-4 h-4' src={retrievalIcon.hybrid} alt='' />}
           title={
             <div className='flex items-center space-x-1'>
               <div>{t('dataset.retrieval.hybrid_search.title')}</div>
@@ -110,11 +140,7 @@ const RetrievalMethodConfig: FC<Props> = ({
           description={t('dataset.retrieval.hybrid_search.description')} isActive={
             value.search_method === RETRIEVE_METHOD.hybrid
           }
-          onSwitched={() => onChange({
-            ...value,
-            search_method: RETRIEVE_METHOD.hybrid,
-            reranking_enable: true,
-          })}
+          onSwitched={() => onSwitch(RETRIEVE_METHOD.hybrid)}
           effectImg={Effect.src}
           activeHeaderClassName='bg-dataset-option-card-purple-gradient'
         >
@@ -1,6 +1,6 @@
 'use client'
 import type { FC } from 'react'
-import React, { useCallback } from 'react'
+import React, { useCallback, useMemo } from 'react'
 import { useTranslation } from 'react-i18next'

 import Image from 'next/image'
@@ -39,8 +39,8 @@ const RetrievalParamConfig: FC<Props> = ({
   const { t } = useTranslation()
   const canToggleRerankModalEnable = type !== RETRIEVE_METHOD.hybrid
   const isEconomical = type === RETRIEVE_METHOD.invertedIndex
+  const isHybridSearch = type === RETRIEVE_METHOD.hybrid
   const {
-    defaultModel: rerankDefaultModel,
     modelList: rerankModelList,
   } = useModelListAndDefaultModel(ModelTypeEnum.rerank)

@@ -48,35 +48,28 @@ const RetrievalParamConfig: FC<Props> = ({
     currentModel,
   } = useCurrentProviderAndModel(
     rerankModelList,
-    rerankDefaultModel
-      ? {
-        ...rerankDefaultModel,
-        provider: rerankDefaultModel.provider.provider,
-      }
-      : undefined,
+    {
+      provider: value.reranking_model?.reranking_provider_name ?? '',
+      model: value.reranking_model?.reranking_model_name ?? '',
+    },
   )

-  const handleDisabledSwitchClick = useCallback(() => {
-    if (!currentModel)
+  const handleDisabledSwitchClick = useCallback((enable: boolean) => {
+    if (enable && !currentModel)
       Toast.notify({ type: 'error', message: t('workflow.errorMsg.rerankModelRequired') })
-  }, [currentModel, rerankDefaultModel, t])
-
-  const isHybridSearch = type === RETRIEVE_METHOD.hybrid
-
-  const rerankModel = (() => {
-    if (value.reranking_model) {
-      return {
-        provider_name: value.reranking_model.reranking_provider_name,
-        model_name: value.reranking_model.reranking_model_name,
-      }
-    }
-    else if (rerankDefaultModel) {
-      return {
-        provider_name: rerankDefaultModel.provider.provider,
-        model_name: rerankDefaultModel.model,
-      }
-    }
-  })()
+    onChange({
+      ...value,
+      reranking_enable: enable,
+    })
+    // eslint-disable-next-line react-hooks/exhaustive-deps
+  }, [currentModel, onChange, value])
+
+  const rerankModel = useMemo(() => {
+    return {
+      provider_name: value.reranking_model.reranking_provider_name,
+      model_name: value.reranking_model.reranking_model_name,
+    }
+  }, [value.reranking_model])

   const handleChangeRerankMode = (v: RerankingModeEnum) => {
     if (v === value.reranking_mode)
@@ -100,6 +93,8 @@ const RetrievalParamConfig: FC<Props> = ({
         },
       }
     }
+    if (v === RerankingModeEnum.RerankingModel && !currentModel)
+      Toast.notify({ type: 'error', message: t('workflow.errorMsg.rerankModelRequired') })
     onChange(result)
   }

@@ -122,22 +117,11 @@ const RetrievalParamConfig: FC<Props> = ({
         <div>
           <div className='flex items-center space-x-2 mb-2'>
             {canToggleRerankModalEnable && (
-              <div
-                className='flex items-center'
-                onClick={handleDisabledSwitchClick}
-              >
-                <Switch
-                  size='md'
-                  defaultValue={currentModel ? value.reranking_enable : false}
-                  onChange={(v) => {
-                    onChange({
-                      ...value,
-                      reranking_enable: v,
-                    })
-                  }}
-                  disabled={!currentModel}
-                />
-              </div>
+              <Switch
+                size='md'
+                defaultValue={value.reranking_enable}
+                onChange={handleDisabledSwitchClick}
+              />
             )}
             <div className='flex items-center'>
               <span className='mr-0.5 system-sm-semibold text-text-secondary'>{t('common.modelProvider.rerankModel.key')}</span>
@@ -148,21 +132,23 @@ const RetrievalParamConfig: FC<Props> = ({
               />
             </div>
           </div>
-          <ModelSelector
-            triggerClassName={`${!value.reranking_enable && '!opacity-60 !cursor-not-allowed'}`}
-            defaultModel={rerankModel && { provider: rerankModel.provider_name, model: rerankModel.model_name }}
-            modelList={rerankModelList}
-            readonly={!value.reranking_enable}
-            onSelect={(v) => {
-              onChange({
-                ...value,
-                reranking_model: {
-                  reranking_provider_name: v.provider,
-                  reranking_model_name: v.model,
-                },
-              })
-            }}
-          />
+          {
+            value.reranking_enable && (
+              <ModelSelector
+                defaultModel={rerankModel && { provider: rerankModel.provider_name, model: rerankModel.model_name }}
+                modelList={rerankModelList}
+                onSelect={(v) => {
+                  onChange({
+                    ...value,
+                    reranking_model: {
+                      reranking_provider_name: v.provider,
+                      reranking_model_name: v.model,
+                    },
+                  })
+                }}
+              />
+            )
+          }
         </div>
       )}
       {
@@ -255,10 +241,8 @@ const RetrievalParamConfig: FC<Props> = ({
       {
         value.reranking_mode !== RerankingModeEnum.WeightedScore && (
           <ModelSelector
-            triggerClassName={`${!value.reranking_enable && '!opacity-60 !cursor-not-allowed'}`}
             defaultModel={rerankModel && { provider: rerankModel.provider_name, model: rerankModel.model_name }}
             modelList={rerankModelList}
-            readonly={!value.reranking_enable}
             onSelect={(v) => {
               onChange({
                 ...value,
@@ -30,6 +30,7 @@ import { useProviderContext } from '@/context/provider-context'
 import { sleep } from '@/utils'
 import { RETRIEVE_METHOD } from '@/types/app'
 import Tooltip from '@/app/components/base/tooltip'
+import { useInvalidDocumentList } from '@/service/knowledge/use-document'

 type Props = {
   datasetId: string
@@ -207,7 +208,9 @@ const EmbeddingProcess: FC<Props> = ({ datasetId, batchId, documents = [], index
   })

   const router = useRouter()
+  const invalidDocumentList = useInvalidDocumentList()
   const navToDocumentList = () => {
+    invalidDocumentList()
     router.push(`/datasets/${datasetId}/documents`)
   }
   const navToApiDocs = () => {
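The hunk above invalidates the cached document-list query before navigating back, so the list page refetches fresh indexing state. A hedged sketch of the same invalidate-then-navigate pattern; the query key and the internals of `useInvalidDocumentList` are not shown in this diff, so treat them as assumptions:

```ts
import { useQueryClient } from '@tanstack/react-query'
import { useRouter } from 'next/navigation'

// Returns a callback that drops the (hypothetical) document-list cache entry
// and then routes to the list page, mirroring navToDocumentList above.
export function useNavToDocumentList(datasetId: string) {
  const queryClient = useQueryClient()
  const router = useRouter()
  return () => {
    queryClient.invalidateQueries({ queryKey: ['datasets', datasetId, 'documents'] })
    router.push(`/datasets/${datasetId}/documents`)
  }
}
```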
@ -31,17 +31,17 @@ import LanguageSelect from './language-select'
|
|||||||
import { DelimiterInput, MaxLengthInput, OverlapInput } from './inputs'
|
import { DelimiterInput, MaxLengthInput, OverlapInput } from './inputs'
|
||||||
import cn from '@/utils/classnames'
|
import cn from '@/utils/classnames'
|
||||||
import type { CrawlOptions, CrawlResultItem, CreateDocumentReq, CustomFile, DocumentItem, FullDocumentDetail, ParentMode, PreProcessingRule, ProcessRule, Rules, createDocumentResponse } from '@/models/datasets'
|
import type { CrawlOptions, CrawlResultItem, CreateDocumentReq, CustomFile, DocumentItem, FullDocumentDetail, ParentMode, PreProcessingRule, ProcessRule, Rules, createDocumentResponse } from '@/models/datasets'
|
||||||
|
import { ChunkingMode, DataSourceType, ProcessMode } from '@/models/datasets'
|
||||||
|
|
||||||
import Button from '@/app/components/base/button'
|
import Button from '@/app/components/base/button'
|
||||||
import FloatRightContainer from '@/app/components/base/float-right-container'
|
import FloatRightContainer from '@/app/components/base/float-right-container'
|
||||||
import RetrievalMethodConfig from '@/app/components/datasets/common/retrieval-method-config'
|
import RetrievalMethodConfig from '@/app/components/datasets/common/retrieval-method-config'
|
||||||
import EconomicalRetrievalMethodConfig from '@/app/components/datasets/common/economical-retrieval-method-config'
|
import EconomicalRetrievalMethodConfig from '@/app/components/datasets/common/economical-retrieval-method-config'
|
||||||
import { type RetrievalConfig } from '@/types/app'
|
import { type RetrievalConfig } from '@/types/app'
|
||||||
import { ensureRerankModelSelected, isReRankModelSelected } from '@/app/components/datasets/common/check-rerank-model'
|
import { isReRankModelSelected } from '@/app/components/datasets/common/check-rerank-model'
|
||||||
import Toast from '@/app/components/base/toast'
|
import Toast from '@/app/components/base/toast'
|
||||||
import type { NotionPage } from '@/models/common'
|
import type { NotionPage } from '@/models/common'
|
||||||
import { DataSourceProvider } from '@/models/common'
|
import { DataSourceProvider } from '@/models/common'
|
||||||
import { ChunkingMode, DataSourceType, RerankingModeEnum } from '@/models/datasets'
|
|
||||||
import { useDatasetDetailContext } from '@/context/dataset-detail'
|
import { useDatasetDetailContext } from '@/context/dataset-detail'
|
||||||
import I18n from '@/context/i18n'
|
import I18n from '@/context/i18n'
|
||||||
import { RETRIEVE_METHOD } from '@/types/app'
|
import { RETRIEVE_METHOD } from '@/types/app'
|
||||||
@ -53,7 +53,7 @@ import type { DefaultModel } from '@/app/components/header/account-setting/model
|
|||||||
import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
|
import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
|
||||||
import Checkbox from '@/app/components/base/checkbox'
|
import Checkbox from '@/app/components/base/checkbox'
|
||||||
import RadioCard from '@/app/components/base/radio-card'
|
import RadioCard from '@/app/components/base/radio-card'
|
||||||
import { IS_CE_EDITION } from '@/config'
|
import { FULL_DOC_PREVIEW_LENGTH, IS_CE_EDITION } from '@/config'
|
||||||
import Divider from '@/app/components/base/divider'
|
import Divider from '@/app/components/base/divider'
|
||||||
import { getNotionInfo, getWebsiteInfo, useCreateDocument, useCreateFirstDocument, useFetchDefaultProcessRule, useFetchFileIndexingEstimateForFile, useFetchFileIndexingEstimateForNotion, useFetchFileIndexingEstimateForWeb } from '@/service/knowledge/use-create-dataset'
|
+import { getNotionInfo, getWebsiteInfo, useCreateDocument, useCreateFirstDocument, useFetchDefaultProcessRule, useFetchFileIndexingEstimateForFile, useFetchFileIndexingEstimateForNotion, useFetchFileIndexingEstimateForWeb } from '@/service/knowledge/use-create-dataset'
 import Badge from '@/app/components/base/badge'
@@ -90,17 +90,13 @@ type StepTwoProps = {
   onCancel?: () => void
 }
 
-export enum SegmentType {
-  AUTO = 'automatic',
-  CUSTOM = 'custom',
-}
 export enum IndexingType {
   QUALIFIED = 'high_quality',
   ECONOMICAL = 'economy',
 }
 
 const DEFAULT_SEGMENT_IDENTIFIER = '\\n\\n'
-const DEFAULT_MAXMIMUM_CHUNK_LENGTH = 500
+const DEFAULT_MAXIMUM_CHUNK_LENGTH = 500
 const DEFAULT_OVERLAP = 50
 
 type ParentChildConfig = {
@@ -131,7 +127,6 @@ const StepTwo = ({
   isSetting,
   documentDetail,
   isAPIKeySet,
-  onSetting,
   datasetId,
   indexingType,
   dataSourceType: inCreatePageDataSourceType,
@@ -162,12 +157,12 @@ const StepTwo = ({
 
   const isInCreatePage = !datasetId || (datasetId && !currentDataset?.data_source_type)
   const dataSourceType = isInCreatePage ? inCreatePageDataSourceType : currentDataset?.data_source_type
-  const [segmentationType, setSegmentationType] = useState<SegmentType>(SegmentType.CUSTOM)
+  const [segmentationType, setSegmentationType] = useState<ProcessMode>(ProcessMode.general)
   const [segmentIdentifier, doSetSegmentIdentifier] = useState(DEFAULT_SEGMENT_IDENTIFIER)
   const setSegmentIdentifier = useCallback((value: string, canEmpty?: boolean) => {
     doSetSegmentIdentifier(value ? escape(value) : (canEmpty ? '' : DEFAULT_SEGMENT_IDENTIFIER))
   }, [])
-  const [maxChunkLength, setMaxChunkLength] = useState(DEFAULT_MAXMIMUM_CHUNK_LENGTH) // default chunk length
+  const [maxChunkLength, setMaxChunkLength] = useState(DEFAULT_MAXIMUM_CHUNK_LENGTH) // default chunk length
   const [limitMaxChunkLength, setLimitMaxChunkLength] = useState(4000)
   const [overlap, setOverlap] = useState(DEFAULT_OVERLAP)
   const [rules, setRules] = useState<PreProcessingRule[]>([])
@@ -198,7 +193,6 @@ const StepTwo = ({
   )
 
   // QA Related
-  const [isLanguageSelectDisabled, _setIsLanguageSelectDisabled] = useState(false)
   const [isQAConfirmDialogOpen, setIsQAConfirmDialogOpen] = useState(false)
   const [docForm, setDocForm] = useState<ChunkingMode>(
     (datasetId && documentDetail) ? documentDetail.doc_form as ChunkingMode : ChunkingMode.text,
@@ -348,7 +342,7 @@ const StepTwo = ({
   }
 
   const updatePreview = () => {
-    if (segmentationType === SegmentType.CUSTOM && maxChunkLength > 4000) {
+    if (segmentationType === ProcessMode.general && maxChunkLength > 4000) {
       Toast.notify({ type: 'error', message: t('datasetCreation.stepTwo.maxLengthCheck') })
       return
     }
@@ -373,13 +367,42 @@ const StepTwo = ({
       model: defaultEmbeddingModel?.model || '',
     },
   )
+  const [retrievalConfig, setRetrievalConfig] = useState(currentDataset?.retrieval_model_dict || {
+    search_method: RETRIEVE_METHOD.semantic,
+    reranking_enable: false,
+    reranking_model: {
+      reranking_provider_name: '',
+      reranking_model_name: '',
+    },
+    top_k: 3,
+    score_threshold_enabled: false,
+    score_threshold: 0.5,
+  } as RetrievalConfig)
+
+  useEffect(() => {
+    if (currentDataset?.retrieval_model_dict)
+      return
+    setRetrievalConfig({
+      search_method: RETRIEVE_METHOD.semantic,
+      reranking_enable: !!isRerankDefaultModelValid,
+      reranking_model: {
+        reranking_provider_name: isRerankDefaultModelValid ? rerankDefaultModel?.provider.provider ?? '' : '',
+        reranking_model_name: isRerankDefaultModelValid ? rerankDefaultModel?.model ?? '' : '',
+      },
+      top_k: 3,
+      score_threshold_enabled: false,
+      score_threshold: 0.5,
+    })
+    // eslint-disable-next-line react-hooks/exhaustive-deps
+  }, [rerankDefaultModel, isRerankDefaultModelValid])
+
   const getCreationParams = () => {
     let params
-    if (segmentationType === SegmentType.CUSTOM && overlap > maxChunkLength) {
+    if (segmentationType === ProcessMode.general && overlap > maxChunkLength) {
       Toast.notify({ type: 'error', message: t('datasetCreation.stepTwo.overlapCheck') })
       return
     }
-    if (segmentationType === SegmentType.CUSTOM && maxChunkLength > limitMaxChunkLength) {
+    if (segmentationType === ProcessMode.general && maxChunkLength > limitMaxChunkLength) {
       Toast.notify({ type: 'error', message: t('datasetCreation.stepTwo.maxLengthCheck', { limit: limitMaxChunkLength }) })
       return
     }
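
Aside on the relocated retrievalConfig block in the hunk above: the state is now seeded with the dataset's saved retrieval settings or a neutral default, and only backfilled with the rerank defaults once that model has loaded. A minimal, self-contained sketch of this "initialize, then backfill async defaults" pattern; the hook and field names below are illustrative placeholders, not Dify's actual APIs:

import { useEffect, useState } from 'react'

type RerankModel = { provider: string; model: string }
type RetrievalSettings = { rerankingEnabled: boolean; rerankProvider: string; rerankModel: string }

// Stub standing in for the async "default rerank model" query;
// it stays undefined until the request resolves.
const useDefaultRerankModel = (): RerankModel | undefined => undefined

export function useRetrievalDefaults(saved?: RetrievalSettings) {
  // 1. Seed with the saved settings or a safe default, so the first render
  //    never waits on async data.
  const [config, setConfig] = useState<RetrievalSettings>(saved ?? {
    rerankingEnabled: false,
    rerankProvider: '',
    rerankModel: '',
  })
  const defaultRerank = useDefaultRerankModel()

  // 2. Backfill once the default model arrives, but never overwrite a saved config.
  useEffect(() => {
    if (saved || !defaultRerank)
      return
    setConfig({
      rerankingEnabled: true,
      rerankProvider: defaultRerank.provider,
      rerankModel: defaultRerank.model,
    })
  }, [saved, defaultRerank])

  return [config, setConfig] as const
}
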
@@ -389,7 +412,6 @@ const StepTwo = ({
       doc_form: currentDocForm,
       doc_language: docLanguage,
       process_rule: getProcessRule(),
-      // eslint-disable-next-line @typescript-eslint/no-use-before-define
       retrieval_model: retrievalConfig, // Readonly. If want to changed, just go to settings page.
       embedding_model: embeddingModel.model, // Readonly
       embedding_model_provider: embeddingModel.provider, // Readonly
@@ -400,10 +422,7 @@ const StepTwo = ({
     const indexMethod = getIndexing_technique()
     if (
       !isReRankModelSelected({
-        rerankDefaultModel,
-        isRerankDefaultModelValid: !!isRerankDefaultModelValid,
         rerankModelList,
-        // eslint-disable-next-line @typescript-eslint/no-use-before-define
         retrievalConfig,
         indexMethod: indexMethod as string,
       })
@@ -411,16 +430,6 @@ const StepTwo = ({
       Toast.notify({ type: 'error', message: t('appDebug.datasetConfig.rerankModelRequired') })
       return
     }
-    const postRetrievalConfig = ensureRerankModelSelected({
-      rerankDefaultModel: rerankDefaultModel!,
-      retrievalConfig: {
-        // eslint-disable-next-line @typescript-eslint/no-use-before-define
-        ...retrievalConfig,
-        // eslint-disable-next-line @typescript-eslint/no-use-before-define
-        reranking_enable: retrievalConfig.reranking_mode === RerankingModeEnum.RerankingModel,
-      },
-      indexMethod: indexMethod as string,
-    })
     params = {
       data_source: {
         type: dataSourceType,
@@ -432,8 +441,7 @@ const StepTwo = ({
       process_rule: getProcessRule(),
       doc_form: currentDocForm,
       doc_language: docLanguage,
-      retrieval_model: postRetrievalConfig,
+      retrieval_model: retrievalConfig,
       embedding_model: embeddingModel.model,
       embedding_model_provider: embeddingModel.provider,
     } as CreateDocumentReq
@@ -490,7 +498,6 @@ const StepTwo = ({
 
   const getDefaultMode = () => {
     if (documentDetail)
-      // @ts-expect-error fix after api refactored
       setSegmentationType(documentDetail.dataset_process_rule.mode)
   }
 
@@ -525,7 +532,6 @@ const StepTwo = ({
       onSuccess(data) {
         updateIndexingTypeCache && updateIndexingTypeCache(indexType as string)
         updateResultCache && updateResultCache(data)
-        // eslint-disable-next-line @typescript-eslint/no-use-before-define
         updateRetrievalMethodCache && updateRetrievalMethodCache(retrievalConfig.search_method as string)
       },
     },
@@ -545,14 +551,6 @@ const StepTwo = ({
     isSetting && onSave && onSave()
   }
 
-  const changeToEconomicalType = () => {
-    if (docForm !== ChunkingMode.text)
-      return
-
-    if (!hasSetIndexType)
-      setIndexType(IndexingType.ECONOMICAL)
-  }
-
   useEffect(() => {
     // fetch rules
     if (!isSetting) {
@@ -574,18 +572,6 @@ const StepTwo = ({
     setIndexType(isAPIKeySet ? IndexingType.QUALIFIED : IndexingType.ECONOMICAL)
   }, [isAPIKeySet, indexingType, datasetId])
 
-  const [retrievalConfig, setRetrievalConfig] = useState(currentDataset?.retrieval_model_dict || {
-    search_method: RETRIEVE_METHOD.semantic,
-    reranking_enable: false,
-    reranking_model: {
-      reranking_provider_name: rerankDefaultModel?.provider.provider,
-      reranking_model_name: rerankDefaultModel?.model,
-    },
-    top_k: 3,
-    score_threshold_enabled: false,
-    score_threshold: 0.5,
-  } as RetrievalConfig)
-
   const economyDomRef = useRef<HTMLDivElement>(null)
   const isHoveringEconomy = useHover(economyDomRef)
 
@@ -946,6 +932,7 @@ const StepTwo = ({
           <div className={cn('system-md-semibold mb-1', datasetId && 'flex justify-between items-center')}>{t('datasetSettings.form.embeddingModel')}</div>
           <ModelSelector
             readonly={!!datasetId}
+            triggerClassName={datasetId ? 'opacity-50' : ''}
             defaultModel={embeddingModel}
             modelList={embeddingModelList}
             onSelect={(model: DefaultModel) => {
@@ -984,12 +971,14 @@ const StepTwo = ({
             getIndexing_technique() === IndexingType.QUALIFIED
               ? (
                 <RetrievalMethodConfig
+                  disabled={!!datasetId}
                   value={retrievalConfig}
                   onChange={setRetrievalConfig}
                 />
               )
               : (
                 <EconomicalRetrievalMethodConfig
+                  disabled={!!datasetId}
                   value={retrievalConfig}
                   onChange={setRetrievalConfig}
                 />
@@ -1010,7 +999,7 @@ const StepTwo = ({
             )
             : (
               <div className='flex items-center mt-8 py-2'>
-                <Button loading={isCreating} variant='primary' onClick={createHandle}>{t('datasetCreation.stepTwo.save')}</Button>
+                {!datasetId && <Button loading={isCreating} variant='primary' onClick={createHandle}>{t('datasetCreation.stepTwo.save')}</Button>}
                 <Button className='ml-2' onClick={onCancel}>{t('datasetCreation.stepTwo.cancel')}</Button>
               </div>
             )}
@@ -1081,11 +1070,11 @@ const StepTwo = ({
           }
           {
             currentDocForm !== ChunkingMode.qa
             && <Badge text={t(
               'datasetCreation.stepTwo.previewChunkCount', {
                 count: estimate?.total_segments || 0,
               }) as string}
             />
           }
         </div>
       </PreviewHeader>}
@@ -1117,6 +1106,9 @@ const StepTwo = ({
         {currentDocForm === ChunkingMode.parentChild && currentEstimateMutation.data?.preview && (
           estimate?.preview?.map((item, index) => {
             const indexForLabel = index + 1
+            const childChunks = parentChildConfig.chunkForContext === 'full-doc'
+              ? item.child_chunks.slice(0, FULL_DOC_PREVIEW_LENGTH)
+              : item.child_chunks
             return (
               <ChunkContainer
                 key={item.content}
@@ -1124,7 +1116,7 @@ const StepTwo = ({
                 characterCount={item.content.length}
               >
                 <FormattedText>
-                  {item.child_chunks.map((child, index) => {
+                  {childChunks.map((child, index) => {
                     const indexForLabel = index + 1
                     return (
                       <PreviewSlice
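
The childChunks branch added in the last two hunks caps how many child chunks are rendered when previewing a full-doc parent chunk. A standalone sketch of the same capping logic; FULL_DOC_PREVIEW_LENGTH and the chunk shapes here are assumptions for illustration, not the component's actual definitions:

type ChildChunk = { content: string }
type ParentChunkPreview = { content: string; child_chunks: ChildChunk[] }

// Assumed cap; the real constant lives in the component's module scope.
const FULL_DOC_PREVIEW_LENGTH = 50

// In full-doc mode a single parent can own a very large number of child
// chunks, so only the first N are rendered in the preview pane.
function childChunksForPreview(item: ParentChunkPreview, chunkForContext: 'full-doc' | 'paragraph'): ChildChunk[] {
  return chunkForContext === 'full-doc'
    ? item.child_chunks.slice(0, FULL_DOC_PREVIEW_LENGTH)
    : item.child_chunks
}
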
@@ -4,7 +4,7 @@ import classNames from '@/utils/classnames'
 
 const TriangleArrow: FC<ComponentProps<'svg'>> = props => (
   <svg xmlns="http://www.w3.org/2000/svg" width="24" height="11" viewBox="0 0 24 11" fill="none" {...props}>
-    <path d="M9.87868 1.12132C11.0503 -0.0502525 12.9497 -0.0502525 14.1213 1.12132L23.3137 10.3137H0.686292L9.87868 1.12132Z" fill="currentColor"/>
+    <path d="M9.87868 1.12132C11.0503 -0.0502525 12.9497 -0.0502525 14.1213 1.12132L23.3137 10.3137H0.686292L9.87868 1.12132Z" fill="currentColor" />
   </svg>
 )
 
@@ -65,7 +65,7 @@ export const OptionCard: FC<OptionCardProps> = forwardRef((props, ref) => {
         (isActive && !noHighlight)
           ? 'border-[1.5px] border-components-option-card-option-selected-border'
           : 'border border-components-option-card-option-border',
-        disabled && 'opacity-50 cursor-not-allowed',
+        disabled && 'opacity-50 pointer-events-none',
         className,
       )}
       style={{
@@ -32,6 +32,9 @@ import Checkbox from '@/app/components/base/checkbox'
 import {
   useChildSegmentList,
   useChildSegmentListKey,
+  useChunkListAllKey,
+  useChunkListDisabledKey,
+  useChunkListEnabledKey,
   useDeleteChildSegment,
   useDeleteSegment,
   useDisableSegment,
@@ -156,18 +159,18 @@ const Completed: FC<ICompletedProps> = ({
         page: isFullDocMode ? 1 : currentPage,
         limit: isFullDocMode ? 10 : limit,
         keyword: isFullDocMode ? '' : searchValue,
-        enabled: selectedStatus === 'all' ? 'all' : !!selectedStatus,
+        enabled: selectedStatus,
       },
     },
-    currentPage === 0,
   )
   const invalidSegmentList = useInvalid(useSegmentListKey)
 
   useEffect(() => {
     if (segmentListData) {
       setSegments(segmentListData.data || [])
-      if (segmentListData.total_pages < currentPage)
-        setCurrentPage(segmentListData.total_pages)
+      const totalPages = segmentListData.total_pages
+      if (totalPages < currentPage)
+        setCurrentPage(totalPages === 0 ? 1 : totalPages)
     }
     // eslint-disable-next-line react-hooks/exhaustive-deps
   }, [segmentListData])
@@ -185,12 +188,12 @@ const Completed: FC<ICompletedProps> = ({
       documentId,
       segmentId: segments[0]?.id || '',
       params: {
-        page: currentPage,
+        page: currentPage === 0 ? 1 : currentPage,
         limit,
         keyword: searchValue,
       },
     },
-    !isFullDocMode || segments.length === 0 || currentPage === 0,
+    !isFullDocMode || segments.length === 0,
   )
   const invalidChildSegmentList = useInvalid(useChildSegmentListKey)
 
@@ -204,21 +207,20 @@ const Completed: FC<ICompletedProps> = ({
   useEffect(() => {
     if (childChunkListData) {
       setChildSegments(childChunkListData.data || [])
-      if (childChunkListData.total_pages < currentPage)
-        setCurrentPage(childChunkListData.total_pages)
+      const totalPages = childChunkListData.total_pages
+      if (totalPages < currentPage)
+        setCurrentPage(totalPages === 0 ? 1 : totalPages)
     }
     // eslint-disable-next-line react-hooks/exhaustive-deps
   }, [childChunkListData])
 
   const resetList = useCallback(() => {
-    setSegments([])
     setSelectedSegmentIds([])
     invalidSegmentList()
     // eslint-disable-next-line react-hooks/exhaustive-deps
   }, [])
 
   const resetChildList = useCallback(() => {
-    setChildSegments([])
     invalidChildSegmentList()
     // eslint-disable-next-line react-hooks/exhaustive-deps
   }, [])
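
Both effects above now clamp the current page after a refetch, so the pager never points past the last page and falls back to page 1 instead of page 0 when the list comes back empty. A minimal illustrative helper with the same behaviour (not part of the component itself):

// Keep the 1-based current page inside [1, totalPages]; an empty result
// set (totalPages === 0) falls back to page 1 rather than page 0.
function clampPage(currentPage: number, totalPages: number): number {
  if (totalPages === 0)
    return 1
  return Math.min(Math.max(currentPage, 1), totalPages)
}

// e.g. after deleting the last item on the final page:
// clampPage(5, 4) === 4; clampPage(1, 0) === 1
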
@@ -232,8 +234,32 @@ const Completed: FC<ICompletedProps> = ({
     setFullScreen(false)
   }, [])
 
+  const onCloseNewSegmentModal = useCallback(() => {
+    onNewSegmentModalChange(false)
+    setFullScreen(false)
+  }, [onNewSegmentModalChange])
+
+  const onCloseNewChildChunkModal = useCallback(() => {
+    setShowNewChildSegmentModal(false)
+    setFullScreen(false)
+  }, [])
+
   const { mutateAsync: enableSegment } = useEnableSegment()
   const { mutateAsync: disableSegment } = useDisableSegment()
+  const invalidChunkListAll = useInvalid(useChunkListAllKey)
+  const invalidChunkListEnabled = useInvalid(useChunkListEnabledKey)
+  const invalidChunkListDisabled = useInvalid(useChunkListDisabledKey)
+
+  const refreshChunkListWithStatusChanged = () => {
+    switch (selectedStatus) {
+      case 'all':
+        invalidChunkListDisabled()
+        invalidChunkListEnabled()
+        break
+      default:
+        invalidSegmentList()
+    }
+  }
 
   const onChangeSwitch = useCallback(async (enable: boolean, segId?: string) => {
     const operationApi = enable ? enableSegment : disableSegment
@@ -245,6 +271,7 @@ const Completed: FC<ICompletedProps> = ({
           seg.enabled = enable
         }
         setSegments([...segments])
+        refreshChunkListWithStatusChanged()
       },
       onError: () => {
         notify({ type: 'error', message: t('common.actionMsg.modifiedUnsuccessfully') })
@@ -271,6 +298,23 @@ const Completed: FC<ICompletedProps> = ({
 
   const { mutateAsync: updateSegment } = useUpdateSegment()
 
+  const refreshChunkListDataWithDetailChanged = () => {
+    switch (selectedStatus) {
+      case 'all':
+        invalidChunkListDisabled()
+        invalidChunkListEnabled()
+        break
+      case true:
+        invalidChunkListAll()
+        invalidChunkListDisabled()
+        break
+      case false:
+        invalidChunkListAll()
+        invalidChunkListEnabled()
+        break
+    }
+  }
+
   const handleUpdateSegment = useCallback(async (
     segmentId: string,
     question: string,
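
The refresh helpers added above keep the cached chunk lists for the other status filters consistent after a mutation: the list on screen is patched in place via setSegments, while the query keys of the remaining filter views are invalidated so they refetch when next shown. A rough sketch of the same idea written directly against TanStack Query — the query keys and the shape of Dify's useInvalid wrapper are assumptions, not the component's real implementation:

import { QueryClient } from '@tanstack/react-query'

type StatusFilter = 'all' | true | false // all / enabled only / disabled only

// Assumed query keys; the real keys live in the knowledge-base service hooks.
const chunkListKey = (status: StatusFilter) => ['segment-list', status] as const

// After editing a chunk while one filter is displayed, mark the caches of the
// other filters as stale so they refetch the next time they are shown.
function invalidateOtherChunkLists(queryClient: QueryClient, current: StatusFilter) {
  const others = (['all', true, false] as StatusFilter[]).filter(s => s !== current)
  for (const status of others)
    queryClient.invalidateQueries({ queryKey: chunkListKey(status) })
}
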
@@ -320,6 +364,7 @@ const Completed: FC<ICompletedProps> = ({
           }
         }
         setSegments([...segments])
+        refreshChunkListDataWithDetailChanged()
         eventEmitter?.emit('update-segment-success')
       },
       onSettled() {
@@ -432,6 +477,7 @@ const Completed: FC<ICompletedProps> = ({
           seg.child_chunks?.push(newChildChunk!)
         }
         setSegments([...segments])
+        refreshChunkListDataWithDetailChanged()
       }
       else {
         resetChildList()
@@ -496,17 +542,10 @@ const Completed: FC<ICompletedProps> = ({
           }
         }
         setSegments([...segments])
+        refreshChunkListDataWithDetailChanged()
       }
       else {
-        for (const childSeg of childSegments) {
-          if (childSeg.id === childChunkId) {
-            childSeg.content = res.data.content
-            childSeg.type = res.data.type
-            childSeg.word_count = res.data.word_count
-            childSeg.updated_at = res.data.updated_at
-          }
-        }
-        setChildSegments([...childSegments])
+        resetChildList()
       }
     },
     onSettled: () => {
@@ -544,12 +583,13 @@ const Completed: FC<ICompletedProps> = ({
           <SimpleSelect
             onSelect={onChangeStatus}
             items={statusList.current}
-            defaultValue={'all'}
+            defaultValue={selectedStatus === 'all' ? 'all' : selectedStatus ? 1 : 0}
             className={s.select}
             wrapperClassName='h-fit mr-2'
             optionWrapClassName='w-[160px]'
             optionClassName='p-0'
             renderOption={({ item, selected }) => <StatusItem item={item} selected={selected} />}
+            notClearable
           />
           <Input
             showLeftIcon
@@ -623,6 +663,7 @@ const Completed: FC<ICompletedProps> = ({
       <FullScreenDrawer
         isOpen={currSegment.showModal}
         fullScreen={fullScreen}
+        onClose={onCloseSegmentDetail}
       >
         <SegmentDetail
           segInfo={currSegment.segInfo ?? { id: '' }}
@@ -636,13 +677,11 @@ const Completed: FC<ICompletedProps> = ({
       <FullScreenDrawer
         isOpen={showNewSegmentModal}
         fullScreen={fullScreen}
+        onClose={onCloseNewSegmentModal}
       >
         <NewSegment
           docForm={docForm}
-          onCancel={() => {
-            onNewSegmentModalChange(false)
-            setFullScreen(false)
-          }}
+          onCancel={onCloseNewSegmentModal}
           onSave={resetList}
           viewNewlyAddedChunk={viewNewlyAddedChunk}
         />
@@ -651,6 +690,7 @@ const Completed: FC<ICompletedProps> = ({
       <FullScreenDrawer
         isOpen={currChildChunk.showModal}
         fullScreen={fullScreen}
+        onClose={onCloseChildSegmentDetail}
       >
         <ChildSegmentDetail
           chunkId={currChunkId}
@@ -664,13 +704,11 @@ const Completed: FC<ICompletedProps> = ({
       <FullScreenDrawer
         isOpen={showNewChildSegmentModal}
         fullScreen={fullScreen}
+        onClose={onCloseNewChildChunkModal}
       >
         <NewChildSegment
           chunkId={currChunkId}
-          onCancel={() => {
-            setShowNewChildSegmentModal(false)
-            setFullScreen(false)
-          }}
+          onCancel={onCloseNewChildChunkModal}
           onSave={onSaveNewChildChunk}
           viewNewlyAddedChildChunk={viewNewlyAddedChildChunk}
         />
|
Some files were not shown because too many files have changed in this diff.