merge main

Joel 2025-05-26 14:39:19 +08:00
commit 94d0ba5dd6
397 changed files with 11292 additions and 5586 deletions

View File

@ -1,5 +1,4 @@
FROM mcr.microsoft.com/devcontainers/python:3.12
# [Optional] Uncomment this section to install additional OS packages.
# RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \
# && apt-get -y install --no-install-recommends <your-package-list-here>
RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \
&& apt-get -y install libgmp-dev libmpfr-dev libmpc-dev
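Note: libgmp-dev, libmpfr-dev and libmpc-dev are the classic native build dependencies for gmpy2; that gmpy2 is the actual consumer here is an assumption, the Dockerfile does not say. A quick sanity check once the image builds:

    import gmpy2  # links against libgmp / libmpfr / libmpc
    print(gmpy2.mpfr("1.5") * 2)  # -> mpfr('3.0')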

View File

@ -139,6 +139,7 @@ jobs:
- name: Checkout code
uses: actions/checkout@v4
with:
fetch-depth: 0
persist-credentials: false
- name: Check changed files

View File

@ -269,6 +269,7 @@ OPENSEARCH_PORT=9200
OPENSEARCH_USER=admin
OPENSEARCH_PASSWORD=admin
OPENSEARCH_SECURE=true
OPENSEARCH_VERIFY_CERTS=true
# Baidu configuration
BAIDU_VECTOR_DB_ENDPOINT=http://127.0.0.1:5287
@ -348,6 +349,7 @@ SENTRY_DSN=
# DEBUG
DEBUG=false
ENABLE_REQUEST_LOGGING=False
SQLALCHEMY_ECHO=false
# Notion import configuration, support public and internal

View File

@ -54,6 +54,7 @@ def initialize_extensions(app: DifyApp):
ext_otel,
ext_proxy_fix,
ext_redis,
ext_request_logging,
ext_sentry,
ext_set_secretkey,
ext_storage,
@ -83,6 +84,7 @@ def initialize_extensions(app: DifyApp):
ext_blueprints,
ext_commands,
ext_otel,
ext_request_logging,
]
for ext in extensions:
short_name = ext.__name__.split(".")[-1]
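For context, a hypothetical sketch of the module shape this loader expects from something like ext_request_logging: a module-level init_app(), plus an optional is_enabled() toggle. Everything in the body below is an assumption, not the actual Dify implementation:

    import logging
    from flask import Flask, request

    logger = logging.getLogger(__name__)

    def is_enabled() -> bool:
        # presumably gated by the new ENABLE_REQUEST_LOGGING setting
        return True

    def init_app(app: Flask) -> None:
        @app.before_request
        def _log_request() -> None:
            logger.info("request: %s %s", request.method, request.path)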

View File

@ -17,6 +17,12 @@ class DeploymentConfig(BaseSettings):
default=False,
)
# Request logging configuration
ENABLE_REQUEST_LOGGING: bool = Field(
description="Enable request and response body logging",
default=False,
)
EDITION: str = Field(
description="Deployment edition of the application (e.g., 'SELF_HOSTED', 'CLOUD')",
default="SELF_HOSTED",
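pydantic-settings resolves ENABLE_REQUEST_LOGGING from the environment variable of the same name, which is why the .env template above can carry ENABLE_REQUEST_LOGGING=False verbatim. A standalone sketch of that mapping:

    import os
    from pydantic import Field
    from pydantic_settings import BaseSettings

    class DeploymentConfig(BaseSettings):
        ENABLE_REQUEST_LOGGING: bool = Field(default=False)

    os.environ["ENABLE_REQUEST_LOGGING"] = "False"
    print(DeploymentConfig().ENABLE_REQUEST_LOGGING)  # False; "true" or "1" flips it on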

View File

@ -1,4 +1,4 @@
from typing import Optional
from typing import Literal, Optional
from pydantic import Field
from pydantic_settings import BaseSettings
@ -34,7 +34,7 @@ class S3StorageConfig(BaseSettings):
default=None,
)
S3_ADDRESS_STYLE: str = Field(
S3_ADDRESS_STYLE: Literal["auto", "virtual", "path"] = Field(
description="S3 addressing style: 'auto', 'path', or 'virtual'",
default="auto",
)
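Tightening the annotation from str to Literal means pydantic now rejects an invalid S3_ADDRESS_STYLE at startup instead of passing it through to the storage client. A minimal standalone illustration:

    from typing import Literal
    from pydantic import Field, ValidationError
    from pydantic_settings import BaseSettings

    class S3StorageConfig(BaseSettings):
        S3_ADDRESS_STYLE: Literal["auto", "virtual", "path"] = Field(default="auto")

    print(S3StorageConfig(S3_ADDRESS_STYLE="path").S3_ADDRESS_STYLE)  # accepted
    try:
        S3StorageConfig(S3_ADDRESS_STYLE="regional")
    except ValidationError:
        print("rejected at startup")  # a plain str field would have let this through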

View File

@ -33,6 +33,11 @@ class OpenSearchConfig(BaseSettings):
default=False,
)
OPENSEARCH_VERIFY_CERTS: bool = Field(
description="Whether to verify SSL certificates for HTTPS connections (recommended to set True in production)",
default=True,
)
OPENSEARCH_AUTH_METHOD: AuthMethod = Field(
description="Authentication method for OpenSearch connection (default is 'basic')",
default=AuthMethod.BASIC,
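A hedged sketch of how OPENSEARCH_SECURE and the new OPENSEARCH_VERIFY_CERTS typically reach the client, assuming the opensearch-py driver; the actual factory wiring sits outside this hunk:

    from opensearchpy import OpenSearch

    client = OpenSearch(
        hosts=[{"host": "localhost", "port": 9200}],
        http_auth=("admin", "admin"),
        use_ssl=True,       # OPENSEARCH_SECURE
        verify_certs=True,  # OPENSEARCH_VERIFY_CERTS; keep True in production
    )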

View File

@ -17,15 +17,13 @@ from controllers.console.wraps import (
)
from core.ops.ops_trace_manager import OpsTraceManager
from extensions.ext_database import db
from fields.app_fields import (
app_detail_fields,
app_detail_fields_with_site,
app_pagination_fields,
)
from fields.app_fields import app_detail_fields, app_detail_fields_with_site, app_pagination_fields
from libs.login import login_required
from models import Account, App
from services.app_dsl_service import AppDslService, ImportMode
from services.app_service import AppService
from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService
ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"]
@ -75,7 +73,17 @@ class AppListApi(Resource):
if not app_pagination:
return {"data": [], "total": 0, "page": 1, "limit": 20, "has_more": False}
return marshal(app_pagination, app_pagination_fields)
if FeatureService.get_system_features().webapp_auth.enabled:
app_ids = [str(app.id) for app in app_pagination.items]
res = EnterpriseService.WebAppAuth.batch_get_app_access_mode_by_id(app_ids=app_ids)
if len(res) != len(app_ids):
raise BadRequest("Invalid app id in webapp auth")
for app in app_pagination.items:
if str(app.id) in res:
app.access_mode = res[str(app.id)].access_mode
return marshal(app_pagination, app_pagination_fields), 200
@setup_required
@login_required
@ -119,6 +127,10 @@ class AppApi(Resource):
app_model = app_service.get_app(app_model)
if FeatureService.get_system_features().webapp_auth.enabled:
app_setting = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id=str(app_model.id))
app_model.access_mode = app_setting.access_mode
return app_model
@setup_required

View File

@ -81,8 +81,7 @@ class DraftWorkflowApi(Resource):
parser.add_argument("graph", type=dict, required=True, nullable=False, location="json")
parser.add_argument("features", type=dict, required=True, nullable=False, location="json")
parser.add_argument("hash", type=str, required=False, location="json")
# TODO: set this to required=True after frontend is updated
parser.add_argument("environment_variables", type=list, required=False, location="json")
parser.add_argument("environment_variables", type=list, required=True, location="json")
parser.add_argument("conversation_variables", type=list, required=False, location="json")
args = parser.parse_args()
elif "text/plain" in content_type:

View File

@ -1,3 +1,6 @@
from typing import cast
from flask_login import current_user
from flask_restful import Resource, marshal_with, reqparse
from flask_restful.inputs import int_range
@ -12,8 +15,7 @@ from fields.workflow_run_fields import (
)
from libs.helper import uuid_value
from libs.login import login_required
from models import App
from models.model import AppMode
from models import Account, App, AppMode, EndUser
from services.workflow_run_service import WorkflowRunService
@ -90,7 +92,12 @@ class WorkflowRunNodeExecutionListApi(Resource):
run_id = str(run_id)
workflow_run_service = WorkflowRunService()
node_executions = workflow_run_service.get_workflow_run_node_executions(app_model=app_model, run_id=run_id)
user = cast("Account | EndUser", current_user)
node_executions = workflow_run_service.get_workflow_run_node_executions(
app_model=app_model,
run_id=run_id,
user=user,
)
return {"data": node_executions}

View File

@ -24,7 +24,7 @@ from libs.password import hash_password, valid_password
from models.account import Account
from services.account_service import AccountService, TenantService
from services.errors.account import AccountRegisterError
from services.errors.workspace import WorkSpaceNotAllowedCreateError
from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkspacesLimitExceededError
from services.feature_service import FeatureService
@ -119,6 +119,9 @@ class ForgotPasswordResetApi(Resource):
if not reset_data:
raise InvalidTokenError()
# Must use token in reset phase
if reset_data.get("phase", "") != "reset":
raise InvalidTokenError()
@ -168,6 +171,8 @@ class ForgotPasswordResetApi(Resource):
)
except WorkSpaceNotAllowedCreateError:
pass
except WorkspacesLimitExceededError:
pass
except AccountRegisterError:
raise AccountInFreezeError()

View File

@ -21,6 +21,7 @@ from controllers.console.error import (
AccountNotFound,
EmailSendIpLimitError,
NotAllowedCreateWorkspace,
WorkspacesLimitExceeded,
)
from controllers.console.wraps import email_password_login_enabled, setup_required
from events.tenant_event import tenant_was_created
@ -30,7 +31,7 @@ from models.account import Account
from services.account_service import AccountService, RegisterService, TenantService
from services.billing_service import BillingService
from services.errors.account import AccountRegisterError
from services.errors.workspace import WorkSpaceNotAllowedCreateError
from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkspacesLimitExceededError
from services.feature_service import FeatureService
@ -88,10 +89,15 @@ class LoginApi(Resource):
# SELF_HOSTED only have one workspace
tenants = TenantService.get_join_tenants(account)
if len(tenants) == 0:
return {
"result": "fail",
"data": "workspace not found, please contact system admin to invite you to join a workspace",
}
system_features = FeatureService.get_system_features()
if system_features.is_allow_create_workspace and not system_features.license.workspaces.is_available():
raise WorkspacesLimitExceeded()
else:
return {
"result": "fail",
"data": "workspace not found, please contact system admin to invite you to join a workspace",
}
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
AccountService.reset_login_error_rate_limit(args["email"])
@ -196,15 +202,18 @@ class EmailCodeLoginApi(Resource):
except AccountRegisterError as are:
raise AccountInFreezeError()
if account:
tenant = TenantService.get_join_tenants(account)
if not tenant:
tenants = TenantService.get_join_tenants(account)
if not tenants:
workspaces = FeatureService.get_system_features().license.workspaces
if not workspaces.is_available():
raise WorkspacesLimitExceeded()
if not FeatureService.get_system_features().is_allow_create_workspace:
raise NotAllowedCreateWorkspace()
else:
tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
TenantService.create_tenant_member(tenant, account, role="owner")
account.current_tenant = tenant
tenant_was_created.send(tenant)
new_tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
TenantService.create_tenant_member(new_tenant, account, role="owner")
account.current_tenant = new_tenant
tenant_was_created.send(new_tenant)
if account is None:
try:
@ -215,6 +224,8 @@ class EmailCodeLoginApi(Resource):
return NotAllowedCreateWorkspace()
except AccountRegisterError as are:
raise AccountInFreezeError()
except WorkspacesLimitExceededError:
raise WorkspacesLimitExceeded()
token_pair = AccountService.login(account, ip_address=extract_remote_ip(request))
AccountService.reset_login_error_rate_limit(args["email"])
return {"result": "success", "data": token_pair.model_dump()}

View File

@ -148,15 +148,15 @@ def _generate_account(provider: str, user_info: OAuthUserInfo):
account = _get_account_by_openid_or_email(provider, user_info)
if account:
tenant = TenantService.get_join_tenants(account)
if not tenant:
tenants = TenantService.get_join_tenants(account)
if not tenants:
if not FeatureService.get_system_features().is_allow_create_workspace:
raise WorkSpaceNotAllowedCreateError()
else:
tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
TenantService.create_tenant_member(tenant, account, role="owner")
account.current_tenant = tenant
tenant_was_created.send(tenant)
new_tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
TenantService.create_tenant_member(new_tenant, account, role="owner")
account.current_tenant = new_tenant
tenant_was_created.send(new_tenant)
if not account:
if not FeatureService.get_system_features().is_allow_register:

View File

@ -540,9 +540,22 @@ class DatasetIndexingStatusApi(Resource):
.filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
.count()
)
document.completed_segments = completed_segments
document.total_segments = total_segments
documents_status.append(marshal(document, document_status_fields))
# Create a dictionary with document attributes and additional fields
document_dict = {
"id": document.id,
"indexing_status": document.indexing_status,
"processing_started_at": document.processing_started_at,
"parsing_completed_at": document.parsing_completed_at,
"cleaning_completed_at": document.cleaning_completed_at,
"splitting_completed_at": document.splitting_completed_at,
"completed_at": document.completed_at,
"paused_at": document.paused_at,
"error": document.error,
"stopped_at": document.stopped_at,
"completed_segments": completed_segments,
"total_segments": total_segments,
}
documents_status.append(marshal(document_dict, document_status_fields))
data = {"data": documents_status}
return data
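The switch from mutating the Document model to building a plain dict works because flask_restful.marshal() only looks up the attribute or key names listed in the field mask, so it treats dicts and model instances alike, and the dict avoids hanging ad-hoc attributes off a live SQLAlchemy instance. The same pattern recurs in the batch-status, workspace, tool-file and service-API hunks below. A self-contained illustration (the field mask here is abbreviated, not the real document_status_fields):

    from flask_restful import fields, marshal

    status_fields = {
        "id": fields.String,
        "indexing_status": fields.String,
        "completed_segments": fields.Integer,
        "total_segments": fields.Integer,
    }

    document_dict = {
        "id": "doc-1",
        "indexing_status": "indexing",
        "completed_segments": 3,
        "total_segments": 10,
    }
    print(marshal(document_dict, status_fields))  # all four keys, same as for a model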

View File

@ -583,11 +583,22 @@ class DocumentBatchIndexingStatusApi(DocumentResource):
.filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
.count()
)
document.completed_segments = completed_segments
document.total_segments = total_segments
if document.is_paused:
document.indexing_status = "paused"
documents_status.append(marshal(document, document_status_fields))
# Create a dictionary with document attributes and additional fields
document_dict = {
"id": document.id,
"indexing_status": "paused" if document.is_paused else document.indexing_status,
"processing_started_at": document.processing_started_at,
"parsing_completed_at": document.parsing_completed_at,
"cleaning_completed_at": document.cleaning_completed_at,
"splitting_completed_at": document.splitting_completed_at,
"completed_at": document.completed_at,
"paused_at": document.paused_at,
"error": document.error,
"stopped_at": document.stopped_at,
"completed_segments": completed_segments,
"total_segments": total_segments,
}
documents_status.append(marshal(document_dict, document_status_fields))
data = {"data": documents_status}
return data
@ -616,11 +627,22 @@ class DocumentIndexingStatusApi(DocumentResource):
.count()
)
document.completed_segments = completed_segments
document.total_segments = total_segments
if document.is_paused:
document.indexing_status = "paused"
return marshal(document, document_status_fields)
# Create a dictionary with document attributes and additional fields
document_dict = {
"id": document.id,
"indexing_status": "paused" if document.is_paused else document.indexing_status,
"processing_started_at": document.processing_started_at,
"parsing_completed_at": document.parsing_completed_at,
"cleaning_completed_at": document.cleaning_completed_at,
"splitting_completed_at": document.splitting_completed_at,
"completed_at": document.completed_at,
"paused_at": document.paused_at,
"error": document.error,
"stopped_at": document.stopped_at,
"completed_segments": completed_segments,
"total_segments": total_segments,
}
return marshal(document_dict, document_status_fields)
class DocumentDetailApi(DocumentResource):

View File

@ -46,6 +46,18 @@ class NotAllowedCreateWorkspace(BaseHTTPException):
code = 400
class WorkspaceMembersLimitExceeded(BaseHTTPException):
error_code = "limit_exceeded"
description = "Unable to add member because the workspace's maximum member limit has been exceeded"
code = 400
class WorkspacesLimitExceeded(BaseHTTPException):
error_code = "limit_exceeded"
description = "Unable to create workspace because the maximum workspace limit has been exceeded"
code = 400
class AccountBannedError(BaseHTTPException):
error_code = "account_banned"
description = "Account is banned."

View File

@ -23,3 +23,9 @@ class AppSuggestedQuestionsAfterAnswerDisabledError(BaseHTTPException):
error_code = "app_suggested_questions_after_answer_disabled"
description = "Function Suggested questions after answer disabled."
code = 403
class AppAccessDeniedError(BaseHTTPException):
error_code = "access_denied"
description = "App access denied."
code = 403

View File

@ -1,3 +1,4 @@
import logging
from datetime import UTC, datetime
from typing import Any
@ -15,6 +16,11 @@ from fields.installed_app_fields import installed_app_list_fields
from libs.login import login_required
from models import App, InstalledApp, RecommendedApp
from services.account_service import TenantService
from services.app_service import AppService
from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService
logger = logging.getLogger(__name__)
class InstalledAppsListApi(Resource):
@ -48,6 +54,21 @@ class InstalledAppsListApi(Resource):
for installed_app in installed_apps
if installed_app.app is not None
]
# filter out apps that user doesn't have access to
if FeatureService.get_system_features().webapp_auth.enabled:
user_id = current_user.id
res = []
for installed_app in installed_app_list:
app_code = AppService.get_app_code_by_id(str(installed_app["app"].id))
if EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(
user_id=user_id,
app_code=app_code,
):
res.append(installed_app)
installed_app_list = res
logger.debug(f"installed_app_list: {installed_app_list}, user_id: {user_id}")
installed_app_list.sort(
key=lambda app: (
-app["is_pinned"],

View File

@ -4,10 +4,14 @@ from flask_login import current_user
from flask_restful import Resource
from werkzeug.exceptions import NotFound
from controllers.console.explore.error import AppAccessDeniedError
from controllers.console.wraps import account_initialization_required
from extensions.ext_database import db
from libs.login import login_required
from models import InstalledApp
from services.app_service import AppService
from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService
def installed_app_required(view=None):
@ -48,6 +52,36 @@ def installed_app_required(view=None):
return decorator
def user_allowed_to_access_app(view=None):
def decorator(view):
@wraps(view)
def decorated(installed_app: InstalledApp, *args, **kwargs):
feature = FeatureService.get_system_features()
if feature.webapp_auth.enabled:
app_id = installed_app.app_id
app_code = AppService.get_app_code_by_id(app_id)
res = EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(
user_id=str(current_user.id),
app_code=app_code,
)
if not res:
raise AppAccessDeniedError()
return view(installed_app, *args, **kwargs)
return decorated
if view:
return decorator(view)
return decorator
class InstalledAppResource(Resource):
# must be reversed if there are multiple decorators
method_decorators = [installed_app_required, account_initialization_required, login_required]
method_decorators = [
user_allowed_to_access_app,
installed_app_required,
account_initialization_required,
login_required,
]
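user_allowed_to_access_app follows the optional-argument decorator idiom already used by installed_app_required, which is what lets it sit bare inside method_decorators. The skeleton of that idiom, stripped of the Dify specifics:

    from functools import wraps

    def guard(view=None):
        def decorator(view):
            @wraps(view)
            def decorated(*args, **kwargs):
                # run the access check here; raise on failure
                return view(*args, **kwargs)
            return decorated
        if view:
            return decorator(view)  # applied bare: @guard
        return decorator            # applied with parentheses: @guard()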

View File

@ -6,6 +6,7 @@ from flask_restful import Resource, abort, marshal_with, reqparse
import services
from configs import dify_config
from controllers.console import api
from controllers.console.error import WorkspaceMembersLimitExceeded
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_resource_check,
@ -17,6 +18,7 @@ from libs.login import login_required
from models.account import Account, TenantAccountRole
from services.account_service import RegisterService, TenantService
from services.errors.account import AccountAlreadyInTenantError
from services.feature_service import FeatureService
class MemberListApi(Resource):
@ -54,6 +56,12 @@ class MemberInviteEmailApi(Resource):
inviter = current_user
invitation_results = []
console_web_url = dify_config.CONSOLE_WEB_URL
workspace_members = FeatureService.get_features(tenant_id=inviter.current_tenant.id).workspace_members
if not workspace_members.is_available(len(invitee_emails)):
raise WorkspaceMembersLimitExceeded()
for invitee_email in invitee_emails:
try:
token = RegisterService.invite_new_member(

View File

@ -68,16 +68,24 @@ class TenantListApi(Resource):
@account_initialization_required
def get(self):
tenants = TenantService.get_join_tenants(current_user)
tenant_dicts = []
for tenant in tenants:
features = FeatureService.get_features(tenant.id)
if features.billing.enabled:
tenant.plan = features.billing.subscription.plan
else:
tenant.plan = "sandbox"
if tenant.id == current_user.current_tenant_id:
tenant.current = True # Set current=True for current tenant
return {"workspaces": marshal(tenants, tenants_fields)}, 200
# Create a dictionary with tenant attributes
tenant_dict = {
"id": tenant.id,
"name": tenant.name,
"status": tenant.status,
"created_at": tenant.created_at,
"plan": features.billing.subscription.plan if features.billing.enabled else "sandbox",
"current": tenant.id == current_user.current_tenant_id,
}
tenant_dicts.append(tenant_dict)
return {"workspaces": marshal(tenant_dicts, tenants_fields)}, 200
class WorkspaceListApi(Resource):

View File

@ -64,9 +64,24 @@ class PluginUploadFileApi(Resource):
extension = guess_extension(tool_file.mimetype) or ".bin"
preview_url = ToolFileManager.sign_file(tool_file_id=tool_file.id, extension=extension)
tool_file.mime_type = mimetype
tool_file.extension = extension
tool_file.preview_url = preview_url
# Create a dictionary with all the necessary attributes
result = {
"id": tool_file.id,
"user_id": tool_file.user_id,
"tenant_id": tool_file.tenant_id,
"conversation_id": tool_file.conversation_id,
"file_key": tool_file.file_key,
"mimetype": tool_file.mimetype,
"original_url": tool_file.original_url,
"name": tool_file.name,
"size": tool_file.size,
"mime_type": mimetype,
"extension": extension,
"preview_url": preview_url,
}
return result, 201
except services.errors.file.FileTooLargeError as file_too_large_error:
raise FileTooLargeError(file_too_large_error.description)
except services.errors.file.UnsupportedFileTypeError:

View File

@ -5,5 +5,6 @@ from libs.external_api import ExternalApi
bp = Blueprint("inner_api", __name__, url_prefix="/inner/api")
api = ExternalApi(bp)
from . import mail
from .plugin import plugin
from .workspace import workspace

View File

@ -0,0 +1,27 @@
from flask_restful import (
Resource, # type: ignore
reqparse,
)
from controllers.console.wraps import setup_required
from controllers.inner_api import api
from controllers.inner_api.wraps import enterprise_inner_api_only
from services.enterprise.mail_service import DifyMail, EnterpriseMailService
class EnterpriseMail(Resource):
@setup_required
@enterprise_inner_api_only
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("to", type=str, action="append", required=True)
parser.add_argument("subject", type=str, required=True)
parser.add_argument("body", type=str, required=True)
parser.add_argument("substitutions", type=dict, required=False)
args = parser.parse_args()
EnterpriseMailService.send_mail(DifyMail(**args))
return {"message": "success"}, 200
api.add_resource(EnterpriseMail, "/enterprise/mail")
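A hedged usage sketch for the new inner-API endpoint; host, port and the authentication header are placeholders, since enterprise_inner_api_only enforces a credential scheme this diff does not show:

    import requests

    resp = requests.post(
        "http://localhost:5001/inner/api/enterprise/mail",
        json={
            "to": ["ops@example.com"],
            "subject": "Maintenance window",
            "body": "<p>Starts at 02:00 UTC</p>",
            "substitutions": {},
        },
        headers={"Authorization": "Bearer <inner-api-key>"},  # placeholder auth
    )
    print(resp.status_code, resp.json())  # expect 200 {'message': 'success'}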

View File

@ -388,11 +388,22 @@ class DocumentIndexingStatusApi(DatasetApiResource):
.filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
.count()
)
document.completed_segments = completed_segments
document.total_segments = total_segments
if document.is_paused:
document.indexing_status = "paused"
documents_status.append(marshal(document, document_status_fields))
# Create a dictionary with document attributes and additional fields
document_dict = {
"id": document.id,
"indexing_status": "paused" if document.is_paused else document.indexing_status,
"processing_started_at": document.processing_started_at,
"parsing_completed_at": document.parsing_completed_at,
"cleaning_completed_at": document.cleaning_completed_at,
"splitting_completed_at": document.splitting_completed_at,
"completed_at": document.completed_at,
"paused_at": document.paused_at,
"error": document.error,
"stopped_at": document.stopped_at,
"completed_segments": completed_segments,
"total_segments": total_segments,
}
documents_status.append(marshal(document_dict, document_status_fields))
data = {"data": documents_status}
return data

View File

@ -1,12 +1,15 @@
from flask_restful import marshal_with
from flask import request
from flask_restful import Resource, marshal_with, reqparse
from controllers.common import fields
from controllers.web import api
from controllers.web.error import AppUnavailableError
from controllers.web.wraps import WebApiResource
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
from libs.passport import PassportService
from models.model import App, AppMode
from services.app_service import AppService
from services.enterprise.enterprise_service import EnterpriseService
class AppParameterApi(WebApiResource):
@ -40,5 +43,51 @@ class AppMeta(WebApiResource):
return AppService().get_app_meta(app_model)
class AppAccessMode(Resource):
def get(self):
parser = reqparse.RequestParser()
parser.add_argument("appId", type=str, required=True, location="args")
args = parser.parse_args()
app_id = args["appId"]
res = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id)
return {"accessMode": res.access_mode}
class AppWebAuthPermission(Resource):
def get(self):
user_id = "visitor"
try:
auth_header = request.headers.get("Authorization")
if auth_header is None:
raise
if " " not in auth_header:
raise
auth_scheme, tk = auth_header.split(None, 1)
auth_scheme = auth_scheme.lower()
if auth_scheme != "bearer":
raise
decoded = PassportService().verify(tk)
user_id = decoded.get("user_id", "visitor")
except Exception:
pass
parser = reqparse.RequestParser()
parser.add_argument("appId", type=str, required=True, location="args")
args = parser.parse_args()
app_id = args["appId"]
app_code = AppService.get_app_code_by_id(app_id)
res = EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(str(user_id), app_code)
return {"result": res}
api.add_resource(AppParameterApi, "/parameters")
api.add_resource(AppMeta, "/meta")
# webapp auth apis
api.add_resource(AppAccessMode, "/webapp/access-mode")
api.add_resource(AppWebAuthPermission, "/webapp/permission")
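A hedged sketch of calling the two new endpoints; base URL and app id are placeholders. The Authorization header on /webapp/permission is optional, since a missing or invalid token falls back to the "visitor" user as shown above:

    import requests

    base = "http://localhost:5001/api"  # web API prefix is an assumption
    app_id = "<app-uuid>"

    print(requests.get(f"{base}/webapp/access-mode", params={"appId": app_id}).json())
    # e.g. {'accessMode': 'public'}

    print(requests.get(
        f"{base}/webapp/permission",
        params={"appId": app_id},
        headers={"Authorization": "Bearer <passport-token>"},  # optional
    ).json())
    # e.g. {'result': True}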

View File

@ -121,9 +121,15 @@ class UnsupportedFileTypeError(BaseHTTPException):
code = 415
class WebSSOAuthRequiredError(BaseHTTPException):
class WebAppAuthRequiredError(BaseHTTPException):
error_code = "web_sso_auth_required"
description = "Web SSO authentication required."
description = "Web app authentication required."
code = 401
class WebAppAuthAccessDeniedError(BaseHTTPException):
error_code = "web_app_access_denied"
description = "You do not have permission to access this web app."
code = 401

View File

@ -0,0 +1,120 @@
from flask import request
from flask_restful import Resource, reqparse
from jwt import InvalidTokenError # type: ignore
from werkzeug.exceptions import BadRequest
import services
from controllers.console.auth.error import EmailCodeError, EmailOrPasswordMismatchError, InvalidEmailError
from controllers.console.error import AccountBannedError, AccountNotFound
from controllers.console.wraps import setup_required
from libs.helper import email
from libs.password import valid_password
from services.account_service import AccountService
from services.webapp_auth_service import WebAppAuthService
class LoginApi(Resource):
"""Resource for web app email/password login."""
def post(self):
"""Authenticate user and login."""
parser = reqparse.RequestParser()
parser.add_argument("email", type=email, required=True, location="json")
parser.add_argument("password", type=valid_password, required=True, location="json")
args = parser.parse_args()
app_code = request.headers.get("X-App-Code")
if app_code is None:
raise BadRequest("X-App-Code header is missing.")
try:
account = WebAppAuthService.authenticate(args["email"], args["password"])
except services.errors.account.AccountLoginError:
raise AccountBannedError()
except services.errors.account.AccountPasswordError:
raise EmailOrPasswordMismatchError()
except services.errors.account.AccountNotFoundError:
raise AccountNotFound()
WebAppAuthService._validate_user_accessibility(account=account, app_code=app_code)
end_user = WebAppAuthService.create_end_user(email=args["email"], app_code=app_code)
token = WebAppAuthService.login(account=account, app_code=app_code, end_user_id=end_user.id)
return {"result": "success", "token": token}
# class LogoutApi(Resource):
# @setup_required
# def get(self):
# account = cast(Account, flask_login.current_user)
# if isinstance(account, flask_login.AnonymousUserMixin):
# return {"result": "success"}
# flask_login.logout_user()
# return {"result": "success"}
class EmailCodeLoginSendEmailApi(Resource):
@setup_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("email", type=email, required=True, location="json")
parser.add_argument("language", type=str, required=False, location="json")
args = parser.parse_args()
if args["language"] is not None and args["language"] == "zh-Hans":
language = "zh-Hans"
else:
language = "en-US"
account = WebAppAuthService.get_user_through_email(args["email"])
if account is None:
raise AccountNotFound()
else:
token = WebAppAuthService.send_email_code_login_email(account=account, language=language)
return {"result": "success", "data": token}
class EmailCodeLoginApi(Resource):
@setup_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("email", type=str, required=True, location="json")
parser.add_argument("code", type=str, required=True, location="json")
parser.add_argument("token", type=str, required=True, location="json")
args = parser.parse_args()
user_email = args["email"]
app_code = request.headers.get("X-App-Code")
if app_code is None:
raise BadRequest("X-App-Code header is missing.")
token_data = WebAppAuthService.get_email_code_login_data(args["token"])
if token_data is None:
raise InvalidTokenError()
if token_data["email"] != args["email"]:
raise InvalidEmailError()
if token_data["code"] != args["code"]:
raise EmailCodeError()
WebAppAuthService.revoke_email_code_login_token(args["token"])
account = WebAppAuthService.get_user_through_email(user_email)
if not account:
raise AccountNotFound()
WebAppAuthService._validate_user_accessibility(account=account, app_code=app_code)
end_user = WebAppAuthService.create_end_user(email=user_email, app_code=app_code)
token = WebAppAuthService.login(account=account, app_code=app_code, end_user_id=end_user.id)
AccountService.reset_login_error_rate_limit(args["email"])
return {"result": "success", "token": token}
# api.add_resource(LoginApi, "/login")
# api.add_resource(LogoutApi, "/logout")
# api.add_resource(EmailCodeLoginSendEmailApi, "/email-code-login")
# api.add_resource(EmailCodeLoginApi, "/email-code-login/validity")

View File

@ -5,7 +5,7 @@ from flask_restful import Resource
from werkzeug.exceptions import NotFound, Unauthorized
from controllers.web import api
from controllers.web.error import WebSSOAuthRequiredError
from controllers.web.error import WebAppAuthRequiredError
from extensions.ext_database import db
from libs.passport import PassportService
from models.model import App, EndUser, Site
@ -24,10 +24,10 @@ class PassportResource(Resource):
if app_code is None:
raise Unauthorized("X-App-Code header is missing.")
if system_features.sso_enforced_for_web:
app_web_sso_enabled = EnterpriseService.get_app_web_sso_enabled(app_code).get("enabled", False)
if app_web_sso_enabled:
raise WebSSOAuthRequiredError()
if system_features.webapp_auth.enabled:
app_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code=app_code)
if not app_settings or not app_settings.access_mode == "public":
raise WebAppAuthRequiredError()
# get site from db and check if it is normal
site = db.session.query(Site).filter(Site.code == app_code, Site.status == "normal").first()

View File

@ -4,7 +4,7 @@ from flask import request
from flask_restful import Resource
from werkzeug.exceptions import BadRequest, NotFound, Unauthorized
from controllers.web.error import WebSSOAuthRequiredError
from controllers.web.error import WebAppAuthAccessDeniedError, WebAppAuthRequiredError
from extensions.ext_database import db
from libs.passport import PassportService
from models.model import App, EndUser, Site
@ -29,7 +29,7 @@ def validate_jwt_token(view=None):
def decode_jwt_token():
system_features = FeatureService.get_system_features()
app_code = request.headers.get("X-App-Code")
app_code = str(request.headers.get("X-App-Code"))
try:
auth_header = request.headers.get("Authorization")
if auth_header is None:
@ -57,35 +57,53 @@ def decode_jwt_token():
if not end_user:
raise NotFound()
_validate_web_sso_token(decoded, system_features, app_code)
# for enterprise webapp auth
app_web_auth_enabled = False
if system_features.webapp_auth.enabled:
app_web_auth_enabled = (
EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code=app_code).access_mode != "public"
)
_validate_webapp_token(decoded, app_web_auth_enabled, system_features.webapp_auth.enabled)
_validate_user_accessibility(decoded, app_code, app_web_auth_enabled, system_features.webapp_auth.enabled)
return app_model, end_user
except Unauthorized as e:
if system_features.sso_enforced_for_web:
app_web_sso_enabled = EnterpriseService.get_app_web_sso_enabled(app_code).get("enabled", False)
if app_web_sso_enabled:
raise WebSSOAuthRequiredError()
if system_features.webapp_auth.enabled:
app_web_auth_enabled = (
EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code=str(app_code)).access_mode != "public"
)
if app_web_auth_enabled:
raise WebAppAuthRequiredError()
raise Unauthorized(e.description)
def _validate_web_sso_token(decoded, system_features, app_code):
app_web_sso_enabled = False
# Check if SSO is enforced for web, and if the token source is not SSO, raise an error and redirect to SSO login
if system_features.sso_enforced_for_web:
app_web_sso_enabled = EnterpriseService.get_app_web_sso_enabled(app_code).get("enabled", False)
if app_web_sso_enabled:
source = decoded.get("token_source")
if not source or source != "sso":
raise WebSSOAuthRequiredError()
# Check if SSO is not enforced for web, and if the token source is SSO,
# raise an error and redirect to normal passport login
if not system_features.sso_enforced_for_web or not app_web_sso_enabled:
def _validate_webapp_token(decoded, app_web_auth_enabled: bool, system_webapp_auth_enabled: bool):
# Check if authentication is enforced for web app, and if the token source is not webapp,
# raise an error and redirect to login
if system_webapp_auth_enabled and app_web_auth_enabled:
source = decoded.get("token_source")
if source and source == "sso":
raise Unauthorized("sso token expired.")
if not source or source != "webapp":
raise WebAppAuthRequiredError()
# Check if authentication is not enforced for web, and if the token source is webapp,
# raise an error and redirect to normal passport login
if not system_webapp_auth_enabled or not app_web_auth_enabled:
source = decoded.get("token_source")
if source and source == "webapp":
raise Unauthorized("webapp token expired.")
def _validate_user_accessibility(decoded, app_code, app_web_auth_enabled: bool, system_webapp_auth_enabled: bool):
if system_webapp_auth_enabled and app_web_auth_enabled:
# Check if the user is allowed to access the web app
user_id = decoded.get("user_id")
if not user_id:
raise WebAppAuthRequiredError()
if not EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(user_id, app_code=app_code):
raise WebAppAuthAccessDeniedError()
class WebApiResource(Resource):

View File

@ -109,6 +109,7 @@ class VariableEntity(BaseModel):
description: str = ""
type: VariableEntityType
required: bool = False
hide: bool = False
max_length: Optional[int] = None
options: Sequence[str] = Field(default_factory=list)
allowed_file_types: Sequence[FileType] = Field(default_factory=list)

View File

@ -26,12 +26,13 @@ from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.ops.ops_trace_manager import TraceQueueManager
from core.prompt.utils.get_thread_messages_length import get_thread_messages_length
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
from core.workflow.repository.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from extensions.ext_database import db
from factories import file_factory
from models.account import Account
from models.model import App, Conversation, EndUser, Message
from models.workflow import Workflow
from models import Account, App, Conversation, EndUser, Message, Workflow, WorkflowNodeExecutionTriggeredFrom
from models.enums import WorkflowRunTriggeredFrom
from services.conversation_service import ConversationService
from services.errors.message import MessageNotExistsError
@ -161,12 +162,27 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())
# Create workflow node execution repository
# Create repositories
#
# Create session factory
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
# Create workflow execution(aka workflow run) repository
if invoke_from == InvokeFrom.DEBUGGER:
workflow_triggered_from = WorkflowRunTriggeredFrom.DEBUGGING
else:
workflow_triggered_from = WorkflowRunTriggeredFrom.APP_RUN
workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=workflow_triggered_from,
)
# Create workflow node execution repository
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=session_factory,
tenant_id=application_generate_entity.app_config.tenant_id,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
return self._generate(
@ -174,6 +190,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
user=user,
invoke_from=invoke_from,
application_generate_entity=application_generate_entity,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
conversation=conversation,
stream=streaming,
@ -227,12 +244,23 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())
# Create workflow node execution repository
# Create repositories
#
# Create session factory
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
# Create workflow execution(aka workflow run) repository
workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
)
# Create workflow node execution repository
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=session_factory,
tenant_id=application_generate_entity.app_config.tenant_id,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
)
return self._generate(
@ -240,6 +268,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
user=user,
invoke_from=InvokeFrom.DEBUGGER,
application_generate_entity=application_generate_entity,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
conversation=None,
stream=streaming,
@ -291,12 +320,23 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())
# Create workflow node execution repository
# Create repositories
#
# Create session factory
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
# Create workflow execution(aka workflow run) repository
workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
)
# Create workflow node execution repository
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=session_factory,
tenant_id=application_generate_entity.app_config.tenant_id,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
)
return self._generate(
@ -304,6 +344,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
user=user,
invoke_from=InvokeFrom.DEBUGGER,
application_generate_entity=application_generate_entity,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
conversation=None,
stream=streaming,
@ -316,6 +357,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
user: Union[Account, EndUser],
invoke_from: InvokeFrom,
application_generate_entity: AdvancedChatAppGenerateEntity,
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
conversation: Optional[Conversation] = None,
stream: bool = True,
@ -380,6 +422,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
conversation=conversation,
message=message,
user=user,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
stream=stream,
)
@ -452,6 +495,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
conversation: Conversation,
message: Message,
user: Union[Account, EndUser],
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
stream: bool = False,
) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
@ -475,9 +519,10 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
conversation=conversation,
message=message,
user=user,
stream=stream,
dialogue_count=self._dialogue_count,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
stream=stream,
)
try:

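One detail worth flagging in the repository wiring above: the session factory is created with expire_on_commit=False, so ORM objects stay readable after commit() while the pipeline keeps streaming. A standalone demonstration of what that flag changes (throwaway in-memory schema, not Dify's models):

    from sqlalchemy import Integer, String, create_engine
    from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, sessionmaker

    class Base(DeclarativeBase):
        pass

    class Run(Base):
        __tablename__ = "runs"
        id: Mapped[int] = mapped_column(Integer, primary_key=True)
        status: Mapped[str] = mapped_column(String)

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)

    session_factory = sessionmaker(bind=engine, expire_on_commit=False)
    with session_factory() as session:
        run = Run(id=1, status="running")
        session.add(run)
        session.commit()

    # With the default expire_on_commit=True this access would raise
    # DetachedInstanceError once the session is closed; with False the
    # cached value survives both the commit and the session's exit.
    print(run.status)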
View File

@ -10,6 +10,7 @@ from sqlalchemy.orm import Session
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
from core.app.entities.app_invoke_entities import (
AdvancedChatAppGenerateEntity,
InvokeFrom,
@ -64,13 +65,14 @@ from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes import NodeType
from core.workflow.repository.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.workflow_cycle_manager import WorkflowCycleManager
from events.message_event import message_was_created
from extensions.ext_database import db
from models import Conversation, EndUser, Message, MessageFile
from models.account import Account
from models.enums import CreatedByRole
from models.enums import CreatorUserRole
from models.workflow import (
Workflow,
WorkflowRunStatus,
@ -94,6 +96,7 @@ class AdvancedChatAppGenerateTaskPipeline:
user: Union[Account, EndUser],
stream: bool,
dialogue_count: int,
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
) -> None:
self._base_task_pipeline = BasedGenerateTaskPipeline(
@ -105,11 +108,11 @@ class AdvancedChatAppGenerateTaskPipeline:
if isinstance(user, EndUser):
self._user_id = user.id
user_session_id = user.session_id
self._created_by_role = CreatedByRole.END_USER
self._created_by_role = CreatorUserRole.END_USER
elif isinstance(user, Account):
self._user_id = user.id
user_session_id = user.id
self._created_by_role = CreatedByRole.ACCOUNT
self._created_by_role = CreatorUserRole.ACCOUNT
else:
raise NotImplementedError(f"User type not supported: {type(user)}")
@ -125,9 +128,14 @@ class AdvancedChatAppGenerateTaskPipeline:
SystemVariableKey.WORKFLOW_ID: workflow.id,
SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
},
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
)
self._workflow_response_converter = WorkflowResponseConverter(
application_generate_entity=application_generate_entity,
)
self._task_state = WorkflowTaskState()
self._message_cycle_manager = MessageCycleManage(
application_generate_entity=application_generate_entity, task_state=self._task_state
@ -294,21 +302,19 @@ class AdvancedChatAppGenerateTaskPipeline:
with Session(db.engine, expire_on_commit=False) as session:
# init workflow run
workflow_run = self._workflow_cycle_manager._handle_workflow_run_start(
workflow_execution = self._workflow_cycle_manager.handle_workflow_run_start(
session=session,
workflow_id=self._workflow_id,
user_id=self._user_id,
created_by_role=self._created_by_role,
)
self._workflow_run_id = workflow_run.id
self._workflow_run_id = workflow_execution.id
message = self._get_message(session=session)
if not message:
raise ValueError(f"Message not found: {self._message_id}")
message.workflow_run_id = workflow_run.id
workflow_start_resp = self._workflow_cycle_manager._workflow_start_to_stream_response(
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
message.workflow_run_id = workflow_execution.id
workflow_start_resp = self._workflow_response_converter.workflow_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution=workflow_execution,
)
session.commit()
yield workflow_start_resp
elif isinstance(
@ -319,13 +325,10 @@ class AdvancedChatAppGenerateTaskPipeline:
raise ValueError("workflow run not initialized.")
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_retried(
workflow_execution_id=self._workflow_run_id, event=event
)
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_retried(
workflow_run=workflow_run, event=event
)
node_retry_resp = self._workflow_cycle_manager._workflow_node_retry_to_stream_response(
node_retry_resp = self._workflow_response_converter.workflow_node_retry_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
@ -338,20 +341,15 @@ class AdvancedChatAppGenerateTaskPipeline:
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
workflow_node_execution = self._workflow_cycle_manager._handle_node_execution_start(
workflow_run=workflow_run, event=event
)
workflow_node_execution = self._workflow_cycle_manager.handle_node_execution_start(
workflow_execution_id=self._workflow_run_id, event=event
)
node_start_resp = self._workflow_cycle_manager._workflow_node_start_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
session.commit()
node_start_resp = self._workflow_response_converter.workflow_node_start_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
if node_start_resp:
yield node_start_resp
@ -359,15 +357,15 @@ class AdvancedChatAppGenerateTaskPipeline:
# Record files if it's an answer node or end node
if event.node_type in [NodeType.ANSWER, NodeType.END]:
self._recorded_files.extend(
self._workflow_cycle_manager._fetch_files_from_node_outputs(event.outputs or {})
self._workflow_response_converter.fetch_files_from_node_outputs(event.outputs or {})
)
with Session(db.engine, expire_on_commit=False) as session:
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_success(
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_success(
event=event
)
node_finish_resp = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
node_finish_resp = self._workflow_response_converter.workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
@ -383,11 +381,11 @@ class AdvancedChatAppGenerateTaskPipeline:
| QueueNodeInLoopFailedEvent
| QueueNodeExceptionEvent,
):
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_failed(
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_failed(
event=event
)
node_finish_resp = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
node_finish_resp = self._workflow_response_converter.workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
@ -399,132 +397,92 @@ class AdvancedChatAppGenerateTaskPipeline:
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
parallel_start_resp = (
self._workflow_cycle_manager._workflow_parallel_branch_start_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event,
)
parallel_start_resp = (
self._workflow_response_converter.workflow_parallel_branch_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
event=event,
)
)
yield parallel_start_resp
elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent):
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
parallel_finish_resp = (
self._workflow_cycle_manager._workflow_parallel_branch_finished_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event,
)
parallel_finish_resp = (
self._workflow_response_converter.workflow_parallel_branch_finished_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
event=event,
)
)
yield parallel_finish_resp
elif isinstance(event, QueueIterationStartEvent):
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
iter_start_resp = self._workflow_cycle_manager._workflow_iteration_start_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event,
)
iter_start_resp = self._workflow_response_converter.workflow_iteration_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
event=event,
)
yield iter_start_resp
elif isinstance(event, QueueIterationNextEvent):
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
iter_next_resp = self._workflow_cycle_manager._workflow_iteration_next_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event,
)
iter_next_resp = self._workflow_response_converter.workflow_iteration_next_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
event=event,
)
yield iter_next_resp
elif isinstance(event, QueueIterationCompletedEvent):
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
iter_finish_resp = self._workflow_cycle_manager._workflow_iteration_completed_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event,
)
iter_finish_resp = self._workflow_response_converter.workflow_iteration_completed_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
event=event,
)
yield iter_finish_resp
elif isinstance(event, QueueLoopStartEvent):
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
loop_start_resp = self._workflow_cycle_manager._workflow_loop_start_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event,
)
loop_start_resp = self._workflow_response_converter.workflow_loop_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
event=event,
)
yield loop_start_resp
elif isinstance(event, QueueLoopNextEvent):
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
loop_next_resp = self._workflow_cycle_manager._workflow_loop_next_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event,
)
loop_next_resp = self._workflow_response_converter.workflow_loop_next_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
event=event,
)
yield loop_next_resp
elif isinstance(event, QueueLoopCompletedEvent):
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
loop_finish_resp = self._workflow_cycle_manager._workflow_loop_completed_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event,
)
loop_finish_resp = self._workflow_response_converter.workflow_loop_completed_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
event=event,
)
yield loop_finish_resp
elif isinstance(event, QueueWorkflowSucceededEvent):
@ -535,10 +493,8 @@ class AdvancedChatAppGenerateTaskPipeline:
raise ValueError("workflow run not initialized.")
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._handle_workflow_run_success(
session=session,
workflow_execution = self._workflow_cycle_manager.handle_workflow_run_success(
workflow_run_id=self._workflow_run_id,
start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
outputs=event.outputs,
@ -546,10 +502,11 @@ class AdvancedChatAppGenerateTaskPipeline:
trace_manager=trace_manager,
)
workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_execution=workflow_execution,
)
session.commit()
yield workflow_finish_resp
self._base_task_pipeline._queue_manager.publish(
@ -562,10 +519,8 @@ class AdvancedChatAppGenerateTaskPipeline:
raise ValueError("graph runtime state not initialized.")
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._handle_workflow_run_partial_success(
session=session,
workflow_execution = self._workflow_cycle_manager.handle_workflow_run_partial_success(
workflow_run_id=self._workflow_run_id,
start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
outputs=event.outputs,
@ -573,10 +528,11 @@ class AdvancedChatAppGenerateTaskPipeline:
conversation_id=None,
trace_manager=trace_manager,
)
workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_execution=workflow_execution,
)
session.commit()
yield workflow_finish_resp
self._base_task_pipeline._queue_manager.publish(
@ -589,26 +545,25 @@ class AdvancedChatAppGenerateTaskPipeline:
raise ValueError("graph runtime state not initialized.")
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._handle_workflow_run_failed(
session=session,
workflow_execution = self._workflow_cycle_manager.handle_workflow_run_failed(
workflow_run_id=self._workflow_run_id,
start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
status=WorkflowRunStatus.FAILED,
error=event.error,
error_message=event.error,
conversation_id=self._conversation_id,
trace_manager=trace_manager,
exceptions_count=event.exceptions_count,
)
workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_execution=workflow_execution,
)
err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_run.error}"))
err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_execution.error_message}"))
err = self._base_task_pipeline._handle_error(
event=err_event, session=session, message_id=self._message_id
)
session.commit()
yield workflow_finish_resp
yield self._base_task_pipeline._error_to_stream_response(err)
@ -616,21 +571,19 @@ class AdvancedChatAppGenerateTaskPipeline:
elif isinstance(event, QueueStopEvent):
if self._workflow_run_id and graph_runtime_state:
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._handle_workflow_run_failed(
session=session,
workflow_execution = self._workflow_cycle_manager.handle_workflow_run_failed(
workflow_run_id=self._workflow_run_id,
start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
status=WorkflowRunStatus.STOPPED,
error=event.get_stop_reason(),
error_message=event.get_stop_reason(),
conversation_id=self._conversation_id,
trace_manager=trace_manager,
)
workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
workflow_execution=workflow_execution,
)
# Save message
self._save_message(session=session, graph_runtime_state=graph_runtime_state)
@ -711,7 +664,7 @@ class AdvancedChatAppGenerateTaskPipeline:
yield self._message_end_to_stream_response()
elif isinstance(event, QueueAgentLogEvent):
yield self._workflow_cycle_manager._handle_agent_log(
yield self._workflow_response_converter.handle_agent_log(
task_id=self._application_generate_entity.task_id, event=event
)
else:
@ -739,9 +692,9 @@ class AdvancedChatAppGenerateTaskPipeline:
url=file["remote_url"],
belongs_to="assistant",
upload_file_id=file["related_id"],
created_by_role=CreatedByRole.ACCOUNT
created_by_role=CreatorUserRole.ACCOUNT
if message.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
else CreatedByRole.END_USER,
else CreatorUserRole.END_USER,
created_by=message.from_account_id or message.from_end_user_id or "",
)
for file in self._recorded_files

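The pattern repeated throughout this file: stream responses are now built from the in-memory workflow execution id rather than a Session-bound WorkflowRun, so most event branches no longer open a database session. A minimal sketch of the new call shape (names are taken from the hunks above; the standalone handler is illustrative):

def _handle_loop_next(self, event: QueueLoopNextEvent):
    # Illustrative handler: the converter needs only the execution id,
    # not a loaded WorkflowRun row, so no Session is opened here.
    if not self._workflow_run_id:
        raise ValueError("workflow run not initialized.")
    yield self._workflow_response_converter.workflow_loop_next_to_stream_response(
        task_id=self._application_generate_entity.task_id,
        workflow_execution_id=self._workflow_run_id,
        event=event,
    )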
View File

@ -0,0 +1,564 @@
import time
from collections.abc import Mapping, Sequence
from datetime import UTC, datetime
from typing import Any, Optional, Union, cast
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
from core.app.entities.queue_entities import (
QueueAgentLogEvent,
QueueIterationCompletedEvent,
QueueIterationNextEvent,
QueueIterationStartEvent,
QueueLoopCompletedEvent,
QueueLoopNextEvent,
QueueLoopStartEvent,
QueueNodeExceptionEvent,
QueueNodeFailedEvent,
QueueNodeInIterationFailedEvent,
QueueNodeInLoopFailedEvent,
QueueNodeRetryEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueParallelBranchRunFailedEvent,
QueueParallelBranchRunStartedEvent,
QueueParallelBranchRunSucceededEvent,
)
from core.app.entities.task_entities import (
AgentLogStreamResponse,
IterationNodeCompletedStreamResponse,
IterationNodeNextStreamResponse,
IterationNodeStartStreamResponse,
LoopNodeCompletedStreamResponse,
LoopNodeNextStreamResponse,
LoopNodeStartStreamResponse,
NodeFinishStreamResponse,
NodeRetryStreamResponse,
NodeStartStreamResponse,
ParallelBranchFinishedStreamResponse,
ParallelBranchStartStreamResponse,
WorkflowFinishStreamResponse,
WorkflowStartStreamResponse,
)
from core.file import FILE_MODEL_IDENTITY, File
from core.tools.tool_manager import ToolManager
from core.workflow.entities.node_execution_entities import NodeExecution
from core.workflow.entities.workflow_execution_entities import WorkflowExecution
from core.workflow.nodes import NodeType
from core.workflow.nodes.tool.entities import ToolNodeData
from models import (
Account,
CreatorUserRole,
EndUser,
WorkflowNodeExecutionStatus,
WorkflowRun,
)
class WorkflowResponseConverter:
def __init__(
self,
*,
application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity],
) -> None:
self._application_generate_entity = application_generate_entity
def workflow_start_to_stream_response(
self,
*,
task_id: str,
workflow_execution: WorkflowExecution,
) -> WorkflowStartStreamResponse:
return WorkflowStartStreamResponse(
task_id=task_id,
workflow_run_id=workflow_execution.id,
data=WorkflowStartStreamResponse.Data(
id=workflow_execution.id,
workflow_id=workflow_execution.workflow_id,
sequence_number=workflow_execution.sequence_number,
inputs=workflow_execution.inputs,
created_at=int(workflow_execution.started_at.timestamp()),
),
)
def workflow_finish_to_stream_response(
self,
*,
session: Session,
task_id: str,
workflow_execution: WorkflowExecution,
) -> WorkflowFinishStreamResponse:
created_by = None
workflow_run = session.scalar(select(WorkflowRun).where(WorkflowRun.id == workflow_execution.id))
assert workflow_run is not None
if workflow_run.created_by_role == CreatorUserRole.ACCOUNT:
stmt = select(Account).where(Account.id == workflow_run.created_by)
account = session.scalar(stmt)
if account:
created_by = {
"id": account.id,
"name": account.name,
"email": account.email,
}
elif workflow_run.created_by_role == CreatorUserRole.END_USER:
stmt = select(EndUser).where(EndUser.id == workflow_run.created_by)
end_user = session.scalar(stmt)
if end_user:
created_by = {
"id": end_user.id,
"user": end_user.session_id,
}
else:
raise NotImplementedError(f"unknown created_by_role: {workflow_run.created_by_role}")
# Handle the case where finished_at is None by using current time as default
finished_at_timestamp = (
int(workflow_execution.finished_at.timestamp())
if workflow_execution.finished_at
else int(datetime.now(UTC).timestamp())
)
return WorkflowFinishStreamResponse(
task_id=task_id,
workflow_run_id=workflow_execution.id,
data=WorkflowFinishStreamResponse.Data(
id=workflow_execution.id,
workflow_id=workflow_execution.workflow_id,
sequence_number=workflow_execution.sequence_number,
status=workflow_execution.status,
outputs=workflow_execution.outputs,
error=workflow_execution.error_message,
elapsed_time=workflow_execution.elapsed_time,
total_tokens=workflow_execution.total_tokens,
total_steps=workflow_execution.total_steps,
created_by=created_by,
created_at=int(workflow_execution.started_at.timestamp()),
finished_at=finished_at_timestamp,
files=self.fetch_files_from_node_outputs(workflow_execution.outputs),
exceptions_count=workflow_execution.exceptions_count,
),
)
def workflow_node_start_to_stream_response(
self,
*,
event: QueueNodeStartedEvent,
task_id: str,
workflow_node_execution: NodeExecution,
) -> Optional[NodeStartStreamResponse]:
if workflow_node_execution.node_type in {NodeType.ITERATION, NodeType.LOOP}:
return None
if not workflow_node_execution.workflow_run_id:
return None
response = NodeStartStreamResponse(
task_id=task_id,
workflow_run_id=workflow_node_execution.workflow_run_id,
data=NodeStartStreamResponse.Data(
id=workflow_node_execution.id,
node_id=workflow_node_execution.node_id,
node_type=workflow_node_execution.node_type,
title=workflow_node_execution.title,
index=workflow_node_execution.index,
predecessor_node_id=workflow_node_execution.predecessor_node_id,
inputs=workflow_node_execution.inputs,
created_at=int(workflow_node_execution.created_at.timestamp()),
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
iteration_id=event.in_iteration_id,
loop_id=event.in_loop_id,
parallel_run_id=event.parallel_mode_run_id,
agent_strategy=event.agent_strategy,
),
)
# extra response fields: tool nodes attach their provider icon
if event.node_type == NodeType.TOOL:
node_data = cast(ToolNodeData, event.node_data)
response.data.extras["icon"] = ToolManager.get_tool_icon(
tenant_id=self._application_generate_entity.app_config.tenant_id,
provider_type=node_data.provider_type,
provider_id=node_data.provider_id,
)
return response
def workflow_node_finish_to_stream_response(
self,
*,
event: QueueNodeSucceededEvent
| QueueNodeFailedEvent
| QueueNodeInIterationFailedEvent
| QueueNodeInLoopFailedEvent
| QueueNodeExceptionEvent,
task_id: str,
workflow_node_execution: NodeExecution,
) -> Optional[NodeFinishStreamResponse]:
if workflow_node_execution.node_type in {NodeType.ITERATION, NodeType.LOOP}:
return None
if not workflow_node_execution.workflow_run_id:
return None
if not workflow_node_execution.finished_at:
return None
return NodeFinishStreamResponse(
task_id=task_id,
workflow_run_id=workflow_node_execution.workflow_run_id,
data=NodeFinishStreamResponse.Data(
id=workflow_node_execution.id,
node_id=workflow_node_execution.node_id,
node_type=workflow_node_execution.node_type,
index=workflow_node_execution.index,
title=workflow_node_execution.title,
predecessor_node_id=workflow_node_execution.predecessor_node_id,
inputs=workflow_node_execution.inputs,
process_data=workflow_node_execution.process_data,
outputs=workflow_node_execution.outputs,
status=workflow_node_execution.status,
error=workflow_node_execution.error,
elapsed_time=workflow_node_execution.elapsed_time,
execution_metadata=workflow_node_execution.metadata,
created_at=int(workflow_node_execution.created_at.timestamp()),
finished_at=int(workflow_node_execution.finished_at.timestamp()),
files=self.fetch_files_from_node_outputs(workflow_node_execution.outputs or {}),
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
iteration_id=event.in_iteration_id,
loop_id=event.in_loop_id,
),
)
def workflow_node_retry_to_stream_response(
self,
*,
event: QueueNodeRetryEvent,
task_id: str,
workflow_node_execution: NodeExecution,
) -> Optional[Union[NodeRetryStreamResponse, NodeFinishStreamResponse]]:
if workflow_node_execution.node_type in {NodeType.ITERATION, NodeType.LOOP}:
return None
if not workflow_node_execution.workflow_run_id:
return None
if not workflow_node_execution.finished_at:
return None
return NodeRetryStreamResponse(
task_id=task_id,
workflow_run_id=workflow_node_execution.workflow_run_id,
data=NodeRetryStreamResponse.Data(
id=workflow_node_execution.id,
node_id=workflow_node_execution.node_id,
node_type=workflow_node_execution.node_type,
index=workflow_node_execution.index,
title=workflow_node_execution.title,
predecessor_node_id=workflow_node_execution.predecessor_node_id,
inputs=workflow_node_execution.inputs,
process_data=workflow_node_execution.process_data,
outputs=workflow_node_execution.outputs,
status=workflow_node_execution.status,
error=workflow_node_execution.error,
elapsed_time=workflow_node_execution.elapsed_time,
execution_metadata=workflow_node_execution.metadata,
created_at=int(workflow_node_execution.created_at.timestamp()),
finished_at=int(workflow_node_execution.finished_at.timestamp()),
files=self.fetch_files_from_node_outputs(workflow_node_execution.outputs or {}),
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
iteration_id=event.in_iteration_id,
loop_id=event.in_loop_id,
retry_index=event.retry_index,
),
)
def workflow_parallel_branch_start_to_stream_response(
self,
*,
task_id: str,
workflow_execution_id: str,
event: QueueParallelBranchRunStartedEvent,
) -> ParallelBranchStartStreamResponse:
return ParallelBranchStartStreamResponse(
task_id=task_id,
workflow_run_id=workflow_execution_id,
data=ParallelBranchStartStreamResponse.Data(
parallel_id=event.parallel_id,
parallel_branch_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
iteration_id=event.in_iteration_id,
loop_id=event.in_loop_id,
created_at=int(time.time()),
),
)
def workflow_parallel_branch_finished_to_stream_response(
self,
*,
task_id: str,
workflow_execution_id: str,
event: QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent,
) -> ParallelBranchFinishedStreamResponse:
return ParallelBranchFinishedStreamResponse(
task_id=task_id,
workflow_run_id=workflow_execution_id,
data=ParallelBranchFinishedStreamResponse.Data(
parallel_id=event.parallel_id,
parallel_branch_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
iteration_id=event.in_iteration_id,
loop_id=event.in_loop_id,
status="succeeded" if isinstance(event, QueueParallelBranchRunSucceededEvent) else "failed",
error=event.error if isinstance(event, QueueParallelBranchRunFailedEvent) else None,
created_at=int(time.time()),
),
)
def workflow_iteration_start_to_stream_response(
self,
*,
task_id: str,
workflow_execution_id: str,
event: QueueIterationStartEvent,
) -> IterationNodeStartStreamResponse:
return IterationNodeStartStreamResponse(
task_id=task_id,
workflow_run_id=workflow_execution_id,
data=IterationNodeStartStreamResponse.Data(
id=event.node_id,
node_id=event.node_id,
node_type=event.node_type.value,
title=event.node_data.title,
created_at=int(time.time()),
extras={},
inputs=event.inputs or {},
metadata=event.metadata or {},
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
),
)
def workflow_iteration_next_to_stream_response(
self,
*,
task_id: str,
workflow_execution_id: str,
event: QueueIterationNextEvent,
) -> IterationNodeNextStreamResponse:
return IterationNodeNextStreamResponse(
task_id=task_id,
workflow_run_id=workflow_execution_id,
data=IterationNodeNextStreamResponse.Data(
id=event.node_id,
node_id=event.node_id,
node_type=event.node_type.value,
title=event.node_data.title,
index=event.index,
pre_iteration_output=event.output,
created_at=int(time.time()),
extras={},
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parallel_mode_run_id=event.parallel_mode_run_id,
duration=event.duration,
),
)
def workflow_iteration_completed_to_stream_response(
self,
*,
task_id: str,
workflow_execution_id: str,
event: QueueIterationCompletedEvent,
) -> IterationNodeCompletedStreamResponse:
return IterationNodeCompletedStreamResponse(
task_id=task_id,
workflow_run_id=workflow_execution_id,
data=IterationNodeCompletedStreamResponse.Data(
id=event.node_id,
node_id=event.node_id,
node_type=event.node_type.value,
title=event.node_data.title,
outputs=event.outputs,
created_at=int(time.time()),
extras={},
inputs=event.inputs or {},
status=WorkflowNodeExecutionStatus.SUCCEEDED
if event.error is None
else WorkflowNodeExecutionStatus.FAILED,
error=None,
elapsed_time=(datetime.now(UTC).replace(tzinfo=None) - event.start_at).total_seconds(),
total_tokens=event.metadata.get("total_tokens", 0) if event.metadata else 0,
execution_metadata=event.metadata,
finished_at=int(time.time()),
steps=event.steps,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
),
)
def workflow_loop_start_to_stream_response(
self, *, task_id: str, workflow_execution_id: str, event: QueueLoopStartEvent
) -> LoopNodeStartStreamResponse:
return LoopNodeStartStreamResponse(
task_id=task_id,
workflow_run_id=workflow_execution_id,
data=LoopNodeStartStreamResponse.Data(
id=event.node_id,
node_id=event.node_id,
node_type=event.node_type.value,
title=event.node_data.title,
created_at=int(time.time()),
extras={},
inputs=event.inputs or {},
metadata=event.metadata or {},
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
),
)
def workflow_loop_next_to_stream_response(
self,
*,
task_id: str,
workflow_execution_id: str,
event: QueueLoopNextEvent,
) -> LoopNodeNextStreamResponse:
return LoopNodeNextStreamResponse(
task_id=task_id,
workflow_run_id=workflow_execution_id,
data=LoopNodeNextStreamResponse.Data(
id=event.node_id,
node_id=event.node_id,
node_type=event.node_type.value,
title=event.node_data.title,
index=event.index,
pre_loop_output=event.output,
created_at=int(time.time()),
extras={},
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parallel_mode_run_id=event.parallel_mode_run_id,
duration=event.duration,
),
)
def workflow_loop_completed_to_stream_response(
self,
*,
task_id: str,
workflow_execution_id: str,
event: QueueLoopCompletedEvent,
) -> LoopNodeCompletedStreamResponse:
return LoopNodeCompletedStreamResponse(
task_id=task_id,
workflow_run_id=workflow_execution_id,
data=LoopNodeCompletedStreamResponse.Data(
id=event.node_id,
node_id=event.node_id,
node_type=event.node_type.value,
title=event.node_data.title,
outputs=event.outputs,
created_at=int(time.time()),
extras={},
inputs=event.inputs or {},
status=WorkflowNodeExecutionStatus.SUCCEEDED
if event.error is None
else WorkflowNodeExecutionStatus.FAILED,
error=None,
elapsed_time=(datetime.now(UTC).replace(tzinfo=None) - event.start_at).total_seconds(),
total_tokens=event.metadata.get("total_tokens", 0) if event.metadata else 0,
execution_metadata=event.metadata,
finished_at=int(time.time()),
steps=event.steps,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
),
)
def fetch_files_from_node_outputs(self, outputs_dict: Mapping[str, Any] | None) -> Sequence[Mapping[str, Any]]:
"""
Fetch files from node outputs
:param outputs_dict: node outputs dict
:return: flattened sequence of file mappings found in the outputs
"""
if not outputs_dict:
return []
files = [self._fetch_files_from_variable_value(output_value) for output_value in outputs_dict.values()]
# Remove None
files = [file for file in files if file]
# Flatten the list of sequences into a single list of mappings
flattened_files = [file for sublist in files if sublist for file in sublist]
# Convert to tuple to match Sequence type
return tuple(flattened_files)
def _fetch_files_from_variable_value(self, value: Union[dict, list]) -> Sequence[Mapping[str, Any]]:
"""
Fetch files from variable value
:param value: variable value
:return: file mappings found in the value
"""
if not value:
return []
files = []
if isinstance(value, list):
for item in value:
file = self._get_file_var_from_value(item)
if file:
files.append(file)
elif isinstance(value, dict):
file = self._get_file_var_from_value(value)
if file:
files.append(file)
return files
def _get_file_var_from_value(self, value: Union[dict, list]) -> Mapping[str, Any] | None:
"""
Get file var from value
:param value: variable value
:return: the file mapping, or None if the value is not a file
"""
if not value:
return None
if isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY:
return value
elif isinstance(value, File):
return value.to_dict()
return None
def handle_agent_log(self, task_id: str, event: QueueAgentLogEvent) -> AgentLogStreamResponse:
"""
Handle agent log
:param task_id: task id
:param event: agent log event
:return: agent log stream response
"""
return AgentLogStreamResponse(
task_id=task_id,
data=AgentLogStreamResponse.Data(
node_execution_id=event.node_execution_id,
id=event.id,
parent_id=event.parent_id,
label=event.label,
error=event.error,
status=event.status,
data=event.data,
metadata=event.metadata,
node_id=event.node_id,
),
)
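A minimal usage sketch for the converter above; `generate_entity`, `execution`, and `task_id` are assumed to come from the calling pipeline, and `db.engine` is the engine used elsewhere in this commit. Note that `workflow_finish_to_stream_response` is the only method here that still takes a Session, since it resolves `created_by` from the persisted WorkflowRun row:

converter = WorkflowResponseConverter(application_generate_entity=generate_entity)

start_resp = converter.workflow_start_to_stream_response(
    task_id=task_id,
    workflow_execution=execution,
)
with Session(db.engine, expire_on_commit=False) as session:
    # Only the finish response needs a Session, to look up the creator
    # of the persisted WorkflowRun row.
    finish_resp = converter.workflow_finish_to_stream_response(
        session=session,
        task_id=task_id,
        workflow_execution=execution,
    )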

View File

@ -25,7 +25,7 @@ from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBa
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from extensions.ext_database import db
from models import Account
from models.enums import CreatedByRole
from models.enums import CreatorUserRole
from models.model import App, AppMode, AppModelConfig, Conversation, EndUser, Message, MessageFile
from services.errors.app_model_config import AppModelConfigBrokenError
from services.errors.conversation import ConversationNotExistsError
@ -223,7 +223,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
belongs_to="user",
url=file.remote_url,
upload_file_id=file.related_id,
created_by_role=(CreatedByRole.ACCOUNT if account_id else CreatedByRole.END_USER),
created_by_role=(CreatorUserRole.ACCOUNT if account_id else CreatorUserRole.END_USER),
created_by=account_id or end_user_id or "",
)
db.session.add(message_file)

View File

@ -18,16 +18,19 @@ from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.app.apps.workflow.app_queue_manager import WorkflowAppQueueManager
from core.app.apps.workflow.app_runner import WorkflowAppRunner
from core.app.apps.workflow.generate_response_converter import WorkflowAppGenerateResponseConverter
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.ops.ops_trace_manager import TraceQueueManager
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
from core.workflow.repository.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.workflow_app_generate_task_pipeline import WorkflowAppGenerateTaskPipeline
from extensions.ext_database import db
from factories import file_factory
from models import Account, App, EndUser, Workflow
from models import Account, App, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom
from models.enums import WorkflowRunTriggeredFrom
logger = logging.getLogger(__name__)
@ -136,12 +139,27 @@ class WorkflowAppGenerator(BaseAppGenerator):
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())
# Create workflow node execution repository
# Create repositories
#
# Create session factory
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
# Create workflow execution (aka workflow run) repository
if invoke_from == InvokeFrom.DEBUGGER:
workflow_triggered_from = WorkflowRunTriggeredFrom.DEBUGGING
else:
workflow_triggered_from = WorkflowRunTriggeredFrom.APP_RUN
workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=workflow_triggered_from,
)
# Create workflow node execution repository
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=session_factory,
tenant_id=application_generate_entity.app_config.tenant_id,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
return self._generate(
@ -150,6 +168,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
user=user,
application_generate_entity=application_generate_entity,
invoke_from=invoke_from,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming,
workflow_thread_pool_id=workflow_thread_pool_id,
@ -163,6 +182,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
user: Union[Account, EndUser],
application_generate_entity: WorkflowAppGenerateEntity,
invoke_from: InvokeFrom,
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
streaming: bool = True,
workflow_thread_pool_id: Optional[str] = None,
@ -207,6 +227,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
workflow=workflow,
queue_manager=queue_manager,
user=user,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
stream=streaming,
)
@ -260,12 +281,25 @@ class WorkflowAppGenerator(BaseAppGenerator):
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())
# Create repositories
#
# Create session factory
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
# Create workflow execution (aka workflow run) repository
workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
)
# Create workflow node execution repository
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=session_factory,
tenant_id=application_generate_entity.app_config.tenant_id,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
)
return self._generate(
@ -274,6 +308,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
user=user,
invoke_from=InvokeFrom.DEBUGGER,
application_generate_entity=application_generate_entity,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming,
)
@ -323,12 +358,25 @@ class WorkflowAppGenerator(BaseAppGenerator):
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())
# Create repositories
#
# Create session factory
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
# Create workflow execution (aka workflow run) repository
workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
)
# Create workflow node execution repository
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=session_factory,
tenant_id=application_generate_entity.app_config.tenant_id,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
)
return self._generate(
@ -337,6 +385,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
user=user,
invoke_from=InvokeFrom.DEBUGGER,
application_generate_entity=application_generate_entity,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming,
)
@ -394,6 +443,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
workflow: Workflow,
queue_manager: AppQueueManager,
user: Union[Account, EndUser],
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
stream: bool = False,
) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
@ -413,8 +463,9 @@ class WorkflowAppGenerator(BaseAppGenerator):
workflow=workflow,
queue_manager=queue_manager,
user=user,
stream=stream,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
stream=stream,
)
try:

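The repository wiring above is repeated across the three entry points in this file; condensed, it reduces to one session factory feeding both repositories (names as imported in the hunks above; `user` and `application_generate_entity` are assumed from the enclosing method):

session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)

workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository(
    session_factory=session_factory,
    user=user,
    app_id=application_generate_entity.app_config.app_id,
    triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,  # or APP_RUN outside the debugger
)
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
    session_factory=session_factory,
    tenant_id=application_generate_entity.app_config.tenant_id,
    user=user,
    app_id=application_generate_entity.app_config.app_id,
    triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,  # WORKFLOW_RUN for full runs
)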
View File

@ -3,10 +3,12 @@ import time
from collections.abc import Generator
from typing import Optional, Union
from sqlalchemy import select
from sqlalchemy.orm import Session
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
from core.app.entities.app_invoke_entities import (
InvokeFrom,
WorkflowAppGenerateEntity,
@ -53,12 +55,14 @@ from core.app.entities.task_entities import (
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.entities.workflow_execution_entities import WorkflowExecution
from core.workflow.enums import SystemVariableKey
from core.workflow.repository.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.workflow_cycle_manager import WorkflowCycleManager
from extensions.ext_database import db
from models.account import Account
from models.enums import CreatedByRole
from models.enums import CreatorUserRole
from models.model import EndUser
from models.workflow import (
Workflow,
@ -83,6 +87,7 @@ class WorkflowAppGenerateTaskPipeline:
queue_manager: AppQueueManager,
user: Union[Account, EndUser],
stream: bool,
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
) -> None:
self._base_task_pipeline = BasedGenerateTaskPipeline(
@ -94,11 +99,11 @@ class WorkflowAppGenerateTaskPipeline:
if isinstance(user, EndUser):
self._user_id = user.id
user_session_id = user.session_id
self._created_by_role = CreatedByRole.END_USER
self._created_by_role = CreatorUserRole.END_USER
elif isinstance(user, Account):
self._user_id = user.id
user_session_id = user.id
self._created_by_role = CreatedByRole.ACCOUNT
self._created_by_role = CreatorUserRole.ACCOUNT
else:
raise ValueError(f"Invalid user type: {type(user)}")
@ -111,9 +116,14 @@ class WorkflowAppGenerateTaskPipeline:
SystemVariableKey.WORKFLOW_ID: workflow.id,
SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
},
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
)
self._workflow_response_converter = WorkflowResponseConverter(
application_generate_entity=application_generate_entity,
)
self._application_generate_entity = application_generate_entity
self._workflow_id = workflow.id
self._workflow_features_dict = workflow.features_dict
@ -258,17 +268,15 @@ class WorkflowAppGenerateTaskPipeline:
with Session(db.engine, expire_on_commit=False) as session:
# init workflow run
workflow_run = self._workflow_cycle_manager._handle_workflow_run_start(
workflow_execution = self._workflow_cycle_manager.handle_workflow_run_start(
session=session,
workflow_id=self._workflow_id,
user_id=self._user_id,
created_by_role=self._created_by_role,
)
self._workflow_run_id = workflow_run.id
start_resp = self._workflow_cycle_manager._workflow_start_to_stream_response(
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
self._workflow_run_id = workflow_execution.id
start_resp = self._workflow_response_converter.workflow_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution=workflow_execution,
)
session.commit()
yield start_resp
elif isinstance(
@ -278,13 +286,11 @@ class WorkflowAppGenerateTaskPipeline:
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_retried(
workflow_execution_id=self._workflow_run_id,
event=event,
)
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_retried(
workflow_run=workflow_run, event=event
)
response = self._workflow_cycle_manager._workflow_node_retry_to_stream_response(
response = self._workflow_response_converter.workflow_node_retry_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
@ -297,27 +303,22 @@ class WorkflowAppGenerateTaskPipeline:
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
workflow_node_execution = self._workflow_cycle_manager._handle_node_execution_start(
workflow_run=workflow_run, event=event
)
node_start_response = self._workflow_cycle_manager._workflow_node_start_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
session.commit()
workflow_node_execution = self._workflow_cycle_manager.handle_node_execution_start(
workflow_execution_id=self._workflow_run_id, event=event
)
node_start_response = self._workflow_response_converter.workflow_node_start_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
if node_start_response:
yield node_start_response
elif isinstance(event, QueueNodeSucceededEvent):
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_success(
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_success(
event=event
)
node_success_response = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
node_success_response = self._workflow_response_converter.workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
@ -332,10 +333,10 @@ class WorkflowAppGenerateTaskPipeline:
| QueueNodeInLoopFailedEvent
| QueueNodeExceptionEvent,
):
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_failed(
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_failed(
event=event,
)
node_failed_response = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
node_failed_response = self._workflow_response_converter.workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
@ -348,18 +349,13 @@ class WorkflowAppGenerateTaskPipeline:
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
parallel_start_resp = (
self._workflow_cycle_manager._workflow_parallel_branch_start_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event,
)
parallel_start_resp = (
self._workflow_response_converter.workflow_parallel_branch_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
event=event,
)
)
yield parallel_start_resp
@ -367,18 +363,13 @@ class WorkflowAppGenerateTaskPipeline:
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
parallel_finish_resp = (
self._workflow_cycle_manager._workflow_parallel_branch_finished_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event,
)
parallel_finish_resp = (
self._workflow_response_converter.workflow_parallel_branch_finished_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
event=event,
)
)
yield parallel_finish_resp
@ -386,16 +377,11 @@ class WorkflowAppGenerateTaskPipeline:
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
iter_start_resp = self._workflow_cycle_manager._workflow_iteration_start_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event,
)
iter_start_resp = self._workflow_response_converter.workflow_iteration_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
event=event,
)
yield iter_start_resp
@ -403,16 +389,11 @@ class WorkflowAppGenerateTaskPipeline:
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
iter_next_resp = self._workflow_cycle_manager._workflow_iteration_next_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event,
)
iter_next_resp = self._workflow_response_converter.workflow_iteration_next_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
event=event,
)
yield iter_next_resp
@ -420,16 +401,11 @@ class WorkflowAppGenerateTaskPipeline:
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
iter_finish_resp = self._workflow_cycle_manager._workflow_iteration_completed_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event,
)
iter_finish_resp = self._workflow_response_converter.workflow_iteration_completed_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
event=event,
)
yield iter_finish_resp
@ -437,16 +413,11 @@ class WorkflowAppGenerateTaskPipeline:
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
loop_start_resp = self._workflow_cycle_manager._workflow_loop_start_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event,
)
loop_start_resp = self._workflow_response_converter.workflow_loop_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
event=event,
)
yield loop_start_resp
@ -454,16 +425,11 @@ class WorkflowAppGenerateTaskPipeline:
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
loop_next_resp = self._workflow_cycle_manager._workflow_loop_next_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event,
)
loop_next_resp = self._workflow_response_converter.workflow_loop_next_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
event=event,
)
yield loop_next_resp
@ -471,16 +437,11 @@ class WorkflowAppGenerateTaskPipeline:
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
loop_finish_resp = self._workflow_cycle_manager._workflow_loop_completed_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event,
)
loop_finish_resp = self._workflow_response_converter.workflow_loop_completed_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
event=event,
)
yield loop_finish_resp
@ -491,10 +452,8 @@ class WorkflowAppGenerateTaskPipeline:
raise ValueError("graph runtime state not initialized.")
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._handle_workflow_run_success(
session=session,
workflow_execution = self._workflow_cycle_manager.handle_workflow_run_success(
workflow_run_id=self._workflow_run_id,
start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
outputs=event.outputs,
@ -503,12 +462,12 @@ class WorkflowAppGenerateTaskPipeline:
)
# save workflow app log
self._save_workflow_app_log(session=session, workflow_run=workflow_run)
self._save_workflow_app_log(session=session, workflow_execution=workflow_execution)
workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
workflow_execution=workflow_execution,
)
session.commit()
@ -520,10 +479,8 @@ class WorkflowAppGenerateTaskPipeline:
raise ValueError("graph runtime state not initialized.")
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._handle_workflow_run_partial_success(
session=session,
workflow_execution = self._workflow_cycle_manager.handle_workflow_run_partial_success(
workflow_run_id=self._workflow_run_id,
start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
outputs=event.outputs,
@ -533,10 +490,12 @@ class WorkflowAppGenerateTaskPipeline:
)
# save workflow app log
self._save_workflow_app_log(session=session, workflow_run=workflow_run)
self._save_workflow_app_log(session=session, workflow_execution=workflow_execution)
workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_execution=workflow_execution,
)
session.commit()
@ -548,26 +507,28 @@ class WorkflowAppGenerateTaskPipeline:
raise ValueError("graph runtime state not initialized.")
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._handle_workflow_run_failed(
session=session,
workflow_execution = self._workflow_cycle_manager.handle_workflow_run_failed(
workflow_run_id=self._workflow_run_id,
start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
status=WorkflowRunStatus.FAILED
if isinstance(event, QueueWorkflowFailedEvent)
else WorkflowRunStatus.STOPPED,
error=event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(),
error_message=event.error
if isinstance(event, QueueWorkflowFailedEvent)
else event.get_stop_reason(),
conversation_id=None,
trace_manager=trace_manager,
exceptions_count=event.exceptions_count if isinstance(event, QueueWorkflowFailedEvent) else 0,
)
# save workflow app log
self._save_workflow_app_log(session=session, workflow_run=workflow_run)
self._save_workflow_app_log(session=session, workflow_execution=workflow_execution)
workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_execution=workflow_execution,
)
session.commit()
@ -586,7 +547,7 @@ class WorkflowAppGenerateTaskPipeline:
delta_text, from_variable_selector=event.from_variable_selector
)
elif isinstance(event, QueueAgentLogEvent):
yield self._workflow_cycle_manager._handle_agent_log(
yield self._workflow_response_converter.handle_agent_log(
task_id=self._application_generate_entity.task_id, event=event
)
else:
@ -595,11 +556,9 @@ class WorkflowAppGenerateTaskPipeline:
if tts_publisher:
tts_publisher.publish(None)
def _save_workflow_app_log(self, *, session: Session, workflow_run: WorkflowRun) -> None:
"""
Save workflow app log.
:return:
"""
def _save_workflow_app_log(self, *, session: Session, workflow_execution: WorkflowExecution) -> None:
workflow_run = session.scalar(select(WorkflowRun).where(WorkflowRun.id == workflow_execution.id))
assert workflow_run is not None
invoke_from = self._application_generate_entity.invoke_from
if invoke_from == InvokeFrom.SERVICE_API:
created_from = WorkflowAppLogCreatedFrom.SERVICE_API

View File

@ -6,7 +6,7 @@ from pydantic import BaseModel, ConfigDict
from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.utils.encoders import jsonable_encoder
from core.workflow.entities.node_entities import AgentNodeStrategyInit
from core.workflow.entities.node_entities import AgentNodeStrategyInit, NodeRunMetadataKey
from models.workflow import WorkflowNodeExecutionStatus
@ -190,7 +190,7 @@ class WorkflowStartStreamResponse(StreamResponse):
id: str
workflow_id: str
sequence_number: int
inputs: dict
inputs: Mapping[str, Any]
created_at: int
event: StreamEvent = StreamEvent.WORKFLOW_STARTED
@ -212,7 +212,7 @@ class WorkflowFinishStreamResponse(StreamResponse):
workflow_id: str
sequence_number: int
status: str
outputs: Optional[dict] = None
outputs: Optional[Mapping[str, Any]] = None
error: Optional[str] = None
elapsed_time: float
total_tokens: int
@ -244,7 +244,7 @@ class NodeStartStreamResponse(StreamResponse):
title: str
index: int
predecessor_node_id: Optional[str] = None
inputs: Optional[dict] = None
inputs: Optional[Mapping[str, Any]] = None
created_at: int
extras: dict = {}
parallel_id: Optional[str] = None
@ -301,13 +301,13 @@ class NodeFinishStreamResponse(StreamResponse):
title: str
index: int
predecessor_node_id: Optional[str] = None
inputs: Optional[dict] = None
process_data: Optional[dict] = None
outputs: Optional[dict] = None
inputs: Optional[Mapping[str, Any]] = None
process_data: Optional[Mapping[str, Any]] = None
outputs: Optional[Mapping[str, Any]] = None
status: str
error: Optional[str] = None
elapsed_time: float
execution_metadata: Optional[dict] = None
execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None
created_at: int
finished_at: int
files: Optional[Sequence[Mapping[str, Any]]] = []
@ -370,13 +370,13 @@ class NodeRetryStreamResponse(StreamResponse):
title: str
index: int
predecessor_node_id: Optional[str] = None
inputs: Optional[dict] = None
process_data: Optional[dict] = None
outputs: Optional[dict] = None
inputs: Optional[Mapping[str, Any]] = None
process_data: Optional[Mapping[str, Any]] = None
outputs: Optional[Mapping[str, Any]] = None
status: str
error: Optional[str] = None
elapsed_time: float
execution_metadata: Optional[dict] = None
execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None
created_at: int
finished_at: int
files: Optional[Sequence[Mapping[str, Any]]] = []
@ -788,7 +788,7 @@ class WorkflowAppBlockingResponse(AppBlockingResponse):
id: str
workflow_id: str
status: str
outputs: Optional[dict] = None
outputs: Optional[Mapping[str, Any]] = None
error: Optional[str] = None
elapsed_time: float
total_tokens: int

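The dict → Mapping changes above make these payload fields read-only at the type level; callers that previously mutated `outputs` in place now need an explicit copy. A small sketch of the new contract (the key name is hypothetical):

from collections.abc import Mapping
from typing import Any

def patch_outputs(outputs: Mapping[str, Any] | None) -> Mapping[str, Any]:
    # Mapping advertises no __setitem__, so outputs["k"] = v is flagged by
    # the type checker; copy into a real dict before modifying.
    patched = dict(outputs or {})
    patched["answer"] = "patched"  # hypothetical key, for illustration
    return patched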
View File

@ -754,7 +754,7 @@ class ProviderConfiguration(BaseModel):
:param only_active: return active model only
:return:
"""
provider_models = self.get_provider_models(model_type, only_active)
provider_models = self.get_provider_models(model_type, only_active, model)
for provider_model in provider_models:
if provider_model.model == model:
@ -763,12 +763,13 @@ class ProviderConfiguration(BaseModel):
return None
def get_provider_models(
self, model_type: Optional[ModelType] = None, only_active: bool = False
self, model_type: Optional[ModelType] = None, only_active: bool = False, model: Optional[str] = None
) -> list[ModelWithProviderEntity]:
"""
Get provider models.
:param model_type: model type
:param only_active: only active models
:param model: model name
:return: list of provider models
"""
model_provider_factory = ModelProviderFactory(self.tenant_id)
@ -791,7 +792,10 @@ class ProviderConfiguration(BaseModel):
)
else:
provider_models = self._get_custom_provider_models(
model_types=model_types, provider_schema=provider_schema, model_setting_map=model_setting_map
model_types=model_types,
provider_schema=provider_schema,
model_setting_map=model_setting_map,
model=model,
)
if only_active:
@ -943,6 +947,7 @@ class ProviderConfiguration(BaseModel):
model_types: Sequence[ModelType],
provider_schema: ProviderEntity,
model_setting_map: dict[ModelType, dict[str, ModelSettings]],
model: Optional[str] = None,
) -> list[ModelWithProviderEntity]:
"""
Get custom provider models.
@ -995,7 +1000,8 @@ class ProviderConfiguration(BaseModel):
for model_configuration in self.custom_configuration.models:
if model_configuration.model_type not in model_types:
continue
if model and model != model_configuration.model:
continue
try:
custom_model_schema = self.get_model_schema(
model_type=model_configuration.model_type,

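Usage sketch for the new optional `model` argument: it narrows the scan to a single custom model before any schema is fetched, instead of building the full list and filtering afterwards (the configuration instance and model name are assumptions for illustration):

models = provider_configuration.get_provider_models(
    model_type=ModelType.LLM,
    only_active=True,
    model="gpt-4o",  # hypothetical model name
)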
View File

@ -1,3 +1,4 @@
from collections.abc import Mapping
from datetime import datetime
from enum import StrEnum
from typing import Any, Optional, Union
@ -155,10 +156,10 @@ class LangfuseSpan(BaseModel):
description="The status message of the span. Additional field for context of the event. E.g. the error "
"message of an error event.",
)
input: Optional[Union[str, dict[str, Any], list, None]] = Field(
input: Optional[Union[str, Mapping[str, Any], list, None]] = Field(
default=None, description="The input of the span. Can be any JSON object."
)
output: Optional[Union[str, dict[str, Any], list, None]] = Field(
output: Optional[Union[str, Mapping[str, Any], list, None]] = Field(
default=None, description="The output of the span. Can be any JSON object."
)
version: Optional[str] = Field(

View File

@ -1,11 +1,10 @@
import json
import logging
import os
from datetime import datetime, timedelta
from typing import Optional
from langfuse import Langfuse # type: ignore
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import Session, sessionmaker
from core.ops.base_trace_instance import BaseTraceInstance
from core.ops.entities.config_entity import LangfuseConfig
@ -30,8 +29,9 @@ from core.ops.langfuse_trace.entities.langfuse_trace_entity import (
)
from core.ops.utils import filter_none_values
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.workflow.nodes.enums import NodeType
from extensions.ext_database import db
from models.model import EndUser
from models import Account, App, EndUser, WorkflowNodeExecutionTriggeredFrom
logger = logging.getLogger(__name__)
@ -113,8 +113,29 @@ class LangFuseDataTrace(BaseTraceInstance):
# Fetch all node executions for this workflow run via the repository
session_factory = sessionmaker(bind=db.engine)
# Find the app's creator account
with Session(db.engine, expire_on_commit=False) as session:
# Get the app to find its creator
app_id = trace_info.metadata.get("app_id")
if not app_id:
raise ValueError("No app_id found in trace_info metadata")
app = session.query(App).filter(App.id == app_id).first()
if not app:
raise ValueError(f"App with id {app_id} not found")
if not app.created_by:
raise ValueError(f"App with id {app_id} has no creator (created_by is None)")
service_account = session.query(Account).filter(Account.id == app.created_by).first()
if not service_account:
raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}")
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=session_factory, tenant_id=trace_info.tenant_id
session_factory=session_factory,
user=service_account,
app_id=trace_info.metadata.get("app_id"),
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
# Get all executions for this workflow run
@ -124,23 +145,22 @@ class LangFuseDataTrace(BaseTraceInstance):
for node_execution in workflow_node_executions:
node_execution_id = node_execution.id
tenant_id = node_execution.tenant_id
app_id = node_execution.app_id
tenant_id = trace_info.tenant_id # Use from trace_info instead
app_id = trace_info.metadata.get("app_id") # Use from trace_info instead
node_name = node_execution.title
node_type = node_execution.node_type
status = node_execution.status
if node_type == "llm":
inputs = (
json.loads(node_execution.process_data).get("prompts", {}) if node_execution.process_data else {}
)
if node_type == NodeType.LLM:
inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {}
else:
inputs = json.loads(node_execution.inputs) if node_execution.inputs else {}
outputs = json.loads(node_execution.outputs) if node_execution.outputs else {}
inputs = node_execution.inputs if node_execution.inputs else {}
outputs = node_execution.outputs if node_execution.outputs else {}
created_at = node_execution.created_at or datetime.now()
elapsed_time = node_execution.elapsed_time
finished_at = created_at + timedelta(seconds=elapsed_time)
metadata = json.loads(node_execution.execution_metadata) if node_execution.execution_metadata else {}
execution_metadata = node_execution.metadata if node_execution.metadata else {}
metadata = {str(k): v for k, v in execution_metadata.items()}
metadata.update(
{
"workflow_run_id": trace_info.workflow_run_id,
@ -152,7 +172,7 @@ class LangFuseDataTrace(BaseTraceInstance):
"status": status,
}
)
process_data = json.loads(node_execution.process_data) if node_execution.process_data else {}
process_data = node_execution.process_data if node_execution.process_data else {}
model_provider = process_data.get("model_provider", None)
model_name = process_data.get("model_name", None)
if model_provider is not None and model_name is not None:

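Compressed before/after of the deserialization change in this hunk: the repository now returns NodeExecution domain objects whose `inputs`, `process_data`, and `metadata` are already materialized mappings, so the per-caller json.loads parsing disappears:

# Before: columns arrived as JSON strings that every caller parsed.
inputs = json.loads(node_execution.inputs) if node_execution.inputs else {}

# After: fields are mappings already; read them directly and stringify keys as needed.
inputs = node_execution.inputs or {}
metadata = {str(k): v for k, v in (node_execution.metadata or {}).items()}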
View File

@ -1,3 +1,4 @@
from collections.abc import Mapping
from datetime import datetime
from enum import StrEnum
from typing import Any, Optional, Union
@ -30,8 +31,8 @@ class LangSmithMultiModel(BaseModel):
class LangSmithRunModel(LangSmithTokenUsage, LangSmithMultiModel):
name: Optional[str] = Field(..., description="Name of the run")
inputs: Optional[Union[str, dict[str, Any], list, None]] = Field(None, description="Inputs of the run")
outputs: Optional[Union[str, dict[str, Any], list, None]] = Field(None, description="Outputs of the run")
inputs: Optional[Union[str, Mapping[str, Any], list, None]] = Field(None, description="Inputs of the run")
outputs: Optional[Union[str, Mapping[str, Any], list, None]] = Field(None, description="Outputs of the run")
run_type: LangSmithRunType = Field(..., description="Type of the run")
start_time: Optional[datetime | str] = Field(None, description="Start time of the run")
end_time: Optional[datetime | str] = Field(None, description="End time of the run")

View File

@ -1,4 +1,3 @@
import json
import logging
import os
import uuid
@ -7,7 +6,7 @@ from typing import Optional, cast
from langsmith import Client
from langsmith.schemas import RunBase
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import Session, sessionmaker
from core.ops.base_trace_instance import BaseTraceInstance
from core.ops.entities.config_entity import LangSmithConfig
@@ -29,8 +28,10 @@ from core.ops.langsmith_trace.entities.langsmith_trace_entity import (
)
from core.ops.utils import filter_none_values, generate_dotted_order
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.workflow.entities.node_entities import NodeRunMetadataKey
from core.workflow.nodes.enums import NodeType
from extensions.ext_database import db
from models.model import EndUser, MessageFile
from models import Account, App, EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom
logger = logging.getLogger(__name__)
@@ -137,8 +138,29 @@ class LangSmithDataTrace(BaseTraceInstance):
# through workflow_run_id get all_nodes_execution using repository
session_factory = sessionmaker(bind=db.engine)
# Find the app's creator account
with Session(db.engine, expire_on_commit=False) as session:
# Get the app to find its creator
app_id = trace_info.metadata.get("app_id")
if not app_id:
raise ValueError("No app_id found in trace_info metadata")
app = session.query(App).filter(App.id == app_id).first()
if not app:
raise ValueError(f"App with id {app_id} not found")
if not app.created_by:
raise ValueError(f"App with id {app_id} has no creator (created_by is None)")
service_account = session.query(Account).filter(Account.id == app.created_by).first()
if not service_account:
raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}")
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=session_factory, tenant_id=trace_info.tenant_id, app_id=trace_info.metadata.get("app_id")
session_factory=session_factory,
user=service_account,
app_id=trace_info.metadata.get("app_id"),
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
# Get all executions for this workflow run
@@ -148,27 +170,23 @@ class LangSmithDataTrace(BaseTraceInstance):
for node_execution in workflow_node_executions:
node_execution_id = node_execution.id
tenant_id = node_execution.tenant_id
app_id = node_execution.app_id
tenant_id = trace_info.tenant_id # Use from trace_info instead
app_id = trace_info.metadata.get("app_id") # Use from trace_info instead
node_name = node_execution.title
node_type = node_execution.node_type
status = node_execution.status
if node_type == "llm":
inputs = (
json.loads(node_execution.process_data).get("prompts", {}) if node_execution.process_data else {}
)
if node_type == NodeType.LLM:
inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {}
else:
inputs = json.loads(node_execution.inputs) if node_execution.inputs else {}
outputs = json.loads(node_execution.outputs) if node_execution.outputs else {}
inputs = node_execution.inputs if node_execution.inputs else {}
outputs = node_execution.outputs if node_execution.outputs else {}
created_at = node_execution.created_at or datetime.now()
elapsed_time = node_execution.elapsed_time
finished_at = created_at + timedelta(seconds=elapsed_time)
execution_metadata = (
json.loads(node_execution.execution_metadata) if node_execution.execution_metadata else {}
)
node_total_tokens = execution_metadata.get("total_tokens", 0)
metadata = execution_metadata.copy()
execution_metadata = node_execution.metadata if node_execution.metadata else {}
node_total_tokens = execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS) or 0
metadata = {str(key): value for key, value in execution_metadata.items()}
metadata.update(
{
"workflow_run_id": trace_info.workflow_run_id,
@@ -181,7 +199,7 @@ class LangSmithDataTrace(BaseTraceInstance):
}
)
process_data = json.loads(node_execution.process_data) if node_execution.process_data else {}
process_data = node_execution.process_data if node_execution.process_data else {}
if process_data and process_data.get("model_mode") == "chat":
run_type = LangSmithRunType.llm
@@ -191,7 +209,7 @@ class LangSmithDataTrace(BaseTraceInstance):
"ls_model_name": process_data.get("model_name", ""),
}
)
elif node_type == "knowledge-retrieval":
elif node_type == NodeType.KNOWLEDGE_RETRIEVAL:
run_type = LangSmithRunType.retriever
else:
run_type = LangSmithRunType.tool

View File

@@ -1,4 +1,3 @@
import json
import logging
import os
import uuid
@@ -7,7 +6,7 @@ from typing import Optional, cast
from opik import Opik, Trace
from opik.id_helpers import uuid4_to_uuid7
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import Session, sessionmaker
from core.ops.base_trace_instance import BaseTraceInstance
from core.ops.entities.config_entity import OpikConfig
@@ -23,8 +22,10 @@ from core.ops.entities.trace_entity import (
WorkflowTraceInfo,
)
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.workflow.entities.node_entities import NodeRunMetadataKey
from core.workflow.nodes.enums import NodeType
from extensions.ext_database import db
from models.model import EndUser, MessageFile
from models import Account, App, EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom
logger = logging.getLogger(__name__)
@@ -114,6 +115,7 @@ class OpikDataTrace(BaseTraceInstance):
"metadata": workflow_metadata,
"input": wrap_dict("input", trace_info.workflow_run_inputs),
"output": wrap_dict("output", trace_info.workflow_run_outputs),
"thread_id": trace_info.conversation_id,
"tags": ["message", "workflow"],
"project_name": self.project,
}
@@ -143,6 +145,7 @@ class OpikDataTrace(BaseTraceInstance):
"metadata": workflow_metadata,
"input": wrap_dict("input", trace_info.workflow_run_inputs),
"output": wrap_dict("output", trace_info.workflow_run_outputs),
"thread_id": trace_info.conversation_id,
"tags": ["workflow"],
"project_name": self.project,
}
@@ -150,8 +153,29 @@ class OpikDataTrace(BaseTraceInstance):
# through workflow_run_id get all_nodes_execution using repository
session_factory = sessionmaker(bind=db.engine)
# Find the app's creator account
with Session(db.engine, expire_on_commit=False) as session:
# Get the app to find its creator
app_id = trace_info.metadata.get("app_id")
if not app_id:
raise ValueError("No app_id found in trace_info metadata")
app = session.query(App).filter(App.id == app_id).first()
if not app:
raise ValueError(f"App with id {app_id} not found")
if not app.created_by:
raise ValueError(f"App with id {app_id} has no creator (created_by is None)")
service_account = session.query(Account).filter(Account.id == app.created_by).first()
if not service_account:
raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}")
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=session_factory, tenant_id=trace_info.tenant_id, app_id=trace_info.metadata.get("app_id")
session_factory=session_factory,
user=service_account,
app_id=trace_info.metadata.get("app_id"),
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
# Get all executions for this workflow run
@@ -161,26 +185,22 @@ class OpikDataTrace(BaseTraceInstance):
for node_execution in workflow_node_executions:
node_execution_id = node_execution.id
tenant_id = node_execution.tenant_id
app_id = node_execution.app_id
tenant_id = trace_info.tenant_id # Use from trace_info instead
app_id = trace_info.metadata.get("app_id") # Use from trace_info instead
node_name = node_execution.title
node_type = node_execution.node_type
status = node_execution.status
if node_type == "llm":
inputs = (
json.loads(node_execution.process_data).get("prompts", {}) if node_execution.process_data else {}
)
if node_type == NodeType.LLM:
inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {}
else:
inputs = json.loads(node_execution.inputs) if node_execution.inputs else {}
outputs = json.loads(node_execution.outputs) if node_execution.outputs else {}
inputs = node_execution.inputs if node_execution.inputs else {}
outputs = node_execution.outputs if node_execution.outputs else {}
created_at = node_execution.created_at or datetime.now()
elapsed_time = node_execution.elapsed_time
finished_at = created_at + timedelta(seconds=elapsed_time)
execution_metadata = (
json.loads(node_execution.execution_metadata) if node_execution.execution_metadata else {}
)
metadata = execution_metadata.copy()
execution_metadata = node_execution.metadata if node_execution.metadata else {}
metadata = {str(k): v for k, v in execution_metadata.items()}
metadata.update(
{
"workflow_run_id": trace_info.workflow_run_id,
@@ -193,7 +213,7 @@ class OpikDataTrace(BaseTraceInstance):
}
)
process_data = json.loads(node_execution.process_data) if node_execution.process_data else {}
process_data = node_execution.process_data if node_execution.process_data else {}
provider = None
model = None
@@ -226,7 +246,7 @@ class OpikDataTrace(BaseTraceInstance):
parent_span_id = trace_info.workflow_app_log_id or trace_info.workflow_run_id
if not total_tokens:
total_tokens = execution_metadata.get("total_tokens", 0)
total_tokens = execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS) or 0
span_data = {
"trace_id": opik_trace_id,
@@ -288,6 +308,7 @@ class OpikDataTrace(BaseTraceInstance):
"metadata": wrap_metadata(metadata),
"input": trace_info.inputs,
"output": message_data.answer,
"thread_id": message_data.conversation_id,
"tags": ["message", str(trace_info.conversation_mode)],
"project_name": self.project,
}
@@ -402,6 +423,7 @@ class OpikDataTrace(BaseTraceInstance):
"metadata": wrap_metadata(trace_info.metadata),
"input": trace_info.inputs,
"output": trace_info.outputs,
"thread_id": trace_info.conversation_id,
"tags": ["generate_name"],
"project_name": self.project,
}

View File

@@ -30,6 +30,7 @@ from core.ops.entities.trace_entity import (
WorkflowTraceInfo,
)
from core.ops.utils import get_message_data
from core.workflow.entities.workflow_execution_entities import WorkflowExecution
from extensions.ext_database import db
from extensions.ext_storage import storage
from models.model import App, AppModelConfig, Conversation, Message, MessageFile, TraceAppConfig
@@ -234,7 +235,11 @@ class OpsTraceManager:
return None
tracing_provider = app_ops_trace_config.get("tracing_provider")
if tracing_provider is None or tracing_provider not in provider_config_map:
if tracing_provider is None:
return None
try:
provider_config_map[tracing_provider]
except KeyError:
return None
# decrypt_token
@@ -287,8 +292,14 @@ class OpsTraceManager:
:return:
"""
# auth check
if tracing_provider not in provider_config_map and tracing_provider is not None:
raise ValueError(f"Invalid tracing provider: {tracing_provider}")
if enabled:
try:
provider_config_map[tracing_provider]
except KeyError:
raise ValueError(f"Invalid tracing provider: {tracing_provider}")
else:
if tracing_provider is not None:
raise ValueError(f"Invalid tracing provider: {tracing_provider}")
app_config: Optional[App] = db.session.query(App).filter(App.id == app_id).first()
if not app_config:
@@ -367,7 +378,7 @@ class TraceTask:
self,
trace_type: Any,
message_id: Optional[str] = None,
workflow_run: Optional[WorkflowRun] = None,
workflow_execution: Optional[WorkflowExecution] = None,
conversation_id: Optional[str] = None,
user_id: Optional[str] = None,
timer: Optional[Any] = None,
@ -375,7 +386,7 @@ class TraceTask:
):
self.trace_type = trace_type
self.message_id = message_id
self.workflow_run_id = workflow_run.id if workflow_run else None
self.workflow_run_id = workflow_execution.id if workflow_execution else None
self.conversation_id = conversation_id
self.user_id = user_id
self.timer = timer

View File

@@ -1,3 +1,4 @@
from collections.abc import Mapping
from typing import Any, Optional, Union
from pydantic import BaseModel, Field, field_validator
@@ -19,8 +20,8 @@ class WeaveMultiModel(BaseModel):
class WeaveTraceModel(WeaveTokenUsage, WeaveMultiModel):
id: str = Field(..., description="ID of the trace")
op: str = Field(..., description="Name of the operation")
inputs: Optional[Union[str, dict[str, Any], list, None]] = Field(None, description="Inputs of the trace")
outputs: Optional[Union[str, dict[str, Any], list, None]] = Field(None, description="Outputs of the trace")
inputs: Optional[Union[str, Mapping[str, Any], list, None]] = Field(None, description="Inputs of the trace")
outputs: Optional[Union[str, Mapping[str, Any], list, None]] = Field(None, description="Outputs of the trace")
attributes: Optional[Union[str, dict[str, Any], list, None]] = Field(
None, description="Metadata and attributes associated with trace"
)

View File

@@ -1,4 +1,3 @@
import json
import logging
import os
import uuid
@@ -7,6 +6,7 @@ from typing import Any, Optional, cast
import wandb
import weave
from sqlalchemy.orm import Session, sessionmaker
from core.ops.base_trace_instance import BaseTraceInstance
from core.ops.entities.config_entity import WeaveConfig
@@ -22,9 +22,11 @@ from core.ops.entities.trace_entity import (
WorkflowTraceInfo,
)
from core.ops.weave_trace.entities.weave_trace_entity import WeaveTraceModel
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.workflow.entities.node_entities import NodeRunMetadataKey
from core.workflow.nodes.enums import NodeType
from extensions.ext_database import db
from models.model import EndUser, MessageFile
from models.workflow import WorkflowNodeExecution
from models import Account, App, EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom
logger = logging.getLogger(__name__)
@@ -128,58 +130,57 @@ class WeaveDataTrace(BaseTraceInstance):
self.start_call(workflow_run, parent_run_id=trace_info.message_id)
# through workflow_run_id get all_nodes_execution
workflow_nodes_execution_id_records = (
db.session.query(WorkflowNodeExecution.id)
.filter(WorkflowNodeExecution.workflow_run_id == trace_info.workflow_run_id)
.all()
)
# through workflow_run_id get all_nodes_execution using repository
session_factory = sessionmaker(bind=db.engine)
# Find the app's creator account
with Session(db.engine, expire_on_commit=False) as session:
# Get the app to find its creator
app_id = trace_info.metadata.get("app_id")
if not app_id:
raise ValueError("No app_id found in trace_info metadata")
app = session.query(App).filter(App.id == app_id).first()
if not app:
raise ValueError(f"App with id {app_id} not found")
if not app.created_by:
raise ValueError(f"App with id {app_id} has no creator (created_by is None)")
service_account = session.query(Account).filter(Account.id == app.created_by).first()
if not service_account:
raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}")
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=session_factory,
user=service_account,
app_id=trace_info.metadata.get("app_id"),
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
for node_execution_id_record in workflow_nodes_execution_id_records:
node_execution = (
db.session.query(
WorkflowNodeExecution.id,
WorkflowNodeExecution.tenant_id,
WorkflowNodeExecution.app_id,
WorkflowNodeExecution.title,
WorkflowNodeExecution.node_type,
WorkflowNodeExecution.status,
WorkflowNodeExecution.inputs,
WorkflowNodeExecution.outputs,
WorkflowNodeExecution.created_at,
WorkflowNodeExecution.elapsed_time,
WorkflowNodeExecution.process_data,
WorkflowNodeExecution.execution_metadata,
)
.filter(WorkflowNodeExecution.id == node_execution_id_record.id)
.first()
)
if not node_execution:
continue
# Get all executions for this workflow run
workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run(
workflow_run_id=trace_info.workflow_run_id
)
for node_execution in workflow_node_executions:
node_execution_id = node_execution.id
tenant_id = node_execution.tenant_id
app_id = node_execution.app_id
tenant_id = trace_info.tenant_id # Use from trace_info instead
app_id = trace_info.metadata.get("app_id") # Use from trace_info instead
node_name = node_execution.title
node_type = node_execution.node_type
status = node_execution.status
if node_type == "llm":
inputs = (
json.loads(node_execution.process_data).get("prompts", {}) if node_execution.process_data else {}
)
if node_type == NodeType.LLM:
inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {}
else:
inputs = json.loads(node_execution.inputs) if node_execution.inputs else {}
outputs = json.loads(node_execution.outputs) if node_execution.outputs else {}
inputs = node_execution.inputs if node_execution.inputs else {}
outputs = node_execution.outputs if node_execution.outputs else {}
created_at = node_execution.created_at or datetime.now()
elapsed_time = node_execution.elapsed_time
finished_at = created_at + timedelta(seconds=elapsed_time)
execution_metadata = (
json.loads(node_execution.execution_metadata) if node_execution.execution_metadata else {}
)
node_total_tokens = execution_metadata.get("total_tokens", 0)
attributes = execution_metadata.copy()
execution_metadata = node_execution.metadata if node_execution.metadata else {}
node_total_tokens = execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS) or 0
attributes = {str(k): v for k, v in execution_metadata.items()}
attributes.update(
{
"workflow_run_id": trace_info.workflow_run_id,
@@ -192,7 +193,7 @@ class WeaveDataTrace(BaseTraceInstance):
}
)
process_data = json.loads(node_execution.process_data) if node_execution.process_data else {}
process_data = node_execution.process_data if node_execution.process_data else {}
if process_data and process_data.get("model_mode") == "chat":
attributes.update(
{

View File

@@ -64,9 +64,9 @@ class PluginNodeBackwardsInvocation(BaseBackwardsInvocation):
)
return {
"inputs": execution.inputs_dict,
"outputs": execution.outputs_dict,
"process_data": execution.process_data_dict,
"inputs": execution.inputs,
"outputs": execution.outputs,
"process_data": execution.process_data,
}
@classmethod
@@ -113,7 +113,7 @@ class PluginNodeBackwardsInvocation(BaseBackwardsInvocation):
)
return {
"inputs": execution.inputs_dict,
"outputs": execution.outputs_dict,
"process_data": execution.process_data_dict,
"inputs": execution.inputs,
"outputs": execution.outputs,
"process_data": execution.process_data,
}

View File

@@ -405,7 +405,29 @@ class RetrievalService:
record["child_chunks"] = segment_child_map[record["segment"].id].get("child_chunks") # type: ignore
record["score"] = segment_child_map[record["segment"].id]["max_score"]
return [RetrievalSegments(**record) for record in records]
result = []
for record in records:
# Extract segment
segment = record["segment"]
# Extract child_chunks, ensuring it's a list or None
child_chunks = record.get("child_chunks")
if not isinstance(child_chunks, list):
child_chunks = None
# Extract score, ensuring it's a float or None
score_value = record.get("score")
score = (
float(score_value)
if score_value is not None and isinstance(score_value, int | float | str)
else None
)
# Create RetrievalSegments object
retrieval_segment = RetrievalSegments(segment=segment, child_chunks=child_chunks, score=score)
result.append(retrieval_segment)
return result
except Exception as e:
db.session.rollback()
raise e

View File

@@ -23,7 +23,8 @@ logger = logging.getLogger(__name__)
class OpenSearchConfig(BaseModel):
host: str
port: int
secure: bool = False
secure: bool = False # use_ssl
verify_certs: bool = True
auth_method: Literal["basic", "aws_managed_iam"] = "basic"
user: Optional[str] = None
password: Optional[str] = None
@@ -42,6 +43,8 @@ class OpenSearchConfig(BaseModel):
raise ValueError("config OPENSEARCH_AWS_REGION is required for AWS_MANAGED_IAM auth method")
if not values.get("aws_service"):
raise ValueError("config OPENSEARCH_AWS_SERVICE is required for AWS_MANAGED_IAM auth method")
if not values.get("OPENSEARCH_SECURE") and values.get("OPENSEARCH_VERIFY_CERTS"):
raise ValueError("verify_certs=True requires secure (HTTPS) connection")
return values
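# Configuration sketch (values illustrative): secure=True with verify_certs=True
# is the recommended production pairing, while secure=False with
# verify_certs=True is rejected by the validator above, since certificates
# cannot be verified over a plain-HTTP connection.
#
# OpenSearchConfig(host="localhost", port=9200, secure=True, verify_certs=True)   # accepted
# OpenSearchConfig(host="localhost", port=9200, secure=False, verify_certs=True)  # raises ValueError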
def create_aws_managed_iam_auth(self) -> Urllib3AWSV4SignerAuth:
@@ -57,7 +60,7 @@ class OpenSearchConfig(BaseModel):
params = {
"hosts": [{"host": self.host, "port": self.port}],
"use_ssl": self.secure,
"verify_certs": self.secure,
"verify_certs": self.verify_certs,
"connection_class": Urllib3HttpConnection,
"pool_maxsize": 20,
}
@@ -279,6 +282,7 @@ class OpenSearchVectorFactory(AbstractVectorFactory):
host=dify_config.OPENSEARCH_HOST or "localhost",
port=dify_config.OPENSEARCH_PORT,
secure=dify_config.OPENSEARCH_SECURE,
verify_certs=dify_config.OPENSEARCH_VERIFY_CERTS,
auth_method=dify_config.OPENSEARCH_AUTH_METHOD.value,
user=dify_config.OPENSEARCH_USER,
password=dify_config.OPENSEARCH_PASSWORD,

View File

@@ -271,12 +271,15 @@ class TencentVector(BaseVector):
for result in res[0]:
meta = result.get(self.field_metadata)
if isinstance(meta, str):
# Compatible with version 1.1.3 and below.
meta = json.loads(meta)
score = 1 - result.get("score", 0.0)
score = result.get("score", 0.0)
if score > score_threshold:
meta["score"] = score
doc = Document(page_content=result.get(self.field_text), metadata=meta)
docs.append(doc)
return docs
def delete(self) -> None:

View File

@@ -6,6 +6,12 @@ from urllib.parse import urljoin
import requests
from requests import Response
from core.rag.extractor.watercrawl.exceptions import (
WaterCrawlAuthenticationError,
WaterCrawlBadRequestError,
WaterCrawlPermissionError,
)
class BaseAPIClient:
def __init__(self, api_key, base_url):
@@ -53,6 +59,15 @@ class WaterCrawlAPIClient(BaseAPIClient):
yield data
def process_response(self, response: Response) -> dict | bytes | list | None | Generator:
if response.status_code == 401:
raise WaterCrawlAuthenticationError(response)
if response.status_code == 403:
raise WaterCrawlPermissionError(response)
if 400 <= response.status_code < 500:
raise WaterCrawlBadRequestError(response)
response.raise_for_status()
if response.status_code == 204:
return None

View File

@@ -0,0 +1,32 @@
import json
class WaterCrawlError(Exception):
pass
class WaterCrawlBadRequestError(WaterCrawlError):
def __init__(self, response):
self.status_code = response.status_code
self.response = response
data = response.json()
self.message = data.get("message", "Unknown error occurred")
self.errors = data.get("errors", {})
super().__init__(self.message)
@property
def flat_errors(self):
return json.dumps(self.errors)
def __str__(self):
return f"WaterCrawlBadRequestError: {self.message} \n {self.flat_errors}"
class WaterCrawlPermissionError(WaterCrawlBadRequestError):
def __str__(self):
return f"You are exceeding your WaterCrawl API limits. {self.message}"
class WaterCrawlAuthenticationError(WaterCrawlBadRequestError):
def __str__(self):
return "WaterCrawl API key is invalid or expired. Please check your API key and try again."

View File

@@ -19,7 +19,7 @@ from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document
from extensions.ext_database import db
from extensions.ext_storage import storage
from models.enums import CreatedByRole
from models.enums import CreatorUserRole
from models.model import UploadFile
logger = logging.getLogger(__name__)
@@ -116,7 +116,7 @@ class WordExtractor(BaseExtractor):
extension=str(image_ext),
mime_type=mime_type or "",
created_by=self.user_id,
created_by_role=CreatedByRole.ACCOUNT,
created_by_role=CreatorUserRole.ACCOUNT,
created_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
used=True,
used_by=self.user_id,

View File

@@ -190,7 +190,7 @@ class DatasetRetrieval:
retrieve_config.rerank_mode or "reranking_model",
retrieve_config.reranking_model,
retrieve_config.weights,
retrieve_config.reranking_enabled or True,
True if retrieve_config.reranking_enabled is None else retrieve_config.reranking_enabled,
message_id,
metadata_filter_document_ids,
metadata_condition,

View File

@@ -0,0 +1,242 @@
"""
SQLAlchemy implementation of the WorkflowExecutionRepository.
"""
import json
import logging
from typing import Optional, Union
from sqlalchemy import select
from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker
from core.workflow.entities.workflow_execution_entities import (
WorkflowExecution,
WorkflowExecutionStatus,
WorkflowType,
)
from core.workflow.repository.workflow_execution_repository import WorkflowExecutionRepository
from models import (
Account,
CreatorUserRole,
EndUser,
WorkflowRun,
)
from models.enums import WorkflowRunTriggeredFrom
logger = logging.getLogger(__name__)
class SQLAlchemyWorkflowExecutionRepository(WorkflowExecutionRepository):
"""
SQLAlchemy implementation of the WorkflowExecutionRepository interface.
This implementation supports multi-tenancy by filtering operations based on tenant_id.
Each method creates its own session, handles the transaction, and commits changes
to the database. This prevents long-running connections in the workflow core.
This implementation also includes an in-memory cache for workflow executions to improve
performance by reducing database queries.
"""
def __init__(
self,
session_factory: sessionmaker | Engine,
user: Union[Account, EndUser],
app_id: Optional[str],
triggered_from: Optional[WorkflowRunTriggeredFrom],
):
"""
Initialize the repository with a SQLAlchemy sessionmaker or engine and context information.
Args:
session_factory: SQLAlchemy sessionmaker or engine for creating sessions
user: Account or EndUser object containing tenant_id, user ID, and role information
app_id: App ID for filtering by application (can be None)
triggered_from: Source of the execution trigger (DEBUGGING or APP_RUN)
"""
# If an engine is provided, create a sessionmaker from it
if isinstance(session_factory, Engine):
self._session_factory = sessionmaker(bind=session_factory, expire_on_commit=False)
elif isinstance(session_factory, sessionmaker):
self._session_factory = session_factory
else:
raise ValueError(
f"Invalid session_factory type {type(session_factory).__name__}; expected sessionmaker or Engine"
)
# Extract tenant_id from user
tenant_id: str | None = user.tenant_id if isinstance(user, EndUser) else user.current_tenant_id
if not tenant_id:
raise ValueError("User must have a tenant_id or current_tenant_id")
self._tenant_id = tenant_id
# Store app context
self._app_id = app_id
# Extract user context
self._triggered_from = triggered_from
self._creator_user_id = user.id
# Determine user role based on user type
self._creator_user_role = CreatorUserRole.ACCOUNT if isinstance(user, Account) else CreatorUserRole.END_USER
# Initialize in-memory cache for workflow executions
# Key: execution_id, Value: WorkflowRun (DB model)
self._execution_cache: dict[str, WorkflowRun] = {}
def _to_domain_model(self, db_model: WorkflowRun) -> WorkflowExecution:
"""
Convert a database model to a domain model.
Args:
db_model: The database model to convert
Returns:
The domain model
"""
# Parse JSON fields
inputs = db_model.inputs_dict
outputs = db_model.outputs_dict
graph = db_model.graph_dict
# Convert status to domain enum
status = WorkflowExecutionStatus(db_model.status)
return WorkflowExecution(
id=db_model.id,
workflow_id=db_model.workflow_id,
sequence_number=db_model.sequence_number,
type=WorkflowType(db_model.type),
workflow_version=db_model.version,
graph=graph,
inputs=inputs,
outputs=outputs,
status=status,
error_message=db_model.error or "",
total_tokens=db_model.total_tokens,
total_steps=db_model.total_steps,
exceptions_count=db_model.exceptions_count,
started_at=db_model.created_at,
finished_at=db_model.finished_at,
)
def _to_db_model(self, domain_model: WorkflowExecution) -> WorkflowRun:
"""
Convert a domain model to a database model.
Args:
domain_model: The domain model to convert
Returns:
The database model
"""
# Use values from constructor if provided
if not self._triggered_from:
raise ValueError("triggered_from is required in repository constructor")
if not self._creator_user_id:
raise ValueError("created_by is required in repository constructor")
if not self._creator_user_role:
raise ValueError("created_by_role is required in repository constructor")
db_model = WorkflowRun()
db_model.id = domain_model.id
db_model.tenant_id = self._tenant_id
if self._app_id is not None:
db_model.app_id = self._app_id
db_model.workflow_id = domain_model.workflow_id
db_model.triggered_from = self._triggered_from
db_model.sequence_number = domain_model.sequence_number
db_model.type = domain_model.type
db_model.version = domain_model.workflow_version
db_model.graph = json.dumps(domain_model.graph) if domain_model.graph else None
db_model.inputs = json.dumps(domain_model.inputs) if domain_model.inputs else None
db_model.outputs = json.dumps(domain_model.outputs) if domain_model.outputs else None
db_model.status = domain_model.status
db_model.error = domain_model.error_message if domain_model.error_message else None
db_model.total_tokens = domain_model.total_tokens
db_model.total_steps = domain_model.total_steps
db_model.exceptions_count = domain_model.exceptions_count
db_model.created_by_role = self._creator_user_role
db_model.created_by = self._creator_user_id
db_model.created_at = domain_model.started_at
db_model.finished_at = domain_model.finished_at
# Calculate elapsed time if finished_at is available
if domain_model.finished_at:
db_model.elapsed_time = (domain_model.finished_at - domain_model.started_at).total_seconds()
else:
db_model.elapsed_time = 0
return db_model
def save(self, execution: WorkflowExecution) -> None:
"""
Save or update a WorkflowExecution domain entity to the database.
This method serves as a domain-to-database adapter that:
1. Converts the domain entity to its database representation
2. Persists the database model using SQLAlchemy's merge operation
3. Maintains proper multi-tenancy by including tenant context during conversion
4. Updates the in-memory cache for faster subsequent lookups
The method handles both creating new records and updating existing ones through
SQLAlchemy's merge operation.
Args:
execution: The WorkflowExecution domain entity to persist
"""
# Convert domain model to database model using tenant context and other attributes
db_model = self._to_db_model(execution)
# Create a new database session
with self._session_factory() as session:
# SQLAlchemy merge intelligently handles both insert and update operations
# based on the presence of the primary key
session.merge(db_model)
session.commit()
# Update the in-memory cache for faster subsequent lookups
logger.debug(f"Updating cache for execution_id: {db_model.id}")
self._execution_cache[db_model.id] = db_model
def get(self, execution_id: str) -> Optional[WorkflowExecution]:
"""
Retrieve a WorkflowExecution by its ID.
First checks the in-memory cache, and if not found, queries the database.
If found in the database, adds it to the cache for future lookups.
Args:
execution_id: The workflow execution ID
Returns:
The WorkflowExecution instance if found, None otherwise
"""
# First check the cache
if execution_id in self._execution_cache:
logger.debug(f"Cache hit for execution_id: {execution_id}")
# Convert cached DB model to domain model
cached_db_model = self._execution_cache[execution_id]
return self._to_domain_model(cached_db_model)
# If not in cache, query the database
logger.debug(f"Cache miss for execution_id: {execution_id}, querying database")
with self._session_factory() as session:
stmt = select(WorkflowRun).where(
WorkflowRun.id == execution_id,
WorkflowRun.tenant_id == self._tenant_id,
)
if self._app_id:
stmt = stmt.where(WorkflowRun.app_id == self._app_id)
db_model = session.scalar(stmt)
if db_model:
# Add DB model to cache
self._execution_cache[execution_id] = db_model
# Convert to domain model and return
return self._to_domain_model(db_model)
return None
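# Usage sketch for the repository above (assumes an existing SQLAlchemy engine
# and an Account with current_tenant_id set; `engine`, `account`, and the IDs
# are illustrative):
#
# repo = SQLAlchemyWorkflowExecutionRepository(
#     session_factory=engine,  # an Engine is accepted; a sessionmaker is derived from it
#     user=account,
#     app_id="app-id",
#     triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
# )
# repo.save(execution)             # merge + commit, then cache the DB model by id
# cached = repo.get(execution.id)  # subsequent reads hit the in-memory cache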

View File

@@ -2,16 +2,31 @@
SQLAlchemy implementation of the WorkflowNodeExecutionRepository.
"""
import json
import logging
from collections.abc import Sequence
from typing import Optional
from typing import Optional, Union
from sqlalchemy import UnaryExpression, asc, delete, desc, select
from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker
from core.model_runtime.utils.encoders import jsonable_encoder
from core.workflow.entities.node_entities import NodeRunMetadataKey
from core.workflow.entities.node_execution_entities import (
NodeExecution,
NodeExecutionStatus,
)
from core.workflow.nodes.enums import NodeType
from core.workflow.repository.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository
from models.workflow import WorkflowNodeExecution, WorkflowNodeExecutionStatus, WorkflowNodeExecutionTriggeredFrom
from models import (
Account,
CreatorUserRole,
EndUser,
WorkflowNodeExecution,
WorkflowNodeExecutionStatus,
WorkflowNodeExecutionTriggeredFrom,
)
logger = logging.getLogger(__name__)
@@ -23,16 +38,26 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
This implementation supports multi-tenancy by filtering operations based on tenant_id.
Each method creates its own session, handles the transaction, and commits changes
to the database. This prevents long-running connections in the workflow core.
This implementation also includes an in-memory cache for node executions to improve
performance by reducing database queries.
"""
def __init__(self, session_factory: sessionmaker | Engine, tenant_id: str, app_id: Optional[str] = None):
def __init__(
self,
session_factory: sessionmaker | Engine,
user: Union[Account, EndUser],
app_id: Optional[str],
triggered_from: Optional[WorkflowNodeExecutionTriggeredFrom],
):
"""
Initialize the repository with a SQLAlchemy sessionmaker or engine and tenant context.
Initialize the repository with a SQLAlchemy sessionmaker or engine and context information.
Args:
session_factory: SQLAlchemy sessionmaker or engine for creating sessions
tenant_id: Tenant ID for multi-tenancy
app_id: Optional app ID for filtering by application
user: Account or EndUser object containing tenant_id, user ID, and role information
app_id: App ID for filtering by application (can be None)
triggered_from: Source of the execution trigger (SINGLE_STEP or WORKFLOW_RUN)
"""
# If an engine is provided, create a sessionmaker from it
if isinstance(session_factory, Engine):
@@ -44,38 +69,167 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
f"Invalid session_factory type {type(session_factory).__name__}; expected sessionmaker or Engine"
)
# Extract tenant_id from user
tenant_id: str | None = user.tenant_id if isinstance(user, EndUser) else user.current_tenant_id
if not tenant_id:
raise ValueError("User must have a tenant_id or current_tenant_id")
self._tenant_id = tenant_id
# Store app context
self._app_id = app_id
def save(self, execution: WorkflowNodeExecution) -> None:
# Extract user context
self._triggered_from = triggered_from
self._creator_user_id = user.id
# Determine user role based on user type
self._creator_user_role = CreatorUserRole.ACCOUNT if isinstance(user, Account) else CreatorUserRole.END_USER
# Initialize in-memory cache for node executions
# Key: node_execution_id, Value: WorkflowNodeExecution (DB model)
self._node_execution_cache: dict[str, WorkflowNodeExecution] = {}
def _to_domain_model(self, db_model: WorkflowNodeExecution) -> NodeExecution:
"""
Save a WorkflowNodeExecution instance and commit changes to the database.
Convert a database model to a domain model.
Args:
execution: The WorkflowNodeExecution instance to save
db_model: The database model to convert
Returns:
The domain model
"""
# Parse JSON fields
inputs = db_model.inputs_dict
process_data = db_model.process_data_dict
outputs = db_model.outputs_dict
metadata = {NodeRunMetadataKey(k): v for k, v in db_model.execution_metadata_dict.items()}
# Convert status to domain enum
status = NodeExecutionStatus(db_model.status)
return NodeExecution(
id=db_model.id,
node_execution_id=db_model.node_execution_id,
workflow_id=db_model.workflow_id,
workflow_run_id=db_model.workflow_run_id,
index=db_model.index,
predecessor_node_id=db_model.predecessor_node_id,
node_id=db_model.node_id,
node_type=NodeType(db_model.node_type),
title=db_model.title,
inputs=inputs,
process_data=process_data,
outputs=outputs,
status=status,
error=db_model.error,
elapsed_time=db_model.elapsed_time,
metadata=metadata,
created_at=db_model.created_at,
finished_at=db_model.finished_at,
)
def to_db_model(self, domain_model: NodeExecution) -> WorkflowNodeExecution:
"""
Convert a domain model to a database model.
Args:
domain_model: The domain model to convert
Returns:
The database model
"""
# Use values from constructor if provided
if not self._triggered_from:
raise ValueError("triggered_from is required in repository constructor")
if not self._creator_user_id:
raise ValueError("created_by is required in repository constructor")
if not self._creator_user_role:
raise ValueError("created_by_role is required in repository constructor")
db_model = WorkflowNodeExecution()
db_model.id = domain_model.id
db_model.tenant_id = self._tenant_id
if self._app_id is not None:
db_model.app_id = self._app_id
db_model.workflow_id = domain_model.workflow_id
db_model.triggered_from = self._triggered_from
db_model.workflow_run_id = domain_model.workflow_run_id
db_model.index = domain_model.index
db_model.predecessor_node_id = domain_model.predecessor_node_id
db_model.node_execution_id = domain_model.node_execution_id
db_model.node_id = domain_model.node_id
db_model.node_type = domain_model.node_type
db_model.title = domain_model.title
db_model.inputs = json.dumps(domain_model.inputs) if domain_model.inputs else None
db_model.process_data = json.dumps(domain_model.process_data) if domain_model.process_data else None
db_model.outputs = json.dumps(domain_model.outputs) if domain_model.outputs else None
db_model.status = domain_model.status
db_model.error = domain_model.error
db_model.elapsed_time = domain_model.elapsed_time
db_model.execution_metadata = (
json.dumps(jsonable_encoder(domain_model.metadata)) if domain_model.metadata else None
)
db_model.created_at = domain_model.created_at
db_model.created_by_role = self._creator_user_role
db_model.created_by = self._creator_user_id
db_model.finished_at = domain_model.finished_at
return db_model
def save(self, execution: NodeExecution) -> None:
"""
Save or update a NodeExecution domain entity to the database.
This method serves as a domain-to-database adapter that:
1. Converts the domain entity to its database representation
2. Persists the database model using SQLAlchemy's merge operation
3. Maintains proper multi-tenancy by including tenant context during conversion
4. Updates the in-memory cache for faster subsequent lookups
The method handles both creating new records and updating existing ones through
SQLAlchemy's merge operation.
Args:
execution: The NodeExecution domain entity to persist
"""
# Convert domain model to database model using tenant context and other attributes
db_model = self.to_db_model(execution)
# Create a new database session
with self._session_factory() as session:
# Ensure tenant_id is set
if not execution.tenant_id:
execution.tenant_id = self._tenant_id
# Set app_id if provided and not already set
if self._app_id and not execution.app_id:
execution.app_id = self._app_id
session.add(execution)
# SQLAlchemy merge intelligently handles both insert and update operations
# based on the presence of the primary key
session.merge(db_model)
session.commit()
def get_by_node_execution_id(self, node_execution_id: str) -> Optional[WorkflowNodeExecution]:
# Update the in-memory cache for faster subsequent lookups
# Only cache if we have a node_execution_id to use as the cache key
if db_model.node_execution_id:
logger.debug(f"Updating cache for node_execution_id: {db_model.node_execution_id}")
self._node_execution_cache[db_model.node_execution_id] = db_model
def get_by_node_execution_id(self, node_execution_id: str) -> Optional[NodeExecution]:
"""
Retrieve a WorkflowNodeExecution by its node_execution_id.
Retrieve a NodeExecution by its node_execution_id.
First checks the in-memory cache, and if not found, queries the database.
If found in the database, adds it to the cache for future lookups.
Args:
node_execution_id: The node execution ID
Returns:
The WorkflowNodeExecution instance if found, None otherwise
The NodeExecution instance if found, None otherwise
"""
# First check the cache
if node_execution_id in self._node_execution_cache:
logger.debug(f"Cache hit for node_execution_id: {node_execution_id}")
# Convert cached DB model to domain model
cached_db_model = self._node_execution_cache[node_execution_id]
return self._to_domain_model(cached_db_model)
# If not in cache, query the database
logger.debug(f"Cache miss for node_execution_id: {node_execution_id}, querying database")
with self._session_factory() as session:
stmt = select(WorkflowNodeExecution).where(
WorkflowNodeExecution.node_execution_id == node_execution_id,
@@ -85,15 +239,27 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
if self._app_id:
stmt = stmt.where(WorkflowNodeExecution.app_id == self._app_id)
return session.scalar(stmt)
db_model = session.scalar(stmt)
if db_model:
# Add DB model to cache
self._node_execution_cache[node_execution_id] = db_model
def get_by_workflow_run(
# Convert to domain model and return
return self._to_domain_model(db_model)
return None
def get_db_models_by_workflow_run(
self,
workflow_run_id: str,
order_config: Optional[OrderConfig] = None,
) -> Sequence[WorkflowNodeExecution]:
"""
Retrieve all WorkflowNodeExecution instances for a specific workflow run.
Retrieve all WorkflowNodeExecution database models for a specific workflow run.
This method directly returns database models without converting to domain models,
which is useful when you need to access database-specific fields like triggered_from.
It also updates the in-memory cache with the retrieved models.
Args:
workflow_run_id: The workflow run ID
@@ -102,7 +268,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
order_config.order_direction: Direction to order ("asc" or "desc")
Returns:
A list of WorkflowNodeExecution instances
A list of WorkflowNodeExecution database models
"""
with self._session_factory() as session:
stmt = select(WorkflowNodeExecution).where(
@@ -129,17 +295,58 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
if order_columns:
stmt = stmt.order_by(*order_columns)
return session.scalars(stmt).all()
db_models = session.scalars(stmt).all()
def get_running_executions(self, workflow_run_id: str) -> Sequence[WorkflowNodeExecution]:
# Update the cache with the retrieved DB models
for model in db_models:
if model.node_execution_id:
self._node_execution_cache[model.node_execution_id] = model
return db_models
def get_by_workflow_run(
self,
workflow_run_id: str,
order_config: Optional[OrderConfig] = None,
) -> Sequence[NodeExecution]:
"""
Retrieve all running WorkflowNodeExecution instances for a specific workflow run.
Retrieve all NodeExecution instances for a specific workflow run.
This method always queries the database to ensure complete and ordered results,
but updates the cache with any retrieved executions.
Args:
workflow_run_id: The workflow run ID
order_config: Optional configuration for ordering results
order_config.order_by: List of fields to order by (e.g., ["index", "created_at"])
order_config.order_direction: Direction to order ("asc" or "desc")
Returns:
A list of NodeExecution instances
"""
# Get the database models using the new method
db_models = self.get_db_models_by_workflow_run(workflow_run_id, order_config)
# Convert database models to domain models
domain_models = []
for model in db_models:
domain_model = self._to_domain_model(model)
domain_models.append(domain_model)
return domain_models
def get_running_executions(self, workflow_run_id: str) -> Sequence[NodeExecution]:
"""
Retrieve all running NodeExecution instances for a specific workflow run.
This method queries the database directly and updates the cache with any
retrieved executions that have a node_execution_id.
Args:
workflow_run_id: The workflow run ID
Returns:
A list of running WorkflowNodeExecution instances
A list of running NodeExecution instances
"""
with self._session_factory() as session:
stmt = select(WorkflowNodeExecution).where(
@@ -152,26 +359,19 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
if self._app_id:
stmt = stmt.where(WorkflowNodeExecution.app_id == self._app_id)
return session.scalars(stmt).all()
db_models = session.scalars(stmt).all()
domain_models = []
def update(self, execution: WorkflowNodeExecution) -> None:
"""
Update an existing WorkflowNodeExecution instance and commit changes to the database.
for model in db_models:
# Update cache if node_execution_id is present
if model.node_execution_id:
self._node_execution_cache[model.node_execution_id] = model
Args:
execution: The WorkflowNodeExecution instance to update
"""
with self._session_factory() as session:
# Ensure tenant_id is set
if not execution.tenant_id:
execution.tenant_id = self._tenant_id
# Convert to domain model
domain_model = self._to_domain_model(model)
domain_models.append(domain_model)
# Set app_id if provided and not already set
if self._app_id and not execution.app_id:
execution.app_id = self._app_id
session.merge(execution)
session.commit()
return domain_models
def clear(self) -> None:
"""
@@ -179,6 +379,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
This method deletes all WorkflowNodeExecution records that match the tenant_id
and app_id (if provided) associated with this repository instance.
It also clears the in-memory cache.
"""
with self._session_factory() as session:
stmt = delete(WorkflowNodeExecution).where(WorkflowNodeExecution.tenant_id == self._tenant_id)
@@ -194,3 +395,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
f"Cleared {deleted_count} workflow node execution records for tenant {self._tenant_id}"
+ (f" and app {self._app_id}" if self._app_id else "")
)
# Clear the in-memory cache
self._node_execution_cache.clear()
logger.info("Cleared in-memory node execution cache")

View File

@@ -32,7 +32,7 @@ from core.tools.errors import (
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from core.tools.workflow_as_tool.tool import WorkflowTool
from extensions.ext_database import db
from models.enums import CreatedByRole
from models.enums import CreatorUserRole
from models.model import Message, MessageFile
@@ -339,9 +339,9 @@ class ToolEngine:
url=message.url,
upload_file_id=tool_file_id,
created_by_role=(
CreatedByRole.ACCOUNT
CreatorUserRole.ACCOUNT
if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
else CreatedByRole.END_USER
else CreatorUserRole.END_USER
),
created_by=user_id,
)

View File

@@ -528,7 +528,7 @@ class ToolManager:
yield provider
except Exception:
logger.exception(f"load builtin provider {provider}")
logger.exception(f"load builtin provider {provider_path}")
continue
# set builtin providers loaded
cls._builtin_providers_loaded = True
@@ -644,10 +644,10 @@ class ToolManager:
)
workflow_provider_controllers: list[WorkflowToolProviderController] = []
for provider in workflow_providers:
for workflow_provider in workflow_providers:
try:
workflow_provider_controllers.append(
ToolTransformService.workflow_provider_to_controller(db_provider=provider)
ToolTransformService.workflow_provider_to_controller(db_provider=workflow_provider)
)
except Exception:
# app has been deleted

View File

@@ -125,6 +125,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
return ""
# get the retrieval model; if it is not set, use the default
retrieval_model: dict[str, Any] = dataset.retrieval_model or default_retrieval_model
retrieval_resource_list = []
if dataset.indexing_technique == "economy":
# use keyword table query
documents = RetrievalService.retrieve(
@@ -181,7 +182,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
score=record.score,
)
)
retrieval_resource_list = []
if self.return_resource:
for record in records:
segment = record.segment

View File

@@ -1,7 +1,9 @@
import json
import logging
from collections.abc import Generator
from typing import Any, Optional, Union, cast
from typing import Any, Optional, cast
from flask_login import current_user
from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod
from core.tools.__base.tool import Tool
@@ -87,7 +89,7 @@ class WorkflowTool(Tool):
result = generator.generate(
app_model=app,
workflow=workflow,
user=self._get_user(user_id),
user=cast("Account | EndUser", current_user),
args={"inputs": tool_parameters, "files": files},
invoke_from=self.runtime.invoke_from,
streaming=False,
@@ -111,20 +113,6 @@ class WorkflowTool(Tool):
yield self.create_text_message(json.dumps(outputs, ensure_ascii=False))
yield self.create_json_message(outputs)
def _get_user(self, user_id: str) -> Union[EndUser, Account]:
"""
get the user by user id
"""
user = db.session.query(EndUser).filter(EndUser.id == user_id).first()
if not user:
user = db.session.query(Account).filter(Account.id == user_id).first()
if not user:
raise ValueError("user not found")
return user
def fork_tool_runtime(self, runtime: ToolRuntime) -> "WorkflowTool":
"""
fork a new tool with metadata

View File

@@ -0,0 +1,7 @@
# The minimal selector length for valid variables.
#
# The first element of the selector is the node id, and the second element is the variable name.
#
# If the selector length is more than 2, the remaining parts are the key / index paths used
# to extract part of the variable value.
MIN_SELECTORS_LENGTH = 2
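# Example (illustrative): ["node_1", "text"] is the minimal valid selector,
# and ["node_1", "text", "0"] additionally indexes into the variable value;
# len(selector) >= MIN_SELECTORS_LENGTH holds for both.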

View File

@@ -0,0 +1,8 @@
from collections.abc import Iterable, Sequence
def to_selector(node_id: str, name: str, paths: Iterable[str] = ()) -> Sequence[str]:
selectors = [node_id, name]
if paths:
selectors.extend(paths)
return selectors
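# Usage sketch (values illustrative):
#   to_selector("node_1", "text")                       -> ["node_1", "text"]
#   to_selector("node_1", "text", paths=["items", "0"]) -> ["node_1", "text", "items", "0"]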

View File

@@ -0,0 +1,98 @@
"""
Domain entities for workflow node execution.
This module contains the domain model for workflow node execution, which is used
by the core workflow module. These models are independent of the storage mechanism
and don't contain implementation details like tenant_id, app_id, etc.
"""
from collections.abc import Mapping
from datetime import datetime
from enum import StrEnum
from typing import Any, Optional
from pydantic import BaseModel, Field
from core.workflow.entities.node_entities import NodeRunMetadataKey
from core.workflow.nodes.enums import NodeType
class NodeExecutionStatus(StrEnum):
"""
Node Execution Status Enum.
"""
RUNNING = "running"
SUCCEEDED = "succeeded"
FAILED = "failed"
EXCEPTION = "exception"
RETRY = "retry"
class NodeExecution(BaseModel):
"""
Domain model for workflow node execution.
This model represents the core business entity of a node execution,
without implementation details like tenant_id, app_id, etc.
Note: User/context-specific fields (triggered_from, created_by, created_by_role)
have been moved to the repository implementation to keep the domain model clean.
These fields are still accepted in the constructor for backward compatibility,
but they are not stored in the model.
"""
# Core identification fields
id: str # Unique identifier for this execution record
node_execution_id: Optional[str] = None # Optional secondary ID for cross-referencing
workflow_id: str # ID of the workflow this node belongs to
workflow_run_id: Optional[str] = None # ID of the specific workflow run (null for single-step debugging)
# Execution positioning and flow
index: int # Sequence number for ordering in trace visualization
predecessor_node_id: Optional[str] = None # ID of the node that executed before this one
node_id: str # ID of the node being executed
node_type: NodeType # Type of node (e.g., start, llm, knowledge)
title: str # Display title of the node
# Execution data
inputs: Optional[Mapping[str, Any]] = None # Input variables used by this node
process_data: Optional[Mapping[str, Any]] = None # Intermediate processing data
outputs: Optional[Mapping[str, Any]] = None # Output variables produced by this node
# Execution state
status: NodeExecutionStatus = NodeExecutionStatus.RUNNING # Current execution status
error: Optional[str] = None # Error message if execution failed
elapsed_time: float = Field(default=0.0) # Time taken for execution in seconds
# Additional metadata
metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None # Execution metadata (tokens, cost, etc.)
# Timing information
created_at: datetime # When execution started
finished_at: Optional[datetime] = None # When execution completed
def update_from_mapping(
self,
inputs: Optional[Mapping[str, Any]] = None,
process_data: Optional[Mapping[str, Any]] = None,
outputs: Optional[Mapping[str, Any]] = None,
metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None,
) -> None:
"""
Update the model from mappings.
Args:
inputs: The inputs to update
process_data: The process data to update
outputs: The outputs to update
metadata: The metadata to update
"""
if inputs is not None:
self.inputs = dict(inputs)
if process_data is not None:
self.process_data = dict(process_data)
if outputs is not None:
self.outputs = dict(outputs)
if metadata is not None:
self.metadata = dict(metadata)
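# update_from_mapping copies each provided mapping into a plain dict, so
# callers may pass read-only Mappings. A sketch (values illustrative):
#
# execution.update_from_mapping(
#     outputs={"text": "hello"},
#     metadata={NodeRunMetadataKey.TOTAL_TOKENS: 42},
# )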

View File

@@ -0,0 +1,91 @@
"""
Domain entities for workflow execution.
Models are independent of the storage mechanism and don't contain
implementation details like tenant_id, app_id, etc.
"""
from collections.abc import Mapping
from datetime import UTC, datetime
from enum import StrEnum
from typing import Any, Optional
from pydantic import BaseModel, Field
class WorkflowType(StrEnum):
"""
Workflow Type Enum for domain layer
"""
WORKFLOW = "workflow"
CHAT = "chat"
class WorkflowExecutionStatus(StrEnum):
RUNNING = "running"
SUCCEEDED = "succeeded"
FAILED = "failed"
STOPPED = "stopped"
PARTIAL_SUCCEEDED = "partial-succeeded"
class WorkflowExecution(BaseModel):
"""
Domain model for workflow execution based on WorkflowRun but without
user, tenant, and app attributes.
"""
id: str = Field(...)
workflow_id: str = Field(...)
workflow_version: str = Field(...)
sequence_number: int = Field(...)
type: WorkflowType = Field(...)
graph: Mapping[str, Any] = Field(...)
inputs: Mapping[str, Any] = Field(...)
outputs: Optional[Mapping[str, Any]] = None
status: WorkflowExecutionStatus = WorkflowExecutionStatus.RUNNING
error_message: str = Field(default="")
total_tokens: int = Field(default=0)
total_steps: int = Field(default=0)
exceptions_count: int = Field(default=0)
started_at: datetime = Field(...)
finished_at: Optional[datetime] = None
@property
def elapsed_time(self) -> float:
"""
Calculate elapsed time in seconds.
If workflow is not finished, use current time.
"""
end_time = self.finished_at or datetime.now(UTC).replace(tzinfo=None)
return (end_time - self.started_at).total_seconds()
@classmethod
def new(
cls,
*,
id: str,
workflow_id: str,
sequence_number: int,
type: WorkflowType,
workflow_version: str,
graph: Mapping[str, Any],
inputs: Mapping[str, Any],
started_at: datetime,
) -> "WorkflowExecution":
return WorkflowExecution(
id=id,
workflow_id=workflow_id,
sequence_number=sequence_number,
type=type,
workflow_version=workflow_version,
graph=graph,
inputs=inputs,
status=WorkflowExecutionStatus.RUNNING,
started_at=started_at,
)
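# Construction sketch (field values illustrative): new() pins the status to
# RUNNING, and elapsed_time falls back to the current time until finished_at
# is set.
#
# execution = WorkflowExecution.new(
#     id="run-id",
#     workflow_id="wf-id",
#     sequence_number=1,
#     type=WorkflowType.WORKFLOW,
#     workflow_version="1",
#     graph={},
#     inputs={},
#     started_at=datetime.now(UTC).replace(tzinfo=None),
# )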

View File

@@ -0,0 +1,42 @@
from typing import Optional, Protocol
from core.workflow.entities.workflow_execution_entities import WorkflowExecution
class WorkflowExecutionRepository(Protocol):
"""
Repository interface for WorkflowExecution.
This interface defines the contract for accessing and manipulating
WorkflowExecution data, regardless of the underlying storage mechanism.
Note: Domain-specific concepts like multi-tenancy (tenant_id), application context (app_id),
and other implementation details should be handled at the implementation level, not in
the core interface. This keeps the core domain model clean and independent of specific
application domains or deployment scenarios.
"""
def save(self, execution: WorkflowExecution) -> None:
"""
Save or update a WorkflowExecution instance.
This method handles both creating new records and updating existing ones.
The implementation should determine whether to create or update based on
the execution's ID or other identifying fields.
Args:
execution: The WorkflowExecution instance to save or update
"""
...
def get(self, execution_id: str) -> Optional[WorkflowExecution]:
"""
Retrieve a WorkflowExecution by its ID.
Args:
execution_id: The workflow execution ID
Returns:
The WorkflowExecution instance if found, None otherwise
"""
...

View File

@@ -2,12 +2,12 @@ from collections.abc import Sequence
from dataclasses import dataclass
from typing import Literal, Optional, Protocol
from models.workflow import WorkflowNodeExecution
from core.workflow.entities.node_execution_entities import NodeExecution
@dataclass
class OrderConfig:
"""Configuration for ordering WorkflowNodeExecution instances."""
"""Configuration for ordering NodeExecution instances."""
order_by: list[str]
order_direction: Optional[Literal["asc", "desc"]] = None
@@ -15,10 +15,10 @@ class OrderConfig:
class WorkflowNodeExecutionRepository(Protocol):
"""
Repository interface for WorkflowNodeExecution.
Repository interface for NodeExecution.
This interface defines the contract for accessing and manipulating
WorkflowNodeExecution data, regardless of the underlying storage mechanism.
NodeExecution data, regardless of the underlying storage mechanism.
Note: Domain-specific concepts like multi-tenancy (tenant_id), application context (app_id),
and trigger sources (triggered_from) should be handled at the implementation level, not in
@ -26,24 +26,28 @@ class WorkflowNodeExecutionRepository(Protocol):
application domains or deployment scenarios.
"""
def save(self, execution: WorkflowNodeExecution) -> None:
def save(self, execution: NodeExecution) -> None:
"""
Save a WorkflowNodeExecution instance.
Save or update a NodeExecution instance.
This method handles both creating new records and updating existing ones.
The implementation should determine whether to create or update based on
the execution's ID or other identifying fields.
Args:
execution: The WorkflowNodeExecution instance to save
execution: The NodeExecution instance to save or update
"""
...
def get_by_node_execution_id(self, node_execution_id: str) -> Optional[WorkflowNodeExecution]:
def get_by_node_execution_id(self, node_execution_id: str) -> Optional[NodeExecution]:
"""
Retrieve a WorkflowNodeExecution by its node_execution_id.
Retrieve a NodeExecution by its node_execution_id.
Args:
node_execution_id: The node execution ID
Returns:
The WorkflowNodeExecution instance if found, None otherwise
The NodeExecution instance if found, None otherwise
"""
...
@ -51,9 +55,9 @@ class WorkflowNodeExecutionRepository(Protocol):
self,
workflow_run_id: str,
order_config: Optional[OrderConfig] = None,
) -> Sequence[WorkflowNodeExecution]:
) -> Sequence[NodeExecution]:
"""
Retrieve all WorkflowNodeExecution instances for a specific workflow run.
Retrieve all NodeExecution instances for a specific workflow run.
Args:
workflow_run_id: The workflow run ID
@ -62,34 +66,25 @@ class WorkflowNodeExecutionRepository(Protocol):
order_config.order_direction: Direction to order ("asc" or "desc")
Returns:
A list of WorkflowNodeExecution instances
A list of NodeExecution instances
"""
...
def get_running_executions(self, workflow_run_id: str) -> Sequence[WorkflowNodeExecution]:
def get_running_executions(self, workflow_run_id: str) -> Sequence[NodeExecution]:
"""
Retrieve all running WorkflowNodeExecution instances for a specific workflow run.
Retrieve all running NodeExecution instances for a specific workflow run.
Args:
workflow_run_id: The workflow run ID
Returns:
A list of running WorkflowNodeExecution instances
"""
...
def update(self, execution: WorkflowNodeExecution) -> None:
"""
Update an existing WorkflowNodeExecution instance.
Args:
execution: The WorkflowNodeExecution instance to update
A list of running NodeExecution instances
"""
...
def clear(self) -> None:
"""
Clear all WorkflowNodeExecution records based on implementation-specific criteria.
Clear all NodeExecution records based on implementation-specific criteria.
This method is intended to be used for bulk deletion operations, such as removing
all records associated with a specific app_id and tenant_id in multi-tenant implementations.
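
A usage sketch of the renamed interface (`repo` stands for any implementation of the protocol; the run id and the `index` ordering key are illustrative):

# Sketch: fetch a run's node executions, newest first.
order_config = OrderConfig(order_by=["index"], order_direction="desc")
for node_execution in repo.get_by_workflow_run("run-123", order_config=order_config):
    print(node_execution)  # each item is a NodeExecution domain object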

File diff suppressed because it is too large

View File

@ -3,11 +3,14 @@ import json
import flask_login # type: ignore
from flask import Response, request
from flask_login import user_loaded_from_request, user_logged_in
from werkzeug.exceptions import Unauthorized
from werkzeug.exceptions import NotFound, Unauthorized
import contexts
from dify_app import DifyApp
from extensions.ext_database import db
from libs.passport import PassportService
from models.account import Account
from models.model import EndUser
from services.account_service import AccountService
login_manager = flask_login.LoginManager()
@ -17,34 +20,48 @@ login_manager = flask_login.LoginManager()
@login_manager.request_loader
def load_user_from_request(request_from_flask_login):
"""Load user based on the request."""
if request.blueprint not in {"console", "inner_api"}:
return None
# Check if the user_id contains a dot, indicating the old format
auth_header = request.headers.get("Authorization", "")
if not auth_header:
auth_token = request.args.get("_token")
if not auth_token:
raise Unauthorized("Invalid Authorization token.")
else:
auth_token: str | None = None
if auth_header:
if " " not in auth_header:
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
auth_scheme, auth_token = auth_header.split(None, 1)
auth_scheme, auth_token = auth_header.split(maxsplit=1)
auth_scheme = auth_scheme.lower()
if auth_scheme != "bearer":
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
else:
auth_token = request.args.get("_token")
decoded = PassportService().verify(auth_token)
user_id = decoded.get("user_id")
if request.blueprint in {"console", "inner_api"}:
if not auth_token:
raise Unauthorized("Invalid Authorization token.")
decoded = PassportService().verify(auth_token)
user_id = decoded.get("user_id")
if not user_id:
raise Unauthorized("Invalid Authorization token.")
logged_in_account = AccountService.load_logged_in_account(account_id=user_id)
return logged_in_account
logged_in_account = AccountService.load_logged_in_account(account_id=user_id)
return logged_in_account
elif request.blueprint == "web":
decoded = PassportService().verify(auth_token)
end_user_id = decoded.get("end_user_id")
if not end_user_id:
raise Unauthorized("Invalid Authorization token.")
end_user = db.session.query(EndUser).filter(EndUser.id == decoded["end_user_id"]).first()
if not end_user:
raise NotFound("End user not found.")
return end_user
@user_logged_in.connect
@user_loaded_from_request.connect
def on_user_logged_in(_sender, user):
"""Called when a user logged in."""
if user:
"""Called when a user logged in.
Note: AccountService.load_logged_in_account will populate user.current_tenant_id
through the load_user method, which calls account.set_tenant_id().
"""
if user and isinstance(user, Account) and user.current_tenant_id:
contexts.tenant_id.set(user.current_tenant_id)
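
The bearer-parsing rule above is easy to check in isolation; a standalone sketch that mirrors (rather than imports) the logic:

from werkzeug.exceptions import Unauthorized

def extract_bearer_token(auth_header: str) -> str:
    """Mirror of the header rule above: 'Bearer <token>' or Unauthorized."""
    if " " not in auth_header:
        raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
    scheme, token = auth_header.split(maxsplit=1)
    if scheme.lower() != "bearer":
        raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
    return token

assert extract_bearer_token("Bearer abc.def") == "abc.def"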

View File

@ -0,0 +1,73 @@
import json
import logging
import flask
import werkzeug.http
from flask import Flask
from flask.signals import request_finished, request_started
from configs import dify_config
_logger = logging.getLogger(__name__)
def _is_content_type_json(content_type: str) -> bool:
if not content_type:
return False
content_type_no_option, _ = werkzeug.http.parse_options_header(content_type)
return content_type_no_option.lower() == "application/json"
def _log_request_started(_sender, **_extra):
"""Log the start of a request."""
if not _logger.isEnabledFor(logging.DEBUG):
return
request = flask.request
if not (_is_content_type_json(request.content_type) and request.data):
_logger.debug("Received Request %s -> %s", request.method, request.path)
return
try:
json_data = json.loads(request.data)
except (TypeError, ValueError):
_logger.exception("Failed to parse JSON request")
return
formatted_json = json.dumps(json_data, ensure_ascii=False, indent=2)
_logger.debug(
"Received Request %s -> %s, Request Body:\n%s",
request.method,
request.path,
formatted_json,
)
def _log_request_finished(_sender, response, **_extra):
"""Log the end of a request."""
if not _logger.isEnabledFor(logging.DEBUG) or response is None:
return
if not _is_content_type_json(response.content_type):
_logger.debug("Response %s %s", response.status, response.content_type)
return
response_data = response.get_data(as_text=True)
try:
json_data = json.loads(response_data)
except (TypeError, ValueError):
_logger.exception("Failed to parse JSON response")
return
formatted_json = json.dumps(json_data, ensure_ascii=False, indent=2)
_logger.debug(
"Response %s %s, Response Body:\n%s",
response.status,
response.content_type,
formatted_json,
)
def init_app(app: Flask):
"""Initialize the request logging extension."""
if not dify_config.ENABLE_REQUEST_LOGGING:
return
request_started.connect(_log_request_started, app)
request_finished.connect(_log_request_finished, app)
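
Note that both handlers return early unless the module logger is at DEBUG, and the signals are only connected when `ENABLE_REQUEST_LOGGING` is true. A sketch of the extra step needed to actually see request/response bodies (the logger path is assumed from the module location):

import logging

# Assumes ENABLE_REQUEST_LOGGING=true in the environment (see the .env change above).
logging.getLogger("extensions.ext_request_logging").setLevel(logging.DEBUG)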

View File

@ -63,6 +63,7 @@ app_detail_fields = {
"created_at": TimestampField,
"updated_by": fields.String,
"updated_at": TimestampField,
"access_mode": fields.String,
}
prompt_config_fields = {
@ -98,6 +99,7 @@ app_partial_fields = {
"updated_by": fields.String,
"updated_at": TimestampField,
"tags": fields.List(fields.Nested(tag_fields)),
"access_mode": fields.String,
}
@ -176,6 +178,7 @@ app_detail_fields_with_site = {
"updated_by": fields.String,
"updated_at": TimestampField,
"deleted_tools": fields.List(fields.Nested(deleted_tool_fields)),
"access_mode": fields.String,
}

View File

@ -0,0 +1,51 @@
"""add WorkflowDraftVariable model
Revision ID: 2adcbe1f5dfb
Revises: d28f2004b072
Create Date: 2025-05-15 15:31:03.128680
"""
import sqlalchemy as sa
from alembic import op
import models as models
# revision identifiers, used by Alembic.
revision = "2adcbe1f5dfb"
down_revision = "d28f2004b072"
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"workflow_draft_variables",
sa.Column("id", models.types.StringUUID(), server_default=sa.text("uuid_generate_v4()"), nullable=False),
sa.Column("created_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
sa.Column("updated_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
sa.Column("app_id", models.types.StringUUID(), nullable=False),
sa.Column("last_edited_at", sa.DateTime(), nullable=True),
sa.Column("node_id", sa.String(length=255), nullable=False),
sa.Column("name", sa.String(length=255), nullable=False),
sa.Column("description", sa.String(length=255), nullable=False),
sa.Column("selector", sa.String(length=255), nullable=False),
sa.Column("value_type", sa.String(length=20), nullable=False),
sa.Column("value", sa.Text(), nullable=False),
sa.Column("visible", sa.Boolean(), nullable=False),
sa.Column("editable", sa.Boolean(), nullable=False),
sa.PrimaryKeyConstraint("id", name=op.f("workflow_draft_variables_pkey")),
sa.UniqueConstraint("app_id", "node_id", "name", name=op.f("workflow_draft_variables_app_id_key")),
)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
# Dropping `workflow_draft_variables` also drops any index associated with it.
op.drop_table("workflow_draft_variables")
# ### end Alembic commands ###

View File

@ -27,7 +27,7 @@ from .dataset import (
Whitelist,
)
from .engine import db
from .enums import CreatedByRole, UserFrom, WorkflowRunTriggeredFrom
from .enums import CreatorUserRole, UserFrom, WorkflowRunTriggeredFrom
from .model import (
ApiRequest,
ApiToken,
@ -112,7 +112,7 @@ __all__ = [
"CeleryTaskSet",
"Conversation",
"ConversationVariable",
"CreatedByRole",
"CreatorUserRole",
"DataSourceApiKeyAuthBinding",
"DataSourceOauthBinding",
"Dataset",

View File

@ -1,10 +1,10 @@
import enum
import json
from typing import cast
from typing import Optional, cast
from flask_login import UserMixin # type: ignore
from sqlalchemy import func
from sqlalchemy.orm import Mapped, mapped_column
from sqlalchemy.orm import Mapped, mapped_column, reconstructor
from models.base import Base
@ -12,6 +12,66 @@ from .engine import db
from .types import StringUUID
class TenantAccountRole(enum.StrEnum):
OWNER = "owner"
ADMIN = "admin"
EDITOR = "editor"
NORMAL = "normal"
DATASET_OPERATOR = "dataset_operator"
@staticmethod
def is_valid_role(role: str) -> bool:
if not role:
return False
return role in {
TenantAccountRole.OWNER,
TenantAccountRole.ADMIN,
TenantAccountRole.EDITOR,
TenantAccountRole.NORMAL,
TenantAccountRole.DATASET_OPERATOR,
}
@staticmethod
def is_privileged_role(role: Optional["TenantAccountRole"]) -> bool:
if not role:
return False
return role in {TenantAccountRole.OWNER, TenantAccountRole.ADMIN}
@staticmethod
def is_admin_role(role: Optional["TenantAccountRole"]) -> bool:
if not role:
return False
return role == TenantAccountRole.ADMIN
@staticmethod
def is_non_owner_role(role: Optional["TenantAccountRole"]) -> bool:
if not role:
return False
return role in {
TenantAccountRole.ADMIN,
TenantAccountRole.EDITOR,
TenantAccountRole.NORMAL,
TenantAccountRole.DATASET_OPERATOR,
}
@staticmethod
def is_editing_role(role: Optional["TenantAccountRole"]) -> bool:
if not role:
return False
return role in {TenantAccountRole.OWNER, TenantAccountRole.ADMIN, TenantAccountRole.EDITOR}
@staticmethod
def is_dataset_edit_role(role: Optional["TenantAccountRole"]) -> bool:
if not role:
return False
return role in {
TenantAccountRole.OWNER,
TenantAccountRole.ADMIN,
TenantAccountRole.EDITOR,
TenantAccountRole.DATASET_OPERATOR,
}
class AccountStatus(enum.StrEnum):
PENDING = "pending"
UNINITIALIZED = "uninitialized"
@ -41,24 +101,27 @@ class Account(UserMixin, Base):
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
@reconstructor
def init_on_load(self):
self.role: Optional[TenantAccountRole] = None
self._current_tenant: Optional[Tenant] = None
@property
def is_password_set(self):
return self.password is not None
@property
def current_tenant(self):
return self._current_tenant # type: ignore
return self._current_tenant
@current_tenant.setter
def current_tenant(self, value: "Tenant"):
tenant = value
def current_tenant(self, tenant: "Tenant"):
ta = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=self.id).first()
if ta:
tenant.current_role = ta.role
else:
tenant = None # type: ignore
self._current_tenant = tenant
self.role = TenantAccountRole(ta.role)
self._current_tenant = tenant
return
self._current_tenant = None
@property
def current_tenant_id(self) -> str | None:
@ -80,12 +143,12 @@ class Account(UserMixin, Base):
return
tenant, join = tenant_account_join
tenant.current_role = join.role
self.role = join.role
self._current_tenant = tenant
@property
def current_role(self):
return self._current_tenant.current_role
return self.role
def get_status(self) -> AccountStatus:
status_str = self.status
@ -105,23 +168,23 @@ class Account(UserMixin, Base):
# check current_user.current_tenant.current_role in ['admin', 'owner']
@property
def is_admin_or_owner(self):
return TenantAccountRole.is_privileged_role(self._current_tenant.current_role)
return TenantAccountRole.is_privileged_role(self.role)
@property
def is_admin(self):
return TenantAccountRole.is_admin_role(self._current_tenant.current_role)
return TenantAccountRole.is_admin_role(self.role)
@property
def is_editor(self):
return TenantAccountRole.is_editing_role(self._current_tenant.current_role)
return TenantAccountRole.is_editing_role(self.role)
@property
def is_dataset_editor(self):
return TenantAccountRole.is_dataset_edit_role(self._current_tenant.current_role)
return TenantAccountRole.is_dataset_edit_role(self.role)
@property
def is_dataset_operator(self):
return self._current_tenant.current_role == TenantAccountRole.DATASET_OPERATOR
return self.role == TenantAccountRole.DATASET_OPERATOR
class TenantStatus(enum.StrEnum):
@ -129,66 +192,6 @@ class TenantStatus(enum.StrEnum):
ARCHIVE = "archive"
class TenantAccountRole(enum.StrEnum):
OWNER = "owner"
ADMIN = "admin"
EDITOR = "editor"
NORMAL = "normal"
DATASET_OPERATOR = "dataset_operator"
@staticmethod
def is_valid_role(role: str) -> bool:
if not role:
return False
return role in {
TenantAccountRole.OWNER,
TenantAccountRole.ADMIN,
TenantAccountRole.EDITOR,
TenantAccountRole.NORMAL,
TenantAccountRole.DATASET_OPERATOR,
}
@staticmethod
def is_privileged_role(role: str) -> bool:
if not role:
return False
return role in {TenantAccountRole.OWNER, TenantAccountRole.ADMIN}
@staticmethod
def is_admin_role(role: str) -> bool:
if not role:
return False
return role == TenantAccountRole.ADMIN
@staticmethod
def is_non_owner_role(role: str) -> bool:
if not role:
return False
return role in {
TenantAccountRole.ADMIN,
TenantAccountRole.EDITOR,
TenantAccountRole.NORMAL,
TenantAccountRole.DATASET_OPERATOR,
}
@staticmethod
def is_editing_role(role: str) -> bool:
if not role:
return False
return role in {TenantAccountRole.OWNER, TenantAccountRole.ADMIN, TenantAccountRole.EDITOR}
@staticmethod
def is_dataset_edit_role(role: str) -> bool:
if not role:
return False
return role in {
TenantAccountRole.OWNER,
TenantAccountRole.ADMIN,
TenantAccountRole.EDITOR,
TenantAccountRole.DATASET_OPERATOR,
}
class Tenant(Base):
__tablename__ = "tenants"
__table_args__ = (db.PrimaryKeyConstraint("id", name="tenant_pkey"),)
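
With the role now cached on the account itself, the permission helpers no longer dereference `_current_tenant`. A sketch of the resulting behaviour (in-memory only; in real code `role` is populated by the `current_tenant` setter or `set_tenant_id`):

# Illustrative: exercise the role helpers without a database session.
account = Account()
account.init_on_load()  # normally invoked by the SQLAlchemy reconstructor
account.role = TenantAccountRole.ADMIN
assert account.is_admin and account.is_admin_or_owner
assert not account.is_dataset_operator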

View File

@ -1,5 +1,7 @@
from sqlalchemy.orm import declarative_base
from sqlalchemy.orm import DeclarativeBase
from models.engine import metadata
Base = declarative_base(metadata=metadata)
class Base(DeclarativeBase):
metadata = metadata
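
Switching from `declarative_base()` to a `DeclarativeBase` subclass keeps the shared `metadata` while enabling typed declarative mappings. A minimal sketch of a model against the new base (the table and columns are hypothetical):

from sqlalchemy.orm import Mapped, mapped_column

class ExampleRecord(Base):  # hypothetical model, for illustration only
    __tablename__ = "example_records"

    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str] = mapped_column(default="")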

View File

@ -1,7 +1,7 @@
from enum import StrEnum
class CreatedByRole(StrEnum):
class CreatorUserRole(StrEnum):
ACCOUNT = "account"
END_USER = "end_user"
@ -14,3 +14,10 @@ class UserFrom(StrEnum):
class WorkflowRunTriggeredFrom(StrEnum):
DEBUGGING = "debugging"
APP_RUN = "app-run"
class DraftVariableType(StrEnum):
# NODE means that the corresponding variable is the output of a specific workflow node.
NODE = "node"
SYS = "sys"
CONVERSATION = "conversation"

View File

@ -29,7 +29,7 @@ from libs.helper import generate_string
from .account import Account, Tenant
from .base import Base
from .engine import db
from .enums import CreatedByRole
from .enums import CreatorUserRole
from .types import StringUUID
from .workflow import WorkflowRunStatus
@ -1270,7 +1270,7 @@ class MessageFile(Base):
url: str | None = None,
belongs_to: Literal["user", "assistant"] | None = None,
upload_file_id: str | None = None,
created_by_role: CreatedByRole,
created_by_role: CreatorUserRole,
created_by: str,
):
self.message_id = message_id
@ -1417,7 +1417,7 @@ class EndUser(Base, UserMixin):
)
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
tenant_id = db.Column(StringUUID, nullable=False)
tenant_id: Mapped[str] = db.Column(StringUUID, nullable=False)
app_id = db.Column(StringUUID, nullable=True)
type = db.Column(db.String(255), nullable=False)
external_user_id = db.Column(db.String(255), nullable=True)
@ -1547,7 +1547,7 @@ class UploadFile(Base):
size: int,
extension: str,
mime_type: str,
created_by_role: CreatedByRole,
created_by_role: CreatorUserRole,
created_by: str,
created_at: datetime,
used: bool,

View File

@ -172,10 +172,6 @@ class WorkflowToolProvider(Base):
db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")
)
@property
def schema_type(self) -> ApiProviderSchemaType:
return ApiProviderSchemaType.value_of(self.schema_type_str)
@property
def user(self) -> Account | None:
return db.session.query(Account).filter(Account.id == self.user_id).first()

View File

@ -1,4 +1,7 @@
from sqlalchemy import CHAR, TypeDecorator
import enum
from typing import Generic, TypeVar
from sqlalchemy import CHAR, VARCHAR, TypeDecorator
from sqlalchemy.dialects.postgresql import UUID
@ -24,3 +27,51 @@ class StringUUID(TypeDecorator):
if value is None:
return value
return str(value)
_E = TypeVar("_E", bound=enum.StrEnum)
class EnumText(TypeDecorator, Generic[_E]):
impl = VARCHAR
cache_ok = True
_length: int
_enum_class: type[_E]
def __init__(self, enum_class: type[_E], length: int | None = None):
self._enum_class = enum_class
max_enum_value_len = max(len(e.value) for e in enum_class)
if length is not None:
if length < max_enum_value_len:
raise ValueError("length should be greater than enum value length.")
self._length = length
else:
# leave some room for future longer enum values.
self._length = max(max_enum_value_len, 20)
def process_bind_param(self, value: _E | str | None, dialect):
if value is None:
return value
if isinstance(value, self._enum_class):
return value.value
elif isinstance(value, str):
self._enum_class(value)
return value
else:
raise TypeError(f"expected str or {self._enum_class}, got {type(value)}")
def load_dialect_impl(self, dialect):
return dialect.type_descriptor(VARCHAR(self._length))
def process_result_value(self, value, dialect) -> _E | None:
if value is None:
return value
if not isinstance(value, str):
raise TypeError(f"expected str, got {type(value)}")
return self._enum_class(value)
def compare_values(self, x, y):
if x is None or y is None:
return x is y
return x == y
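
A usage sketch for `EnumText`: declare a typed enum column and let the decorator validate on bind and coerce on load (the `Color` enum and `Palette` table are hypothetical; `Base` is the project's declarative base):

import enum
from sqlalchemy.orm import Mapped, mapped_column

class Color(enum.StrEnum):  # hypothetical enum for illustration
    RED = "red"
    GREEN = "green"

class Palette(Base):
    __tablename__ = "palettes"

    id: Mapped[int] = mapped_column(primary_key=True)
    # Stored as VARCHAR(20): the longest value is 5 chars, padded to the 20-char floor.
    color: Mapped[Color] = mapped_column(EnumText(Color))

Both `Color.RED` and the raw string "red" are accepted on bind; an unknown string raises `ValueError` from the enum constructor inside `process_bind_param`.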

View File

@ -1,29 +1,36 @@
import json
import logging
from collections.abc import Mapping, Sequence
from datetime import UTC, datetime
from enum import Enum, StrEnum
from typing import TYPE_CHECKING, Any, Optional, Self, Union
from typing import TYPE_CHECKING, Any, Optional, Union
from uuid import uuid4
from core.variables import utils as variable_utils
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from factories.variable_factory import build_segment
if TYPE_CHECKING:
from models.model import AppMode
import sqlalchemy as sa
from sqlalchemy import func
from sqlalchemy import UniqueConstraint, func
from sqlalchemy.orm import Mapped, mapped_column
import contexts
from constants import DEFAULT_FILE_NUMBER_LIMITS, HIDDEN_VALUE
from core.helper import encrypter
from core.variables import SecretVariable, Variable
from core.variables import SecretVariable, Segment, SegmentType, Variable
from factories import variable_factory
from libs import helper
from .account import Account
from .base import Base
from .engine import db
from .enums import CreatedByRole
from .types import StringUUID
from .enums import CreatorUserRole, DraftVariableType
from .types import EnumText, StringUUID
_logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from models.model import AppMode
@ -143,7 +150,7 @@ class Workflow(Base):
conversation_variables: Sequence[Variable],
marked_name: str = "",
marked_comment: str = "",
) -> Self:
) -> "Workflow":
workflow = Workflow()
workflow.id = str(uuid4())
workflow.tenant_id = tenant_id
@ -192,7 +199,9 @@ class Workflow(Base):
features["file_upload"]["number_limits"] = image_number_limits
features["file_upload"]["allowed_file_upload_methods"] = image_transfer_methods
features["file_upload"]["allowed_file_types"] = features["file_upload"].get("allowed_file_types", ["image"])
features["file_upload"]["allowed_file_extensions"] = []
features["file_upload"]["allowed_file_extensions"] = features["file_upload"].get(
"allowed_file_extensions", []
)
del features["file_upload"]["image"]
self._features = json.dumps(features)
return self._features
@ -418,29 +427,29 @@ class WorkflowRun(Base):
status: Mapped[str] = mapped_column(db.String(255)) # running, succeeded, failed, stopped, partial-succeeded
outputs: Mapped[Optional[str]] = mapped_column(sa.Text, default="{}")
error: Mapped[Optional[str]] = mapped_column(db.Text)
elapsed_time = db.Column(db.Float, nullable=False, server_default=sa.text("0"))
elapsed_time: Mapped[float] = mapped_column(db.Float, nullable=False, server_default=sa.text("0"))
total_tokens: Mapped[int] = mapped_column(sa.BigInteger, server_default=sa.text("0"))
total_steps = db.Column(db.Integer, server_default=db.text("0"))
total_steps: Mapped[int] = mapped_column(db.Integer, server_default=db.text("0"))
created_by_role: Mapped[str] = mapped_column(db.String(255)) # account, end_user
created_by = db.Column(StringUUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
finished_at = db.Column(db.DateTime)
exceptions_count = db.Column(db.Integer, server_default=db.text("0"))
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
finished_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime)
exceptions_count: Mapped[int] = mapped_column(db.Integer, server_default=db.text("0"))
@property
def created_by_account(self):
created_by_role = CreatedByRole(self.created_by_role)
return db.session.get(Account, self.created_by) if created_by_role == CreatedByRole.ACCOUNT else None
created_by_role = CreatorUserRole(self.created_by_role)
return db.session.get(Account, self.created_by) if created_by_role == CreatorUserRole.ACCOUNT else None
@property
def created_by_end_user(self):
from models.model import EndUser
created_by_role = CreatedByRole(self.created_by_role)
return db.session.get(EndUser, self.created_by) if created_by_role == CreatedByRole.END_USER else None
created_by_role = CreatorUserRole(self.created_by_role)
return db.session.get(EndUser, self.created_by) if created_by_role == CreatorUserRole.END_USER else None
@property
def graph_dict(self):
def graph_dict(self) -> Mapping[str, Any]:
return json.loads(self.graph) if self.graph else {}
@property
@ -634,24 +643,24 @@ class WorkflowNodeExecution(Base):
@property
def created_by_account(self):
created_by_role = CreatedByRole(self.created_by_role)
created_by_role = CreatorUserRole(self.created_by_role)
# TODO(-LAN-): Avoid using db.session.get() here.
return db.session.get(Account, self.created_by) if created_by_role == CreatedByRole.ACCOUNT else None
return db.session.get(Account, self.created_by) if created_by_role == CreatorUserRole.ACCOUNT else None
@property
def created_by_end_user(self):
from models.model import EndUser
created_by_role = CreatedByRole(self.created_by_role)
created_by_role = CreatorUserRole(self.created_by_role)
# TODO(-LAN-): Avoid using db.session.get() here.
return db.session.get(EndUser, self.created_by) if created_by_role == CreatedByRole.END_USER else None
return db.session.get(EndUser, self.created_by) if created_by_role == CreatorUserRole.END_USER else None
@property
def inputs_dict(self):
return json.loads(self.inputs) if self.inputs else None
@property
def outputs_dict(self):
def outputs_dict(self) -> dict[str, Any] | None:
return json.loads(self.outputs) if self.outputs else None
@property
@ -659,8 +668,11 @@ class WorkflowNodeExecution(Base):
return json.loads(self.process_data) if self.process_data else None
@property
def execution_metadata_dict(self):
return json.loads(self.execution_metadata) if self.execution_metadata else None
def execution_metadata_dict(self) -> dict[str, Any]:
# When the metadata is unset, we return an empty dictionary instead of `None`.
# This approach streamlines the logic for the caller, making it easier to handle
# cases where metadata is absent.
return json.loads(self.execution_metadata) if self.execution_metadata else {}
@property
def extras(self):
@ -742,12 +754,12 @@ class WorkflowAppLog(Base):
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
tenant_id: Mapped[str] = mapped_column(StringUUID)
app_id: Mapped[str] = mapped_column(StringUUID)
workflow_id = db.Column(StringUUID, nullable=False)
workflow_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
workflow_run_id: Mapped[str] = mapped_column(StringUUID)
created_from = db.Column(db.String(255), nullable=False)
created_by_role = db.Column(db.String(255), nullable=False)
created_by = db.Column(StringUUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
created_from: Mapped[str] = mapped_column(db.String(255), nullable=False)
created_by_role: Mapped[str] = mapped_column(db.String(255), nullable=False)
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
@property
def workflow_run(self):
@ -755,15 +767,15 @@ class WorkflowAppLog(Base):
@property
def created_by_account(self):
created_by_role = CreatedByRole(self.created_by_role)
return db.session.get(Account, self.created_by) if created_by_role == CreatedByRole.ACCOUNT else None
created_by_role = CreatorUserRole(self.created_by_role)
return db.session.get(Account, self.created_by) if created_by_role == CreatorUserRole.ACCOUNT else None
@property
def created_by_end_user(self):
from models.model import EndUser
created_by_role = CreatedByRole(self.created_by_role)
return db.session.get(EndUser, self.created_by) if created_by_role == CreatedByRole.END_USER else None
created_by_role = CreatorUserRole(self.created_by_role)
return db.session.get(EndUser, self.created_by) if created_by_role == CreatorUserRole.END_USER else None
class ConversationVariable(Base):
@ -772,9 +784,11 @@ class ConversationVariable(Base):
id: Mapped[str] = mapped_column(StringUUID, primary_key=True)
conversation_id: Mapped[str] = mapped_column(StringUUID, nullable=False, primary_key=True, index=True)
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False, index=True)
data = mapped_column(db.Text, nullable=False)
created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp(), index=True)
updated_at = mapped_column(
data: Mapped[str] = mapped_column(db.Text, nullable=False)
created_at: Mapped[datetime] = mapped_column(
db.DateTime, nullable=False, server_default=func.current_timestamp(), index=True
)
updated_at: Mapped[datetime] = mapped_column(
db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
@ -797,3 +811,201 @@ class ConversationVariable(Base):
def to_variable(self) -> Variable:
mapping = json.loads(self.data)
return variable_factory.build_conversation_variable_from_mapping(mapping)
# Only `sys.query` and `sys.files` can be modified.
_EDITABLE_SYSTEM_VARIABLE = frozenset(["query", "files"])
def _naive_utc_datetime():
return datetime.now(UTC).replace(tzinfo=None)
class WorkflowDraftVariable(Base):
@staticmethod
def unique_columns() -> list[str]:
return [
"app_id",
"node_id",
"name",
]
__tablename__ = "workflow_draft_variables"
__table_args__ = (UniqueConstraint(*unique_columns()),)
# id is the unique identifier of a draft variable.
id: Mapped[str] = mapped_column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()"))
created_at: Mapped[datetime] = mapped_column(
db.DateTime,
nullable=False,
default=_naive_utc_datetime,
server_default=func.current_timestamp(),
)
updated_at: Mapped[datetime] = mapped_column(
db.DateTime,
nullable=False,
default=_naive_utc_datetime,
server_default=func.current_timestamp(),
onupdate=func.current_timestamp(),
)
# `app_id` maps to the `id` field in the `model.App` model.
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
# `last_edited_at` records when the value of a given draft variable
# is edited.
#
# If it's not edited after creation, its value is `None`.
last_edited_at: Mapped[datetime | None] = mapped_column(
db.DateTime,
nullable=True,
default=None,
)
# The `node_id` field is special.
#
# If the variable is a conversation variable or a system variable, then the value of `node_id`
# is `conversation` or `sys`, respectively.
#
# Otherwise, if the variable belongs to a specific node, the value of `node_id` is
# the identity of the corresponding node in the graph definition. An example node id is `"1745769620734"`.
#
# However, there's one caveat. The id of the first "Answer" node in a chatflow is "answer". (Other
# "Answer" nodes conform to the rules above.)
node_id: Mapped[str] = mapped_column(sa.String(255), nullable=False, name="node_id")
# From `VARIABLE_PATTERN`, we may conclude that the length of a top-level variable name is less than
# 80 chars.
#
# ref: api/core/workflow/entities/variable_pool.py:18
name: Mapped[str] = mapped_column(sa.String(255), nullable=False)
description: Mapped[str] = mapped_column(
sa.String(255),
default="",
nullable=False,
)
selector: Mapped[str] = mapped_column(sa.String(255), nullable=False, name="selector")
value_type: Mapped[SegmentType] = mapped_column(EnumText(SegmentType, length=20))
# JSON string
value: Mapped[str] = mapped_column(sa.Text, nullable=False, name="value")
# visible
visible: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, default=True)
editable: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, default=False)
def get_selector(self) -> list[str]:
selector = json.loads(self.selector)
if not isinstance(selector, list):
_logger.error(
"invalid selector loaded from database, type=%s, value=%s",
type(selector),
self.selector,
)
raise ValueError("invalid selector.")
return selector
def _set_selector(self, value: list[str]):
self.selector = json.dumps(value)
def get_value(self) -> Segment | None:
return build_segment(json.loads(self.value))
def set_name(self, name: str):
self.name = name
self._set_selector([self.node_id, name])
def set_value(self, value: Segment):
self.value = json.dumps(value.value)
self.value_type = value.value_type
def get_node_id(self) -> str | None:
if self.get_variable_type() == DraftVariableType.NODE:
return self.node_id
else:
return None
def get_variable_type(self) -> DraftVariableType:
match self.node_id:
case DraftVariableType.CONVERSATION:
return DraftVariableType.CONVERSATION
case DraftVariableType.SYS:
return DraftVariableType.SYS
case _:
return DraftVariableType.NODE
@classmethod
def _new(
cls,
*,
app_id: str,
node_id: str,
name: str,
value: Segment,
description: str = "",
) -> "WorkflowDraftVariable":
variable = WorkflowDraftVariable()
variable.created_at = _naive_utc_datetime()
variable.updated_at = _naive_utc_datetime()
variable.description = description
variable.app_id = app_id
variable.node_id = node_id
variable.name = name
variable.set_value(value)
variable._set_selector(list(variable_utils.to_selector(node_id, name)))
return variable
@classmethod
def new_conversation_variable(
cls,
*,
app_id: str,
name: str,
value: Segment,
) -> "WorkflowDraftVariable":
variable = cls._new(
app_id=app_id,
node_id=CONVERSATION_VARIABLE_NODE_ID,
name=name,
value=value,
)
return variable
@classmethod
def new_sys_variable(
cls,
*,
app_id: str,
name: str,
value: Segment,
editable: bool = False,
) -> "WorkflowDraftVariable":
variable = cls._new(app_id=app_id, node_id=SYSTEM_VARIABLE_NODE_ID, name=name, value=value)
variable.editable = editable
return variable
@classmethod
def new_node_variable(
cls,
*,
app_id: str,
node_id: str,
name: str,
value: Segment,
visible: bool = True,
) -> "WorkflowDraftVariable":
variable = cls._new(app_id=app_id, node_id=node_id, name=name, value=value)
variable.visible = visible
variable.editable = True
return variable
@property
def edited(self):
return self.last_edited_at is not None
def is_system_variable_editable(name: str) -> bool:
return name in _EDITABLE_SYSTEM_VARIABLE
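
A sketch of the three factory paths above (ids are illustrative; `build_segment` is the factory imported at the top of this module):

conv_var = WorkflowDraftVariable.new_conversation_variable(
    app_id="app-1", name="topic", value=build_segment("databases")
)
sys_var = WorkflowDraftVariable.new_sys_variable(
    app_id="app-1", name="query", value=build_segment("hello"),
    editable=is_system_variable_editable("query"),
)
node_var = WorkflowDraftVariable.new_node_variable(
    app_id="app-1", node_id="1745769620734", name="text", value=build_segment("output")
)
assert conv_var.get_variable_type() == DraftVariableType.CONVERSATION
assert sys_var.editable        # "query" is in _EDITABLE_SYSTEM_VARIABLE
assert node_var.get_node_id() == "1745769620734"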

View File

@ -39,7 +39,7 @@ dependencies = [
"oci~=2.135.1",
"openai~=1.61.0",
"openpyxl~=3.1.5",
"opik~=1.3.4",
"opik~=1.7.25",
"opentelemetry-api==1.27.0",
"opentelemetry-distro==0.48b0",
"opentelemetry-exporter-otlp==1.27.0",
@ -72,7 +72,7 @@ dependencies = [
"python-dotenv==1.0.1",
"pyyaml~=6.0.1",
"readabilipy~=0.3.0",
"redis[hiredis]~=6.0.0",
"redis[hiredis]~=6.1.0",
"resend~=2.9.0",
"sentry-sdk[flask]~=2.28.0",
"sqlalchemy~=2.0.29",
@ -148,6 +148,7 @@ dev = [
"types-tensorflow~=2.18.0",
"types-tqdm~=4.67.0",
"types-ujson~=5.10.0",
"boto3-stubs>=1.38.20",
]
############################################################

View File

@ -49,7 +49,7 @@ from services.errors.account import (
RoleAlreadyAssignedError,
TenantNotFoundError,
)
from services.errors.workspace import WorkSpaceNotAllowedCreateError
from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkspacesLimitExceededError
from services.feature_service import FeatureService
from tasks.delete_account_task import delete_account_task
from tasks.mail_account_deletion_task import send_account_deletion_verification_code
@ -628,6 +628,10 @@ class TenantService:
if not FeatureService.get_system_features().is_allow_create_workspace and not is_setup:
raise WorkSpaceNotAllowedCreateError()
workspaces = FeatureService.get_system_features().license.workspaces
if not workspaces.is_available():
raise WorkspacesLimitExceededError()
if name:
tenant = TenantService.create_tenant(name=name, is_setup=is_setup)
else:
@ -928,7 +932,11 @@ class RegisterService:
if open_id is not None and provider is not None:
AccountService.link_account_integrate(provider, open_id, account)
if FeatureService.get_system_features().is_allow_create_workspace and create_workspace_required:
if (
FeatureService.get_system_features().is_allow_create_workspace
and create_workspace_required
and FeatureService.get_system_features().license.workspaces.is_available()
):
tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
TenantService.create_tenant_member(tenant, account, role="owner")
account.current_tenant = tenant

View File

@ -18,8 +18,10 @@ from core.tools.utils.configuration import ToolParameterConfigurationManager
from events.app_event import app_was_created
from extensions.ext_database import db
from models.account import Account
from models.model import App, AppMode, AppModelConfig
from models.model import App, AppMode, AppModelConfig, Site
from models.tools import ApiToolProvider
from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService
from services.tag_service import TagService
from tasks.remove_app_and_related_data_task import remove_app_and_related_data_task
@ -155,6 +157,10 @@ class AppService:
app_was_created.send(app, account=account)
if FeatureService.get_system_features().webapp_auth.enabled:
# update web app setting as private
EnterpriseService.WebAppAuth.update_app_access_mode(app.id, "private")
return app
def get_app(self, app: App) -> App:
@ -307,6 +313,10 @@ class AppService:
db.session.delete(app)
db.session.commit()
# clean up web app settings
if FeatureService.get_system_features().webapp_auth.enabled:
EnterpriseService.WebAppAuth.cleanup_webapp(app.id)
# Trigger asynchronous deletion of app and related data
remove_app_and_related_data_task.delay(tenant_id=app.tenant_id, app_id=app.id)
@ -373,3 +383,15 @@ class AppService:
meta["tool_icons"][tool_name] = {"background": "#252525", "content": "\ud83d\ude01"}
return meta
@staticmethod
def get_app_code_by_id(app_id: str) -> str:
"""
Get app code by app id
:param app_id: app id
:return: app code
"""
site = db.session.query(Site).filter(Site.app_id == app_id).first()
if not site:
raise ValueError(f"App with id {app_id} not found")
return str(site.code)

View File

@ -960,11 +960,11 @@ class DocumentService:
"score_threshold_enabled": False,
}
dataset.retrieval_model = (
knowledge_config.retrieval_model.model_dump()
if knowledge_config.retrieval_model
else default_retrieval_model
) # type: ignore
dataset.retrieval_model = (
knowledge_config.retrieval_model.model_dump()
if knowledge_config.retrieval_model
else default_retrieval_model
) # type: ignore
documents = []
if knowledge_config.original_document_id:
@ -992,7 +992,7 @@ class DocumentService:
created_by=account.id,
)
else:
logging.warn(
logging.warning(
f"Invalid process rule mode: {process_rule.mode}, can not find dataset process rule"
)
return

View File

@ -1,11 +1,90 @@
from pydantic import BaseModel, Field
from services.enterprise.base import EnterpriseRequest
class WebAppSettings(BaseModel):
access_mode: str = Field(
description="Access mode for the web app. Can be 'public' or 'private'",
default="private",
alias="accessMode",
)
class EnterpriseService:
@classmethod
def get_info(cls):
return EnterpriseRequest.send_request("GET", "/info")
@classmethod
def get_app_web_sso_enabled(cls, app_code):
return EnterpriseRequest.send_request("GET", f"/app-sso-setting?appCode={app_code}")
def get_workspace_info(cls, tenant_id: str):
return EnterpriseRequest.send_request("GET", f"/workspace/{tenant_id}/info")
class WebAppAuth:
@classmethod
def is_user_allowed_to_access_webapp(cls, user_id: str, app_code: str):
params = {"userId": user_id, "appCode": app_code}
data = EnterpriseRequest.send_request("GET", "/webapp/permission", params=params)
return data.get("result", False)
@classmethod
def get_app_access_mode_by_id(cls, app_id: str) -> WebAppSettings:
if not app_id:
raise ValueError("app_id must be provided.")
params = {"appId": app_id}
data = EnterpriseRequest.send_request("GET", "/webapp/access-mode/id", params=params)
if not data:
raise ValueError("No data found.")
return WebAppSettings(**data)
@classmethod
def batch_get_app_access_mode_by_id(cls, app_ids: list[str]) -> dict[str, WebAppSettings]:
if not app_ids:
return {}
body = {"appIds": app_ids}
data: dict[str, str] = EnterpriseRequest.send_request("POST", "/webapp/access-mode/batch/id", json=body)
if not data:
raise ValueError("No data found.")
if not isinstance(data["accessModes"], dict):
raise ValueError("Invalid data format.")
ret = {}
for key, value in data["accessModes"].items():
curr = WebAppSettings()
curr.access_mode = value
ret[key] = curr
return ret
@classmethod
def get_app_access_mode_by_code(cls, app_code: str) -> WebAppSettings:
if not app_code:
raise ValueError("app_code must be provided.")
params = {"appCode": app_code}
data = EnterpriseRequest.send_request("GET", "/webapp/access-mode/code", params=params)
if not data:
raise ValueError("No data found.")
return WebAppSettings(**data)
@classmethod
def update_app_access_mode(cls, app_id: str, access_mode: str):
if not app_id:
raise ValueError("app_id must be provided.")
if access_mode not in ["public", "private", "private_all"]:
raise ValueError("access_mode must be either 'public', 'private', or 'private_all'")
data = {"appId": app_id, "accessMode": access_mode}
response = EnterpriseRequest.send_request("POST", "/webapp/access-mode", json=data)
return response.get("result", False)
@classmethod
def cleanup_webapp(cls, app_id: str):
if not app_id:
raise ValueError("app_id must be provided.")
body = {"appId": app_id}
EnterpriseRequest.send_request("DELETE", "/webapp/clean", json=body)

View File

@ -0,0 +1,18 @@
from pydantic import BaseModel
from tasks.mail_enterprise_task import send_enterprise_email_task
class DifyMail(BaseModel):
to: list[str]
subject: str
body: str
substitutions: dict[str, str] = {}
class EnterpriseMailService:
@classmethod
def send_mail(cls, mail: DifyMail):
send_enterprise_email_task.delay(
to=mail.to, subject=mail.subject, body=mail.body, substitutions=mail.substitutions
)

View File

@ -7,3 +7,7 @@ class WorkSpaceNotAllowedCreateError(BaseServiceError):
class WorkSpaceNotFoundError(BaseServiceError):
pass
class WorkspacesLimitExceededError(BaseServiceError):
pass

View File

@ -1,6 +1,6 @@
from enum import StrEnum
from pydantic import BaseModel, ConfigDict
from pydantic import BaseModel, ConfigDict, Field
from configs import dify_config
from services.billing_service import BillingService
@ -27,6 +27,32 @@ class LimitationModel(BaseModel):
limit: int = 0
class LicenseLimitationModel(BaseModel):
"""
- enabled: whether this limit is enforced
- size: current usage count
- limit: maximum allowed count; 0 means unlimited
"""
enabled: bool = Field(False, description="Whether this limit is currently active")
size: int = Field(0, description="Number of resources already consumed")
limit: int = Field(0, description="Maximum number of resources allowed; 0 means no limit")
def is_available(self, required: int = 1) -> bool:
"""
Determine whether the requested amount can be allocated.
Returns True if:
- this limit is not active, or
- the limit is zero (unlimited), or
- there is enough remaining quota.
"""
if not self.enabled or self.limit == 0:
return True
return (self.limit - self.size) >= required
class LicenseStatus(StrEnum):
NONE = "none"
INACTIVE = "inactive"
@ -39,6 +65,27 @@ class LicenseStatus(StrEnum):
class LicenseModel(BaseModel):
status: LicenseStatus = LicenseStatus.NONE
expired_at: str = ""
workspaces: LicenseLimitationModel = LicenseLimitationModel(enabled=False, size=0, limit=0)
class BrandingModel(BaseModel):
enabled: bool = False
application_title: str = ""
login_page_logo: str = ""
workspace_logo: str = ""
favicon: str = ""
class WebAppAuthSSOModel(BaseModel):
protocol: str = ""
class WebAppAuthModel(BaseModel):
enabled: bool = False
allow_sso: bool = False
sso_config: WebAppAuthSSOModel = WebAppAuthSSOModel()
allow_email_code_login: bool = False
allow_email_password_login: bool = False
class FeatureModel(BaseModel):
@ -54,6 +101,8 @@ class FeatureModel(BaseModel):
can_replace_logo: bool = False
model_load_balancing_enabled: bool = False
dataset_operator_enabled: bool = False
webapp_copyright_enabled: bool = False
workspace_members: LicenseLimitationModel = LicenseLimitationModel(enabled=False, size=0, limit=0)
# pydantic configs
model_config = ConfigDict(protected_namespaces=())
@ -68,9 +117,6 @@ class KnowledgeRateLimitModel(BaseModel):
class SystemFeatureModel(BaseModel):
sso_enforced_for_signin: bool = False
sso_enforced_for_signin_protocol: str = ""
sso_enforced_for_web: bool = False
sso_enforced_for_web_protocol: str = ""
enable_web_sso_switch_component: bool = False
enable_marketplace: bool = False
max_plugin_package_size: int = dify_config.PLUGIN_MAX_PACKAGE_SIZE
enable_email_code_login: bool = False
@ -80,6 +126,8 @@ class SystemFeatureModel(BaseModel):
is_allow_create_workspace: bool = False
is_email_setup: bool = False
license: LicenseModel = LicenseModel()
branding: BrandingModel = BrandingModel()
webapp_auth: WebAppAuthModel = WebAppAuthModel()
class FeatureService:
@ -92,6 +140,10 @@ class FeatureService:
if dify_config.BILLING_ENABLED and tenant_id:
cls._fulfill_params_from_billing_api(features, tenant_id)
if dify_config.ENTERPRISE_ENABLED:
features.webapp_copyright_enabled = True
cls._fulfill_params_from_workspace_info(features, tenant_id)
return features
@classmethod
@ -111,8 +163,8 @@ class FeatureService:
cls._fulfill_system_params_from_env(system_features)
if dify_config.ENTERPRISE_ENABLED:
system_features.enable_web_sso_switch_component = True
system_features.branding.enabled = True
system_features.webapp_auth.enabled = True
cls._fulfill_params_from_enterprise(system_features)
if dify_config.MARKETPLACE_ENABLED:
@ -136,6 +188,14 @@ class FeatureService:
features.dataset_operator_enabled = dify_config.DATASET_OPERATOR_ENABLED
features.education.enabled = dify_config.EDUCATION_ENABLED
@classmethod
def _fulfill_params_from_workspace_info(cls, features: FeatureModel, tenant_id: str):
workspace_info = EnterpriseService.get_workspace_info(tenant_id)
if "WorkspaceMembers" in workspace_info:
features.workspace_members.size = workspace_info["WorkspaceMembers"]["used"]
features.workspace_members.limit = workspace_info["WorkspaceMembers"]["limit"]
features.workspace_members.enabled = workspace_info["WorkspaceMembers"]["enabled"]
@classmethod
def _fulfill_params_from_billing_api(cls, features: FeatureModel, tenant_id: str):
billing_info = BillingService.get_info(tenant_id)
@ -145,6 +205,9 @@ class FeatureService:
features.billing.subscription.interval = billing_info["subscription"]["interval"]
features.education.activated = billing_info["subscription"].get("education", False)
if features.billing.subscription.plan != "sandbox":
features.webapp_copyright_enabled = True
if "members" in billing_info:
features.members.size = billing_info["members"]["size"]
features.members.limit = billing_info["members"]["limit"]
@ -178,38 +241,53 @@ class FeatureService:
features.knowledge_rate_limit = billing_info["knowledge_rate_limit"]["limit"]
@classmethod
def _fulfill_params_from_enterprise(cls, features):
def _fulfill_params_from_enterprise(cls, features: SystemFeatureModel):
enterprise_info = EnterpriseService.get_info()
if "sso_enforced_for_signin" in enterprise_info:
features.sso_enforced_for_signin = enterprise_info["sso_enforced_for_signin"]
if "SSOEnforcedForSignin" in enterprise_info:
features.sso_enforced_for_signin = enterprise_info["SSOEnforcedForSignin"]
if "sso_enforced_for_signin_protocol" in enterprise_info:
features.sso_enforced_for_signin_protocol = enterprise_info["sso_enforced_for_signin_protocol"]
if "SSOEnforcedForSigninProtocol" in enterprise_info:
features.sso_enforced_for_signin_protocol = enterprise_info["SSOEnforcedForSigninProtocol"]
if "sso_enforced_for_web" in enterprise_info:
features.sso_enforced_for_web = enterprise_info["sso_enforced_for_web"]
if "EnableEmailCodeLogin" in enterprise_info:
features.enable_email_code_login = enterprise_info["EnableEmailCodeLogin"]
if "sso_enforced_for_web_protocol" in enterprise_info:
features.sso_enforced_for_web_protocol = enterprise_info["sso_enforced_for_web_protocol"]
if "EnableEmailPasswordLogin" in enterprise_info:
features.enable_email_password_login = enterprise_info["EnableEmailPasswordLogin"]
if "enable_email_code_login" in enterprise_info:
features.enable_email_code_login = enterprise_info["enable_email_code_login"]
if "IsAllowRegister" in enterprise_info:
features.is_allow_register = enterprise_info["IsAllowRegister"]
if "enable_email_password_login" in enterprise_info:
features.enable_email_password_login = enterprise_info["enable_email_password_login"]
if "IsAllowCreateWorkspace" in enterprise_info:
features.is_allow_create_workspace = enterprise_info["IsAllowCreateWorkspace"]
if "is_allow_register" in enterprise_info:
features.is_allow_register = enterprise_info["is_allow_register"]
if "Branding" in enterprise_info:
features.branding.application_title = enterprise_info["Branding"].get("applicationTitle", "")
features.branding.login_page_logo = enterprise_info["Branding"].get("loginPageLogo", "")
features.branding.workspace_logo = enterprise_info["Branding"].get("workspaceLogo", "")
features.branding.favicon = enterprise_info["Branding"].get("favicon", "")
if "is_allow_create_workspace" in enterprise_info:
features.is_allow_create_workspace = enterprise_info["is_allow_create_workspace"]
if "WebAppAuth" in enterprise_info:
features.webapp_auth.allow_sso = enterprise_info["WebAppAuth"].get("allowSso", False)
features.webapp_auth.allow_email_code_login = enterprise_info["WebAppAuth"].get(
"allowEmailCodeLogin", False
)
features.webapp_auth.allow_email_password_login = enterprise_info["WebAppAuth"].get(
"allowEmailPasswordLogin", False
)
features.webapp_auth.sso_config.protocol = enterprise_info.get("SSOEnforcedForWebProtocol", "")
if "license" in enterprise_info:
license_info = enterprise_info["license"]
if "License" in enterprise_info:
license_info = enterprise_info["License"]
if "status" in license_info:
features.license.status = LicenseStatus(license_info.get("status", LicenseStatus.INACTIVE))
if "expired_at" in license_info:
features.license.expired_at = license_info["expired_at"]
if "expiredAt" in license_info:
features.license.expired_at = license_info["expiredAt"]
if "workspaces" in license_info:
features.license.workspaces.enabled = license_info["workspaces"]["enabled"]
features.license.workspaces.limit = license_info["workspaces"]["limit"]
features.license.workspaces.size = license_info["workspaces"]["used"]
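
The `LicenseLimitationModel.is_available` check added above is what gates workspace creation in `account_service.py`; a standalone sketch of its semantics (values illustrative):

unlimited = LicenseLimitationModel(enabled=True, size=99, limit=0)
assert unlimited.is_available()            # limit == 0 means no cap

capped = LicenseLimitationModel(enabled=True, size=4, limit=5)
assert capped.is_available()               # one slot remaining
assert not capped.is_available(required=2)

disabled = LicenseLimitationModel(enabled=False, size=100, limit=5)
assert disabled.is_available()             # limit not enforced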

View File

@ -19,7 +19,7 @@ from core.rag.extractor.extract_processor import ExtractProcessor
from extensions.ext_database import db
from extensions.ext_storage import storage
from models.account import Account
from models.enums import CreatedByRole
from models.enums import CreatorUserRole
from models.model import EndUser, UploadFile
from .errors.file import FileTooLargeError, UnsupportedFileTypeError
@ -81,7 +81,7 @@ class FileService:
size=file_size,
extension=extension,
mime_type=mimetype,
created_by_role=(CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER),
created_by_role=(CreatorUserRole.ACCOUNT if isinstance(user, Account) else CreatorUserRole.END_USER),
created_by=user.id,
created_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
used=False,
@ -133,7 +133,7 @@ class FileService:
extension="txt",
mime_type="text/plain",
created_by=current_user.id,
created_by_role=CreatedByRole.ACCOUNT,
created_by_role=CreatorUserRole.ACCOUNT,
created_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
used=True,
used_by=current_user.id,

View File

@ -87,7 +87,9 @@ class OpsService:
:param tracing_config: tracing config
:return:
"""
if tracing_provider not in provider_config_map and tracing_provider:
try:
provider_config_map[tracing_provider]
except KeyError:
return {"error": f"Invalid tracing provider: {tracing_provider}"}
config_class, other_keys = (
@ -150,7 +152,9 @@ class OpsService:
:param tracing_config: tracing config
:return:
"""
if tracing_provider not in provider_config_map:
try:
provider_config_map[tracing_provider]
except KeyError:
raise ValueError(f"Invalid tracing provider: {tracing_provider}")
# check if trace config already exists

View File

@ -23,11 +23,10 @@ class VectorService:
):
documents: list[Document] = []
document: Document | None = None
for segment in segments:
if doc_form == IndexType.PARENT_CHILD_INDEX:
document = db.session.query(DatasetDocument).filter_by(id=segment.document_id).first()
if not document:
dataset_document = db.session.query(DatasetDocument).filter_by(id=segment.document_id).first()
if not dataset_document:
_logger.warning(
"Expected DatasetDocument record to exist, but none was found, document_id=%s, segment_id=%s",
segment.document_id,
@ -37,7 +36,7 @@ class VectorService:
# get the process rule
processing_rule = (
db.session.query(DatasetProcessRule)
.filter(DatasetProcessRule.id == document.dataset_process_rule_id)
.filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
.first()
)
if not processing_rule:
@ -61,9 +60,11 @@ class VectorService:
)
else:
raise ValueError("The knowledge base index technique is not high quality!")
cls.generate_child_chunks(segment, document, dataset, embedding_model_instance, processing_rule, False)
cls.generate_child_chunks(
segment, dataset_document, dataset, embedding_model_instance, processing_rule, False
)
else:
document = Document(
rag_document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
@ -72,7 +73,7 @@ class VectorService:
"dataset_id": segment.dataset_id,
},
)
documents.append(document)
documents.append(rag_document)
if len(documents) > 0:
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
index_processor.load(dataset, documents, with_keywords=True, keywords_list=keywords_list)

View File

@ -0,0 +1,141 @@
import random
from datetime import UTC, datetime, timedelta
from typing import Any, Optional, cast

from werkzeug.exceptions import NotFound, Unauthorized

from configs import dify_config
from controllers.web.error import WebAppAuthAccessDeniedError
from extensions.ext_database import db
from libs.helper import TokenManager
from libs.passport import PassportService
from libs.password import compare_password
from models.account import Account, AccountStatus
from models.model import App, EndUser, Site
from services.enterprise.enterprise_service import EnterpriseService
from services.errors.account import AccountLoginError, AccountNotFoundError, AccountPasswordError
from services.feature_service import FeatureService
from tasks.mail_email_code_login import send_email_code_login_mail_task


class WebAppAuthService:
    """Service for web app authentication."""

    @staticmethod
    def authenticate(email: str, password: str) -> Account:
        """authenticate account with email and password"""
        account = Account.query.filter_by(email=email).first()
        if not account:
            raise AccountNotFoundError()

        if account.status == AccountStatus.BANNED.value:
            raise AccountLoginError("Account is banned.")

        if account.password is None or not compare_password(password, account.password, account.password_salt):
            raise AccountPasswordError("Invalid email or password.")

        return cast(Account, account)

    @classmethod
    def login(cls, account: Account, app_code: str, end_user_id: str) -> str:
        site = db.session.query(Site).filter(Site.code == app_code).first()
        if not site:
            raise NotFound("Site not found.")

        access_token = cls._get_account_jwt_token(account=account, site=site, end_user_id=end_user_id)
        return access_token

    @classmethod
    def get_user_through_email(cls, email: str):
        account = db.session.query(Account).filter(Account.email == email).first()
        if not account:
            return None

        if account.status == AccountStatus.BANNED.value:
            raise Unauthorized("Account is banned.")

        return account

    @classmethod
    def send_email_code_login_email(
        cls, account: Optional[Account] = None, email: Optional[str] = None, language: Optional[str] = "en-US"
    ):
        email = account.email if account else email
        if email is None:
            raise ValueError("Email must be provided.")

        code = "".join([str(random.randint(0, 9)) for _ in range(6)])
        token = TokenManager.generate_token(
            account=account, email=email, token_type="webapp_email_code_login", additional_data={"code": code}
        )
        send_email_code_login_mail_task.delay(
            language=language,
            to=account.email if account else email,
            code=code,
        )
        return token

    @classmethod
    def get_email_code_login_data(cls, token: str) -> Optional[dict[str, Any]]:
        return TokenManager.get_token_data(token, "webapp_email_code_login")

    @classmethod
    def revoke_email_code_login_token(cls, token: str):
        TokenManager.revoke_token(token, "webapp_email_code_login")

    @classmethod
    def create_end_user(cls, app_code, email) -> EndUser:
        site = db.session.query(Site).filter(Site.code == app_code).first()
        if not site:
            raise NotFound("Site not found.")

        app_model = db.session.query(App).filter(App.id == site.app_id).first()
        if not app_model:
            raise NotFound("App not found.")

        end_user = EndUser(
            tenant_id=app_model.tenant_id,
            app_id=app_model.id,
            type="browser",
            is_anonymous=False,
            session_id=email,
            name="enterpriseuser",
            external_user_id="enterpriseuser",
        )
        db.session.add(end_user)
        db.session.commit()

        return end_user

    @classmethod
    def _validate_user_accessibility(cls, account: Account, app_code: str):
        """Check if the user is allowed to access the app."""
        system_features = FeatureService.get_system_features()
        if system_features.webapp_auth.enabled:
            app_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code=app_code)
            if (
                app_settings.access_mode != "public"
                and not EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(account.id, app_code=app_code)
            ):
                raise WebAppAuthAccessDeniedError()

    @classmethod
    def _get_account_jwt_token(cls, account: Account, site: Site, end_user_id: str) -> str:
        exp_dt = datetime.now(UTC) + timedelta(hours=dify_config.ACCESS_TOKEN_EXPIRE_MINUTES * 24)
        exp = int(exp_dt.timestamp())
        payload = {
            "iss": site.id,
            "sub": "Web API Passport",
            "app_id": site.app_id,
            "app_code": site.code,
            "user_id": account.id,
            "end_user_id": end_user_id,
            "token_source": "webapp",
            "exp": exp,
        }

        token: str = PassportService().issue(payload)
        return token
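
For reference, a self-contained sketch of the token issuance at the end of this file, using PyJWT directly as a stand-in for PassportService, with a placeholder secret and setting value. One detail worth double-checking in the original: the expiry multiplies a minutes-named setting by 24 and passes it as hours, so the effective token lifetime may be far longer than the name suggests.

    from datetime import UTC, datetime, timedelta

    import jwt  # PyJWT, standing in for PassportService's signing step

    ACCESS_TOKEN_EXPIRE_MINUTES = 60  # placeholder for the dify_config setting

    exp_dt = datetime.now(UTC) + timedelta(hours=ACCESS_TOKEN_EXPIRE_MINUTES * 24)
    payload = {
        "iss": "site-id",  # Site.id
        "sub": "Web API Passport",
        "app_id": "app-id",
        "app_code": "app-code",
        "user_id": "account-id",
        "end_user_id": "end-user-id",
        "token_source": "webapp",
        "exp": int(exp_dt.timestamp()),
    }
    print(jwt.encode(payload, "server-secret", algorithm="HS256"))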

View File

@ -5,7 +5,7 @@ from sqlalchemy import and_, func, or_, select
from sqlalchemy.orm import Session
from models import App, EndUser, WorkflowAppLog, WorkflowRun
from models.enums import CreatedByRole
from models.enums import CreatorUserRole
from models.workflow import WorkflowRunStatus
@ -58,7 +58,7 @@ class WorkflowAppService:
stmt = stmt.outerjoin(
EndUser,
and_(WorkflowRun.created_by == EndUser.id, WorkflowRun.created_by_role == CreatedByRole.END_USER),
and_(WorkflowRun.created_by == EndUser.id, WorkflowRun.created_by_role == CreatorUserRole.END_USER),
).where(or_(*keyword_conditions))
if status:
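
The enum swap only changes the discriminator used in the join condition. A self-contained sketch of the pattern on toy tables and an in-memory SQLite database (an illustration, not the app's real models):

    from enum import StrEnum

    from sqlalchemy import String, and_, create_engine, select
    from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


    class CreatorUserRole(StrEnum):  # illustrative mirror of the renamed enum
        ACCOUNT = "account"
        END_USER = "end_user"


    class Base(DeclarativeBase):
        pass


    class Run(Base):  # toy stand-in for WorkflowRun
        __tablename__ = "runs"
        id: Mapped[str] = mapped_column(String, primary_key=True)
        created_by: Mapped[str] = mapped_column(String)
        created_by_role: Mapped[str] = mapped_column(String)


    class EndUser(Base):  # toy stand-in for the real EndUser model
        __tablename__ = "end_users"
        id: Mapped[str] = mapped_column(String, primary_key=True)
        session_id: Mapped[str] = mapped_column(String)


    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)

    # Attach end-user rows only where the run was created by an end user,
    # mirroring the CreatorUserRole.END_USER guard in the hunk above.
    stmt = select(Run.id, EndUser.session_id).outerjoin(
        EndUser,
        and_(Run.created_by == EndUser.id, Run.created_by_role == CreatorUserRole.END_USER),
    )
    with Session(engine) as session:
        print(session.execute(stmt).all())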

View File

@ -1,4 +1,5 @@
import threading
from collections.abc import Sequence
from typing import Optional
import contexts
@ -6,12 +7,15 @@ from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.workflow.repository.workflow_node_execution_repository import OrderConfig
from extensions.ext_database import db
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models.enums import WorkflowRunTriggeredFrom
from models.model import App
from models.workflow import (
from models import (
Account,
App,
EndUser,
WorkflowNodeExecution,
WorkflowRun,
WorkflowRunTriggeredFrom,
)
from models.workflow import WorkflowNodeExecutionTriggeredFrom
class WorkflowRunService:
@ -116,7 +120,12 @@ class WorkflowRunService:
return workflow_run
def get_workflow_run_node_executions(self, app_model: App, run_id: str) -> list[WorkflowNodeExecution]:
def get_workflow_run_node_executions(
self,
app_model: App,
run_id: str,
user: Account | EndUser,
) -> Sequence[WorkflowNodeExecution]:
"""
Get workflow run node execution list
"""
@ -128,13 +137,17 @@ class WorkflowRunService:
if not workflow_run:
return []
# Use the repository to get the node executions
repository = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=db.engine, tenant_id=app_model.tenant_id, app_id=app_model.id
session_factory=db.engine,
user=user,
app_id=app_model.id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
# Use the repository to get the node executions with ordering
# Use the repository to get the database models directly
order_config = OrderConfig(order_by=["index"], order_direction="desc")
node_executions = repository.get_by_workflow_run(workflow_run_id=run_id, order_config=order_config)
workflow_node_executions = repository.get_db_models_by_workflow_run(
workflow_run_id=run_id, order_config=order_config
)
return list(node_executions)
return workflow_node_executions
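
The repository now takes the acting user and a triggered_from scope at construction time, and this service reads ORM rows directly instead of domain objects. A hypothetical, heavily trimmed mirror of that API; the real SQLAlchemyWorkflowNodeExecutionRepository additionally manages sessions, ordering, and tenant scoping:

    from dataclasses import dataclass
    from typing import Any


    @dataclass
    class NodeExecutionDomain:  # illustrative domain object
        id: str
        node_id: str
        status: str


    class NodeExecutionRepository:
        def __init__(self, session_factory, user, app_id, triggered_from):
            self._session_factory = session_factory
            self._user = user  # creator context replaces the bare tenant_id
            self._app_id = app_id
            self._triggered_from = triggered_from

        def _rows(self, workflow_run_id: str) -> list[Any]:
            return []  # placeholder for the real SQLAlchemy query

        def get_by_workflow_run(self, workflow_run_id: str):
            # Domain path: map ORM rows into plain objects.
            return [NodeExecutionDomain(r.id, r.node_id, r.status) for r in self._rows(workflow_run_id)]

        def get_db_models_by_workflow_run(self, workflow_run_id: str):
            # ORM path: hand back the rows themselves, as the service now does.
            return self._rows(workflow_run_id)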

View File

@ -10,10 +10,10 @@ from sqlalchemy.orm import Session
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.model_runtime.utils.encoders import jsonable_encoder
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.variables import Variable
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.node_execution_entities import NodeExecution, NodeExecutionStatus
from core.workflow.errors import WorkflowNodeRunFailedError
from core.workflow.graph_engine.entities.event import InNodeEvent
from core.workflow.nodes import NodeType
@ -26,7 +26,6 @@ from core.workflow.workflow_entry import WorkflowEntry
from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated
from extensions.ext_database import db
from models.account import Account
from models.enums import CreatedByRole
from models.model import App, AppMode
from models.tools import WorkflowToolProvider
from models.workflow import (
@ -268,33 +267,37 @@ class WorkflowService:
# run draft workflow node
start_at = time.perf_counter()
workflow_node_execution = self._handle_node_run_result(
getter=lambda: WorkflowEntry.single_step_run(
node_execution = self._handle_node_run_result(
invoke_node_fn=lambda: WorkflowEntry.single_step_run(
workflow=draft_workflow,
node_id=node_id,
user_inputs=user_inputs,
user_id=account.id,
),
start_at=start_at,
tenant_id=app_model.tenant_id,
node_id=node_id,
)
workflow_node_execution.app_id = app_model.id
workflow_node_execution.created_by = account.id
workflow_node_execution.workflow_id = draft_workflow.id
# Set workflow_id on the NodeExecution
node_execution.workflow_id = draft_workflow.id
# Use the repository to save the workflow node execution
# Create repository and save the node execution
repository = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=db.engine, tenant_id=app_model.tenant_id, app_id=app_model.id
session_factory=db.engine,
user=account,
app_id=app_model.id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
)
repository.save(workflow_node_execution)
repository.save(node_execution)
# Convert node_execution to WorkflowNodeExecution after save
workflow_node_execution = repository.to_db_model(node_execution)
return workflow_node_execution
def run_free_workflow_node(
self, node_data: dict, tenant_id: str, user_id: str, node_id: str, user_inputs: dict[str, Any]
) -> WorkflowNodeExecution:
) -> NodeExecution:
"""
Run draft workflow node
"""
@ -302,7 +305,7 @@ class WorkflowService:
start_at = time.perf_counter()
workflow_node_execution = self._handle_node_run_result(
getter=lambda: WorkflowEntry.run_free_node(
invoke_node_fn=lambda: WorkflowEntry.run_free_node(
node_id=node_id,
node_data=node_data,
tenant_id=tenant_id,
@ -310,7 +313,6 @@ class WorkflowService:
user_inputs=user_inputs,
),
start_at=start_at,
tenant_id=tenant_id,
node_id=node_id,
)
@ -318,21 +320,12 @@ class WorkflowService:
def _handle_node_run_result(
self,
getter: Callable[[], tuple[BaseNode, Generator[NodeEvent | InNodeEvent, None, None]]],
invoke_node_fn: Callable[[], tuple[BaseNode, Generator[NodeEvent | InNodeEvent, None, None]]],
start_at: float,
tenant_id: str,
node_id: str,
) -> WorkflowNodeExecution:
"""
Handle node run result
:param getter: Callable[[], tuple[BaseNode, Generator[RunEvent | InNodeEvent, None, None]]]
:param start_at: float
:param tenant_id: str
:param node_id: str
"""
) -> NodeExecution:
try:
node_instance, generator = getter()
node_instance, generator = invoke_node_fn()
node_run_result: NodeRunResult | None = None
for event in generator:
@ -381,20 +374,21 @@ class WorkflowService:
node_run_result = None
error = e.error
workflow_node_execution = WorkflowNodeExecution()
workflow_node_execution.id = str(uuid4())
workflow_node_execution.tenant_id = tenant_id
workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP.value
workflow_node_execution.index = 1
workflow_node_execution.node_id = node_id
workflow_node_execution.node_type = node_instance.node_type
workflow_node_execution.title = node_instance.node_data.title
workflow_node_execution.elapsed_time = time.perf_counter() - start_at
workflow_node_execution.created_by_role = CreatedByRole.ACCOUNT.value
workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None)
workflow_node_execution.finished_at = datetime.now(UTC).replace(tzinfo=None)
# Create a NodeExecution domain model
node_execution = NodeExecution(
id=str(uuid4()),
workflow_id="", # This is a single-step execution, so no workflow ID
index=1,
node_id=node_id,
node_type=node_instance.node_type,
title=node_instance.node_data.title,
elapsed_time=time.perf_counter() - start_at,
created_at=datetime.now(UTC).replace(tzinfo=None),
finished_at=datetime.now(UTC).replace(tzinfo=None),
)
if run_succeeded and node_run_result:
# create workflow node execution
# Set inputs, process_data, and outputs as dictionaries (not JSON strings)
inputs = WorkflowEntry.handle_special_values(node_run_result.inputs) if node_run_result.inputs else None
process_data = (
WorkflowEntry.handle_special_values(node_run_result.process_data)
@ -403,23 +397,23 @@ class WorkflowService:
)
outputs = WorkflowEntry.handle_special_values(node_run_result.outputs) if node_run_result.outputs else None
workflow_node_execution.inputs = json.dumps(inputs)
workflow_node_execution.process_data = json.dumps(process_data)
workflow_node_execution.outputs = json.dumps(outputs)
workflow_node_execution.execution_metadata = (
json.dumps(jsonable_encoder(node_run_result.metadata)) if node_run_result.metadata else None
)
if node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value
elif node_run_result.status == WorkflowNodeExecutionStatus.EXCEPTION:
workflow_node_execution.status = WorkflowNodeExecutionStatus.EXCEPTION.value
workflow_node_execution.error = node_run_result.error
else:
# create workflow node execution
workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
workflow_node_execution.error = error
node_execution.inputs = inputs
node_execution.process_data = process_data
node_execution.outputs = outputs
node_execution.metadata = node_run_result.metadata
return workflow_node_execution
# Map status from WorkflowNodeExecutionStatus to NodeExecutionStatus
if node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
node_execution.status = NodeExecutionStatus.SUCCEEDED
elif node_run_result.status == WorkflowNodeExecutionStatus.EXCEPTION:
node_execution.status = NodeExecutionStatus.EXCEPTION
node_execution.error = node_run_result.error
else:
# Set failed status and error
node_execution.status = NodeExecutionStatus.FAILED
node_execution.error = error
return node_execution
def convert_to_workflow(self, app_model: App, account: Account, args: dict) -> App:
"""
@ -514,11 +508,11 @@ class WorkflowService:
raise DraftWorkflowDeletionError("Cannot delete draft workflow versions")
# Check if this workflow is currently referenced by an app
stmt = select(App).where(App.workflow_id == workflow_id)
app = session.scalar(stmt)
app_stmt = select(App).where(App.workflow_id == workflow_id)
app = session.scalar(app_stmt)
if app:
# Cannot delete a workflow that's currently in use by an app
raise WorkflowInUseError(f"Cannot delete workflow that is currently in use by app '{app.name}'")
raise WorkflowInUseError(f"Cannot delete workflow that is currently in use by app '{app.id}'")
# Don't use workflow.tool_published as it's not accurate for specific workflow versions
# Check if there's a tool provider using this specific workflow version
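
The single-step refactor above builds a NodeExecution domain object that keeps inputs, process_data, and outputs as plain dicts rather than json.dumps strings, and defers persistence mapping to the repository's to_db_model. A runnable sketch of that shape, using a trimmed, hypothetical stand-in for the real NodeExecution entity:

    from dataclasses import dataclass
    from datetime import UTC, datetime
    from enum import Enum
    from typing import Any, Optional
    from uuid import uuid4


    class NodeExecutionStatus(Enum):  # statuses used by the mapping above
        SUCCEEDED = "succeeded"
        EXCEPTION = "exception"
        FAILED = "failed"


    @dataclass
    class NodeExecutionSketch:
        # Hypothetical, trimmed stand-in for the real NodeExecution entity.
        id: str
        workflow_id: str
        index: int
        node_id: str
        title: str
        created_at: datetime
        status: NodeExecutionStatus = NodeExecutionStatus.FAILED
        error: Optional[str] = None
        inputs: Optional[dict[str, Any]] = None
        outputs: Optional[dict[str, Any]] = None


    execution = NodeExecutionSketch(
        id=str(uuid4()),
        workflow_id="",  # single-step runs carry no workflow id, as in the diff
        index=1,
        node_id="llm-node",  # illustrative node id
        title="LLM",
        created_at=datetime.now(UTC).replace(tzinfo=None),
    )
    execution.status = NodeExecutionStatus.SUCCEEDED
    execution.inputs = {"query": "hello"}  # plain dicts, not json.dumps strings
    print(execution)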

View File

@ -111,7 +111,7 @@ def add_document_to_index_task(dataset_document_id: str):
logging.exception("add document to index failed")
dataset_document.enabled = False
dataset_document.disabled_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
dataset_document.status = "error"
dataset_document.indexing_status = "error"
dataset_document.error = str(e)
db.session.commit()
finally:
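
The one-line fix records the failure on indexing_status, which is presumably the field the indexing pipeline checks, rather than the generic status attribute. A toy sketch of the error path, with a plain object standing in for the ORM row:

    import datetime


    class Doc:  # toy stand-in for the dataset document ORM row
        enabled = True
        disabled_at = None
        indexing_status = "indexing"
        error = None


    def mark_index_failure(doc: Doc, exc: Exception) -> None:
        doc.enabled = False
        doc.disabled_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
        doc.indexing_status = "error"  # the corrected field; previously doc.status
        doc.error = str(exc)


    mark_index_failure(Doc(), RuntimeError("index build failed"))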

View File

@ -6,6 +6,7 @@ from celery import shared_task # type: ignore
from flask import render_template
from extensions.ext_mail import mail
from services.feature_service import FeatureService
@shared_task(queue="mail")
@ -25,10 +26,24 @@ def send_email_code_login_mail_task(language: str, to: str, code: str):
# send email code login mail using different languages
try:
if language == "zh-Hans":
html_content = render_template("email_code_login_mail_template_zh-CN.html", to=to, code=code)
template = "email_code_login_mail_template_zh-CN.html"
system_features = FeatureService.get_system_features()
if system_features.branding.enabled:
application_title = system_features.branding.application_title
template = "without-brand/email_code_login_mail_template_zh-CN.html"
html_content = render_template(template, to=to, code=code, application_title=application_title)
else:
html_content = render_template(template, to=to, code=code)
mail.send(to=to, subject="邮箱验证码", html=html_content)
else:
html_content = render_template("email_code_login_mail_template_en-US.html", to=to, code=code)
template = "email_code_login_mail_template_en-US.html"
system_features = FeatureService.get_system_features()
if system_features.branding.enabled:
application_title = system_features.branding.application_title
template = "without-brand/email_code_login_mail_template_en-US.html"
html_content = render_template(template, to=to, code=code, application_title=application_title)
else:
html_content = render_template(template, to=to, code=code)
mail.send(to=to, subject="Email Code", html=html_content)
end_at = time.perf_counter()
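
Both language branches now repeat the same branding check. One possible consolidation, sketched with toy stand-ins for the FeatureService feature flags; the template naming scheme follows the code above:

    from dataclasses import dataclass, field


    @dataclass
    class Branding:  # toy stand-in for the branding feature flags
        enabled: bool = False
        application_title: str = "Dify"


    @dataclass
    class SystemFeatures:
        branding: Branding = field(default_factory=Branding)


    def resolve_login_mail(language: str, features: SystemFeatures) -> tuple[str, str, str | None]:
        """Pick (template, subject, application_title) once instead of per branch."""
        if language == "zh-Hans":
            template, subject = "email_code_login_mail_template_zh-CN.html", "邮箱验证码"
        else:
            template, subject = "email_code_login_mail_template_en-US.html", "Email Code"
        if features.branding.enabled:
            # Branded installs use the without-brand variant plus a custom title.
            return f"without-brand/{template}", subject, features.branding.application_title
        return template, subject, None


    print(resolve_login_mail("en-US", SystemFeatures(Branding(enabled=True, application_title="Acme"))))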

View File

@ -0,0 +1,33 @@
import logging
import time

import click
from celery import shared_task  # type: ignore
from flask import render_template_string

from extensions.ext_mail import mail


@shared_task(queue="mail")
def send_enterprise_email_task(to, subject, body, substitutions):
    if not mail.is_inited():
        return

    logging.info(click.style("Start enterprise mail to {} with subject {}".format(to, subject), fg="green"))
    start_at = time.perf_counter()

    try:
        html_content = render_template_string(body, **substitutions)

        if isinstance(to, list):
            for t in to:
                mail.send(to=t, subject=subject, html=html_content)
        else:
            mail.send(to=to, subject=subject, html=html_content)

        end_at = time.perf_counter()
        logging.info(
            click.style("Send enterprise mail to {} succeeded: latency: {}".format(to, end_at - start_at), fg="green")
        )
    except Exception:
        logging.exception("Send enterprise mail to {} failed".format(to))
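
A hypothetical invocation of the new task; the import path is a guess since the file name is not shown in this view, and the recipients and substitutions are illustrative. The body is rendered with render_template_string, so it accepts Jinja2 placeholders:

    # Import path is a guess; a worker must be consuming the "mail" queue.
    from tasks.mail_enterprise_task import send_enterprise_email_task

    send_enterprise_email_task.delay(
        to=["a@example.com", "b@example.com"],  # a single address string also works
        subject="Scheduled maintenance",
        body="<p>Hello {{ name }}, maintenance starts at {{ start }}.</p>",
        substitutions={"name": "Ada", "start": "02:00 UTC"},
    )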

View File

@ -7,6 +7,7 @@ from flask import render_template
from configs import dify_config
from extensions.ext_mail import mail
from services.feature_service import FeatureService
@shared_task(queue="mail")
@ -33,23 +34,45 @@ def send_invite_member_mail_task(language: str, to: str, token: str, inviter_nam
try:
url = f"{dify_config.CONSOLE_WEB_URL}/activate?token={token}"
if language == "zh-Hans":
html_content = render_template(
"invite_member_mail_template_zh-CN.html",
to=to,
inviter_name=inviter_name,
workspace_name=workspace_name,
url=url,
)
mail.send(to=to, subject="立即加入 Dify 工作空间", html=html_content)
template = "invite_member_mail_template_zh-CN.html"
system_features = FeatureService.get_system_features()
if system_features.branding.enabled:
application_title = system_features.branding.application_title
template = "without-brand/invite_member_mail_template_zh-CN.html"
html_content = render_template(
template,
to=to,
inviter_name=inviter_name,
workspace_name=workspace_name,
url=url,
application_title=application_title,
)
mail.send(to=to, subject=f"立即加入 {application_title} 工作空间", html=html_content)
else:
html_content = render_template(
template, to=to, inviter_name=inviter_name, workspace_name=workspace_name, url=url
)
mail.send(to=to, subject="立即加入 Dify 工作空间", html=html_content)
else:
html_content = render_template(
"invite_member_mail_template_en-US.html",
to=to,
inviter_name=inviter_name,
workspace_name=workspace_name,
url=url,
)
mail.send(to=to, subject="Join Dify Workspace Now", html=html_content)
template = "invite_member_mail_template_en-US.html"
system_features = FeatureService.get_system_features()
if system_features.branding.enabled:
application_title = system_features.branding.application_title
template = "without-brand/invite_member_mail_template_en-US.html"
html_content = render_template(
template,
to=to,
inviter_name=inviter_name,
workspace_name=workspace_name,
url=url,
application_title=application_title,
)
mail.send(to=to, subject=f"Join {application_title} Workspace Now", html=html_content)
else:
html_content = render_template(
template, to=to, inviter_name=inviter_name, workspace_name=workspace_name, url=url
)
mail.send(to=to, subject="Join Dify Workspace Now", html=html_content)
end_at = time.perf_counter()
logging.info(

Some files were not shown because too many files have changed in this diff.