fix: Copy request context and current user in app generators. (#20240)

Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
-LAN- 2025-05-27 10:56:23 +08:00 committed by GitHub
parent acd4b9a8ac
commit b357eca307
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 186 additions and 99 deletions

View File

@ -11,10 +11,6 @@ if TYPE_CHECKING:
from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.variable_pool import VariablePool
tenant_id: ContextVar[str] = ContextVar("tenant_id")
workflow_variable_pool: ContextVar["VariablePool"] = ContextVar("workflow_variable_pool")
""" """
To avoid race-conditions caused by gunicorn thread recycling, using RecyclableContextVar to replace with To avoid race-conditions caused by gunicorn thread recycling, using RecyclableContextVar to replace with
""" """

View File

@ -3,7 +3,7 @@ from flask_restful import Resource, marshal, marshal_with, reqparse
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
from controllers.service_api import api from controllers.service_api import api
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token from controllers.service_api.wraps import validate_app_token
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from fields.annotation_fields import ( from fields.annotation_fields import (
annotation_fields, annotation_fields,
@ -14,7 +14,7 @@ from services.annotation_service import AppAnnotationService
class AnnotationReplyActionApi(Resource): class AnnotationReplyActionApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON)) @validate_app_token
def post(self, app_model: App, end_user: EndUser, action): def post(self, app_model: App, end_user: EndUser, action):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("score_threshold", required=True, type=float, location="json") parser.add_argument("score_threshold", required=True, type=float, location="json")
@ -31,7 +31,7 @@ class AnnotationReplyActionApi(Resource):
class AnnotationReplyActionStatusApi(Resource): class AnnotationReplyActionStatusApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY)) @validate_app_token
def get(self, app_model: App, end_user: EndUser, job_id, action): def get(self, app_model: App, end_user: EndUser, job_id, action):
job_id = str(job_id) job_id = str(job_id)
app_annotation_job_key = "{}_app_annotation_job_{}".format(action, str(job_id)) app_annotation_job_key = "{}_app_annotation_job_{}".format(action, str(job_id))
@ -49,7 +49,7 @@ class AnnotationReplyActionStatusApi(Resource):
class AnnotationListApi(Resource): class AnnotationListApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY)) @validate_app_token
def get(self, app_model: App, end_user: EndUser): def get(self, app_model: App, end_user: EndUser):
page = request.args.get("page", default=1, type=int) page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int) limit = request.args.get("limit", default=20, type=int)
@ -65,7 +65,7 @@ class AnnotationListApi(Resource):
} }
return response, 200 return response, 200
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON)) @validate_app_token
@marshal_with(annotation_fields) @marshal_with(annotation_fields)
def post(self, app_model: App, end_user: EndUser): def post(self, app_model: App, end_user: EndUser):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
@ -77,7 +77,7 @@ class AnnotationListApi(Resource):
class AnnotationUpdateDeleteApi(Resource): class AnnotationUpdateDeleteApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON)) @validate_app_token
@marshal_with(annotation_fields) @marshal_with(annotation_fields)
def put(self, app_model: App, end_user: EndUser, annotation_id): def put(self, app_model: App, end_user: EndUser, annotation_id):
if not current_user.is_editor: if not current_user.is_editor:
@ -91,7 +91,7 @@ class AnnotationUpdateDeleteApi(Resource):
annotation = AppAnnotationService.update_app_annotation_directly(args, app_model.id, annotation_id) annotation = AppAnnotationService.update_app_annotation_directly(args, app_model.id, annotation_id)
return annotation return annotation
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY)) @validate_app_token
def delete(self, app_model: App, end_user: EndUser, annotation_id): def delete(self, app_model: App, end_user: EndUser, annotation_id):
if not current_user.is_editor: if not current_user.is_editor:
raise Forbidden() raise Forbidden()

View File

@ -99,7 +99,12 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio
if user_id: if user_id:
user_id = str(user_id) user_id = str(user_id)
kwargs["end_user"] = create_or_update_end_user_for_user_id(app_model, user_id) end_user = create_or_update_end_user_for_user_id(app_model, user_id)
kwargs["end_user"] = end_user
# Set EndUser as current logged-in user for flask_login.current_user
current_app.login_manager._update_request_context_with_user(end_user) # type: ignore
user_logged_in.send(current_app._get_current_object(), user=end_user) # type: ignore
return view_func(*args, **kwargs) return view_func(*args, **kwargs)

View File

@ -5,7 +5,7 @@ import uuid
from collections.abc import Generator, Mapping from collections.abc import Generator, Mapping
from typing import Any, Literal, Optional, Union, overload from typing import Any, Literal, Optional, Union, overload
from flask import Flask, current_app from flask import Flask, copy_current_request_context, current_app, has_request_context
from pydantic import ValidationError from pydantic import ValidationError
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker
@ -158,7 +158,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
trace_manager=trace_manager, trace_manager=trace_manager,
workflow_run_id=workflow_run_id, workflow_run_id=workflow_run_id,
) )
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
contexts.plugin_tool_providers.set({}) contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock()) contexts.plugin_tool_providers_lock.set(threading.Lock())
@ -240,7 +239,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
node_id=node_id, inputs=args["inputs"] node_id=node_id, inputs=args["inputs"]
), ),
) )
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
contexts.plugin_tool_providers.set({}) contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock()) contexts.plugin_tool_providers_lock.set(threading.Lock())
@ -316,7 +314,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
extras={"auto_generate_conversation_name": False}, extras={"auto_generate_conversation_name": False},
single_loop_run=AdvancedChatAppGenerateEntity.SingleLoopRunEntity(node_id=node_id, inputs=args["inputs"]), single_loop_run=AdvancedChatAppGenerateEntity.SingleLoopRunEntity(node_id=node_id, inputs=args["inputs"]),
) )
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
contexts.plugin_tool_providers.set({}) contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock()) contexts.plugin_tool_providers_lock.set(threading.Lock())
@ -399,18 +396,23 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
message_id=message.id, message_id=message.id,
) )
# new thread # new thread with request context and contextvars
worker_thread = threading.Thread( context = contextvars.copy_context()
target=self._generate_worker,
kwargs={ @copy_current_request_context
"flask_app": current_app._get_current_object(), # type: ignore def worker_with_context():
"application_generate_entity": application_generate_entity, # Run the worker within the copied context
"queue_manager": queue_manager, return context.run(
"conversation_id": conversation.id, self._generate_worker,
"message_id": message.id, flask_app=current_app._get_current_object(), # type: ignore
"context": contextvars.copy_context(), application_generate_entity=application_generate_entity,
}, queue_manager=queue_manager,
) conversation_id=conversation.id,
message_id=message.id,
context=context,
)
worker_thread = threading.Thread(target=worker_with_context)
worker_thread.start() worker_thread.start()
@ -449,8 +451,22 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
""" """
for var, val in context.items(): for var, val in context.items():
var.set(val) var.set(val)
# Save current user before entering new app context
from flask import g
saved_user = None
if has_request_context() and hasattr(g, "_login_user"):
saved_user = g._login_user
with flask_app.app_context(): with flask_app.app_context():
try: try:
# Restore user in new app context
if saved_user is not None:
from flask import g
g._login_user = saved_user
# get conversation and message # get conversation and message
conversation = self._get_conversation(conversation_id) conversation = self._get_conversation(conversation_id)
message = self._get_message(message_id) message = self._get_message(message_id)

View File

@ -5,7 +5,7 @@ import uuid
from collections.abc import Generator, Mapping from collections.abc import Generator, Mapping
from typing import Any, Literal, Union, overload from typing import Any, Literal, Union, overload
from flask import Flask, current_app from flask import Flask, copy_current_request_context, current_app, has_request_context
from pydantic import ValidationError from pydantic import ValidationError
from configs import dify_config from configs import dify_config
@ -179,18 +179,23 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
message_id=message.id, message_id=message.id,
) )
# new thread # new thread with request context and contextvars
worker_thread = threading.Thread( context = contextvars.copy_context()
target=self._generate_worker,
kwargs={ @copy_current_request_context
"flask_app": current_app._get_current_object(), # type: ignore def worker_with_context():
"context": contextvars.copy_context(), # Run the worker within the copied context
"application_generate_entity": application_generate_entity, return context.run(
"queue_manager": queue_manager, self._generate_worker,
"conversation_id": conversation.id, flask_app=current_app._get_current_object(), # type: ignore
"message_id": message.id, context=context,
}, application_generate_entity=application_generate_entity,
) queue_manager=queue_manager,
conversation_id=conversation.id,
message_id=message.id,
)
worker_thread = threading.Thread(target=worker_with_context)
worker_thread.start() worker_thread.start()
@ -227,8 +232,21 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
for var, val in context.items(): for var, val in context.items():
var.set(val) var.set(val)
# Save current user before entering new app context
from flask import g
saved_user = None
if has_request_context() and hasattr(g, "_login_user"):
saved_user = g._login_user
with flask_app.app_context(): with flask_app.app_context():
try: try:
# Restore user in new app context
if saved_user is not None:
from flask import g
g._login_user = saved_user
# get conversation and message # get conversation and message
conversation = self._get_conversation(conversation_id) conversation = self._get_conversation(conversation_id)
message = self._get_message(message_id) message = self._get_message(message_id)

View File

@ -4,7 +4,7 @@ import uuid
from collections.abc import Generator, Mapping from collections.abc import Generator, Mapping
from typing import Any, Literal, Union, overload from typing import Any, Literal, Union, overload
from flask import Flask, current_app from flask import Flask, copy_current_request_context, current_app
from pydantic import ValidationError from pydantic import ValidationError
from configs import dify_config from configs import dify_config
@ -170,17 +170,18 @@ class ChatAppGenerator(MessageBasedAppGenerator):
message_id=message.id, message_id=message.id,
) )
# new thread # new thread with request context
worker_thread = threading.Thread( @copy_current_request_context
target=self._generate_worker, def worker_with_context():
kwargs={ return self._generate_worker(
"flask_app": current_app._get_current_object(), # type: ignore flask_app=current_app._get_current_object(), # type: ignore
"application_generate_entity": application_generate_entity, application_generate_entity=application_generate_entity,
"queue_manager": queue_manager, queue_manager=queue_manager,
"conversation_id": conversation.id, conversation_id=conversation.id,
"message_id": message.id, message_id=message.id,
}, )
)
worker_thread = threading.Thread(target=worker_with_context)
worker_thread.start() worker_thread.start()

View File

@ -4,7 +4,7 @@ import uuid
from collections.abc import Generator, Mapping from collections.abc import Generator, Mapping
from typing import Any, Literal, Union, overload from typing import Any, Literal, Union, overload
from flask import Flask, current_app from flask import Flask, copy_current_request_context, current_app
from pydantic import ValidationError from pydantic import ValidationError
from configs import dify_config from configs import dify_config
@ -151,16 +151,17 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
message_id=message.id, message_id=message.id,
) )
# new thread # new thread with request context
worker_thread = threading.Thread( @copy_current_request_context
target=self._generate_worker, def worker_with_context():
kwargs={ return self._generate_worker(
"flask_app": current_app._get_current_object(), # type: ignore flask_app=current_app._get_current_object(), # type: ignore
"application_generate_entity": application_generate_entity, application_generate_entity=application_generate_entity,
"queue_manager": queue_manager, queue_manager=queue_manager,
"message_id": message.id, message_id=message.id,
}, )
)
worker_thread = threading.Thread(target=worker_with_context)
worker_thread.start() worker_thread.start()
@ -313,16 +314,17 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
message_id=message.id, message_id=message.id,
) )
# new thread # new thread with request context
worker_thread = threading.Thread( @copy_current_request_context
target=self._generate_worker, def worker_with_context():
kwargs={ return self._generate_worker(
"flask_app": current_app._get_current_object(), # type: ignore flask_app=current_app._get_current_object(), # type: ignore
"application_generate_entity": application_generate_entity, application_generate_entity=application_generate_entity,
"queue_manager": queue_manager, queue_manager=queue_manager,
"message_id": message.id, message_id=message.id,
}, )
)
worker_thread = threading.Thread(target=worker_with_context)
worker_thread.start() worker_thread.start()

View File

@ -5,7 +5,7 @@ import uuid
from collections.abc import Generator, Mapping, Sequence from collections.abc import Generator, Mapping, Sequence
from typing import Any, Literal, Optional, Union, overload from typing import Any, Literal, Optional, Union, overload
from flask import Flask, current_app from flask import Flask, copy_current_request_context, current_app, has_request_context
from pydantic import ValidationError from pydantic import ValidationError
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker
@ -135,7 +135,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
workflow_run_id=workflow_run_id, workflow_run_id=workflow_run_id,
) )
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
contexts.plugin_tool_providers.set({}) contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock()) contexts.plugin_tool_providers_lock.set(threading.Lock())
@ -207,17 +206,22 @@ class WorkflowAppGenerator(BaseAppGenerator):
app_mode=app_model.mode, app_mode=app_model.mode,
) )
# new thread # new thread with request context and contextvars
worker_thread = threading.Thread( context = contextvars.copy_context()
target=self._generate_worker,
kwargs={ @copy_current_request_context
"flask_app": current_app._get_current_object(), # type: ignore def worker_with_context():
"application_generate_entity": application_generate_entity, # Run the worker within the copied context
"queue_manager": queue_manager, return context.run(
"context": contextvars.copy_context(), self._generate_worker,
"workflow_thread_pool_id": workflow_thread_pool_id, flask_app=current_app._get_current_object(), # type: ignore
}, application_generate_entity=application_generate_entity,
) queue_manager=queue_manager,
context=context,
workflow_thread_pool_id=workflow_thread_pool_id,
)
worker_thread = threading.Thread(target=worker_with_context)
worker_thread.start() worker_thread.start()
@ -277,7 +281,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
), ),
workflow_run_id=str(uuid.uuid4()), workflow_run_id=str(uuid.uuid4()),
) )
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
contexts.plugin_tool_providers.set({}) contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock()) contexts.plugin_tool_providers_lock.set(threading.Lock())
@ -354,7 +357,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
single_loop_run=WorkflowAppGenerateEntity.SingleLoopRunEntity(node_id=node_id, inputs=args["inputs"]), single_loop_run=WorkflowAppGenerateEntity.SingleLoopRunEntity(node_id=node_id, inputs=args["inputs"]),
workflow_run_id=str(uuid.uuid4()), workflow_run_id=str(uuid.uuid4()),
) )
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
contexts.plugin_tool_providers.set({}) contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock()) contexts.plugin_tool_providers_lock.set(threading.Lock())
@ -408,8 +410,22 @@ class WorkflowAppGenerator(BaseAppGenerator):
""" """
for var, val in context.items(): for var, val in context.items():
var.set(val) var.set(val)
# Save current user before entering new app context
from flask import g
saved_user = None
if has_request_context() and hasattr(g, "_login_user"):
saved_user = g._login_user
with flask_app.app_context(): with flask_app.app_context():
try: try:
# Restore user in new app context
if saved_user is not None:
from flask import g
g._login_user = saved_user
# workflow app # workflow app
runner = WorkflowAppRunner( runner = WorkflowAppRunner(
application_generate_entity=application_generate_entity, application_generate_entity=application_generate_entity,

View File

@ -5,7 +5,6 @@ from flask import Response, request
from flask_login import user_loaded_from_request, user_logged_in from flask_login import user_loaded_from_request, user_logged_in
from werkzeug.exceptions import NotFound, Unauthorized from werkzeug.exceptions import NotFound, Unauthorized
import contexts
from configs import dify_config from configs import dify_config
from dify_app import DifyApp from dify_app import DifyApp
from extensions.ext_database import db from extensions.ext_database import db
@ -82,8 +81,8 @@ def on_user_logged_in(_sender, user):
Note: AccountService.load_logged_in_account will populate user.current_tenant_id Note: AccountService.load_logged_in_account will populate user.current_tenant_id
through the load_user method, which calls account.set_tenant_id(). through the load_user method, which calls account.set_tenant_id().
""" """
if user and isinstance(user, Account) and user.current_tenant_id: # tenant_id context variable removed - using current_user.current_tenant_id directly
contexts.tenant_id.set(user.current_tenant_id) pass
@login_manager.unauthorized_handler @login_manager.unauthorized_handler

View File

@ -6,6 +6,8 @@ from enum import Enum, StrEnum
from typing import TYPE_CHECKING, Any, Optional, Union from typing import TYPE_CHECKING, Any, Optional, Union
from uuid import uuid4 from uuid import uuid4
from flask_login import current_user
from core.variables import utils as variable_utils from core.variables import utils as variable_utils
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from factories.variable_factory import build_segment from factories.variable_factory import build_segment
@ -17,7 +19,6 @@ import sqlalchemy as sa
from sqlalchemy import UniqueConstraint, func from sqlalchemy import UniqueConstraint, func
from sqlalchemy.orm import Mapped, mapped_column from sqlalchemy.orm import Mapped, mapped_column
import contexts
from constants import DEFAULT_FILE_NUMBER_LIMITS, HIDDEN_VALUE from constants import DEFAULT_FILE_NUMBER_LIMITS, HIDDEN_VALUE
from core.helper import encrypter from core.helper import encrypter
from core.variables import SecretVariable, Segment, SegmentType, Variable from core.variables import SecretVariable, Segment, SegmentType, Variable
@ -274,7 +275,16 @@ class Workflow(Base):
if self._environment_variables is None: if self._environment_variables is None:
self._environment_variables = "{}" self._environment_variables = "{}"
tenant_id = contexts.tenant_id.get() # Get tenant_id from current_user (Account or EndUser)
if isinstance(current_user, Account):
# Account user
tenant_id = current_user.current_tenant_id
else:
# EndUser
tenant_id = current_user.tenant_id
if not tenant_id:
return []
environment_variables_dict: dict[str, Any] = json.loads(self._environment_variables) environment_variables_dict: dict[str, Any] = json.loads(self._environment_variables)
results = [ results = [
@ -297,7 +307,17 @@ class Workflow(Base):
self._environment_variables = "{}" self._environment_variables = "{}"
return return
tenant_id = contexts.tenant_id.get() # Get tenant_id from current_user (Account or EndUser)
if isinstance(current_user, Account):
# Account user
tenant_id = current_user.current_tenant_id
else:
# EndUser
tenant_id = current_user.tenant_id
if not tenant_id:
self._environment_variables = "{}"
return
value = list(value) value = list(value)
if any(var for var in value if not var.id): if any(var for var in value if not var.id):

View File

@ -2,14 +2,13 @@ import json
from unittest import mock from unittest import mock
from uuid import uuid4 from uuid import uuid4
import contexts
from constants import HIDDEN_VALUE from constants import HIDDEN_VALUE
from core.variables import FloatVariable, IntegerVariable, SecretVariable, StringVariable from core.variables import FloatVariable, IntegerVariable, SecretVariable, StringVariable
from models.workflow import Workflow, WorkflowNodeExecution from models.workflow import Workflow, WorkflowNodeExecution
def test_environment_variables(): def test_environment_variables():
contexts.tenant_id.set("tenant_id") # tenant_id context variable removed - using current_user.current_tenant_id directly
# Create a Workflow instance # Create a Workflow instance
workflow = Workflow( workflow = Workflow(
@ -38,9 +37,14 @@ def test_environment_variables():
{"name": "var4", "value": 3.14, "id": str(uuid4()), "selector": ["env", "var4"]} {"name": "var4", "value": 3.14, "id": str(uuid4()), "selector": ["env", "var4"]}
) )
# Mock current_user as an EndUser
mock_user = mock.Mock()
mock_user.tenant_id = "tenant_id"
with ( with (
mock.patch("core.helper.encrypter.encrypt_token", return_value="encrypted_token"), mock.patch("core.helper.encrypter.encrypt_token", return_value="encrypted_token"),
mock.patch("core.helper.encrypter.decrypt_token", return_value="secret"), mock.patch("core.helper.encrypter.decrypt_token", return_value="secret"),
mock.patch("models.workflow.current_user", mock_user),
): ):
# Set the environment_variables property of the Workflow instance # Set the environment_variables property of the Workflow instance
variables = [variable1, variable2, variable3, variable4] variables = [variable1, variable2, variable3, variable4]
@ -51,7 +55,7 @@ def test_environment_variables():
def test_update_environment_variables(): def test_update_environment_variables():
contexts.tenant_id.set("tenant_id") # tenant_id context variable removed - using current_user.current_tenant_id directly
# Create a Workflow instance # Create a Workflow instance
workflow = Workflow( workflow = Workflow(
@ -80,9 +84,14 @@ def test_update_environment_variables():
{"name": "var4", "value": 3.14, "id": str(uuid4()), "selector": ["env", "var4"]} {"name": "var4", "value": 3.14, "id": str(uuid4()), "selector": ["env", "var4"]}
) )
# Mock current_user as an EndUser
mock_user = mock.Mock()
mock_user.tenant_id = "tenant_id"
with ( with (
mock.patch("core.helper.encrypter.encrypt_token", return_value="encrypted_token"), mock.patch("core.helper.encrypter.encrypt_token", return_value="encrypted_token"),
mock.patch("core.helper.encrypter.decrypt_token", return_value="secret"), mock.patch("core.helper.encrypter.decrypt_token", return_value="secret"),
mock.patch("models.workflow.current_user", mock_user),
): ):
variables = [variable1, variable2, variable3, variable4] variables = [variable1, variable2, variable3, variable4]
@ -104,7 +113,7 @@ def test_update_environment_variables():
def test_to_dict(): def test_to_dict():
contexts.tenant_id.set("tenant_id") # tenant_id context variable removed - using current_user.current_tenant_id directly
# Create a Workflow instance # Create a Workflow instance
workflow = Workflow( workflow = Workflow(
@ -121,9 +130,14 @@ def test_to_dict():
# Create some EnvironmentVariable instances # Create some EnvironmentVariable instances
# Mock current_user as an EndUser
mock_user = mock.Mock()
mock_user.tenant_id = "tenant_id"
with ( with (
mock.patch("core.helper.encrypter.encrypt_token", return_value="encrypted_token"), mock.patch("core.helper.encrypter.encrypt_token", return_value="encrypted_token"),
mock.patch("core.helper.encrypter.decrypt_token", return_value="secret"), mock.patch("core.helper.encrypter.decrypt_token", return_value="secret"),
mock.patch("models.workflow.current_user", mock_user),
): ):
# Set the environment_variables property of the Workflow instance # Set the environment_variables property of the Workflow instance
workflow.environment_variables = [ workflow.environment_variables = [