refactor(models): Use the SQLAlchemy base model. (#19435)

Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
-LAN- 2025-05-09 13:52:05 +08:00 committed by GitHub
parent 2ad7305349
commit 792b321a81
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 61 additions and 87 deletions

View File

@ -1,5 +1,6 @@
import enum import enum
import json import json
from typing import cast
from flask_login import UserMixin # type: ignore from flask_login import UserMixin # type: ignore
from sqlalchemy import func from sqlalchemy import func
@ -46,7 +47,6 @@ class Account(UserMixin, Base):
@property @property
def current_tenant(self): def current_tenant(self):
# FIXME: fix the type error later, because the type is important maybe cause some bugs
return self._current_tenant # type: ignore return self._current_tenant # type: ignore
@current_tenant.setter @current_tenant.setter
@ -64,25 +64,23 @@ class Account(UserMixin, Base):
def current_tenant_id(self) -> str | None: def current_tenant_id(self) -> str | None:
return self._current_tenant.id if self._current_tenant else None return self._current_tenant.id if self._current_tenant else None
@current_tenant_id.setter def set_tenant_id(self, tenant_id: str):
def current_tenant_id(self, value: str): tenant_account_join = cast(
try: tuple[Tenant, TenantAccountJoin],
tenant_account_join = ( (
db.session.query(Tenant, TenantAccountJoin) db.session.query(Tenant, TenantAccountJoin)
.filter(Tenant.id == value) .filter(Tenant.id == tenant_id)
.filter(TenantAccountJoin.tenant_id == Tenant.id) .filter(TenantAccountJoin.tenant_id == Tenant.id)
.filter(TenantAccountJoin.account_id == self.id) .filter(TenantAccountJoin.account_id == self.id)
.one_or_none() .one_or_none()
) ),
)
if tenant_account_join: if not tenant_account_join:
tenant, ta = tenant_account_join return
tenant.current_role = ta.role
else:
tenant = None
except Exception:
tenant = None
tenant, join = tenant_account_join
tenant.current_role = join.role
self._current_tenant = tenant self._current_tenant = tenant
@property @property
@ -191,7 +189,7 @@ class TenantAccountRole(enum.StrEnum):
} }
class Tenant(db.Model): # type: ignore[name-defined] class Tenant(Base):
__tablename__ = "tenants" __tablename__ = "tenants"
__table_args__ = (db.PrimaryKeyConstraint("id", name="tenant_pkey"),) __table_args__ = (db.PrimaryKeyConstraint("id", name="tenant_pkey"),)
@ -220,7 +218,7 @@ class Tenant(db.Model): # type: ignore[name-defined]
self.custom_config = json.dumps(value) self.custom_config = json.dumps(value)
class TenantAccountJoin(db.Model): # type: ignore[name-defined] class TenantAccountJoin(Base):
__tablename__ = "tenant_account_joins" __tablename__ = "tenant_account_joins"
__table_args__ = ( __table_args__ = (
db.PrimaryKeyConstraint("id", name="tenant_account_join_pkey"), db.PrimaryKeyConstraint("id", name="tenant_account_join_pkey"),
@ -239,7 +237,7 @@ class TenantAccountJoin(db.Model): # type: ignore[name-defined]
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class AccountIntegrate(db.Model): # type: ignore[name-defined] class AccountIntegrate(Base):
__tablename__ = "account_integrates" __tablename__ = "account_integrates"
__table_args__ = ( __table_args__ = (
db.PrimaryKeyConstraint("id", name="account_integrate_pkey"), db.PrimaryKeyConstraint("id", name="account_integrate_pkey"),
@ -256,7 +254,7 @@ class AccountIntegrate(db.Model): # type: ignore[name-defined]
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class InvitationCode(db.Model): # type: ignore[name-defined] class InvitationCode(Base):
__tablename__ = "invitation_codes" __tablename__ = "invitation_codes"
__table_args__ = ( __table_args__ = (
db.PrimaryKeyConstraint("id", name="invitation_code_pkey"), db.PrimaryKeyConstraint("id", name="invitation_code_pkey"),

View File

@ -2,6 +2,7 @@ import enum
from sqlalchemy import func from sqlalchemy import func
from .base import Base
from .engine import db from .engine import db
from .types import StringUUID from .types import StringUUID
@ -13,7 +14,7 @@ class APIBasedExtensionPoint(enum.Enum):
APP_MODERATION_OUTPUT = "app.moderation.output" APP_MODERATION_OUTPUT = "app.moderation.output"
class APIBasedExtension(db.Model): # type: ignore[name-defined] class APIBasedExtension(Base):
__tablename__ = "api_based_extensions" __tablename__ = "api_based_extensions"
__table_args__ = ( __table_args__ = (
db.PrimaryKeyConstraint("id", name="api_based_extension_pkey"), db.PrimaryKeyConstraint("id", name="api_based_extension_pkey"),

View File

@ -22,6 +22,7 @@ from extensions.ext_storage import storage
from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule
from .account import Account from .account import Account
from .base import Base
from .engine import db from .engine import db
from .model import App, Tag, TagBinding, UploadFile from .model import App, Tag, TagBinding, UploadFile
from .types import StringUUID from .types import StringUUID
@ -33,7 +34,7 @@ class DatasetPermissionEnum(enum.StrEnum):
PARTIAL_TEAM = "partial_members" PARTIAL_TEAM = "partial_members"
class Dataset(db.Model): # type: ignore[name-defined] class Dataset(Base):
__tablename__ = "datasets" __tablename__ = "datasets"
__table_args__ = ( __table_args__ = (
db.PrimaryKeyConstraint("id", name="dataset_pkey"), db.PrimaryKeyConstraint("id", name="dataset_pkey"),
@ -255,7 +256,7 @@ class Dataset(db.Model): # type: ignore[name-defined]
return f"Vector_index_{normalized_dataset_id}_Node" return f"Vector_index_{normalized_dataset_id}_Node"
class DatasetProcessRule(db.Model): # type: ignore[name-defined] class DatasetProcessRule(Base):
__tablename__ = "dataset_process_rules" __tablename__ = "dataset_process_rules"
__table_args__ = ( __table_args__ = (
db.PrimaryKeyConstraint("id", name="dataset_process_rule_pkey"), db.PrimaryKeyConstraint("id", name="dataset_process_rule_pkey"),
@ -295,7 +296,7 @@ class DatasetProcessRule(db.Model): # type: ignore[name-defined]
return None return None
class Document(db.Model): # type: ignore[name-defined] class Document(Base):
__tablename__ = "documents" __tablename__ = "documents"
__table_args__ = ( __table_args__ = (
db.PrimaryKeyConstraint("id", name="document_pkey"), db.PrimaryKeyConstraint("id", name="document_pkey"),
@ -635,7 +636,7 @@ class Document(db.Model): # type: ignore[name-defined]
) )
class DocumentSegment(db.Model): # type: ignore[name-defined] class DocumentSegment(Base):
__tablename__ = "document_segments" __tablename__ = "document_segments"
__table_args__ = ( __table_args__ = (
db.PrimaryKeyConstraint("id", name="document_segment_pkey"), db.PrimaryKeyConstraint("id", name="document_segment_pkey"),
@ -786,7 +787,7 @@ class DocumentSegment(db.Model): # type: ignore[name-defined]
return text return text
class ChildChunk(db.Model): # type: ignore[name-defined] class ChildChunk(Base):
__tablename__ = "child_chunks" __tablename__ = "child_chunks"
__table_args__ = ( __table_args__ = (
db.PrimaryKeyConstraint("id", name="child_chunk_pkey"), db.PrimaryKeyConstraint("id", name="child_chunk_pkey"),
@ -829,7 +830,7 @@ class ChildChunk(db.Model): # type: ignore[name-defined]
return db.session.query(DocumentSegment).filter(DocumentSegment.id == self.segment_id).first() return db.session.query(DocumentSegment).filter(DocumentSegment.id == self.segment_id).first()
class AppDatasetJoin(db.Model): # type: ignore[name-defined] class AppDatasetJoin(Base):
__tablename__ = "app_dataset_joins" __tablename__ = "app_dataset_joins"
__table_args__ = ( __table_args__ = (
db.PrimaryKeyConstraint("id", name="app_dataset_join_pkey"), db.PrimaryKeyConstraint("id", name="app_dataset_join_pkey"),
@ -846,7 +847,7 @@ class AppDatasetJoin(db.Model): # type: ignore[name-defined]
return db.session.get(App, self.app_id) return db.session.get(App, self.app_id)
class DatasetQuery(db.Model): # type: ignore[name-defined] class DatasetQuery(Base):
__tablename__ = "dataset_queries" __tablename__ = "dataset_queries"
__table_args__ = ( __table_args__ = (
db.PrimaryKeyConstraint("id", name="dataset_query_pkey"), db.PrimaryKeyConstraint("id", name="dataset_query_pkey"),
@ -863,7 +864,7 @@ class DatasetQuery(db.Model): # type: ignore[name-defined]
created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp()) created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())
class DatasetKeywordTable(db.Model): # type: ignore[name-defined] class DatasetKeywordTable(Base):
__tablename__ = "dataset_keyword_tables" __tablename__ = "dataset_keyword_tables"
__table_args__ = ( __table_args__ = (
db.PrimaryKeyConstraint("id", name="dataset_keyword_table_pkey"), db.PrimaryKeyConstraint("id", name="dataset_keyword_table_pkey"),
@ -908,7 +909,7 @@ class DatasetKeywordTable(db.Model): # type: ignore[name-defined]
return None return None
class Embedding(db.Model): # type: ignore[name-defined] class Embedding(Base):
__tablename__ = "embeddings" __tablename__ = "embeddings"
__table_args__ = ( __table_args__ = (
db.PrimaryKeyConstraint("id", name="embedding_pkey"), db.PrimaryKeyConstraint("id", name="embedding_pkey"),
@ -932,7 +933,7 @@ class Embedding(db.Model): # type: ignore[name-defined]
return cast(list[float], pickle.loads(self.embedding)) # noqa: S301 return cast(list[float], pickle.loads(self.embedding)) # noqa: S301
class DatasetCollectionBinding(db.Model): # type: ignore[name-defined] class DatasetCollectionBinding(Base):
__tablename__ = "dataset_collection_bindings" __tablename__ = "dataset_collection_bindings"
__table_args__ = ( __table_args__ = (
db.PrimaryKeyConstraint("id", name="dataset_collection_bindings_pkey"), db.PrimaryKeyConstraint("id", name="dataset_collection_bindings_pkey"),
@ -947,7 +948,7 @@ class DatasetCollectionBinding(db.Model): # type: ignore[name-defined]
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class TidbAuthBinding(db.Model): # type: ignore[name-defined] class TidbAuthBinding(Base):
__tablename__ = "tidb_auth_bindings" __tablename__ = "tidb_auth_bindings"
__table_args__ = ( __table_args__ = (
db.PrimaryKeyConstraint("id", name="tidb_auth_bindings_pkey"), db.PrimaryKeyConstraint("id", name="tidb_auth_bindings_pkey"),
@ -967,7 +968,7 @@ class TidbAuthBinding(db.Model): # type: ignore[name-defined]
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class Whitelist(db.Model): # type: ignore[name-defined] class Whitelist(Base):
__tablename__ = "whitelists" __tablename__ = "whitelists"
__table_args__ = ( __table_args__ = (
db.PrimaryKeyConstraint("id", name="whitelists_pkey"), db.PrimaryKeyConstraint("id", name="whitelists_pkey"),
@ -979,7 +980,7 @@ class Whitelist(db.Model): # type: ignore[name-defined]
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class DatasetPermission(db.Model): # type: ignore[name-defined] class DatasetPermission(Base):
__tablename__ = "dataset_permissions" __tablename__ = "dataset_permissions"
__table_args__ = ( __table_args__ = (
db.PrimaryKeyConstraint("id", name="dataset_permission_pkey"), db.PrimaryKeyConstraint("id", name="dataset_permission_pkey"),
@ -996,7 +997,7 @@ class DatasetPermission(db.Model): # type: ignore[name-defined]
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class ExternalKnowledgeApis(db.Model): # type: ignore[name-defined] class ExternalKnowledgeApis(Base):
__tablename__ = "external_knowledge_apis" __tablename__ = "external_knowledge_apis"
__table_args__ = ( __table_args__ = (
db.PrimaryKeyConstraint("id", name="external_knowledge_apis_pkey"), db.PrimaryKeyConstraint("id", name="external_knowledge_apis_pkey"),
@ -1049,7 +1050,7 @@ class ExternalKnowledgeApis(db.Model): # type: ignore[name-defined]
return dataset_bindings return dataset_bindings
class ExternalKnowledgeBindings(db.Model): # type: ignore[name-defined] class ExternalKnowledgeBindings(Base):
__tablename__ = "external_knowledge_bindings" __tablename__ = "external_knowledge_bindings"
__table_args__ = ( __table_args__ = (
db.PrimaryKeyConstraint("id", name="external_knowledge_bindings_pkey"), db.PrimaryKeyConstraint("id", name="external_knowledge_bindings_pkey"),
@ -1070,7 +1071,7 @@ class ExternalKnowledgeBindings(db.Model): # type: ignore[name-defined]
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class DatasetAutoDisableLog(db.Model): # type: ignore[name-defined] class DatasetAutoDisableLog(Base):
__tablename__ = "dataset_auto_disable_logs" __tablename__ = "dataset_auto_disable_logs"
__table_args__ = ( __table_args__ = (
db.PrimaryKeyConstraint("id", name="dataset_auto_disable_log_pkey"), db.PrimaryKeyConstraint("id", name="dataset_auto_disable_log_pkey"),
@ -1087,7 +1088,7 @@ class DatasetAutoDisableLog(db.Model): # type: ignore[name-defined]
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
class RateLimitLog(db.Model): # type: ignore[name-defined] class RateLimitLog(Base):
__tablename__ = "rate_limit_logs" __tablename__ = "rate_limit_logs"
__table_args__ = ( __table_args__ = (
db.PrimaryKeyConstraint("id", name="rate_limit_log_pkey"), db.PrimaryKeyConstraint("id", name="rate_limit_log_pkey"),
@ -1102,7 +1103,7 @@ class RateLimitLog(db.Model): # type: ignore[name-defined]
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
class DatasetMetadata(db.Model): # type: ignore[name-defined] class DatasetMetadata(Base):
__tablename__ = "dataset_metadatas" __tablename__ = "dataset_metadatas"
__table_args__ = ( __table_args__ = (
db.PrimaryKeyConstraint("id", name="dataset_metadata_pkey"), db.PrimaryKeyConstraint("id", name="dataset_metadata_pkey"),
@ -1121,7 +1122,7 @@ class DatasetMetadata(db.Model): # type: ignore[name-defined]
updated_by = db.Column(StringUUID, nullable=True) updated_by = db.Column(StringUUID, nullable=True)
class DatasetMetadataBinding(db.Model): # type: ignore[name-defined] class DatasetMetadataBinding(Base):
__tablename__ = "dataset_metadata_bindings" __tablename__ = "dataset_metadata_bindings"
__table_args__ = ( __table_args__ = (
db.PrimaryKeyConstraint("id", name="dataset_metadata_binding_pkey"), db.PrimaryKeyConstraint("id", name="dataset_metadata_binding_pkey"),

View File

@ -16,7 +16,7 @@ if TYPE_CHECKING:
import sqlalchemy as sa import sqlalchemy as sa
from flask import request from flask import request
from flask_login import UserMixin # type: ignore from flask_login import UserMixin
from sqlalchemy import Float, Index, PrimaryKeyConstraint, func, text from sqlalchemy import Float, Index, PrimaryKeyConstraint, func, text
from sqlalchemy.orm import Mapped, Session, mapped_column from sqlalchemy.orm import Mapped, Session, mapped_column
@ -25,13 +25,13 @@ from constants import DEFAULT_FILE_NUMBER_LIMITS
from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType
from core.file import helpers as file_helpers from core.file import helpers as file_helpers
from libs.helper import generate_string from libs.helper import generate_string
from models.base import Base
from models.enums import CreatedByRole
from models.workflow import WorkflowRunStatus
from .account import Account, Tenant from .account import Account, Tenant
from .base import Base
from .engine import db from .engine import db
from .enums import CreatedByRole
from .types import StringUUID from .types import StringUUID
from .workflow import WorkflowRunStatus
if TYPE_CHECKING: if TYPE_CHECKING:
from .workflow import Workflow from .workflow import Workflow
@ -602,7 +602,7 @@ class InstalledApp(Base):
return tenant return tenant
class Conversation(db.Model): # type: ignore[name-defined] class Conversation(Base):
__tablename__ = "conversations" __tablename__ = "conversations"
__table_args__ = ( __table_args__ = (
db.PrimaryKeyConstraint("id", name="conversation_pkey"), db.PrimaryKeyConstraint("id", name="conversation_pkey"),
@ -794,7 +794,7 @@ class Conversation(db.Model): # type: ignore[name-defined]
for message in messages: for message in messages:
if message.workflow_run: if message.workflow_run:
status_counts[message.workflow_run.status] += 1 status_counts[WorkflowRunStatus(message.workflow_run.status)] += 1
return ( return (
{ {
@ -864,7 +864,7 @@ class Conversation(db.Model): # type: ignore[name-defined]
} }
class Message(db.Model): # type: ignore[name-defined] class Message(Base):
__tablename__ = "messages" __tablename__ = "messages"
__table_args__ = ( __table_args__ = (
PrimaryKeyConstraint("id", name="message_pkey"), PrimaryKeyConstraint("id", name="message_pkey"),
@ -1211,7 +1211,7 @@ class Message(db.Model): # type: ignore[name-defined]
) )
class MessageFeedback(db.Model): # type: ignore[name-defined] class MessageFeedback(Base):
__tablename__ = "message_feedbacks" __tablename__ = "message_feedbacks"
__table_args__ = ( __table_args__ = (
db.PrimaryKeyConstraint("id", name="message_feedback_pkey"), db.PrimaryKeyConstraint("id", name="message_feedback_pkey"),
@ -1238,7 +1238,7 @@ class MessageFeedback(db.Model): # type: ignore[name-defined]
return account return account
class MessageFile(db.Model): # type: ignore[name-defined] class MessageFile(Base):
__tablename__ = "message_files" __tablename__ = "message_files"
__table_args__ = ( __table_args__ = (
db.PrimaryKeyConstraint("id", name="message_file_pkey"), db.PrimaryKeyConstraint("id", name="message_file_pkey"),
@ -1279,7 +1279,7 @@ class MessageFile(db.Model): # type: ignore[name-defined]
created_at: Mapped[datetime] = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) created_at: Mapped[datetime] = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class MessageAnnotation(db.Model): # type: ignore[name-defined] class MessageAnnotation(Base):
__tablename__ = "message_annotations" __tablename__ = "message_annotations"
__table_args__ = ( __table_args__ = (
db.PrimaryKeyConstraint("id", name="message_annotation_pkey"), db.PrimaryKeyConstraint("id", name="message_annotation_pkey"),
@ -1310,7 +1310,7 @@ class MessageAnnotation(db.Model): # type: ignore[name-defined]
return account return account
class AppAnnotationHitHistory(db.Model): # type: ignore[name-defined] class AppAnnotationHitHistory(Base):
__tablename__ = "app_annotation_hit_histories" __tablename__ = "app_annotation_hit_histories"
__table_args__ = ( __table_args__ = (
db.PrimaryKeyConstraint("id", name="app_annotation_hit_histories_pkey"), db.PrimaryKeyConstraint("id", name="app_annotation_hit_histories_pkey"),
@ -1322,7 +1322,7 @@ class AppAnnotationHitHistory(db.Model): # type: ignore[name-defined]
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
app_id = db.Column(StringUUID, nullable=False) app_id = db.Column(StringUUID, nullable=False)
annotation_id = db.Column(StringUUID, nullable=False) annotation_id: Mapped[str] = db.Column(StringUUID, nullable=False)
source = db.Column(db.Text, nullable=False) source = db.Column(db.Text, nullable=False)
question = db.Column(db.Text, nullable=False) question = db.Column(db.Text, nullable=False)
account_id = db.Column(StringUUID, nullable=False) account_id = db.Column(StringUUID, nullable=False)
@ -1348,7 +1348,7 @@ class AppAnnotationHitHistory(db.Model): # type: ignore[name-defined]
return account return account
class AppAnnotationSetting(db.Model): # type: ignore[name-defined] class AppAnnotationSetting(Base):
__tablename__ = "app_annotation_settings" __tablename__ = "app_annotation_settings"
__table_args__ = ( __table_args__ = (
db.PrimaryKeyConstraint("id", name="app_annotation_settings_pkey"), db.PrimaryKeyConstraint("id", name="app_annotation_settings_pkey"),
@ -1364,26 +1364,6 @@ class AppAnnotationSetting(db.Model): # type: ignore[name-defined]
updated_user_id = db.Column(StringUUID, nullable=False) updated_user_id = db.Column(StringUUID, nullable=False)
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
@property
def created_account(self):
account = (
db.session.query(Account)
.join(AppAnnotationSetting, AppAnnotationSetting.created_user_id == Account.id)
.filter(AppAnnotationSetting.id == self.annotation_id)
.first()
)
return account
@property
def updated_account(self):
account = (
db.session.query(Account)
.join(AppAnnotationSetting, AppAnnotationSetting.updated_user_id == Account.id)
.filter(AppAnnotationSetting.id == self.annotation_id)
.first()
)
return account
@property @property
def collection_binding_detail(self): def collection_binding_detail(self):
from .dataset import DatasetCollectionBinding from .dataset import DatasetCollectionBinding

View File

@ -2,8 +2,7 @@ from enum import Enum
from sqlalchemy import func from sqlalchemy import func
from models.base import Base from .base import Base
from .engine import db from .engine import db
from .types import StringUUID from .types import StringUUID

View File

@ -9,7 +9,7 @@ from .engine import db
from .types import StringUUID from .types import StringUUID
class DataSourceOauthBinding(db.Model): # type: ignore[name-defined] class DataSourceOauthBinding(Base):
__tablename__ = "data_source_oauth_bindings" __tablename__ = "data_source_oauth_bindings"
__table_args__ = ( __table_args__ = (
db.PrimaryKeyConstraint("id", name="source_binding_pkey"), db.PrimaryKeyConstraint("id", name="source_binding_pkey"),

View File

@ -9,7 +9,7 @@ if TYPE_CHECKING:
from models.model import AppMode from models.model import AppMode
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy import Index, PrimaryKeyConstraint, func from sqlalchemy import func
from sqlalchemy.orm import Mapped, mapped_column from sqlalchemy.orm import Mapped, mapped_column
import contexts import contexts
@ -18,11 +18,11 @@ from core.helper import encrypter
from core.variables import SecretVariable, Variable from core.variables import SecretVariable, Variable
from factories import variable_factory from factories import variable_factory
from libs import helper from libs import helper
from models.base import Base
from models.enums import CreatedByRole
from .account import Account from .account import Account
from .base import Base
from .engine import db from .engine import db
from .enums import CreatedByRole
from .types import StringUUID from .types import StringUUID
if TYPE_CHECKING: if TYPE_CHECKING:
@ -768,17 +768,12 @@ class WorkflowAppLog(Base):
class ConversationVariable(Base): class ConversationVariable(Base):
__tablename__ = "workflow_conversation_variables" __tablename__ = "workflow_conversation_variables"
__table_args__ = (
PrimaryKeyConstraint("id", "conversation_id", name="workflow_conversation_variables_pkey"),
Index("workflow__conversation_variables_app_id_idx", "app_id"),
Index("workflow__conversation_variables_created_at_idx", "created_at"),
)
id: Mapped[str] = mapped_column(StringUUID, primary_key=True) id: Mapped[str] = mapped_column(StringUUID, primary_key=True)
conversation_id: Mapped[str] = mapped_column(StringUUID, nullable=False, primary_key=True) conversation_id: Mapped[str] = mapped_column(StringUUID, nullable=False, primary_key=True)
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) app_id: Mapped[str] = mapped_column(StringUUID, nullable=False, index=True)
data = mapped_column(db.Text, nullable=False) data = mapped_column(db.Text, nullable=False)
created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp(), index=True)
updated_at = mapped_column( updated_at = mapped_column(
db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
) )

View File

@ -110,7 +110,7 @@ class AccountService:
current_tenant = TenantAccountJoin.query.filter_by(account_id=account.id, current=True).first() current_tenant = TenantAccountJoin.query.filter_by(account_id=account.id, current=True).first()
if current_tenant: if current_tenant:
account.current_tenant_id = current_tenant.tenant_id account.set_tenant_id(current_tenant.tenant_id)
else: else:
available_ta = ( available_ta = (
TenantAccountJoin.query.filter_by(account_id=account.id).order_by(TenantAccountJoin.id.asc()).first() TenantAccountJoin.query.filter_by(account_id=account.id).order_by(TenantAccountJoin.id.asc()).first()
@ -118,7 +118,7 @@ class AccountService:
if not available_ta: if not available_ta:
return None return None
account.current_tenant_id = available_ta.tenant_id account.set_tenant_id(available_ta.tenant_id)
available_ta.current = True available_ta.current = True
db.session.commit() db.session.commit()
@ -700,7 +700,7 @@ class TenantService:
).update({"current": False}) ).update({"current": False})
tenant_account_join.current = True tenant_account_join.current = True
# Set the current tenant for the account # Set the current tenant for the account
account.current_tenant_id = tenant_account_join.tenant_id account.set_tenant_id(tenant_account_join.tenant_id)
db.session.commit() db.session.commit()
@staticmethod @staticmethod