refactor(models): Use the SQLAlchemy base model. (#19435)

Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
-LAN- 2025-05-09 13:52:05 +08:00 committed by GitHub
parent 2ad7305349
commit 792b321a81
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 61 additions and 87 deletions

View File

@ -1,5 +1,6 @@
import enum
import json
from typing import cast
from flask_login import UserMixin # type: ignore
from sqlalchemy import func
@ -46,7 +47,6 @@ class Account(UserMixin, Base):
@property
def current_tenant(self):
# FIXME: fix the type error later, because the type is important maybe cause some bugs
return self._current_tenant # type: ignore
@current_tenant.setter
@ -64,25 +64,23 @@ class Account(UserMixin, Base):
def current_tenant_id(self) -> str | None:
return self._current_tenant.id if self._current_tenant else None
@current_tenant_id.setter
def current_tenant_id(self, value: str):
try:
tenant_account_join = (
def set_tenant_id(self, tenant_id: str):
tenant_account_join = cast(
tuple[Tenant, TenantAccountJoin],
(
db.session.query(Tenant, TenantAccountJoin)
.filter(Tenant.id == value)
.filter(Tenant.id == tenant_id)
.filter(TenantAccountJoin.tenant_id == Tenant.id)
.filter(TenantAccountJoin.account_id == self.id)
.one_or_none()
)
),
)
if tenant_account_join:
tenant, ta = tenant_account_join
tenant.current_role = ta.role
else:
tenant = None
except Exception:
tenant = None
if not tenant_account_join:
return
tenant, join = tenant_account_join
tenant.current_role = join.role
self._current_tenant = tenant
@property
@ -191,7 +189,7 @@ class TenantAccountRole(enum.StrEnum):
}
class Tenant(db.Model): # type: ignore[name-defined]
class Tenant(Base):
__tablename__ = "tenants"
__table_args__ = (db.PrimaryKeyConstraint("id", name="tenant_pkey"),)
@ -220,7 +218,7 @@ class Tenant(db.Model): # type: ignore[name-defined]
self.custom_config = json.dumps(value)
class TenantAccountJoin(db.Model): # type: ignore[name-defined]
class TenantAccountJoin(Base):
__tablename__ = "tenant_account_joins"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="tenant_account_join_pkey"),
@ -239,7 +237,7 @@ class TenantAccountJoin(db.Model): # type: ignore[name-defined]
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class AccountIntegrate(db.Model): # type: ignore[name-defined]
class AccountIntegrate(Base):
__tablename__ = "account_integrates"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="account_integrate_pkey"),
@ -256,7 +254,7 @@ class AccountIntegrate(db.Model): # type: ignore[name-defined]
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class InvitationCode(db.Model): # type: ignore[name-defined]
class InvitationCode(Base):
__tablename__ = "invitation_codes"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="invitation_code_pkey"),

View File

@ -2,6 +2,7 @@ import enum
from sqlalchemy import func
from .base import Base
from .engine import db
from .types import StringUUID
@ -13,7 +14,7 @@ class APIBasedExtensionPoint(enum.Enum):
APP_MODERATION_OUTPUT = "app.moderation.output"
class APIBasedExtension(db.Model): # type: ignore[name-defined]
class APIBasedExtension(Base):
__tablename__ = "api_based_extensions"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="api_based_extension_pkey"),

View File

@ -22,6 +22,7 @@ from extensions.ext_storage import storage
from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule
from .account import Account
from .base import Base
from .engine import db
from .model import App, Tag, TagBinding, UploadFile
from .types import StringUUID
@ -33,7 +34,7 @@ class DatasetPermissionEnum(enum.StrEnum):
PARTIAL_TEAM = "partial_members"
class Dataset(db.Model): # type: ignore[name-defined]
class Dataset(Base):
__tablename__ = "datasets"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="dataset_pkey"),
@ -255,7 +256,7 @@ class Dataset(db.Model): # type: ignore[name-defined]
return f"Vector_index_{normalized_dataset_id}_Node"
class DatasetProcessRule(db.Model): # type: ignore[name-defined]
class DatasetProcessRule(Base):
__tablename__ = "dataset_process_rules"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="dataset_process_rule_pkey"),
@ -295,7 +296,7 @@ class DatasetProcessRule(db.Model): # type: ignore[name-defined]
return None
class Document(db.Model): # type: ignore[name-defined]
class Document(Base):
__tablename__ = "documents"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="document_pkey"),
@ -635,7 +636,7 @@ class Document(db.Model): # type: ignore[name-defined]
)
class DocumentSegment(db.Model): # type: ignore[name-defined]
class DocumentSegment(Base):
__tablename__ = "document_segments"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="document_segment_pkey"),
@ -786,7 +787,7 @@ class DocumentSegment(db.Model): # type: ignore[name-defined]
return text
class ChildChunk(db.Model): # type: ignore[name-defined]
class ChildChunk(Base):
__tablename__ = "child_chunks"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="child_chunk_pkey"),
@ -829,7 +830,7 @@ class ChildChunk(db.Model): # type: ignore[name-defined]
return db.session.query(DocumentSegment).filter(DocumentSegment.id == self.segment_id).first()
class AppDatasetJoin(db.Model): # type: ignore[name-defined]
class AppDatasetJoin(Base):
__tablename__ = "app_dataset_joins"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="app_dataset_join_pkey"),
@ -846,7 +847,7 @@ class AppDatasetJoin(db.Model): # type: ignore[name-defined]
return db.session.get(App, self.app_id)
class DatasetQuery(db.Model): # type: ignore[name-defined]
class DatasetQuery(Base):
__tablename__ = "dataset_queries"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="dataset_query_pkey"),
@ -863,7 +864,7 @@ class DatasetQuery(db.Model): # type: ignore[name-defined]
created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())
class DatasetKeywordTable(db.Model): # type: ignore[name-defined]
class DatasetKeywordTable(Base):
__tablename__ = "dataset_keyword_tables"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="dataset_keyword_table_pkey"),
@ -908,7 +909,7 @@ class DatasetKeywordTable(db.Model): # type: ignore[name-defined]
return None
class Embedding(db.Model): # type: ignore[name-defined]
class Embedding(Base):
__tablename__ = "embeddings"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="embedding_pkey"),
@ -932,7 +933,7 @@ class Embedding(db.Model): # type: ignore[name-defined]
return cast(list[float], pickle.loads(self.embedding)) # noqa: S301
class DatasetCollectionBinding(db.Model): # type: ignore[name-defined]
class DatasetCollectionBinding(Base):
__tablename__ = "dataset_collection_bindings"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="dataset_collection_bindings_pkey"),
@ -947,7 +948,7 @@ class DatasetCollectionBinding(db.Model): # type: ignore[name-defined]
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class TidbAuthBinding(db.Model): # type: ignore[name-defined]
class TidbAuthBinding(Base):
__tablename__ = "tidb_auth_bindings"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="tidb_auth_bindings_pkey"),
@ -967,7 +968,7 @@ class TidbAuthBinding(db.Model): # type: ignore[name-defined]
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class Whitelist(db.Model): # type: ignore[name-defined]
class Whitelist(Base):
__tablename__ = "whitelists"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="whitelists_pkey"),
@ -979,7 +980,7 @@ class Whitelist(db.Model): # type: ignore[name-defined]
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class DatasetPermission(db.Model): # type: ignore[name-defined]
class DatasetPermission(Base):
__tablename__ = "dataset_permissions"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="dataset_permission_pkey"),
@ -996,7 +997,7 @@ class DatasetPermission(db.Model): # type: ignore[name-defined]
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class ExternalKnowledgeApis(db.Model): # type: ignore[name-defined]
class ExternalKnowledgeApis(Base):
__tablename__ = "external_knowledge_apis"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="external_knowledge_apis_pkey"),
@ -1049,7 +1050,7 @@ class ExternalKnowledgeApis(db.Model): # type: ignore[name-defined]
return dataset_bindings
class ExternalKnowledgeBindings(db.Model): # type: ignore[name-defined]
class ExternalKnowledgeBindings(Base):
__tablename__ = "external_knowledge_bindings"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="external_knowledge_bindings_pkey"),
@ -1070,7 +1071,7 @@ class ExternalKnowledgeBindings(db.Model): # type: ignore[name-defined]
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class DatasetAutoDisableLog(db.Model): # type: ignore[name-defined]
class DatasetAutoDisableLog(Base):
__tablename__ = "dataset_auto_disable_logs"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="dataset_auto_disable_log_pkey"),
@ -1087,7 +1088,7 @@ class DatasetAutoDisableLog(db.Model): # type: ignore[name-defined]
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
class RateLimitLog(db.Model): # type: ignore[name-defined]
class RateLimitLog(Base):
__tablename__ = "rate_limit_logs"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="rate_limit_log_pkey"),
@ -1102,7 +1103,7 @@ class RateLimitLog(db.Model): # type: ignore[name-defined]
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
class DatasetMetadata(db.Model): # type: ignore[name-defined]
class DatasetMetadata(Base):
__tablename__ = "dataset_metadatas"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="dataset_metadata_pkey"),
@ -1121,7 +1122,7 @@ class DatasetMetadata(db.Model): # type: ignore[name-defined]
updated_by = db.Column(StringUUID, nullable=True)
class DatasetMetadataBinding(db.Model): # type: ignore[name-defined]
class DatasetMetadataBinding(Base):
__tablename__ = "dataset_metadata_bindings"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="dataset_metadata_binding_pkey"),

View File

@ -16,7 +16,7 @@ if TYPE_CHECKING:
import sqlalchemy as sa
from flask import request
from flask_login import UserMixin # type: ignore
from flask_login import UserMixin
from sqlalchemy import Float, Index, PrimaryKeyConstraint, func, text
from sqlalchemy.orm import Mapped, Session, mapped_column
@ -25,13 +25,13 @@ from constants import DEFAULT_FILE_NUMBER_LIMITS
from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType
from core.file import helpers as file_helpers
from libs.helper import generate_string
from models.base import Base
from models.enums import CreatedByRole
from models.workflow import WorkflowRunStatus
from .account import Account, Tenant
from .base import Base
from .engine import db
from .enums import CreatedByRole
from .types import StringUUID
from .workflow import WorkflowRunStatus
if TYPE_CHECKING:
from .workflow import Workflow
@ -602,7 +602,7 @@ class InstalledApp(Base):
return tenant
class Conversation(db.Model): # type: ignore[name-defined]
class Conversation(Base):
__tablename__ = "conversations"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="conversation_pkey"),
@ -794,7 +794,7 @@ class Conversation(db.Model): # type: ignore[name-defined]
for message in messages:
if message.workflow_run:
status_counts[message.workflow_run.status] += 1
status_counts[WorkflowRunStatus(message.workflow_run.status)] += 1
return (
{
@ -864,7 +864,7 @@ class Conversation(db.Model): # type: ignore[name-defined]
}
class Message(db.Model): # type: ignore[name-defined]
class Message(Base):
__tablename__ = "messages"
__table_args__ = (
PrimaryKeyConstraint("id", name="message_pkey"),
@ -1211,7 +1211,7 @@ class Message(db.Model): # type: ignore[name-defined]
)
class MessageFeedback(db.Model): # type: ignore[name-defined]
class MessageFeedback(Base):
__tablename__ = "message_feedbacks"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="message_feedback_pkey"),
@ -1238,7 +1238,7 @@ class MessageFeedback(db.Model): # type: ignore[name-defined]
return account
class MessageFile(db.Model): # type: ignore[name-defined]
class MessageFile(Base):
__tablename__ = "message_files"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="message_file_pkey"),
@ -1279,7 +1279,7 @@ class MessageFile(db.Model): # type: ignore[name-defined]
created_at: Mapped[datetime] = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class MessageAnnotation(db.Model): # type: ignore[name-defined]
class MessageAnnotation(Base):
__tablename__ = "message_annotations"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="message_annotation_pkey"),
@ -1310,7 +1310,7 @@ class MessageAnnotation(db.Model): # type: ignore[name-defined]
return account
class AppAnnotationHitHistory(db.Model): # type: ignore[name-defined]
class AppAnnotationHitHistory(Base):
__tablename__ = "app_annotation_hit_histories"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="app_annotation_hit_histories_pkey"),
@ -1322,7 +1322,7 @@ class AppAnnotationHitHistory(db.Model): # type: ignore[name-defined]
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
app_id = db.Column(StringUUID, nullable=False)
annotation_id = db.Column(StringUUID, nullable=False)
annotation_id: Mapped[str] = db.Column(StringUUID, nullable=False)
source = db.Column(db.Text, nullable=False)
question = db.Column(db.Text, nullable=False)
account_id = db.Column(StringUUID, nullable=False)
@ -1348,7 +1348,7 @@ class AppAnnotationHitHistory(db.Model): # type: ignore[name-defined]
return account
class AppAnnotationSetting(db.Model): # type: ignore[name-defined]
class AppAnnotationSetting(Base):
__tablename__ = "app_annotation_settings"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="app_annotation_settings_pkey"),
@ -1364,26 +1364,6 @@ class AppAnnotationSetting(db.Model): # type: ignore[name-defined]
updated_user_id = db.Column(StringUUID, nullable=False)
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
@property
def created_account(self):
account = (
db.session.query(Account)
.join(AppAnnotationSetting, AppAnnotationSetting.created_user_id == Account.id)
.filter(AppAnnotationSetting.id == self.annotation_id)
.first()
)
return account
@property
def updated_account(self):
account = (
db.session.query(Account)
.join(AppAnnotationSetting, AppAnnotationSetting.updated_user_id == Account.id)
.filter(AppAnnotationSetting.id == self.annotation_id)
.first()
)
return account
@property
def collection_binding_detail(self):
from .dataset import DatasetCollectionBinding

View File

@ -2,8 +2,7 @@ from enum import Enum
from sqlalchemy import func
from models.base import Base
from .base import Base
from .engine import db
from .types import StringUUID

View File

@ -9,7 +9,7 @@ from .engine import db
from .types import StringUUID
class DataSourceOauthBinding(db.Model): # type: ignore[name-defined]
class DataSourceOauthBinding(Base):
__tablename__ = "data_source_oauth_bindings"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="source_binding_pkey"),

View File

@ -9,7 +9,7 @@ if TYPE_CHECKING:
from models.model import AppMode
import sqlalchemy as sa
from sqlalchemy import Index, PrimaryKeyConstraint, func
from sqlalchemy import func
from sqlalchemy.orm import Mapped, mapped_column
import contexts
@ -18,11 +18,11 @@ from core.helper import encrypter
from core.variables import SecretVariable, Variable
from factories import variable_factory
from libs import helper
from models.base import Base
from models.enums import CreatedByRole
from .account import Account
from .base import Base
from .engine import db
from .enums import CreatedByRole
from .types import StringUUID
if TYPE_CHECKING:
@ -768,17 +768,12 @@ class WorkflowAppLog(Base):
class ConversationVariable(Base):
__tablename__ = "workflow_conversation_variables"
__table_args__ = (
PrimaryKeyConstraint("id", "conversation_id", name="workflow_conversation_variables_pkey"),
Index("workflow__conversation_variables_app_id_idx", "app_id"),
Index("workflow__conversation_variables_created_at_idx", "created_at"),
)
id: Mapped[str] = mapped_column(StringUUID, primary_key=True)
conversation_id: Mapped[str] = mapped_column(StringUUID, nullable=False, primary_key=True)
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False, index=True)
data = mapped_column(db.Text, nullable=False)
created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp(), index=True)
updated_at = mapped_column(
db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)

View File

@ -110,7 +110,7 @@ class AccountService:
current_tenant = TenantAccountJoin.query.filter_by(account_id=account.id, current=True).first()
if current_tenant:
account.current_tenant_id = current_tenant.tenant_id
account.set_tenant_id(current_tenant.tenant_id)
else:
available_ta = (
TenantAccountJoin.query.filter_by(account_id=account.id).order_by(TenantAccountJoin.id.asc()).first()
@ -118,7 +118,7 @@ class AccountService:
if not available_ta:
return None
account.current_tenant_id = available_ta.tenant_id
account.set_tenant_id(available_ta.tenant_id)
available_ta.current = True
db.session.commit()
@ -700,7 +700,7 @@ class TenantService:
).update({"current": False})
tenant_account_join.current = True
# Set the current tenant for the account
account.current_tenant_id = tenant_account_join.tenant_id
account.set_tenant_id(tenant_account_join.tenant_id)
db.session.commit()
@staticmethod