diff --git a/api/controllers/console/app/wraps.py b/api/controllers/console/app/wraps.py index c71ee8e5df..63edb83079 100644 --- a/api/controllers/console/app/wraps.py +++ b/api/controllers/console/app/wraps.py @@ -5,8 +5,7 @@ from typing import Optional, Union from controllers.console.app.error import AppNotFoundError from extensions.ext_database import db from libs.login import current_user -from models import App -from models.model import AppMode +from models import App, AppMode def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[AppMode]] = None): diff --git a/api/core/helper/encrypter.py b/api/core/helper/encrypter.py index 96341a1b78..744fce1cf9 100644 --- a/api/core/helper/encrypter.py +++ b/api/core/helper/encrypter.py @@ -1,6 +1,5 @@ import base64 -from extensions.ext_database import db from libs import rsa @@ -14,6 +13,7 @@ def obfuscated_token(token: str): def encrypt_token(tenant_id: str, token: str): from models.account import Tenant + from models.engine import db if not (tenant := db.session.query(Tenant).filter(Tenant.id == tenant_id).first()): raise ValueError(f"Tenant with id {tenant_id} not found") diff --git a/api/extensions/ext_database.py b/api/extensions/ext_database.py index e293afa111..93842a3036 100644 --- a/api/extensions/ext_database.py +++ b/api/extensions/ext_database.py @@ -1,18 +1,5 @@ -from flask_sqlalchemy import SQLAlchemy -from sqlalchemy import MetaData - from dify_app import DifyApp - -POSTGRES_INDEXES_NAMING_CONVENTION = { - "ix": "%(column_0_label)s_idx", - "uq": "%(table_name)s_%(column_0_name)s_key", - "ck": "%(table_name)s_%(constraint_name)s_check", - "fk": "%(table_name)s_%(column_0_name)s_fkey", - "pk": "%(table_name)s_pkey", -} - -metadata = MetaData(naming_convention=POSTGRES_INDEXES_NAMING_CONVENTION) -db = SQLAlchemy(metadata=metadata) +from models import db def init_app(app: DifyApp): diff --git a/api/extensions/ext_import_modules.py b/api/extensions/ext_import_modules.py index eefdfd3823..9566f430b6 100644 --- a/api/extensions/ext_import_modules.py +++ b/api/extensions/ext_import_modules.py @@ -3,4 +3,3 @@ from dify_app import DifyApp def init_app(app: DifyApp): from events import event_handlers # noqa: F401 - from models import account, dataset, model, source, task, tool, tools, web # noqa: F401 diff --git a/api/libs/helper.py b/api/libs/helper.py index 026ded3506..91b1d1fe17 100644 --- a/api/libs/helper.py +++ b/api/libs/helper.py @@ -13,7 +13,7 @@ from typing import Any, Optional, Union, cast from zoneinfo import available_timezones from flask import Response, stream_with_context -from flask_restful import fields # type: ignore +from flask_restful import fields from configs import dify_config from core.app.features.rate_limiting.rate_limit import RateLimitGenerator diff --git a/api/migrations/versions/2024_12_19_1746-11b07f66c737_remove_unused_tool_providers.py b/api/migrations/versions/2024_12_19_1746-11b07f66c737_remove_unused_tool_providers.py new file mode 100644 index 0000000000..881a9e3c1e --- /dev/null +++ b/api/migrations/versions/2024_12_19_1746-11b07f66c737_remove_unused_tool_providers.py @@ -0,0 +1,39 @@ +"""remove unused tool_providers + +Revision ID: 11b07f66c737 +Revises: cf8f4fc45278 +Create Date: 2024-12-19 17:46:25.780116 + +""" +from alembic import op +import models as models +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '11b07f66c737' +down_revision = 'cf8f4fc45278' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table('tool_providers') + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('tool_providers', + sa.Column('id', sa.UUID(), server_default=sa.text('uuid_generate_v4()'), autoincrement=False, nullable=False), + sa.Column('tenant_id', sa.UUID(), autoincrement=False, nullable=False), + sa.Column('tool_name', sa.VARCHAR(length=40), autoincrement=False, nullable=False), + sa.Column('encrypted_credentials', sa.TEXT(), autoincrement=False, nullable=True), + sa.Column('is_enabled', sa.BOOLEAN(), server_default=sa.text('false'), autoincrement=False, nullable=False), + sa.Column('created_at', postgresql.TIMESTAMP(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), autoincrement=False, nullable=False), + sa.Column('updated_at', postgresql.TIMESTAMP(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), autoincrement=False, nullable=False), + sa.PrimaryKeyConstraint('id', name='tool_provider_pkey'), + sa.UniqueConstraint('tenant_id', 'tool_name', name='unique_tool_provider_tool_name') + ) + # ### end Alembic commands ### diff --git a/api/models/__init__.py b/api/models/__init__.py index 61a38870cf..b0b9880ca4 100644 --- a/api/models/__init__.py +++ b/api/models/__init__.py @@ -1,53 +1,187 @@ -from .account import Account, AccountIntegrate, InvitationCode, Tenant -from .dataset import Dataset, DatasetProcessRule, Document, DocumentSegment +from .account import ( + Account, + AccountIntegrate, + AccountStatus, + InvitationCode, + Tenant, + TenantAccountJoin, + TenantAccountJoinRole, + TenantAccountRole, + TenantStatus, +) +from .api_based_extension import APIBasedExtension, APIBasedExtensionPoint +from .dataset import ( + AppDatasetJoin, + Dataset, + DatasetCollectionBinding, + DatasetKeywordTable, + DatasetPermission, + DatasetPermissionEnum, + DatasetProcessRule, + DatasetQuery, + Document, + DocumentSegment, + Embedding, + ExternalKnowledgeApis, + ExternalKnowledgeBindings, + TidbAuthBinding, + Whitelist, +) +from .engine import db +from .enums import CreatedByRole, UserFrom, WorkflowRunTriggeredFrom from .model import ( + ApiRequest, ApiToken, App, + AppAnnotationHitHistory, + AppAnnotationSetting, AppMode, + AppModelConfig, Conversation, + DatasetRetrieverResource, + DifySetup, EndUser, + IconType, InstalledApp, Message, + MessageAgentThought, MessageAnnotation, + MessageChain, + MessageFeedback, MessageFile, + OperationLog, RecommendedApp, Site, + Tag, + TagBinding, + TraceAppConfig, UploadFile, ) -from .source import DataSourceOauthBinding -from .tools import ToolFile +from .provider import ( + LoadBalancingModelConfig, + Provider, + ProviderModel, + ProviderModelSetting, + ProviderOrder, + ProviderQuotaType, + ProviderType, + TenantDefaultModel, + TenantPreferredModelProvider, +) +from .source import DataSourceApiKeyAuthBinding, DataSourceOauthBinding +from .task import CeleryTask, CeleryTaskSet +from .tools import ( + ApiToolProvider, + BuiltinToolProvider, + PublishedAppTool, + ToolConversationVariables, + ToolFile, + ToolLabelBinding, + ToolModelInvoke, + WorkflowToolProvider, +) +from .web import PinnedConversation, SavedMessage from .workflow import ( ConversationVariable, Workflow, WorkflowAppLog, + WorkflowAppLogCreatedFrom, + WorkflowNodeExecution, + WorkflowNodeExecutionStatus, + WorkflowNodeExecutionTriggeredFrom, WorkflowRun, + WorkflowRunStatus, + WorkflowType, ) __all__ = [ + "APIBasedExtension", + "APIBasedExtensionPoint", "Account", "AccountIntegrate", + "AccountStatus", + "ApiRequest", "ApiToken", + "ApiToolProvider", # Added "App", + "AppAnnotationHitHistory", + "AppAnnotationSetting", + "AppDatasetJoin", "AppMode", + "AppModelConfig", + "BuiltinToolProvider", # Added + "CeleryTask", + "CeleryTaskSet", "Conversation", "ConversationVariable", + "CreatedByRole", + "DataSourceApiKeyAuthBinding", "DataSourceOauthBinding", "Dataset", + "DatasetCollectionBinding", + "DatasetKeywordTable", + "DatasetPermission", + "DatasetPermissionEnum", "DatasetProcessRule", + "DatasetQuery", + "DatasetRetrieverResource", + "DifySetup", "Document", "DocumentSegment", + "Embedding", "EndUser", + "ExternalKnowledgeApis", + "ExternalKnowledgeBindings", + "IconType", "InstalledApp", "InvitationCode", + "LoadBalancingModelConfig", "Message", + "MessageAgentThought", "MessageAnnotation", + "MessageChain", + "MessageFeedback", "MessageFile", + "OperationLog", + "PinnedConversation", + "Provider", + "ProviderModel", + "ProviderModelSetting", + "ProviderOrder", + "ProviderQuotaType", + "ProviderType", + "PublishedAppTool", "RecommendedApp", + "SavedMessage", "Site", + "Tag", + "TagBinding", "Tenant", + "TenantAccountJoin", + "TenantAccountJoinRole", + "TenantAccountRole", + "TenantDefaultModel", + "TenantPreferredModelProvider", + "TenantStatus", + "TidbAuthBinding", + "ToolConversationVariables", "ToolFile", + "ToolLabelBinding", + "ToolModelInvoke", + "TraceAppConfig", "UploadFile", + "UserFrom", + "Whitelist", "Workflow", "WorkflowAppLog", + "WorkflowAppLogCreatedFrom", + "WorkflowNodeExecution", + "WorkflowNodeExecutionStatus", + "WorkflowNodeExecutionTriggeredFrom", "WorkflowRun", + "WorkflowRunStatus", + "WorkflowRunTriggeredFrom", + "WorkflowToolProvider", + "WorkflowType", + "db", ] diff --git a/api/models/account.py b/api/models/account.py index 951e836dec..ce17b90def 100644 --- a/api/models/account.py +++ b/api/models/account.py @@ -3,8 +3,7 @@ import json from flask_login import UserMixin -from extensions.ext_database import db - +from .engine import db from .types import StringUUID diff --git a/api/models/api_based_extension.py b/api/models/api_based_extension.py index 97173747af..4d4182cabd 100644 --- a/api/models/api_based_extension.py +++ b/api/models/api_based_extension.py @@ -1,7 +1,6 @@ import enum -from extensions.ext_database import db - +from .engine import db from .types import StringUUID diff --git a/api/models/dataset.py b/api/models/dataset.py index 8ab957e875..97e4d6c0ef 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -15,10 +15,10 @@ from sqlalchemy.dialects.postgresql import JSONB from configs import dify_config from core.rag.retrieval.retrieval_methods import RetrievalMethod -from extensions.ext_database import db from extensions.ext_storage import storage from .account import Account +from .engine import db from .model import App, Tag, TagBinding, UploadFile from .types import StringUUID diff --git a/api/models/engine.py b/api/models/engine.py new file mode 100644 index 0000000000..dda93bc941 --- /dev/null +++ b/api/models/engine.py @@ -0,0 +1,13 @@ +from flask_sqlalchemy import SQLAlchemy +from sqlalchemy import MetaData + +POSTGRES_INDEXES_NAMING_CONVENTION = { + "ix": "%(column_0_label)s_idx", + "uq": "%(table_name)s_%(column_0_name)s_key", + "ck": "%(table_name)s_%(constraint_name)s_check", + "fk": "%(table_name)s_%(column_0_name)s_fkey", + "pk": "%(table_name)s_pkey", +} + +metadata = MetaData(naming_convention=POSTGRES_INDEXES_NAMING_CONVENTION) +db = SQLAlchemy(metadata=metadata) diff --git a/api/models/model.py b/api/models/model.py index 03b8e0bea5..54b719628c 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -16,11 +16,11 @@ from configs import dify_config from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType from core.file import helpers as file_helpers from core.file.tool_file_parser import ToolFileParser -from extensions.ext_database import db from libs.helper import generate_string from models.enums import CreatedByRole from .account import Account, Tenant +from .engine import db from .types import StringUUID diff --git a/api/models/provider.py b/api/models/provider.py index 644915e781..65f70b76e9 100644 --- a/api/models/provider.py +++ b/api/models/provider.py @@ -1,7 +1,6 @@ from enum import Enum -from extensions.ext_database import db - +from .engine import db from .types import StringUUID diff --git a/api/models/source.py b/api/models/source.py index 07695f06e6..4d98572ef8 100644 --- a/api/models/source.py +++ b/api/models/source.py @@ -2,8 +2,7 @@ import json from sqlalchemy.dialects.postgresql import JSONB -from extensions.ext_database import db - +from .engine import db from .types import StringUUID diff --git a/api/models/task.py b/api/models/task.py index 5d89ff85ac..27571e2474 100644 --- a/api/models/task.py +++ b/api/models/task.py @@ -2,7 +2,7 @@ from datetime import UTC, datetime from celery import states -from extensions.ext_database import db +from .engine import db class CeleryTask(db.Model): diff --git a/api/models/tool.py b/api/models/tool.py deleted file mode 100644 index a81bb65174..0000000000 --- a/api/models/tool.py +++ /dev/null @@ -1,47 +0,0 @@ -import json -from enum import Enum - -from extensions.ext_database import db - -from .types import StringUUID - - -class ToolProviderName(Enum): - SERPAPI = "serpapi" - - @staticmethod - def value_of(value): - for member in ToolProviderName: - if member.value == value: - return member - raise ValueError(f"No matching enum found for value '{value}'") - - -class ToolProvider(db.Model): - __tablename__ = "tool_providers" - __table_args__ = ( - db.PrimaryKeyConstraint("id", name="tool_provider_pkey"), - db.UniqueConstraint("tenant_id", "tool_name", name="unique_tool_provider_tool_name"), - ) - - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - tenant_id = db.Column(StringUUID, nullable=False) - tool_name = db.Column(db.String(40), nullable=False) - encrypted_credentials = db.Column(db.Text, nullable=True) - is_enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - - @property - def credentials_is_set(self): - """ - Returns True if the encrypted_config is not None, indicating that the token is set. - """ - return self.encrypted_credentials is not None - - @property - def credentials(self): - """ - Returns the decrypted config. - """ - return json.loads(self.encrypted_credentials) if self.encrypted_credentials is not None else None diff --git a/api/models/tools.py b/api/models/tools.py index 4040339e02..c390be4625 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -8,8 +8,8 @@ from sqlalchemy.orm import Mapped, mapped_column from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_bundle import ApiToolBundle from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToolParameterConfiguration -from extensions.ext_database import db +from .engine import db from .model import Account, App, Tenant from .types import StringUUID @@ -82,7 +82,7 @@ class PublishedAppTool(db.Model): return I18nObject(**json.loads(self.description)) @property - def app(self) -> App: + def app(self): return db.session.query(App).filter(App.id == self.app_id).first() @@ -201,10 +201,6 @@ class WorkflowToolProvider(db.Model): created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - @property - def schema_type(self) -> ApiProviderSchemaType: - return ApiProviderSchemaType.value_of(self.schema_type_str) - @property def user(self) -> Account | None: return db.session.query(Account).filter(Account.id == self.user_id).first() diff --git a/api/models/web.py b/api/models/web.py index bc088c185d..a0f87cf456 100644 --- a/api/models/web.py +++ b/api/models/web.py @@ -1,5 +1,4 @@ -from extensions.ext_database import db - +from .engine import db from .model import Message from .types import StringUUID diff --git a/api/models/workflow.py b/api/models/workflow.py index 09e3728d7c..4ddf2e9082 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -12,12 +12,12 @@ import contexts from constants import HIDDEN_VALUE from core.helper import encrypter from core.variables import SecretVariable, Variable -from extensions.ext_database import db from factories import variable_factory from libs import helper from models.enums import CreatedByRole from .account import Account +from .engine import db from .types import StringUUID @@ -399,7 +399,7 @@ class WorkflowRun(db.Model): graph = db.Column(db.Text) inputs = db.Column(db.Text) status = db.Column(db.String(255), nullable=False) # running, succeeded, failed, stopped, partial-succeeded - outputs: Mapped[str] = mapped_column(sa.Text, default="{}") + outputs: Mapped[Optional[str]] = mapped_column(sa.Text, default="{}") error = db.Column(db.Text) elapsed_time = db.Column(db.Float, nullable=False, server_default=db.text("0")) total_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0"))