mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-12 08:28:57 +08:00
fix: Copy request context and current user in app generators. (#20240)
Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
parent
acd4b9a8ac
commit
b357eca307
@ -11,10 +11,6 @@ if TYPE_CHECKING:
|
|||||||
from core.workflow.entities.variable_pool import VariablePool
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
|
|
||||||
|
|
||||||
tenant_id: ContextVar[str] = ContextVar("tenant_id")
|
|
||||||
|
|
||||||
workflow_variable_pool: ContextVar["VariablePool"] = ContextVar("workflow_variable_pool")
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
To avoid race-conditions caused by gunicorn thread recycling, using RecyclableContextVar to replace with
|
To avoid race-conditions caused by gunicorn thread recycling, using RecyclableContextVar to replace with
|
||||||
"""
|
"""
|
||||||
|
@ -3,7 +3,7 @@ from flask_restful import Resource, marshal, marshal_with, reqparse
|
|||||||
from werkzeug.exceptions import Forbidden
|
from werkzeug.exceptions import Forbidden
|
||||||
|
|
||||||
from controllers.service_api import api
|
from controllers.service_api import api
|
||||||
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
|
from controllers.service_api.wraps import validate_app_token
|
||||||
from extensions.ext_redis import redis_client
|
from extensions.ext_redis import redis_client
|
||||||
from fields.annotation_fields import (
|
from fields.annotation_fields import (
|
||||||
annotation_fields,
|
annotation_fields,
|
||||||
@ -14,7 +14,7 @@ from services.annotation_service import AppAnnotationService
|
|||||||
|
|
||||||
|
|
||||||
class AnnotationReplyActionApi(Resource):
|
class AnnotationReplyActionApi(Resource):
|
||||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
|
@validate_app_token
|
||||||
def post(self, app_model: App, end_user: EndUser, action):
|
def post(self, app_model: App, end_user: EndUser, action):
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("score_threshold", required=True, type=float, location="json")
|
parser.add_argument("score_threshold", required=True, type=float, location="json")
|
||||||
@ -31,7 +31,7 @@ class AnnotationReplyActionApi(Resource):
|
|||||||
|
|
||||||
|
|
||||||
class AnnotationReplyActionStatusApi(Resource):
|
class AnnotationReplyActionStatusApi(Resource):
|
||||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
|
@validate_app_token
|
||||||
def get(self, app_model: App, end_user: EndUser, job_id, action):
|
def get(self, app_model: App, end_user: EndUser, job_id, action):
|
||||||
job_id = str(job_id)
|
job_id = str(job_id)
|
||||||
app_annotation_job_key = "{}_app_annotation_job_{}".format(action, str(job_id))
|
app_annotation_job_key = "{}_app_annotation_job_{}".format(action, str(job_id))
|
||||||
@ -49,7 +49,7 @@ class AnnotationReplyActionStatusApi(Resource):
|
|||||||
|
|
||||||
|
|
||||||
class AnnotationListApi(Resource):
|
class AnnotationListApi(Resource):
|
||||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
|
@validate_app_token
|
||||||
def get(self, app_model: App, end_user: EndUser):
|
def get(self, app_model: App, end_user: EndUser):
|
||||||
page = request.args.get("page", default=1, type=int)
|
page = request.args.get("page", default=1, type=int)
|
||||||
limit = request.args.get("limit", default=20, type=int)
|
limit = request.args.get("limit", default=20, type=int)
|
||||||
@ -65,7 +65,7 @@ class AnnotationListApi(Resource):
|
|||||||
}
|
}
|
||||||
return response, 200
|
return response, 200
|
||||||
|
|
||||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
|
@validate_app_token
|
||||||
@marshal_with(annotation_fields)
|
@marshal_with(annotation_fields)
|
||||||
def post(self, app_model: App, end_user: EndUser):
|
def post(self, app_model: App, end_user: EndUser):
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
@ -77,7 +77,7 @@ class AnnotationListApi(Resource):
|
|||||||
|
|
||||||
|
|
||||||
class AnnotationUpdateDeleteApi(Resource):
|
class AnnotationUpdateDeleteApi(Resource):
|
||||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
|
@validate_app_token
|
||||||
@marshal_with(annotation_fields)
|
@marshal_with(annotation_fields)
|
||||||
def put(self, app_model: App, end_user: EndUser, annotation_id):
|
def put(self, app_model: App, end_user: EndUser, annotation_id):
|
||||||
if not current_user.is_editor:
|
if not current_user.is_editor:
|
||||||
@ -91,7 +91,7 @@ class AnnotationUpdateDeleteApi(Resource):
|
|||||||
annotation = AppAnnotationService.update_app_annotation_directly(args, app_model.id, annotation_id)
|
annotation = AppAnnotationService.update_app_annotation_directly(args, app_model.id, annotation_id)
|
||||||
return annotation
|
return annotation
|
||||||
|
|
||||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
|
@validate_app_token
|
||||||
def delete(self, app_model: App, end_user: EndUser, annotation_id):
|
def delete(self, app_model: App, end_user: EndUser, annotation_id):
|
||||||
if not current_user.is_editor:
|
if not current_user.is_editor:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
@ -99,7 +99,12 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio
|
|||||||
if user_id:
|
if user_id:
|
||||||
user_id = str(user_id)
|
user_id = str(user_id)
|
||||||
|
|
||||||
kwargs["end_user"] = create_or_update_end_user_for_user_id(app_model, user_id)
|
end_user = create_or_update_end_user_for_user_id(app_model, user_id)
|
||||||
|
kwargs["end_user"] = end_user
|
||||||
|
|
||||||
|
# Set EndUser as current logged-in user for flask_login.current_user
|
||||||
|
current_app.login_manager._update_request_context_with_user(end_user) # type: ignore
|
||||||
|
user_logged_in.send(current_app._get_current_object(), user=end_user) # type: ignore
|
||||||
|
|
||||||
return view_func(*args, **kwargs)
|
return view_func(*args, **kwargs)
|
||||||
|
|
||||||
|
@ -5,7 +5,7 @@ import uuid
|
|||||||
from collections.abc import Generator, Mapping
|
from collections.abc import Generator, Mapping
|
||||||
from typing import Any, Literal, Optional, Union, overload
|
from typing import Any, Literal, Optional, Union, overload
|
||||||
|
|
||||||
from flask import Flask, current_app
|
from flask import Flask, copy_current_request_context, current_app, has_request_context
|
||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
from sqlalchemy.orm import sessionmaker
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
|
||||||
@ -158,7 +158,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||||||
trace_manager=trace_manager,
|
trace_manager=trace_manager,
|
||||||
workflow_run_id=workflow_run_id,
|
workflow_run_id=workflow_run_id,
|
||||||
)
|
)
|
||||||
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
|
|
||||||
contexts.plugin_tool_providers.set({})
|
contexts.plugin_tool_providers.set({})
|
||||||
contexts.plugin_tool_providers_lock.set(threading.Lock())
|
contexts.plugin_tool_providers_lock.set(threading.Lock())
|
||||||
|
|
||||||
@ -240,7 +239,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||||||
node_id=node_id, inputs=args["inputs"]
|
node_id=node_id, inputs=args["inputs"]
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
|
|
||||||
contexts.plugin_tool_providers.set({})
|
contexts.plugin_tool_providers.set({})
|
||||||
contexts.plugin_tool_providers_lock.set(threading.Lock())
|
contexts.plugin_tool_providers_lock.set(threading.Lock())
|
||||||
|
|
||||||
@ -316,7 +314,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||||||
extras={"auto_generate_conversation_name": False},
|
extras={"auto_generate_conversation_name": False},
|
||||||
single_loop_run=AdvancedChatAppGenerateEntity.SingleLoopRunEntity(node_id=node_id, inputs=args["inputs"]),
|
single_loop_run=AdvancedChatAppGenerateEntity.SingleLoopRunEntity(node_id=node_id, inputs=args["inputs"]),
|
||||||
)
|
)
|
||||||
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
|
|
||||||
contexts.plugin_tool_providers.set({})
|
contexts.plugin_tool_providers.set({})
|
||||||
contexts.plugin_tool_providers_lock.set(threading.Lock())
|
contexts.plugin_tool_providers_lock.set(threading.Lock())
|
||||||
|
|
||||||
@ -399,19 +396,24 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||||||
message_id=message.id,
|
message_id=message.id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# new thread
|
# new thread with request context and contextvars
|
||||||
worker_thread = threading.Thread(
|
context = contextvars.copy_context()
|
||||||
target=self._generate_worker,
|
|
||||||
kwargs={
|
@copy_current_request_context
|
||||||
"flask_app": current_app._get_current_object(), # type: ignore
|
def worker_with_context():
|
||||||
"application_generate_entity": application_generate_entity,
|
# Run the worker within the copied context
|
||||||
"queue_manager": queue_manager,
|
return context.run(
|
||||||
"conversation_id": conversation.id,
|
self._generate_worker,
|
||||||
"message_id": message.id,
|
flask_app=current_app._get_current_object(), # type: ignore
|
||||||
"context": contextvars.copy_context(),
|
application_generate_entity=application_generate_entity,
|
||||||
},
|
queue_manager=queue_manager,
|
||||||
|
conversation_id=conversation.id,
|
||||||
|
message_id=message.id,
|
||||||
|
context=context,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
worker_thread = threading.Thread(target=worker_with_context)
|
||||||
|
|
||||||
worker_thread.start()
|
worker_thread.start()
|
||||||
|
|
||||||
# return response or stream generator
|
# return response or stream generator
|
||||||
@ -449,8 +451,22 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||||||
"""
|
"""
|
||||||
for var, val in context.items():
|
for var, val in context.items():
|
||||||
var.set(val)
|
var.set(val)
|
||||||
|
|
||||||
|
# Save current user before entering new app context
|
||||||
|
from flask import g
|
||||||
|
|
||||||
|
saved_user = None
|
||||||
|
if has_request_context() and hasattr(g, "_login_user"):
|
||||||
|
saved_user = g._login_user
|
||||||
|
|
||||||
with flask_app.app_context():
|
with flask_app.app_context():
|
||||||
try:
|
try:
|
||||||
|
# Restore user in new app context
|
||||||
|
if saved_user is not None:
|
||||||
|
from flask import g
|
||||||
|
|
||||||
|
g._login_user = saved_user
|
||||||
|
|
||||||
# get conversation and message
|
# get conversation and message
|
||||||
conversation = self._get_conversation(conversation_id)
|
conversation = self._get_conversation(conversation_id)
|
||||||
message = self._get_message(message_id)
|
message = self._get_message(message_id)
|
||||||
|
@ -5,7 +5,7 @@ import uuid
|
|||||||
from collections.abc import Generator, Mapping
|
from collections.abc import Generator, Mapping
|
||||||
from typing import Any, Literal, Union, overload
|
from typing import Any, Literal, Union, overload
|
||||||
|
|
||||||
from flask import Flask, current_app
|
from flask import Flask, copy_current_request_context, current_app, has_request_context
|
||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
@ -179,19 +179,24 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
|
|||||||
message_id=message.id,
|
message_id=message.id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# new thread
|
# new thread with request context and contextvars
|
||||||
worker_thread = threading.Thread(
|
context = contextvars.copy_context()
|
||||||
target=self._generate_worker,
|
|
||||||
kwargs={
|
@copy_current_request_context
|
||||||
"flask_app": current_app._get_current_object(), # type: ignore
|
def worker_with_context():
|
||||||
"context": contextvars.copy_context(),
|
# Run the worker within the copied context
|
||||||
"application_generate_entity": application_generate_entity,
|
return context.run(
|
||||||
"queue_manager": queue_manager,
|
self._generate_worker,
|
||||||
"conversation_id": conversation.id,
|
flask_app=current_app._get_current_object(), # type: ignore
|
||||||
"message_id": message.id,
|
context=context,
|
||||||
},
|
application_generate_entity=application_generate_entity,
|
||||||
|
queue_manager=queue_manager,
|
||||||
|
conversation_id=conversation.id,
|
||||||
|
message_id=message.id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
worker_thread = threading.Thread(target=worker_with_context)
|
||||||
|
|
||||||
worker_thread.start()
|
worker_thread.start()
|
||||||
|
|
||||||
# return response or stream generator
|
# return response or stream generator
|
||||||
@ -227,8 +232,21 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
|
|||||||
for var, val in context.items():
|
for var, val in context.items():
|
||||||
var.set(val)
|
var.set(val)
|
||||||
|
|
||||||
|
# Save current user before entering new app context
|
||||||
|
from flask import g
|
||||||
|
|
||||||
|
saved_user = None
|
||||||
|
if has_request_context() and hasattr(g, "_login_user"):
|
||||||
|
saved_user = g._login_user
|
||||||
|
|
||||||
with flask_app.app_context():
|
with flask_app.app_context():
|
||||||
try:
|
try:
|
||||||
|
# Restore user in new app context
|
||||||
|
if saved_user is not None:
|
||||||
|
from flask import g
|
||||||
|
|
||||||
|
g._login_user = saved_user
|
||||||
|
|
||||||
# get conversation and message
|
# get conversation and message
|
||||||
conversation = self._get_conversation(conversation_id)
|
conversation = self._get_conversation(conversation_id)
|
||||||
message = self._get_message(message_id)
|
message = self._get_message(message_id)
|
||||||
|
@ -4,7 +4,7 @@ import uuid
|
|||||||
from collections.abc import Generator, Mapping
|
from collections.abc import Generator, Mapping
|
||||||
from typing import Any, Literal, Union, overload
|
from typing import Any, Literal, Union, overload
|
||||||
|
|
||||||
from flask import Flask, current_app
|
from flask import Flask, copy_current_request_context, current_app
|
||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
@ -170,18 +170,19 @@ class ChatAppGenerator(MessageBasedAppGenerator):
|
|||||||
message_id=message.id,
|
message_id=message.id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# new thread
|
# new thread with request context
|
||||||
worker_thread = threading.Thread(
|
@copy_current_request_context
|
||||||
target=self._generate_worker,
|
def worker_with_context():
|
||||||
kwargs={
|
return self._generate_worker(
|
||||||
"flask_app": current_app._get_current_object(), # type: ignore
|
flask_app=current_app._get_current_object(), # type: ignore
|
||||||
"application_generate_entity": application_generate_entity,
|
application_generate_entity=application_generate_entity,
|
||||||
"queue_manager": queue_manager,
|
queue_manager=queue_manager,
|
||||||
"conversation_id": conversation.id,
|
conversation_id=conversation.id,
|
||||||
"message_id": message.id,
|
message_id=message.id,
|
||||||
},
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
worker_thread = threading.Thread(target=worker_with_context)
|
||||||
|
|
||||||
worker_thread.start()
|
worker_thread.start()
|
||||||
|
|
||||||
# return response or stream generator
|
# return response or stream generator
|
||||||
|
@ -4,7 +4,7 @@ import uuid
|
|||||||
from collections.abc import Generator, Mapping
|
from collections.abc import Generator, Mapping
|
||||||
from typing import Any, Literal, Union, overload
|
from typing import Any, Literal, Union, overload
|
||||||
|
|
||||||
from flask import Flask, current_app
|
from flask import Flask, copy_current_request_context, current_app
|
||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
@ -151,17 +151,18 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
|||||||
message_id=message.id,
|
message_id=message.id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# new thread
|
# new thread with request context
|
||||||
worker_thread = threading.Thread(
|
@copy_current_request_context
|
||||||
target=self._generate_worker,
|
def worker_with_context():
|
||||||
kwargs={
|
return self._generate_worker(
|
||||||
"flask_app": current_app._get_current_object(), # type: ignore
|
flask_app=current_app._get_current_object(), # type: ignore
|
||||||
"application_generate_entity": application_generate_entity,
|
application_generate_entity=application_generate_entity,
|
||||||
"queue_manager": queue_manager,
|
queue_manager=queue_manager,
|
||||||
"message_id": message.id,
|
message_id=message.id,
|
||||||
},
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
worker_thread = threading.Thread(target=worker_with_context)
|
||||||
|
|
||||||
worker_thread.start()
|
worker_thread.start()
|
||||||
|
|
||||||
# return response or stream generator
|
# return response or stream generator
|
||||||
@ -313,17 +314,18 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
|||||||
message_id=message.id,
|
message_id=message.id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# new thread
|
# new thread with request context
|
||||||
worker_thread = threading.Thread(
|
@copy_current_request_context
|
||||||
target=self._generate_worker,
|
def worker_with_context():
|
||||||
kwargs={
|
return self._generate_worker(
|
||||||
"flask_app": current_app._get_current_object(), # type: ignore
|
flask_app=current_app._get_current_object(), # type: ignore
|
||||||
"application_generate_entity": application_generate_entity,
|
application_generate_entity=application_generate_entity,
|
||||||
"queue_manager": queue_manager,
|
queue_manager=queue_manager,
|
||||||
"message_id": message.id,
|
message_id=message.id,
|
||||||
},
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
worker_thread = threading.Thread(target=worker_with_context)
|
||||||
|
|
||||||
worker_thread.start()
|
worker_thread.start()
|
||||||
|
|
||||||
# return response or stream generator
|
# return response or stream generator
|
||||||
|
@ -5,7 +5,7 @@ import uuid
|
|||||||
from collections.abc import Generator, Mapping, Sequence
|
from collections.abc import Generator, Mapping, Sequence
|
||||||
from typing import Any, Literal, Optional, Union, overload
|
from typing import Any, Literal, Optional, Union, overload
|
||||||
|
|
||||||
from flask import Flask, current_app
|
from flask import Flask, copy_current_request_context, current_app, has_request_context
|
||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
from sqlalchemy.orm import sessionmaker
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
|
||||||
@ -135,7 +135,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||||||
workflow_run_id=workflow_run_id,
|
workflow_run_id=workflow_run_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
|
|
||||||
contexts.plugin_tool_providers.set({})
|
contexts.plugin_tool_providers.set({})
|
||||||
contexts.plugin_tool_providers_lock.set(threading.Lock())
|
contexts.plugin_tool_providers_lock.set(threading.Lock())
|
||||||
|
|
||||||
@ -207,18 +206,23 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||||||
app_mode=app_model.mode,
|
app_mode=app_model.mode,
|
||||||
)
|
)
|
||||||
|
|
||||||
# new thread
|
# new thread with request context and contextvars
|
||||||
worker_thread = threading.Thread(
|
context = contextvars.copy_context()
|
||||||
target=self._generate_worker,
|
|
||||||
kwargs={
|
@copy_current_request_context
|
||||||
"flask_app": current_app._get_current_object(), # type: ignore
|
def worker_with_context():
|
||||||
"application_generate_entity": application_generate_entity,
|
# Run the worker within the copied context
|
||||||
"queue_manager": queue_manager,
|
return context.run(
|
||||||
"context": contextvars.copy_context(),
|
self._generate_worker,
|
||||||
"workflow_thread_pool_id": workflow_thread_pool_id,
|
flask_app=current_app._get_current_object(), # type: ignore
|
||||||
},
|
application_generate_entity=application_generate_entity,
|
||||||
|
queue_manager=queue_manager,
|
||||||
|
context=context,
|
||||||
|
workflow_thread_pool_id=workflow_thread_pool_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
worker_thread = threading.Thread(target=worker_with_context)
|
||||||
|
|
||||||
worker_thread.start()
|
worker_thread.start()
|
||||||
|
|
||||||
# return response or stream generator
|
# return response or stream generator
|
||||||
@ -277,7 +281,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||||||
),
|
),
|
||||||
workflow_run_id=str(uuid.uuid4()),
|
workflow_run_id=str(uuid.uuid4()),
|
||||||
)
|
)
|
||||||
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
|
|
||||||
contexts.plugin_tool_providers.set({})
|
contexts.plugin_tool_providers.set({})
|
||||||
contexts.plugin_tool_providers_lock.set(threading.Lock())
|
contexts.plugin_tool_providers_lock.set(threading.Lock())
|
||||||
|
|
||||||
@ -354,7 +357,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||||||
single_loop_run=WorkflowAppGenerateEntity.SingleLoopRunEntity(node_id=node_id, inputs=args["inputs"]),
|
single_loop_run=WorkflowAppGenerateEntity.SingleLoopRunEntity(node_id=node_id, inputs=args["inputs"]),
|
||||||
workflow_run_id=str(uuid.uuid4()),
|
workflow_run_id=str(uuid.uuid4()),
|
||||||
)
|
)
|
||||||
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
|
|
||||||
contexts.plugin_tool_providers.set({})
|
contexts.plugin_tool_providers.set({})
|
||||||
contexts.plugin_tool_providers_lock.set(threading.Lock())
|
contexts.plugin_tool_providers_lock.set(threading.Lock())
|
||||||
|
|
||||||
@ -408,8 +410,22 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||||||
"""
|
"""
|
||||||
for var, val in context.items():
|
for var, val in context.items():
|
||||||
var.set(val)
|
var.set(val)
|
||||||
|
|
||||||
|
# Save current user before entering new app context
|
||||||
|
from flask import g
|
||||||
|
|
||||||
|
saved_user = None
|
||||||
|
if has_request_context() and hasattr(g, "_login_user"):
|
||||||
|
saved_user = g._login_user
|
||||||
|
|
||||||
with flask_app.app_context():
|
with flask_app.app_context():
|
||||||
try:
|
try:
|
||||||
|
# Restore user in new app context
|
||||||
|
if saved_user is not None:
|
||||||
|
from flask import g
|
||||||
|
|
||||||
|
g._login_user = saved_user
|
||||||
|
|
||||||
# workflow app
|
# workflow app
|
||||||
runner = WorkflowAppRunner(
|
runner = WorkflowAppRunner(
|
||||||
application_generate_entity=application_generate_entity,
|
application_generate_entity=application_generate_entity,
|
||||||
|
@ -5,7 +5,6 @@ from flask import Response, request
|
|||||||
from flask_login import user_loaded_from_request, user_logged_in
|
from flask_login import user_loaded_from_request, user_logged_in
|
||||||
from werkzeug.exceptions import NotFound, Unauthorized
|
from werkzeug.exceptions import NotFound, Unauthorized
|
||||||
|
|
||||||
import contexts
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from dify_app import DifyApp
|
from dify_app import DifyApp
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
@ -82,8 +81,8 @@ def on_user_logged_in(_sender, user):
|
|||||||
Note: AccountService.load_logged_in_account will populate user.current_tenant_id
|
Note: AccountService.load_logged_in_account will populate user.current_tenant_id
|
||||||
through the load_user method, which calls account.set_tenant_id().
|
through the load_user method, which calls account.set_tenant_id().
|
||||||
"""
|
"""
|
||||||
if user and isinstance(user, Account) and user.current_tenant_id:
|
# tenant_id context variable removed - using current_user.current_tenant_id directly
|
||||||
contexts.tenant_id.set(user.current_tenant_id)
|
pass
|
||||||
|
|
||||||
|
|
||||||
@login_manager.unauthorized_handler
|
@login_manager.unauthorized_handler
|
||||||
|
@ -6,6 +6,8 @@ from enum import Enum, StrEnum
|
|||||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
|
from flask_login import current_user
|
||||||
|
|
||||||
from core.variables import utils as variable_utils
|
from core.variables import utils as variable_utils
|
||||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
||||||
from factories.variable_factory import build_segment
|
from factories.variable_factory import build_segment
|
||||||
@ -17,7 +19,6 @@ import sqlalchemy as sa
|
|||||||
from sqlalchemy import UniqueConstraint, func
|
from sqlalchemy import UniqueConstraint, func
|
||||||
from sqlalchemy.orm import Mapped, mapped_column
|
from sqlalchemy.orm import Mapped, mapped_column
|
||||||
|
|
||||||
import contexts
|
|
||||||
from constants import DEFAULT_FILE_NUMBER_LIMITS, HIDDEN_VALUE
|
from constants import DEFAULT_FILE_NUMBER_LIMITS, HIDDEN_VALUE
|
||||||
from core.helper import encrypter
|
from core.helper import encrypter
|
||||||
from core.variables import SecretVariable, Segment, SegmentType, Variable
|
from core.variables import SecretVariable, Segment, SegmentType, Variable
|
||||||
@ -274,7 +275,16 @@ class Workflow(Base):
|
|||||||
if self._environment_variables is None:
|
if self._environment_variables is None:
|
||||||
self._environment_variables = "{}"
|
self._environment_variables = "{}"
|
||||||
|
|
||||||
tenant_id = contexts.tenant_id.get()
|
# Get tenant_id from current_user (Account or EndUser)
|
||||||
|
if isinstance(current_user, Account):
|
||||||
|
# Account user
|
||||||
|
tenant_id = current_user.current_tenant_id
|
||||||
|
else:
|
||||||
|
# EndUser
|
||||||
|
tenant_id = current_user.tenant_id
|
||||||
|
|
||||||
|
if not tenant_id:
|
||||||
|
return []
|
||||||
|
|
||||||
environment_variables_dict: dict[str, Any] = json.loads(self._environment_variables)
|
environment_variables_dict: dict[str, Any] = json.loads(self._environment_variables)
|
||||||
results = [
|
results = [
|
||||||
@ -297,7 +307,17 @@ class Workflow(Base):
|
|||||||
self._environment_variables = "{}"
|
self._environment_variables = "{}"
|
||||||
return
|
return
|
||||||
|
|
||||||
tenant_id = contexts.tenant_id.get()
|
# Get tenant_id from current_user (Account or EndUser)
|
||||||
|
if isinstance(current_user, Account):
|
||||||
|
# Account user
|
||||||
|
tenant_id = current_user.current_tenant_id
|
||||||
|
else:
|
||||||
|
# EndUser
|
||||||
|
tenant_id = current_user.tenant_id
|
||||||
|
|
||||||
|
if not tenant_id:
|
||||||
|
self._environment_variables = "{}"
|
||||||
|
return
|
||||||
|
|
||||||
value = list(value)
|
value = list(value)
|
||||||
if any(var for var in value if not var.id):
|
if any(var for var in value if not var.id):
|
||||||
|
@ -2,14 +2,13 @@ import json
|
|||||||
from unittest import mock
|
from unittest import mock
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
import contexts
|
|
||||||
from constants import HIDDEN_VALUE
|
from constants import HIDDEN_VALUE
|
||||||
from core.variables import FloatVariable, IntegerVariable, SecretVariable, StringVariable
|
from core.variables import FloatVariable, IntegerVariable, SecretVariable, StringVariable
|
||||||
from models.workflow import Workflow, WorkflowNodeExecution
|
from models.workflow import Workflow, WorkflowNodeExecution
|
||||||
|
|
||||||
|
|
||||||
def test_environment_variables():
|
def test_environment_variables():
|
||||||
contexts.tenant_id.set("tenant_id")
|
# tenant_id context variable removed - using current_user.current_tenant_id directly
|
||||||
|
|
||||||
# Create a Workflow instance
|
# Create a Workflow instance
|
||||||
workflow = Workflow(
|
workflow = Workflow(
|
||||||
@ -38,9 +37,14 @@ def test_environment_variables():
|
|||||||
{"name": "var4", "value": 3.14, "id": str(uuid4()), "selector": ["env", "var4"]}
|
{"name": "var4", "value": 3.14, "id": str(uuid4()), "selector": ["env", "var4"]}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Mock current_user as an EndUser
|
||||||
|
mock_user = mock.Mock()
|
||||||
|
mock_user.tenant_id = "tenant_id"
|
||||||
|
|
||||||
with (
|
with (
|
||||||
mock.patch("core.helper.encrypter.encrypt_token", return_value="encrypted_token"),
|
mock.patch("core.helper.encrypter.encrypt_token", return_value="encrypted_token"),
|
||||||
mock.patch("core.helper.encrypter.decrypt_token", return_value="secret"),
|
mock.patch("core.helper.encrypter.decrypt_token", return_value="secret"),
|
||||||
|
mock.patch("models.workflow.current_user", mock_user),
|
||||||
):
|
):
|
||||||
# Set the environment_variables property of the Workflow instance
|
# Set the environment_variables property of the Workflow instance
|
||||||
variables = [variable1, variable2, variable3, variable4]
|
variables = [variable1, variable2, variable3, variable4]
|
||||||
@ -51,7 +55,7 @@ def test_environment_variables():
|
|||||||
|
|
||||||
|
|
||||||
def test_update_environment_variables():
|
def test_update_environment_variables():
|
||||||
contexts.tenant_id.set("tenant_id")
|
# tenant_id context variable removed - using current_user.current_tenant_id directly
|
||||||
|
|
||||||
# Create a Workflow instance
|
# Create a Workflow instance
|
||||||
workflow = Workflow(
|
workflow = Workflow(
|
||||||
@ -80,9 +84,14 @@ def test_update_environment_variables():
|
|||||||
{"name": "var4", "value": 3.14, "id": str(uuid4()), "selector": ["env", "var4"]}
|
{"name": "var4", "value": 3.14, "id": str(uuid4()), "selector": ["env", "var4"]}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Mock current_user as an EndUser
|
||||||
|
mock_user = mock.Mock()
|
||||||
|
mock_user.tenant_id = "tenant_id"
|
||||||
|
|
||||||
with (
|
with (
|
||||||
mock.patch("core.helper.encrypter.encrypt_token", return_value="encrypted_token"),
|
mock.patch("core.helper.encrypter.encrypt_token", return_value="encrypted_token"),
|
||||||
mock.patch("core.helper.encrypter.decrypt_token", return_value="secret"),
|
mock.patch("core.helper.encrypter.decrypt_token", return_value="secret"),
|
||||||
|
mock.patch("models.workflow.current_user", mock_user),
|
||||||
):
|
):
|
||||||
variables = [variable1, variable2, variable3, variable4]
|
variables = [variable1, variable2, variable3, variable4]
|
||||||
|
|
||||||
@ -104,7 +113,7 @@ def test_update_environment_variables():
|
|||||||
|
|
||||||
|
|
||||||
def test_to_dict():
|
def test_to_dict():
|
||||||
contexts.tenant_id.set("tenant_id")
|
# tenant_id context variable removed - using current_user.current_tenant_id directly
|
||||||
|
|
||||||
# Create a Workflow instance
|
# Create a Workflow instance
|
||||||
workflow = Workflow(
|
workflow = Workflow(
|
||||||
@ -121,9 +130,14 @@ def test_to_dict():
|
|||||||
|
|
||||||
# Create some EnvironmentVariable instances
|
# Create some EnvironmentVariable instances
|
||||||
|
|
||||||
|
# Mock current_user as an EndUser
|
||||||
|
mock_user = mock.Mock()
|
||||||
|
mock_user.tenant_id = "tenant_id"
|
||||||
|
|
||||||
with (
|
with (
|
||||||
mock.patch("core.helper.encrypter.encrypt_token", return_value="encrypted_token"),
|
mock.patch("core.helper.encrypter.encrypt_token", return_value="encrypted_token"),
|
||||||
mock.patch("core.helper.encrypter.decrypt_token", return_value="secret"),
|
mock.patch("core.helper.encrypter.decrypt_token", return_value="secret"),
|
||||||
|
mock.patch("models.workflow.current_user", mock_user),
|
||||||
):
|
):
|
||||||
# Set the environment_variables property of the Workflow instance
|
# Set the environment_variables property of the Workflow instance
|
||||||
workflow.environment_variables = [
|
workflow.environment_variables = [
|
||||||
|
Loading…
x
Reference in New Issue
Block a user