diff --git a/api/core/tools/tool_file_manager.py b/api/core/tools/tool_file_manager.py index 207f009eed..f9f7c7d78a 100644 --- a/api/core/tools/tool_file_manager.py +++ b/api/core/tools/tool_file_manager.py @@ -9,9 +9,9 @@ from mimetypes import guess_extension, guess_type from typing import Optional, Union from uuid import uuid4 -from flask import current_app from httpx import get +from configs import dify_config from extensions.ext_database import db from extensions.ext_storage import storage from models.model import MessageFile @@ -26,25 +26,25 @@ class ToolFileManager: """ sign file to get a temporary url """ - base_url = current_app.config.get('FILES_URL') + base_url = dify_config.FILES_URL file_preview_url = f'{base_url}/files/tools/{tool_file_id}{extension}' timestamp = str(int(time.time())) nonce = os.urandom(16).hex() - data_to_sign = f"file-preview|{tool_file_id}|{timestamp}|{nonce}" - secret_key = current_app.config['SECRET_KEY'].encode() + data_to_sign = f'file-preview|{tool_file_id}|{timestamp}|{nonce}' + secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b'' sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() encoded_sign = base64.urlsafe_b64encode(sign).decode() - return f"{file_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}" + return f'{file_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}' @staticmethod def verify_file(file_id: str, timestamp: str, nonce: str, sign: str) -> bool: """ verify signature """ - data_to_sign = f"file-preview|{file_id}|{timestamp}|{nonce}" - secret_key = current_app.config['SECRET_KEY'].encode() + data_to_sign = f'file-preview|{file_id}|{timestamp}|{nonce}' + secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b'' recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode() @@ -53,23 +53,23 @@ class ToolFileManager: return False current_time = int(time.time()) - return current_time - int(timestamp) <= current_app.config.get('FILES_ACCESS_TIMEOUT') + return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT @staticmethod - def create_file_by_raw(user_id: str, tenant_id: str, - conversation_id: Optional[str], file_binary: bytes, - mimetype: str - ) -> ToolFile: + def create_file_by_raw( + user_id: str, tenant_id: str, conversation_id: Optional[str], file_binary: bytes, mimetype: str + ) -> ToolFile: """ create file """ extension = guess_extension(mimetype) or '.bin' unique_name = uuid4().hex - filename = f"tools/{tenant_id}/{unique_name}{extension}" + filename = f'tools/{tenant_id}/{unique_name}{extension}' storage.save(filename, file_binary) - tool_file = ToolFile(user_id=user_id, tenant_id=tenant_id, - conversation_id=conversation_id, file_key=filename, mimetype=mimetype) + tool_file = ToolFile( + user_id=user_id, tenant_id=tenant_id, conversation_id=conversation_id, file_key=filename, mimetype=mimetype + ) db.session.add(tool_file) db.session.commit() @@ -77,9 +77,12 @@ class ToolFileManager: return tool_file @staticmethod - def create_file_by_url(user_id: str, tenant_id: str, - conversation_id: str, file_url: str, - ) -> ToolFile: + def create_file_by_url( + user_id: str, + tenant_id: str, + conversation_id: str, + file_url: str, + ) -> ToolFile: """ create file """ @@ -90,12 +93,17 @@ class ToolFileManager: mimetype = guess_type(file_url)[0] or 'octet/stream' extension = guess_extension(mimetype) or '.bin' unique_name = uuid4().hex - filename = f"tools/{tenant_id}/{unique_name}{extension}" + filename = f'tools/{tenant_id}/{unique_name}{extension}' storage.save(filename, blob) - tool_file = ToolFile(user_id=user_id, tenant_id=tenant_id, - conversation_id=conversation_id, file_key=filename, - mimetype=mimetype, original_url=file_url) + tool_file = ToolFile( + user_id=user_id, + tenant_id=tenant_id, + conversation_id=conversation_id, + file_key=filename, + mimetype=mimetype, + original_url=file_url, + ) db.session.add(tool_file) db.session.commit() @@ -103,15 +111,15 @@ class ToolFileManager: return tool_file @staticmethod - def create_file_by_key(user_id: str, tenant_id: str, - conversation_id: str, file_key: str, - mimetype: str - ) -> ToolFile: + def create_file_by_key( + user_id: str, tenant_id: str, conversation_id: str, file_key: str, mimetype: str + ) -> ToolFile: """ create file """ - tool_file = ToolFile(user_id=user_id, tenant_id=tenant_id, - conversation_id=conversation_id, file_key=file_key, mimetype=mimetype) + tool_file = ToolFile( + user_id=user_id, tenant_id=tenant_id, conversation_id=conversation_id, file_key=file_key, mimetype=mimetype + ) return tool_file @staticmethod @@ -123,9 +131,13 @@ class ToolFileManager: :return: the binary of the file, mime type """ - tool_file: ToolFile = db.session.query(ToolFile).filter( - ToolFile.id == id, - ).first() + tool_file: ToolFile = ( + db.session.query(ToolFile) + .filter( + ToolFile.id == id, + ) + .first() + ) if not tool_file: return None @@ -143,18 +155,31 @@ class ToolFileManager: :return: the binary of the file, mime type """ - message_file: MessageFile = db.session.query(MessageFile).filter( - MessageFile.id == id, - ).first() + message_file: MessageFile = ( + db.session.query(MessageFile) + .filter( + MessageFile.id == id, + ) + .first() + ) - # get tool file id - tool_file_id = message_file.url.split('/')[-1] - # trim extension - tool_file_id = tool_file_id.split('.')[0] + # Check if message_file is not None + if message_file is not None: + # get tool file id + tool_file_id = message_file.url.split('/')[-1] + # trim extension + tool_file_id = tool_file_id.split('.')[0] + else: + tool_file_id = None - tool_file: ToolFile = db.session.query(ToolFile).filter( - ToolFile.id == tool_file_id, - ).first() + + tool_file: ToolFile = ( + db.session.query(ToolFile) + .filter( + ToolFile.id == tool_file_id, + ) + .first() + ) if not tool_file: return None @@ -172,9 +197,13 @@ class ToolFileManager: :return: the binary of the file, mime type """ - tool_file: ToolFile = db.session.query(ToolFile).filter( - ToolFile.id == tool_file_id, - ).first() + tool_file: ToolFile = ( + db.session.query(ToolFile) + .filter( + ToolFile.id == tool_file_id, + ) + .first() + ) if not tool_file: return None diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index e30a905cbc..3342300eb4 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -6,8 +6,7 @@ from os import listdir, path from threading import Lock from typing import Any, Union -from flask import current_app - +from configs import dify_config from core.agent.entities import AgentToolEntity from core.app.entities.app_invoke_entities import InvokeFrom from core.helper.module_import_helper import load_single_subclass_from_source @@ -566,7 +565,7 @@ class ToolManager: provider_type = provider_type provider_id = provider_id if provider_type == 'builtin': - return (current_app.config.get("CONSOLE_API_URL") + return (dify_config.CONSOLE_API_URL + "/console/api/workspaces/current/tool-provider/builtin/" + provider_id + "/icon") @@ -594,4 +593,4 @@ class ToolManager: else: raise ValueError(f"provider type {provider_type} not found") -ToolManager.load_builtin_providers_cache() \ No newline at end of file +ToolManager.load_builtin_providers_cache() diff --git a/api/core/workflow/workflow_engine_manager.py b/api/core/workflow/workflow_engine_manager.py index 22deafb8a3..e81bf684a9 100644 --- a/api/core/workflow/workflow_engine_manager.py +++ b/api/core/workflow/workflow_engine_manager.py @@ -2,8 +2,7 @@ import logging import time from typing import Optional, cast -from flask import current_app - +from configs import dify_config from core.app.app_config.entities import FileExtraConfig from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException from core.app.entities.app_invoke_entities import InvokeFrom @@ -118,7 +117,7 @@ class WorkflowEngineManager: if not isinstance(graph.get('edges'), list): raise ValueError('edges in workflow graph must be a list') - + # init variable pool if not variable_pool: variable_pool = VariablePool( @@ -126,7 +125,7 @@ class WorkflowEngineManager: user_inputs=user_inputs ) - workflow_call_max_depth = current_app.config.get("WORKFLOW_CALL_MAX_DEPTH") + workflow_call_max_depth = dify_config.WORKFLOW_CALL_MAX_DEPTH if call_depth > workflow_call_max_depth: raise ValueError('Max workflow call depth {} reached.'.format(workflow_call_max_depth)) @@ -177,8 +176,8 @@ class WorkflowEngineManager: predecessor_node: BaseNode = None current_iteration_node: BaseIterationNode = None has_entry_node = False - max_execution_steps = current_app.config.get("WORKFLOW_MAX_EXECUTION_STEPS") - max_execution_time = current_app.config.get("WORKFLOW_MAX_EXECUTION_TIME") + max_execution_steps = dify_config.WORKFLOW_MAX_EXECUTION_STEPS + max_execution_time = dify_config.WORKFLOW_MAX_EXECUTION_TIME while True: # get next node, multiple target nodes in the future next_node = self._get_next_overall_node( @@ -237,7 +236,7 @@ class WorkflowEngineManager: next_node_id = next_iteration # get next id next_node = self._get_node(workflow_run_state, graph, next_node_id, callbacks) - + if not next_node: break @@ -398,7 +397,7 @@ class WorkflowEngineManager: tenant_id=workflow.tenant_id, node_instance=node_instance ) - + # run node node_run_result = node_instance.run( variable_pool=variable_pool @@ -443,7 +442,7 @@ class WorkflowEngineManager: node_config = node else: raise ValueError('node id is not an iteration node') - + # init variable pool variable_pool = VariablePool( system_variables={}, @@ -452,7 +451,7 @@ class WorkflowEngineManager: # variable selector to variable mapping iteration_nested_nodes = [ - node for node in nodes + node for node in nodes if node.get('data', {}).get('iteration_id') == node_id or node.get('id') == node_id ] iteration_nested_node_ids = [node.get('id') for node in iteration_nested_nodes] @@ -475,13 +474,13 @@ class WorkflowEngineManager: # remove iteration variables variable_mapping = { - f'{node_config.get("id")}.{key}': value for key, value in variable_mapping.items() + f'{node_config.get("id")}.{key}': value for key, value in variable_mapping.items() if value[0] != node_id } # remove variable out from iteration variable_mapping = { - key: value for key, value in variable_mapping.items() + key: value for key, value in variable_mapping.items() if value[0] not in iteration_nested_node_ids } @@ -561,7 +560,7 @@ class WorkflowEngineManager: error=error ) - def _workflow_iteration_started(self, graph: dict, + def _workflow_iteration_started(self, graph: dict, current_iteration_node: BaseIterationNode, workflow_run_state: WorkflowRunState, predecessor_node_id: Optional[str] = None, @@ -600,7 +599,7 @@ class WorkflowEngineManager: def _workflow_iteration_next(self, graph: dict, current_iteration_node: BaseIterationNode, - workflow_run_state: WorkflowRunState, + workflow_run_state: WorkflowRunState, callbacks: list[BaseWorkflowCallback] = None) -> None: """ Workflow iteration next @@ -629,9 +628,9 @@ class WorkflowEngineManager: for node in nodes: workflow_run_state.variable_pool.clear_node_variables(node_id=node.get('id')) - + def _workflow_iteration_completed(self, current_iteration_node: BaseIterationNode, - workflow_run_state: WorkflowRunState, + workflow_run_state: WorkflowRunState, callbacks: list[BaseWorkflowCallback] = None) -> None: if callbacks: if isinstance(workflow_run_state.current_iteration_state, IterationState): @@ -684,7 +683,7 @@ class WorkflowEngineManager: callbacks=callbacks, workflow_call_depth=workflow_run_state.workflow_call_depth ) - + else: edges = graph.get('edges') source_node_id = predecessor_node.node_id @@ -738,9 +737,9 @@ class WorkflowEngineManager: callbacks=callbacks, workflow_call_depth=workflow_run_state.workflow_call_depth ) - - def _get_node(self, workflow_run_state: WorkflowRunState, - graph: dict, + + def _get_node(self, workflow_run_state: WorkflowRunState, + graph: dict, node_id: str, callbacks: list[BaseWorkflowCallback]) -> Optional[BaseNode]: """ @@ -940,7 +939,7 @@ class WorkflowEngineManager: return new_value - def _mapping_user_inputs_to_variable_pool(self, + def _mapping_user_inputs_to_variable_pool(self, variable_mapping: dict, user_inputs: dict, variable_pool: VariablePool, @@ -988,4 +987,4 @@ class WorkflowEngineManager: node_id=variable_node_id, variable_key_list=variable_key_list, value=value - ) \ No newline at end of file + ) diff --git a/api/models/dataset.py b/api/models/dataset.py index b0e3702dd7..d0be005a15 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -9,10 +9,10 @@ import re import time from json import JSONDecodeError -from flask import current_app from sqlalchemy import func from sqlalchemy.dialects.postgresql import JSONB +from configs import dify_config from core.rag.retrieval.retrival_methods import RetrievalMethod from extensions.ext_database import db from extensions.ext_storage import storage @@ -528,7 +528,7 @@ class DocumentSegment(db.Model): nonce = os.urandom(16).hex() timestamp = str(int(time.time())) data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}" - secret_key = current_app.config['SECRET_KEY'].encode() + secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b'' sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() encoded_sign = base64.urlsafe_b64encode(sign).decode() diff --git a/api/models/model.py b/api/models/model.py index 4d67272c1a..331bb91c29 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -4,10 +4,11 @@ import uuid from enum import Enum from typing import Optional -from flask import current_app, request +from flask import request from flask_login import UserMixin from sqlalchemy import Float, func, text +from configs import dify_config from core.file.tool_file_parser import ToolFileParser from core.file.upload_file_parser import UploadFileParser from extensions.ext_database import db @@ -111,7 +112,7 @@ class App(db.Model): @property def api_base_url(self): - return (current_app.config['SERVICE_API_URL'] if current_app.config['SERVICE_API_URL'] + return (dify_config.SERVICE_API_URL if dify_config.SERVICE_API_URL else request.host_url.rstrip('/')) + '/v1' @property @@ -1113,7 +1114,7 @@ class Site(db.Model): @property def app_base_url(self): return ( - current_app.config['APP_WEB_URL'] if current_app.config['APP_WEB_URL'] else request.host_url.rstrip('/')) + dify_config.APP_WEB_URL if dify_config.APP_WEB_URL else request.host_url.rstrip('/')) class ApiToken(db.Model):