diff --git a/api/contexts/__init__.py b/api/contexts/__init__.py index 127b8fe76d..ae41a2c03a 100644 --- a/api/contexts/__init__.py +++ b/api/contexts/__init__.py @@ -11,10 +11,6 @@ if TYPE_CHECKING: from core.workflow.entities.variable_pool import VariablePool -tenant_id: ContextVar[str] = ContextVar("tenant_id") - -workflow_variable_pool: ContextVar["VariablePool"] = ContextVar("workflow_variable_pool") - """ To avoid race-conditions caused by gunicorn thread recycling, using RecyclableContextVar to replace with """ diff --git a/api/controllers/service_api/app/annotation.py b/api/controllers/service_api/app/annotation.py index bd1a23b723..1a7f0c935b 100644 --- a/api/controllers/service_api/app/annotation.py +++ b/api/controllers/service_api/app/annotation.py @@ -3,7 +3,7 @@ from flask_restful import Resource, marshal, marshal_with, reqparse from werkzeug.exceptions import Forbidden from controllers.service_api import api -from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token +from controllers.service_api.wraps import validate_app_token from extensions.ext_redis import redis_client from fields.annotation_fields import ( annotation_fields, @@ -14,7 +14,7 @@ from services.annotation_service import AppAnnotationService class AnnotationReplyActionApi(Resource): - @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON)) + @validate_app_token def post(self, app_model: App, end_user: EndUser, action): parser = reqparse.RequestParser() parser.add_argument("score_threshold", required=True, type=float, location="json") @@ -31,7 +31,7 @@ class AnnotationReplyActionApi(Resource): class AnnotationReplyActionStatusApi(Resource): - @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY)) + @validate_app_token def get(self, app_model: App, end_user: EndUser, job_id, action): job_id = str(job_id) app_annotation_job_key = "{}_app_annotation_job_{}".format(action, str(job_id)) @@ -49,7 +49,7 @@ class AnnotationReplyActionStatusApi(Resource): class AnnotationListApi(Resource): - @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY)) + @validate_app_token def get(self, app_model: App, end_user: EndUser): page = request.args.get("page", default=1, type=int) limit = request.args.get("limit", default=20, type=int) @@ -65,7 +65,7 @@ class AnnotationListApi(Resource): } return response, 200 - @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON)) + @validate_app_token @marshal_with(annotation_fields) def post(self, app_model: App, end_user: EndUser): parser = reqparse.RequestParser() @@ -77,7 +77,7 @@ class AnnotationListApi(Resource): class AnnotationUpdateDeleteApi(Resource): - @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON)) + @validate_app_token @marshal_with(annotation_fields) def put(self, app_model: App, end_user: EndUser, annotation_id): if not current_user.is_editor: @@ -91,7 +91,7 @@ class AnnotationUpdateDeleteApi(Resource): annotation = AppAnnotationService.update_app_annotation_directly(args, app_model.id, annotation_id) return annotation - @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY)) + @validate_app_token def delete(self, app_model: App, end_user: EndUser, annotation_id): if not current_user.is_editor: raise Forbidden() diff --git a/api/controllers/service_api/wraps.py b/api/controllers/service_api/wraps.py index cd35ceac1d..d3316a5159 100644 --- a/api/controllers/service_api/wraps.py +++ b/api/controllers/service_api/wraps.py @@ -99,7 +99,12 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio if user_id: user_id = str(user_id) - kwargs["end_user"] = create_or_update_end_user_for_user_id(app_model, user_id) + end_user = create_or_update_end_user_for_user_id(app_model, user_id) + kwargs["end_user"] = end_user + + # Set EndUser as current logged-in user for flask_login.current_user + current_app.login_manager._update_request_context_with_user(end_user) # type: ignore + user_logged_in.send(current_app._get_current_object(), user=end_user) # type: ignore return view_func(*args, **kwargs) diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index 4b021aa0b3..50a843c7e8 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -5,7 +5,7 @@ import uuid from collections.abc import Generator, Mapping from typing import Any, Literal, Optional, Union, overload -from flask import Flask, current_app +from flask import Flask, copy_current_request_context, current_app, has_request_context from pydantic import ValidationError from sqlalchemy.orm import sessionmaker @@ -158,7 +158,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): trace_manager=trace_manager, workflow_run_id=workflow_run_id, ) - contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) contexts.plugin_tool_providers.set({}) contexts.plugin_tool_providers_lock.set(threading.Lock()) @@ -240,7 +239,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): node_id=node_id, inputs=args["inputs"] ), ) - contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) contexts.plugin_tool_providers.set({}) contexts.plugin_tool_providers_lock.set(threading.Lock()) @@ -316,7 +314,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): extras={"auto_generate_conversation_name": False}, single_loop_run=AdvancedChatAppGenerateEntity.SingleLoopRunEntity(node_id=node_id, inputs=args["inputs"]), ) - contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) contexts.plugin_tool_providers.set({}) contexts.plugin_tool_providers_lock.set(threading.Lock()) @@ -399,18 +396,23 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): message_id=message.id, ) - # new thread - worker_thread = threading.Thread( - target=self._generate_worker, - kwargs={ - "flask_app": current_app._get_current_object(), # type: ignore - "application_generate_entity": application_generate_entity, - "queue_manager": queue_manager, - "conversation_id": conversation.id, - "message_id": message.id, - "context": contextvars.copy_context(), - }, - ) + # new thread with request context and contextvars + context = contextvars.copy_context() + + @copy_current_request_context + def worker_with_context(): + # Run the worker within the copied context + return context.run( + self._generate_worker, + flask_app=current_app._get_current_object(), # type: ignore + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + conversation_id=conversation.id, + message_id=message.id, + context=context, + ) + + worker_thread = threading.Thread(target=worker_with_context) worker_thread.start() @@ -449,8 +451,22 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): """ for var, val in context.items(): var.set(val) + + # Save current user before entering new app context + from flask import g + + saved_user = None + if has_request_context() and hasattr(g, "_login_user"): + saved_user = g._login_user + with flask_app.app_context(): try: + # Restore user in new app context + if saved_user is not None: + from flask import g + + g._login_user = saved_user + # get conversation and message conversation = self._get_conversation(conversation_id) message = self._get_message(message_id) diff --git a/api/core/app/apps/agent_chat/app_generator.py b/api/core/app/apps/agent_chat/app_generator.py index 3ed436c07a..d5ee324360 100644 --- a/api/core/app/apps/agent_chat/app_generator.py +++ b/api/core/app/apps/agent_chat/app_generator.py @@ -5,7 +5,7 @@ import uuid from collections.abc import Generator, Mapping from typing import Any, Literal, Union, overload -from flask import Flask, current_app +from flask import Flask, copy_current_request_context, current_app, has_request_context from pydantic import ValidationError from configs import dify_config @@ -179,18 +179,23 @@ class AgentChatAppGenerator(MessageBasedAppGenerator): message_id=message.id, ) - # new thread - worker_thread = threading.Thread( - target=self._generate_worker, - kwargs={ - "flask_app": current_app._get_current_object(), # type: ignore - "context": contextvars.copy_context(), - "application_generate_entity": application_generate_entity, - "queue_manager": queue_manager, - "conversation_id": conversation.id, - "message_id": message.id, - }, - ) + # new thread with request context and contextvars + context = contextvars.copy_context() + + @copy_current_request_context + def worker_with_context(): + # Run the worker within the copied context + return context.run( + self._generate_worker, + flask_app=current_app._get_current_object(), # type: ignore + context=context, + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + conversation_id=conversation.id, + message_id=message.id, + ) + + worker_thread = threading.Thread(target=worker_with_context) worker_thread.start() @@ -227,8 +232,21 @@ class AgentChatAppGenerator(MessageBasedAppGenerator): for var, val in context.items(): var.set(val) + # Save current user before entering new app context + from flask import g + + saved_user = None + if has_request_context() and hasattr(g, "_login_user"): + saved_user = g._login_user + with flask_app.app_context(): try: + # Restore user in new app context + if saved_user is not None: + from flask import g + + g._login_user = saved_user + # get conversation and message conversation = self._get_conversation(conversation_id) message = self._get_message(message_id) diff --git a/api/core/app/apps/chat/app_generator.py b/api/core/app/apps/chat/app_generator.py index 2d865795d8..a1329cb938 100644 --- a/api/core/app/apps/chat/app_generator.py +++ b/api/core/app/apps/chat/app_generator.py @@ -4,7 +4,7 @@ import uuid from collections.abc import Generator, Mapping from typing import Any, Literal, Union, overload -from flask import Flask, current_app +from flask import Flask, copy_current_request_context, current_app from pydantic import ValidationError from configs import dify_config @@ -170,17 +170,18 @@ class ChatAppGenerator(MessageBasedAppGenerator): message_id=message.id, ) - # new thread - worker_thread = threading.Thread( - target=self._generate_worker, - kwargs={ - "flask_app": current_app._get_current_object(), # type: ignore - "application_generate_entity": application_generate_entity, - "queue_manager": queue_manager, - "conversation_id": conversation.id, - "message_id": message.id, - }, - ) + # new thread with request context + @copy_current_request_context + def worker_with_context(): + return self._generate_worker( + flask_app=current_app._get_current_object(), # type: ignore + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + conversation_id=conversation.id, + message_id=message.id, + ) + + worker_thread = threading.Thread(target=worker_with_context) worker_thread.start() diff --git a/api/core/app/apps/completion/app_generator.py b/api/core/app/apps/completion/app_generator.py index b1bc412616..adcbaad3ec 100644 --- a/api/core/app/apps/completion/app_generator.py +++ b/api/core/app/apps/completion/app_generator.py @@ -4,7 +4,7 @@ import uuid from collections.abc import Generator, Mapping from typing import Any, Literal, Union, overload -from flask import Flask, current_app +from flask import Flask, copy_current_request_context, current_app from pydantic import ValidationError from configs import dify_config @@ -151,16 +151,17 @@ class CompletionAppGenerator(MessageBasedAppGenerator): message_id=message.id, ) - # new thread - worker_thread = threading.Thread( - target=self._generate_worker, - kwargs={ - "flask_app": current_app._get_current_object(), # type: ignore - "application_generate_entity": application_generate_entity, - "queue_manager": queue_manager, - "message_id": message.id, - }, - ) + # new thread with request context + @copy_current_request_context + def worker_with_context(): + return self._generate_worker( + flask_app=current_app._get_current_object(), # type: ignore + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + message_id=message.id, + ) + + worker_thread = threading.Thread(target=worker_with_context) worker_thread.start() @@ -313,16 +314,17 @@ class CompletionAppGenerator(MessageBasedAppGenerator): message_id=message.id, ) - # new thread - worker_thread = threading.Thread( - target=self._generate_worker, - kwargs={ - "flask_app": current_app._get_current_object(), # type: ignore - "application_generate_entity": application_generate_entity, - "queue_manager": queue_manager, - "message_id": message.id, - }, - ) + # new thread with request context + @copy_current_request_context + def worker_with_context(): + return self._generate_worker( + flask_app=current_app._get_current_object(), # type: ignore + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + message_id=message.id, + ) + + worker_thread = threading.Thread(target=worker_with_context) worker_thread.start() diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py index 01d35c57ce..5442927413 100644 --- a/api/core/app/apps/workflow/app_generator.py +++ b/api/core/app/apps/workflow/app_generator.py @@ -5,7 +5,7 @@ import uuid from collections.abc import Generator, Mapping, Sequence from typing import Any, Literal, Optional, Union, overload -from flask import Flask, current_app +from flask import Flask, copy_current_request_context, current_app, has_request_context from pydantic import ValidationError from sqlalchemy.orm import sessionmaker @@ -135,7 +135,6 @@ class WorkflowAppGenerator(BaseAppGenerator): workflow_run_id=workflow_run_id, ) - contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) contexts.plugin_tool_providers.set({}) contexts.plugin_tool_providers_lock.set(threading.Lock()) @@ -207,17 +206,22 @@ class WorkflowAppGenerator(BaseAppGenerator): app_mode=app_model.mode, ) - # new thread - worker_thread = threading.Thread( - target=self._generate_worker, - kwargs={ - "flask_app": current_app._get_current_object(), # type: ignore - "application_generate_entity": application_generate_entity, - "queue_manager": queue_manager, - "context": contextvars.copy_context(), - "workflow_thread_pool_id": workflow_thread_pool_id, - }, - ) + # new thread with request context and contextvars + context = contextvars.copy_context() + + @copy_current_request_context + def worker_with_context(): + # Run the worker within the copied context + return context.run( + self._generate_worker, + flask_app=current_app._get_current_object(), # type: ignore + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + context=context, + workflow_thread_pool_id=workflow_thread_pool_id, + ) + + worker_thread = threading.Thread(target=worker_with_context) worker_thread.start() @@ -277,7 +281,6 @@ class WorkflowAppGenerator(BaseAppGenerator): ), workflow_run_id=str(uuid.uuid4()), ) - contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) contexts.plugin_tool_providers.set({}) contexts.plugin_tool_providers_lock.set(threading.Lock()) @@ -354,7 +357,6 @@ class WorkflowAppGenerator(BaseAppGenerator): single_loop_run=WorkflowAppGenerateEntity.SingleLoopRunEntity(node_id=node_id, inputs=args["inputs"]), workflow_run_id=str(uuid.uuid4()), ) - contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) contexts.plugin_tool_providers.set({}) contexts.plugin_tool_providers_lock.set(threading.Lock()) @@ -408,8 +410,22 @@ class WorkflowAppGenerator(BaseAppGenerator): """ for var, val in context.items(): var.set(val) + + # Save current user before entering new app context + from flask import g + + saved_user = None + if has_request_context() and hasattr(g, "_login_user"): + saved_user = g._login_user + with flask_app.app_context(): try: + # Restore user in new app context + if saved_user is not None: + from flask import g + + g._login_user = saved_user + # workflow app runner = WorkflowAppRunner( application_generate_entity=application_generate_entity, diff --git a/api/extensions/ext_login.py b/api/extensions/ext_login.py index 6a5721c021..06f42494ec 100644 --- a/api/extensions/ext_login.py +++ b/api/extensions/ext_login.py @@ -5,7 +5,6 @@ from flask import Response, request from flask_login import user_loaded_from_request, user_logged_in from werkzeug.exceptions import NotFound, Unauthorized -import contexts from configs import dify_config from dify_app import DifyApp from extensions.ext_database import db @@ -82,8 +81,8 @@ def on_user_logged_in(_sender, user): Note: AccountService.load_logged_in_account will populate user.current_tenant_id through the load_user method, which calls account.set_tenant_id(). """ - if user and isinstance(user, Account) and user.current_tenant_id: - contexts.tenant_id.set(user.current_tenant_id) + # tenant_id context variable removed - using current_user.current_tenant_id directly + pass @login_manager.unauthorized_handler diff --git a/api/models/workflow.py b/api/models/workflow.py index b0cb8dccd9..ae341dd1b5 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -6,6 +6,8 @@ from enum import Enum, StrEnum from typing import TYPE_CHECKING, Any, Optional, Union from uuid import uuid4 +from flask_login import current_user + from core.variables import utils as variable_utils from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID from factories.variable_factory import build_segment @@ -17,7 +19,6 @@ import sqlalchemy as sa from sqlalchemy import UniqueConstraint, func from sqlalchemy.orm import Mapped, mapped_column -import contexts from constants import DEFAULT_FILE_NUMBER_LIMITS, HIDDEN_VALUE from core.helper import encrypter from core.variables import SecretVariable, Segment, SegmentType, Variable @@ -274,7 +275,16 @@ class Workflow(Base): if self._environment_variables is None: self._environment_variables = "{}" - tenant_id = contexts.tenant_id.get() + # Get tenant_id from current_user (Account or EndUser) + if isinstance(current_user, Account): + # Account user + tenant_id = current_user.current_tenant_id + else: + # EndUser + tenant_id = current_user.tenant_id + + if not tenant_id: + return [] environment_variables_dict: dict[str, Any] = json.loads(self._environment_variables) results = [ @@ -297,7 +307,17 @@ class Workflow(Base): self._environment_variables = "{}" return - tenant_id = contexts.tenant_id.get() + # Get tenant_id from current_user (Account or EndUser) + if isinstance(current_user, Account): + # Account user + tenant_id = current_user.current_tenant_id + else: + # EndUser + tenant_id = current_user.tenant_id + + if not tenant_id: + self._environment_variables = "{}" + return value = list(value) if any(var for var in value if not var.id): diff --git a/api/tests/unit_tests/models/test_workflow.py b/api/tests/unit_tests/models/test_workflow.py index 70ce224eb6..34802d47a7 100644 --- a/api/tests/unit_tests/models/test_workflow.py +++ b/api/tests/unit_tests/models/test_workflow.py @@ -2,14 +2,13 @@ import json from unittest import mock from uuid import uuid4 -import contexts from constants import HIDDEN_VALUE from core.variables import FloatVariable, IntegerVariable, SecretVariable, StringVariable from models.workflow import Workflow, WorkflowNodeExecution def test_environment_variables(): - contexts.tenant_id.set("tenant_id") + # tenant_id context variable removed - using current_user.current_tenant_id directly # Create a Workflow instance workflow = Workflow( @@ -38,9 +37,14 @@ def test_environment_variables(): {"name": "var4", "value": 3.14, "id": str(uuid4()), "selector": ["env", "var4"]} ) + # Mock current_user as an EndUser + mock_user = mock.Mock() + mock_user.tenant_id = "tenant_id" + with ( mock.patch("core.helper.encrypter.encrypt_token", return_value="encrypted_token"), mock.patch("core.helper.encrypter.decrypt_token", return_value="secret"), + mock.patch("models.workflow.current_user", mock_user), ): # Set the environment_variables property of the Workflow instance variables = [variable1, variable2, variable3, variable4] @@ -51,7 +55,7 @@ def test_environment_variables(): def test_update_environment_variables(): - contexts.tenant_id.set("tenant_id") + # tenant_id context variable removed - using current_user.current_tenant_id directly # Create a Workflow instance workflow = Workflow( @@ -80,9 +84,14 @@ def test_update_environment_variables(): {"name": "var4", "value": 3.14, "id": str(uuid4()), "selector": ["env", "var4"]} ) + # Mock current_user as an EndUser + mock_user = mock.Mock() + mock_user.tenant_id = "tenant_id" + with ( mock.patch("core.helper.encrypter.encrypt_token", return_value="encrypted_token"), mock.patch("core.helper.encrypter.decrypt_token", return_value="secret"), + mock.patch("models.workflow.current_user", mock_user), ): variables = [variable1, variable2, variable3, variable4] @@ -104,7 +113,7 @@ def test_update_environment_variables(): def test_to_dict(): - contexts.tenant_id.set("tenant_id") + # tenant_id context variable removed - using current_user.current_tenant_id directly # Create a Workflow instance workflow = Workflow( @@ -121,9 +130,14 @@ def test_to_dict(): # Create some EnvironmentVariable instances + # Mock current_user as an EndUser + mock_user = mock.Mock() + mock_user.tenant_id = "tenant_id" + with ( mock.patch("core.helper.encrypter.encrypt_token", return_value="encrypted_token"), mock.patch("core.helper.encrypter.decrypt_token", return_value="secret"), + mock.patch("models.workflow.current_user", mock_user), ): # Set the environment_variables property of the Workflow instance workflow.environment_variables = [