chore(db): use a better way to export models and remove unused table (#11838)

Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
-LAN- 2024-12-20 14:12:29 +08:00 committed by GitHub
parent 2d186e1e76
commit 3599751f93
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
19 changed files with 206 additions and 91 deletions

View File

@ -5,8 +5,7 @@ from typing import Optional, Union
from controllers.console.app.error import AppNotFoundError from controllers.console.app.error import AppNotFoundError
from extensions.ext_database import db from extensions.ext_database import db
from libs.login import current_user from libs.login import current_user
from models import App from models import App, AppMode
from models.model import AppMode
def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[AppMode]] = None): def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[AppMode]] = None):

View File

@ -1,6 +1,5 @@
import base64 import base64
from extensions.ext_database import db
from libs import rsa from libs import rsa
@ -14,6 +13,7 @@ def obfuscated_token(token: str):
def encrypt_token(tenant_id: str, token: str): def encrypt_token(tenant_id: str, token: str):
from models.account import Tenant from models.account import Tenant
from models.engine import db
if not (tenant := db.session.query(Tenant).filter(Tenant.id == tenant_id).first()): if not (tenant := db.session.query(Tenant).filter(Tenant.id == tenant_id).first()):
raise ValueError(f"Tenant with id {tenant_id} not found") raise ValueError(f"Tenant with id {tenant_id} not found")

View File

@ -1,18 +1,5 @@
from flask_sqlalchemy import SQLAlchemy
from sqlalchemy import MetaData
from dify_app import DifyApp from dify_app import DifyApp
from models import db
POSTGRES_INDEXES_NAMING_CONVENTION = {
"ix": "%(column_0_label)s_idx",
"uq": "%(table_name)s_%(column_0_name)s_key",
"ck": "%(table_name)s_%(constraint_name)s_check",
"fk": "%(table_name)s_%(column_0_name)s_fkey",
"pk": "%(table_name)s_pkey",
}
metadata = MetaData(naming_convention=POSTGRES_INDEXES_NAMING_CONVENTION)
db = SQLAlchemy(metadata=metadata)
def init_app(app: DifyApp): def init_app(app: DifyApp):

View File

@ -3,4 +3,3 @@ from dify_app import DifyApp
def init_app(app: DifyApp): def init_app(app: DifyApp):
from events import event_handlers # noqa: F401 from events import event_handlers # noqa: F401
from models import account, dataset, model, source, task, tool, tools, web # noqa: F401

View File

@ -13,7 +13,7 @@ from typing import Any, Optional, Union, cast
from zoneinfo import available_timezones from zoneinfo import available_timezones
from flask import Response, stream_with_context from flask import Response, stream_with_context
from flask_restful import fields # type: ignore from flask_restful import fields
from configs import dify_config from configs import dify_config
from core.app.features.rate_limiting.rate_limit import RateLimitGenerator from core.app.features.rate_limiting.rate_limit import RateLimitGenerator

View File

@ -0,0 +1,39 @@
"""remove unused tool_providers
Revision ID: 11b07f66c737
Revises: cf8f4fc45278
Create Date: 2024-12-19 17:46:25.780116
"""
from alembic import op
import models as models
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = '11b07f66c737'
down_revision = 'cf8f4fc45278'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table('tool_providers')
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('tool_providers',
sa.Column('id', sa.UUID(), server_default=sa.text('uuid_generate_v4()'), autoincrement=False, nullable=False),
sa.Column('tenant_id', sa.UUID(), autoincrement=False, nullable=False),
sa.Column('tool_name', sa.VARCHAR(length=40), autoincrement=False, nullable=False),
sa.Column('encrypted_credentials', sa.TEXT(), autoincrement=False, nullable=True),
sa.Column('is_enabled', sa.BOOLEAN(), server_default=sa.text('false'), autoincrement=False, nullable=False),
sa.Column('created_at', postgresql.TIMESTAMP(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), autoincrement=False, nullable=False),
sa.Column('updated_at', postgresql.TIMESTAMP(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), autoincrement=False, nullable=False),
sa.PrimaryKeyConstraint('id', name='tool_provider_pkey'),
sa.UniqueConstraint('tenant_id', 'tool_name', name='unique_tool_provider_tool_name')
)
# ### end Alembic commands ###

View File

@ -1,53 +1,187 @@
from .account import Account, AccountIntegrate, InvitationCode, Tenant from .account import (
from .dataset import Dataset, DatasetProcessRule, Document, DocumentSegment Account,
AccountIntegrate,
AccountStatus,
InvitationCode,
Tenant,
TenantAccountJoin,
TenantAccountJoinRole,
TenantAccountRole,
TenantStatus,
)
from .api_based_extension import APIBasedExtension, APIBasedExtensionPoint
from .dataset import (
AppDatasetJoin,
Dataset,
DatasetCollectionBinding,
DatasetKeywordTable,
DatasetPermission,
DatasetPermissionEnum,
DatasetProcessRule,
DatasetQuery,
Document,
DocumentSegment,
Embedding,
ExternalKnowledgeApis,
ExternalKnowledgeBindings,
TidbAuthBinding,
Whitelist,
)
from .engine import db
from .enums import CreatedByRole, UserFrom, WorkflowRunTriggeredFrom
from .model import ( from .model import (
ApiRequest,
ApiToken, ApiToken,
App, App,
AppAnnotationHitHistory,
AppAnnotationSetting,
AppMode, AppMode,
AppModelConfig,
Conversation, Conversation,
DatasetRetrieverResource,
DifySetup,
EndUser, EndUser,
IconType,
InstalledApp, InstalledApp,
Message, Message,
MessageAgentThought,
MessageAnnotation, MessageAnnotation,
MessageChain,
MessageFeedback,
MessageFile, MessageFile,
OperationLog,
RecommendedApp, RecommendedApp,
Site, Site,
Tag,
TagBinding,
TraceAppConfig,
UploadFile, UploadFile,
) )
from .source import DataSourceOauthBinding from .provider import (
from .tools import ToolFile LoadBalancingModelConfig,
Provider,
ProviderModel,
ProviderModelSetting,
ProviderOrder,
ProviderQuotaType,
ProviderType,
TenantDefaultModel,
TenantPreferredModelProvider,
)
from .source import DataSourceApiKeyAuthBinding, DataSourceOauthBinding
from .task import CeleryTask, CeleryTaskSet
from .tools import (
ApiToolProvider,
BuiltinToolProvider,
PublishedAppTool,
ToolConversationVariables,
ToolFile,
ToolLabelBinding,
ToolModelInvoke,
WorkflowToolProvider,
)
from .web import PinnedConversation, SavedMessage
from .workflow import ( from .workflow import (
ConversationVariable, ConversationVariable,
Workflow, Workflow,
WorkflowAppLog, WorkflowAppLog,
WorkflowAppLogCreatedFrom,
WorkflowNodeExecution,
WorkflowNodeExecutionStatus,
WorkflowNodeExecutionTriggeredFrom,
WorkflowRun, WorkflowRun,
WorkflowRunStatus,
WorkflowType,
) )
__all__ = [ __all__ = [
"APIBasedExtension",
"APIBasedExtensionPoint",
"Account", "Account",
"AccountIntegrate", "AccountIntegrate",
"AccountStatus",
"ApiRequest",
"ApiToken", "ApiToken",
"ApiToolProvider", # Added
"App", "App",
"AppAnnotationHitHistory",
"AppAnnotationSetting",
"AppDatasetJoin",
"AppMode", "AppMode",
"AppModelConfig",
"BuiltinToolProvider", # Added
"CeleryTask",
"CeleryTaskSet",
"Conversation", "Conversation",
"ConversationVariable", "ConversationVariable",
"CreatedByRole",
"DataSourceApiKeyAuthBinding",
"DataSourceOauthBinding", "DataSourceOauthBinding",
"Dataset", "Dataset",
"DatasetCollectionBinding",
"DatasetKeywordTable",
"DatasetPermission",
"DatasetPermissionEnum",
"DatasetProcessRule", "DatasetProcessRule",
"DatasetQuery",
"DatasetRetrieverResource",
"DifySetup",
"Document", "Document",
"DocumentSegment", "DocumentSegment",
"Embedding",
"EndUser", "EndUser",
"ExternalKnowledgeApis",
"ExternalKnowledgeBindings",
"IconType",
"InstalledApp", "InstalledApp",
"InvitationCode", "InvitationCode",
"LoadBalancingModelConfig",
"Message", "Message",
"MessageAgentThought",
"MessageAnnotation", "MessageAnnotation",
"MessageChain",
"MessageFeedback",
"MessageFile", "MessageFile",
"OperationLog",
"PinnedConversation",
"Provider",
"ProviderModel",
"ProviderModelSetting",
"ProviderOrder",
"ProviderQuotaType",
"ProviderType",
"PublishedAppTool",
"RecommendedApp", "RecommendedApp",
"SavedMessage",
"Site", "Site",
"Tag",
"TagBinding",
"Tenant", "Tenant",
"TenantAccountJoin",
"TenantAccountJoinRole",
"TenantAccountRole",
"TenantDefaultModel",
"TenantPreferredModelProvider",
"TenantStatus",
"TidbAuthBinding",
"ToolConversationVariables",
"ToolFile", "ToolFile",
"ToolLabelBinding",
"ToolModelInvoke",
"TraceAppConfig",
"UploadFile", "UploadFile",
"UserFrom",
"Whitelist",
"Workflow", "Workflow",
"WorkflowAppLog", "WorkflowAppLog",
"WorkflowAppLogCreatedFrom",
"WorkflowNodeExecution",
"WorkflowNodeExecutionStatus",
"WorkflowNodeExecutionTriggeredFrom",
"WorkflowRun", "WorkflowRun",
"WorkflowRunStatus",
"WorkflowRunTriggeredFrom",
"WorkflowToolProvider",
"WorkflowType",
"db",
] ]

View File

@ -3,8 +3,7 @@ import json
from flask_login import UserMixin from flask_login import UserMixin
from extensions.ext_database import db from .engine import db
from .types import StringUUID from .types import StringUUID

View File

@ -1,7 +1,6 @@
import enum import enum
from extensions.ext_database import db from .engine import db
from .types import StringUUID from .types import StringUUID

View File

@ -15,10 +15,10 @@ from sqlalchemy.dialects.postgresql import JSONB
from configs import dify_config from configs import dify_config
from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db
from extensions.ext_storage import storage from extensions.ext_storage import storage
from .account import Account from .account import Account
from .engine import db
from .model import App, Tag, TagBinding, UploadFile from .model import App, Tag, TagBinding, UploadFile
from .types import StringUUID from .types import StringUUID

13
api/models/engine.py Normal file
View File

@ -0,0 +1,13 @@
from flask_sqlalchemy import SQLAlchemy
from sqlalchemy import MetaData
POSTGRES_INDEXES_NAMING_CONVENTION = {
"ix": "%(column_0_label)s_idx",
"uq": "%(table_name)s_%(column_0_name)s_key",
"ck": "%(table_name)s_%(constraint_name)s_check",
"fk": "%(table_name)s_%(column_0_name)s_fkey",
"pk": "%(table_name)s_pkey",
}
metadata = MetaData(naming_convention=POSTGRES_INDEXES_NAMING_CONVENTION)
db = SQLAlchemy(metadata=metadata)

View File

@ -16,11 +16,11 @@ from configs import dify_config
from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType
from core.file import helpers as file_helpers from core.file import helpers as file_helpers
from core.file.tool_file_parser import ToolFileParser from core.file.tool_file_parser import ToolFileParser
from extensions.ext_database import db
from libs.helper import generate_string from libs.helper import generate_string
from models.enums import CreatedByRole from models.enums import CreatedByRole
from .account import Account, Tenant from .account import Account, Tenant
from .engine import db
from .types import StringUUID from .types import StringUUID

View File

@ -1,7 +1,6 @@
from enum import Enum from enum import Enum
from extensions.ext_database import db from .engine import db
from .types import StringUUID from .types import StringUUID

View File

@ -2,8 +2,7 @@ import json
from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.dialects.postgresql import JSONB
from extensions.ext_database import db from .engine import db
from .types import StringUUID from .types import StringUUID

View File

@ -2,7 +2,7 @@ from datetime import UTC, datetime
from celery import states from celery import states
from extensions.ext_database import db from .engine import db
class CeleryTask(db.Model): class CeleryTask(db.Model):

View File

@ -1,47 +0,0 @@
import json
from enum import Enum
from extensions.ext_database import db
from .types import StringUUID
class ToolProviderName(Enum):
SERPAPI = "serpapi"
@staticmethod
def value_of(value):
for member in ToolProviderName:
if member.value == value:
return member
raise ValueError(f"No matching enum found for value '{value}'")
class ToolProvider(db.Model):
__tablename__ = "tool_providers"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="tool_provider_pkey"),
db.UniqueConstraint("tenant_id", "tool_name", name="unique_tool_provider_tool_name"),
)
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
tenant_id = db.Column(StringUUID, nullable=False)
tool_name = db.Column(db.String(40), nullable=False)
encrypted_credentials = db.Column(db.Text, nullable=True)
is_enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
@property
def credentials_is_set(self):
"""
Returns True if the encrypted_config is not None, indicating that the token is set.
"""
return self.encrypted_credentials is not None
@property
def credentials(self):
"""
Returns the decrypted config.
"""
return json.loads(self.encrypted_credentials) if self.encrypted_credentials is not None else None

View File

@ -8,8 +8,8 @@ from sqlalchemy.orm import Mapped, mapped_column
from core.tools.entities.common_entities import I18nObject from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_bundle import ApiToolBundle from core.tools.entities.tool_bundle import ApiToolBundle
from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToolParameterConfiguration from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToolParameterConfiguration
from extensions.ext_database import db
from .engine import db
from .model import Account, App, Tenant from .model import Account, App, Tenant
from .types import StringUUID from .types import StringUUID
@ -82,7 +82,7 @@ class PublishedAppTool(db.Model):
return I18nObject(**json.loads(self.description)) return I18nObject(**json.loads(self.description))
@property @property
def app(self) -> App: def app(self):
return db.session.query(App).filter(App.id == self.app_id).first() return db.session.query(App).filter(App.id == self.app_id).first()
@ -201,10 +201,6 @@ class WorkflowToolProvider(db.Model):
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
@property
def schema_type(self) -> ApiProviderSchemaType:
return ApiProviderSchemaType.value_of(self.schema_type_str)
@property @property
def user(self) -> Account | None: def user(self) -> Account | None:
return db.session.query(Account).filter(Account.id == self.user_id).first() return db.session.query(Account).filter(Account.id == self.user_id).first()

View File

@ -1,5 +1,4 @@
from extensions.ext_database import db from .engine import db
from .model import Message from .model import Message
from .types import StringUUID from .types import StringUUID

View File

@ -12,12 +12,12 @@ import contexts
from constants import HIDDEN_VALUE from constants import HIDDEN_VALUE
from core.helper import encrypter from core.helper import encrypter
from core.variables import SecretVariable, Variable from core.variables import SecretVariable, Variable
from extensions.ext_database import db
from factories import variable_factory from factories import variable_factory
from libs import helper from libs import helper
from models.enums import CreatedByRole from models.enums import CreatedByRole
from .account import Account from .account import Account
from .engine import db
from .types import StringUUID from .types import StringUUID
@ -399,7 +399,7 @@ class WorkflowRun(db.Model):
graph = db.Column(db.Text) graph = db.Column(db.Text)
inputs = db.Column(db.Text) inputs = db.Column(db.Text)
status = db.Column(db.String(255), nullable=False) # running, succeeded, failed, stopped, partial-succeeded status = db.Column(db.String(255), nullable=False) # running, succeeded, failed, stopped, partial-succeeded
outputs: Mapped[str] = mapped_column(sa.Text, default="{}") outputs: Mapped[Optional[str]] = mapped_column(sa.Text, default="{}")
error = db.Column(db.Text) error = db.Column(db.Text)
elapsed_time = db.Column(db.Float, nullable=False, server_default=db.text("0")) elapsed_time = db.Column(db.Float, nullable=False, server_default=db.text("0"))
total_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0")) total_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0"))