From 7b225a5ab05c99ecb49c16d2e4d60b802d41f6f5 Mon Sep 17 00:00:00 2001 From: Waffle <52460705+ox01024@users.noreply.github.com> Date: Fri, 12 Jul 2024 12:25:38 +0800 Subject: [PATCH] refactor(services/tasks): Swtich to dify_config witch Pydantic (#6203) --- .../rag/datasource/keyword/keyword_factory.py | 7 +++---- api/services/account_service.py | 10 +++++----- api/services/app_generate_service.py | 4 ++-- api/services/app_service.py | 4 ++-- api/services/dataset_service.py | 6 +++--- api/services/entities/model_provider_entities.py | 8 ++++---- api/services/feature_service.py | 12 ++++++------ api/services/file_service.py | 16 +++++++--------- api/services/recommended_app_service.py | 10 +++++----- api/services/tools/tools_transform_service.py | 5 ++--- api/services/workspace_service.py | 4 ++-- api/tasks/document_indexing_task.py | 4 ++-- api/tasks/duplicate_document_indexing_task.py | 4 ++-- api/tasks/mail_invite_member_task.py | 5 +++-- api/tasks/mail_reset_password_task.py | 5 +++-- 15 files changed, 51 insertions(+), 53 deletions(-) diff --git a/api/core/rag/datasource/keyword/keyword_factory.py b/api/core/rag/datasource/keyword/keyword_factory.py index beb3322aa6..6ac610f82b 100644 --- a/api/core/rag/datasource/keyword/keyword_factory.py +++ b/api/core/rag/datasource/keyword/keyword_factory.py @@ -1,7 +1,6 @@ from typing import Any -from flask import current_app - +from configs import dify_config from core.rag.datasource.keyword.jieba.jieba import Jieba from core.rag.datasource.keyword.keyword_base import BaseKeyword from core.rag.models.document import Document @@ -14,8 +13,8 @@ class Keyword: self._keyword_processor = self._init_keyword() def _init_keyword(self) -> BaseKeyword: - config = current_app.config - keyword_type = config.get('KEYWORD_STORE') + config = dify_config + keyword_type = config.KEYWORD_STORE if not keyword_type: raise ValueError("Keyword store must be specified.") diff --git a/api/services/account_service.py b/api/services/account_service.py index 3fd2b5c627..0bcbe8b2c0 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -6,10 +6,10 @@ from datetime import datetime, timedelta, timezone from hashlib import sha256 from typing import Any, Optional -from flask import current_app from sqlalchemy import func from werkzeug.exceptions import Unauthorized +from configs import dify_config from constants.languages import language_timezone_mapping, languages from events.tenant_event import tenant_was_created from extensions.ext_redis import redis_client @@ -80,7 +80,7 @@ class AccountService: payload = { "user_id": account.id, "exp": datetime.now(timezone.utc).replace(tzinfo=None) + exp, - "iss": current_app.config['EDITION'], + "iss": dify_config.EDITION, "sub": 'Console API Passport', } @@ -524,7 +524,7 @@ class RegisterService: TenantService.create_owner_tenant_if_not_exist(account) dify_setup = DifySetup( - version=current_app.config['CURRENT_VERSION'] + version=dify_config.CURRENT_VERSION ) db.session.add(dify_setup) db.session.commit() @@ -559,7 +559,7 @@ class RegisterService: if open_id is not None or provider is not None: AccountService.link_account_integrate(provider, open_id, account) - if current_app.config['EDITION'] != 'SELF_HOSTED': + if dify_config.EDITION != 'SELF_HOSTED': tenant = TenantService.create_tenant(f"{account.name}'s Workspace") TenantService.create_tenant_member(tenant, account, role='owner') @@ -623,7 +623,7 @@ class RegisterService: 'email': account.email, 'workspace_id': tenant.id, } - expiryHours = current_app.config['INVITE_EXPIRY_HOURS'] + expiryHours = dify_config.INVITE_EXPIRY_HOURS redis_client.setex( cls._get_invitation_token_key(token), expiryHours * 60 * 60, diff --git a/api/services/app_generate_service.py b/api/services/app_generate_service.py index 3acd3becdb..e894570b97 100644 --- a/api/services/app_generate_service.py +++ b/api/services/app_generate_service.py @@ -1,6 +1,7 @@ from collections.abc import Generator from typing import Any, Union +from configs import dify_config from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator from core.app.apps.agent_chat.app_generator import AgentChatAppGenerator from core.app.apps.chat.app_generator import ChatAppGenerator @@ -89,8 +90,7 @@ class AppGenerateService: def _get_max_active_requests(app_model: App) -> int: max_active_requests = app_model.max_active_requests if app_model.max_active_requests is None: - from flask import current_app - max_active_requests = int(current_app.config['APP_MAX_ACTIVE_REQUESTS']) + max_active_requests = int(dify_config.APP_MAX_ACTIVE_REQUESTS) return max_active_requests @classmethod diff --git a/api/services/app_service.py b/api/services/app_service.py index 03986db2ae..ca3c8d4fdc 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -4,10 +4,10 @@ from datetime import datetime, timezone from typing import cast import yaml -from flask import current_app from flask_login import current_user from flask_sqlalchemy.pagination import Pagination +from configs import dify_config from constants.model_template import default_app_templates from core.agent.entities import AgentToolEntity from core.app.features.rate_limiting import RateLimit @@ -446,7 +446,7 @@ class AppService: # get all tools tools = agent_config.get('tools', []) - url_prefix = (current_app.config.get("CONSOLE_API_URL") + url_prefix = (dify_config.CONSOLE_API_URL + "/console/api/workspaces/current/tool-provider/builtin/") for tool in tools: diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index fa0a1bbc58..fbaf44c9a4 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -6,10 +6,10 @@ import time import uuid from typing import Optional -from flask import current_app from flask_login import current_user from sqlalchemy import func +from configs import dify_config from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType @@ -650,7 +650,7 @@ class DocumentService: elif document_data["data_source"]["type"] == "website_crawl": website_info = document_data["data_source"]['info_list']['website_info_list'] count = len(website_info['urls']) - batch_upload_limit = int(current_app.config['BATCH_UPLOAD_LIMIT']) + batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT) if count > batch_upload_limit: raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.") @@ -1028,7 +1028,7 @@ class DocumentService: elif document_data["data_source"]["type"] == "website_crawl": website_info = document_data["data_source"]['info_list']['website_info_list'] count = len(website_info['urls']) - batch_upload_limit = int(current_app.config['BATCH_UPLOAD_LIMIT']) + batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT) if count > batch_upload_limit: raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.") diff --git a/api/services/entities/model_provider_entities.py b/api/services/entities/model_provider_entities.py index 853172ea13..e5e4d7e235 100644 --- a/api/services/entities/model_provider_entities.py +++ b/api/services/entities/model_provider_entities.py @@ -1,9 +1,9 @@ from enum import Enum from typing import Optional -from flask import current_app from pydantic import BaseModel, ConfigDict +from configs import dify_config from core.entities.model_entities import ModelWithProviderEntity, ProviderModelWithStatusEntity from core.entities.provider_entities import QuotaConfiguration from core.model_runtime.entities.common_entities import I18nObject @@ -67,7 +67,7 @@ class ProviderResponse(BaseModel): def __init__(self, **data) -> None: super().__init__(**data) - url_prefix = (current_app.config.get("CONSOLE_API_URL") + url_prefix = (dify_config.CONSOLE_API_URL + f"/console/api/workspaces/current/model-providers/{self.provider}") if self.icon_small is not None: self.icon_small = I18nObject( @@ -96,7 +96,7 @@ class ProviderWithModelsResponse(BaseModel): def __init__(self, **data) -> None: super().__init__(**data) - url_prefix = (current_app.config.get("CONSOLE_API_URL") + url_prefix = (dify_config.CONSOLE_API_URL + f"/console/api/workspaces/current/model-providers/{self.provider}") if self.icon_small is not None: self.icon_small = I18nObject( @@ -119,7 +119,7 @@ class SimpleProviderEntityResponse(SimpleProviderEntity): def __init__(self, **data) -> None: super().__init__(**data) - url_prefix = (current_app.config.get("CONSOLE_API_URL") + url_prefix = (dify_config.CONSOLE_API_URL + f"/console/api/workspaces/current/model-providers/{self.provider}") if self.icon_small is not None: self.icon_small = I18nObject( diff --git a/api/services/feature_service.py b/api/services/feature_service.py index 7375554156..83e675a9d2 100644 --- a/api/services/feature_service.py +++ b/api/services/feature_service.py @@ -1,6 +1,6 @@ -from flask import current_app from pydantic import BaseModel, ConfigDict +from configs import dify_config from services.billing_service import BillingService from services.enterprise.enterprise_service import EnterpriseService @@ -51,7 +51,7 @@ class FeatureService: cls._fulfill_params_from_env(features) - if current_app.config['BILLING_ENABLED']: + if dify_config.BILLING_ENABLED: cls._fulfill_params_from_billing_api(features, tenant_id) return features @@ -60,16 +60,16 @@ class FeatureService: def get_system_features(cls) -> SystemFeatureModel: system_features = SystemFeatureModel() - if current_app.config['ENTERPRISE_ENABLED']: + if dify_config.ENTERPRISE_ENABLED: cls._fulfill_params_from_enterprise(system_features) return system_features @classmethod def _fulfill_params_from_env(cls, features: FeatureModel): - features.can_replace_logo = current_app.config['CAN_REPLACE_LOGO'] - features.model_load_balancing_enabled = current_app.config['MODEL_LB_ENABLED'] - features.dataset_operator_enabled = current_app.config['DATASET_OPERATOR_ENABLED'] + features.can_replace_logo = dify_config.CAN_REPLACE_LOGO + features.model_load_balancing_enabled = dify_config.MODEL_LB_ENABLED + features.dataset_operator_enabled = dify_config.DATASET_OPERATOR_ENABLED @classmethod def _fulfill_params_from_billing_api(cls, features: FeatureModel, tenant_id: str): diff --git a/api/services/file_service.py b/api/services/file_service.py index 6c308a09df..c686b190fe 100644 --- a/api/services/file_service.py +++ b/api/services/file_service.py @@ -4,11 +4,11 @@ import uuid from collections.abc import Generator from typing import Union -from flask import current_app from flask_login import current_user from werkzeug.datastructures import FileStorage from werkzeug.exceptions import NotFound +from configs import dify_config from core.file.upload_file_parser import UploadFileParser from core.rag.extractor.extract_processor import ExtractProcessor from extensions.ext_database import db @@ -35,7 +35,7 @@ class FileService: extension = file.filename.split('.')[-1] if len(filename) > 200: filename = filename.split('.')[0][:200] + '.' + extension - etl_type = current_app.config['ETL_TYPE'] + etl_type = dify_config.ETL_TYPE allowed_extensions = UNSTRUCTURED_ALLOWED_EXTENSIONS + IMAGE_EXTENSIONS if etl_type == 'Unstructured' \ else ALLOWED_EXTENSIONS + IMAGE_EXTENSIONS if extension.lower() not in allowed_extensions: @@ -50,9 +50,9 @@ class FileService: file_size = len(file_content) if extension.lower() in IMAGE_EXTENSIONS: - file_size_limit = current_app.config.get("UPLOAD_IMAGE_FILE_SIZE_LIMIT") * 1024 * 1024 + file_size_limit = dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT * 1024 * 1024 else: - file_size_limit = current_app.config.get("UPLOAD_FILE_SIZE_LIMIT") * 1024 * 1024 + file_size_limit = dify_config.UPLOAD_FILE_SIZE_LIMIT * 1024 * 1024 if file_size > file_size_limit: message = f'File size exceeded. {file_size} > {file_size_limit}' @@ -73,10 +73,9 @@ class FileService: storage.save(file_key, file_content) # save file to db - config = current_app.config upload_file = UploadFile( tenant_id=current_tenant_id, - storage_type=config['STORAGE_TYPE'], + storage_type=dify_config.STORAGE_TYPE, key=file_key, name=filename, size=file_size, @@ -106,10 +105,9 @@ class FileService: storage.save(file_key, text.encode('utf-8')) # save file to db - config = current_app.config upload_file = UploadFile( tenant_id=current_user.current_tenant_id, - storage_type=config['STORAGE_TYPE'], + storage_type=dify_config.STORAGE_TYPE, key=file_key, name=text_name + '.txt', size=len(text), @@ -138,7 +136,7 @@ class FileService: # extract text from file extension = upload_file.extension - etl_type = current_app.config['ETL_TYPE'] + etl_type = dify_config.ETL_TYPE allowed_extensions = UNSTRUCTURED_ALLOWED_EXTENSIONS if etl_type == 'Unstructured' else ALLOWED_EXTENSIONS if extension.lower() not in allowed_extensions: raise UnsupportedFileTypeError() diff --git a/api/services/recommended_app_service.py b/api/services/recommended_app_service.py index d32ab2af33..c4733b6d3f 100644 --- a/api/services/recommended_app_service.py +++ b/api/services/recommended_app_service.py @@ -4,8 +4,8 @@ from os import path from typing import Optional import requests -from flask import current_app +from configs import dify_config from constants.languages import languages from extensions.ext_database import db from models.model import App, RecommendedApp @@ -25,7 +25,7 @@ class RecommendedAppService: :param language: language :return: """ - mode = current_app.config.get('HOSTED_FETCH_APP_TEMPLATES_MODE', 'remote') + mode = dify_config.HOSTED_FETCH_APP_TEMPLATES_MODE if mode == 'remote': try: result = cls._fetch_recommended_apps_from_dify_official(language) @@ -104,7 +104,7 @@ class RecommendedAppService: :param language: language :return: """ - domain = current_app.config.get('HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN', 'https://tmpl.dify.ai') + domain = dify_config.HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN url = f'{domain}/apps?language={language}' response = requests.get(url, timeout=(3, 10)) if response.status_code != 200: @@ -134,7 +134,7 @@ class RecommendedAppService: :param app_id: app id :return: """ - mode = current_app.config.get('HOSTED_FETCH_APP_TEMPLATES_MODE', 'remote') + mode = dify_config.HOSTED_FETCH_APP_TEMPLATES_MODE if mode == 'remote': try: result = cls._fetch_recommended_app_detail_from_dify_official(app_id) @@ -157,7 +157,7 @@ class RecommendedAppService: :param app_id: App ID :return: """ - domain = current_app.config.get('HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN', 'https://tmpl.dify.ai') + domain = dify_config.HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN url = f'{domain}/apps/{app_id}' response = requests.get(url, timeout=(3, 10)) if response.status_code != 200: diff --git a/api/services/tools/tools_transform_service.py b/api/services/tools/tools_transform_service.py index 5c77732468..cfce3fbd01 100644 --- a/api/services/tools/tools_transform_service.py +++ b/api/services/tools/tools_transform_service.py @@ -2,8 +2,7 @@ import json import logging from typing import Optional, Union -from flask import current_app - +from configs import dify_config from core.tools.entities.api_entities import UserTool, UserToolProvider from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_bundle import ApiToolBundle @@ -29,7 +28,7 @@ class ToolTransformService: """ get tool provider icon url """ - url_prefix = (current_app.config.get("CONSOLE_API_URL") + url_prefix = (dify_config.CONSOLE_API_URL + "/console/api/workspaces/current/tool-provider/") if provider_type == ToolProviderType.BUILT_IN.value: diff --git a/api/services/workspace_service.py b/api/services/workspace_service.py index 778b4e51d3..2bcbe5c6f6 100644 --- a/api/services/workspace_service.py +++ b/api/services/workspace_service.py @@ -1,7 +1,7 @@ -from flask import current_app from flask_login import current_user +from configs import dify_config from extensions.ext_database import db from models.account import Tenant, TenantAccountJoin, TenantAccountJoinRole from services.account_service import TenantService @@ -35,7 +35,7 @@ class WorkspaceService: if can_replace_logo and TenantService.has_roles(tenant, [TenantAccountJoinRole.OWNER, TenantAccountJoinRole.ADMIN]): - base_url = current_app.config.get('FILES_URL') + base_url = dify_config.FILES_URL replace_webapp_logo = f'{base_url}/files/workspaces/{tenant.id}/webapp-logo' if tenant.custom_config_dict.get('replace_webapp_logo') else None remove_webapp_brand = tenant.custom_config_dict.get('remove_webapp_brand', False) diff --git a/api/tasks/document_indexing_task.py b/api/tasks/document_indexing_task.py index 43d1cc13f9..cc93a1341e 100644 --- a/api/tasks/document_indexing_task.py +++ b/api/tasks/document_indexing_task.py @@ -4,8 +4,8 @@ import time import click from celery import shared_task -from flask import current_app +from configs import dify_config from core.indexing_runner import DocumentIsPausedException, IndexingRunner from extensions.ext_database import db from models.dataset import Dataset, Document @@ -32,7 +32,7 @@ def document_indexing_task(dataset_id: str, document_ids: list): if features.billing.enabled: vector_space = features.vector_space count = len(document_ids) - batch_upload_limit = int(current_app.config['BATCH_UPLOAD_LIMIT']) + batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT) if count > batch_upload_limit: raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.") if 0 < vector_space.limit <= vector_space.size: diff --git a/api/tasks/duplicate_document_indexing_task.py b/api/tasks/duplicate_document_indexing_task.py index 1854589e7f..884e222d1b 100644 --- a/api/tasks/duplicate_document_indexing_task.py +++ b/api/tasks/duplicate_document_indexing_task.py @@ -4,8 +4,8 @@ import time import click from celery import shared_task -from flask import current_app +from configs import dify_config from core.indexing_runner import DocumentIsPausedException, IndexingRunner from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from extensions.ext_database import db @@ -33,7 +33,7 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list): if features.billing.enabled: vector_space = features.vector_space count = len(document_ids) - batch_upload_limit = int(current_app.config['BATCH_UPLOAD_LIMIT']) + batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT) if count > batch_upload_limit: raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.") if 0 < vector_space.limit <= vector_space.size: diff --git a/api/tasks/mail_invite_member_task.py b/api/tasks/mail_invite_member_task.py index 1f40c05077..a46eafa797 100644 --- a/api/tasks/mail_invite_member_task.py +++ b/api/tasks/mail_invite_member_task.py @@ -3,8 +3,9 @@ import time import click from celery import shared_task -from flask import current_app, render_template +from flask import render_template +from configs import dify_config from extensions.ext_mail import mail @@ -29,7 +30,7 @@ def send_invite_member_mail_task(language: str, to: str, token: str, inviter_nam # send invite member mail using different languages try: - url = f'{current_app.config.get("CONSOLE_WEB_URL")}/activate?token={token}' + url = f'{dify_config.CONSOLE_WEB_URL}/activate?token={token}' if language == 'zh-Hans': html_content = render_template('invite_member_mail_template_zh-CN.html', to=to, diff --git a/api/tasks/mail_reset_password_task.py b/api/tasks/mail_reset_password_task.py index 0e64c6f163..4e1b8a8913 100644 --- a/api/tasks/mail_reset_password_task.py +++ b/api/tasks/mail_reset_password_task.py @@ -3,8 +3,9 @@ import time import click from celery import shared_task -from flask import current_app, render_template +from flask import render_template +from configs import dify_config from extensions.ext_mail import mail @@ -24,7 +25,7 @@ def send_reset_password_mail_task(language: str, to: str, token: str): # send reset password mail using different languages try: - url = f'{current_app.config.get("CONSOLE_WEB_URL")}/forgot-password?token={token}' + url = f'{dify_config.CONSOLE_WEB_URL}/forgot-password?token={token}' if language == 'zh-Hans': html_content = render_template('reset_password_mail_template_zh-CN.html', to=to,