refactor: Use typed SQLAlchemy base model and fix type errors (#19980)

Signed-off-by: -LAN- <laipz8200@outlook.com>
-LAN- 2025-05-21 15:38:03 +08:00 committed by GitHub
parent ef3569e667
commit 3196dc2d61
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
20 changed files with 272 additions and 174 deletions

View File

@@ -202,18 +202,18 @@ class EmailCodeLoginApi(Resource):
         except AccountRegisterError as are:
             raise AccountInFreezeError()
         if account:
-            tenant = TenantService.get_join_tenants(account)
-            if not tenant:
+            tenants = TenantService.get_join_tenants(account)
+            if not tenants:
                 workspaces = FeatureService.get_system_features().license.workspaces
                 if not workspaces.is_available():
                     raise WorkspacesLimitExceeded()
                 if not FeatureService.get_system_features().is_allow_create_workspace:
                     raise NotAllowedCreateWorkspace()
                 else:
-                    tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
-                    TenantService.create_tenant_member(tenant, account, role="owner")
-                    account.current_tenant = tenant
-                    tenant_was_created.send(tenant)
+                    new_tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
+                    TenantService.create_tenant_member(new_tenant, account, role="owner")
+                    account.current_tenant = new_tenant
+                    tenant_was_created.send(new_tenant)
         if account is None:
             try:

View File

@@ -148,15 +148,15 @@ def _generate_account(provider: str, user_info: OAuthUserInfo):
     account = _get_account_by_openid_or_email(provider, user_info)
     if account:
-        tenant = TenantService.get_join_tenants(account)
-        if not tenant:
+        tenants = TenantService.get_join_tenants(account)
+        if not tenants:
             if not FeatureService.get_system_features().is_allow_create_workspace:
                 raise WorkSpaceNotAllowedCreateError()
             else:
-                tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
-                TenantService.create_tenant_member(tenant, account, role="owner")
-                account.current_tenant = tenant
-                tenant_was_created.send(tenant)
+                new_tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
+                TenantService.create_tenant_member(new_tenant, account, role="owner")
+                account.current_tenant = new_tenant
+                tenant_was_created.send(new_tenant)
     if not account:
         if not FeatureService.get_system_features().is_allow_register:

View File

@@ -540,9 +540,22 @@ class DatasetIndexingStatusApi(Resource):
                 .filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
                 .count()
             )
-            document.completed_segments = completed_segments
-            document.total_segments = total_segments
-            documents_status.append(marshal(document, document_status_fields))
+            # Create a dictionary with document attributes and additional fields
+            document_dict = {
+                "id": document.id,
+                "indexing_status": document.indexing_status,
+                "processing_started_at": document.processing_started_at,
+                "parsing_completed_at": document.parsing_completed_at,
+                "cleaning_completed_at": document.cleaning_completed_at,
+                "splitting_completed_at": document.splitting_completed_at,
+                "completed_at": document.completed_at,
+                "paused_at": document.paused_at,
+                "error": document.error,
+                "stopped_at": document.stopped_at,
+                "completed_segments": completed_segments,
+                "total_segments": total_segments,
+            }
+            documents_status.append(marshal(document_dict, document_status_fields))
         data = {"data": documents_status}
         return data

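The pattern in this hunk (repeated in the document-status hunks below) works because flask_restful's marshal() reads fields from plain dicts as well as from objects, so computed values no longer need to be attached as ad-hoc attributes on typed ORM instances. A minimal sketch with hypothetical field values and a reduced field set:

from flask_restful import fields, marshal

document_status_fields = {
    "id": fields.String,
    "completed_segments": fields.Integer,
    "total_segments": fields.Integer,
}

# marshal() resolves keys from a dict exactly as it would from
# attributes on a model instance.
document_dict = {"id": "doc-1", "completed_segments": 10, "total_segments": 12}
print(marshal(document_dict, document_status_fields))
# -> OrderedDict([('id', 'doc-1'), ('completed_segments', 10), ('total_segments', 12)])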
View File

@@ -583,11 +583,22 @@ class DocumentBatchIndexingStatusApi(DocumentResource):
                 .filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
                 .count()
             )
-            document.completed_segments = completed_segments
-            document.total_segments = total_segments
-            if document.is_paused:
-                document.indexing_status = "paused"
-            documents_status.append(marshal(document, document_status_fields))
+            # Create a dictionary with document attributes and additional fields
+            document_dict = {
+                "id": document.id,
+                "indexing_status": "paused" if document.is_paused else document.indexing_status,
+                "processing_started_at": document.processing_started_at,
+                "parsing_completed_at": document.parsing_completed_at,
+                "cleaning_completed_at": document.cleaning_completed_at,
+                "splitting_completed_at": document.splitting_completed_at,
+                "completed_at": document.completed_at,
+                "paused_at": document.paused_at,
+                "error": document.error,
+                "stopped_at": document.stopped_at,
+                "completed_segments": completed_segments,
+                "total_segments": total_segments,
+            }
+            documents_status.append(marshal(document_dict, document_status_fields))
         data = {"data": documents_status}
         return data
@@ -616,11 +627,22 @@ class DocumentIndexingStatusApi(DocumentResource):
             .count()
         )
-        document.completed_segments = completed_segments
-        document.total_segments = total_segments
-        if document.is_paused:
-            document.indexing_status = "paused"
-        return marshal(document, document_status_fields)
+        # Create a dictionary with document attributes and additional fields
+        document_dict = {
+            "id": document.id,
+            "indexing_status": "paused" if document.is_paused else document.indexing_status,
+            "processing_started_at": document.processing_started_at,
+            "parsing_completed_at": document.parsing_completed_at,
+            "cleaning_completed_at": document.cleaning_completed_at,
+            "splitting_completed_at": document.splitting_completed_at,
+            "completed_at": document.completed_at,
+            "paused_at": document.paused_at,
+            "error": document.error,
+            "stopped_at": document.stopped_at,
+            "completed_segments": completed_segments,
+            "total_segments": total_segments,
+        }
+        return marshal(document_dict, document_status_fields)


 class DocumentDetailApi(DocumentResource):

View File

@@ -68,16 +68,24 @@ class TenantListApi(Resource):
     @account_initialization_required
     def get(self):
         tenants = TenantService.get_join_tenants(current_user)
+        tenant_dicts = []

         for tenant in tenants:
             features = FeatureService.get_features(tenant.id)
-            if features.billing.enabled:
-                tenant.plan = features.billing.subscription.plan
-            else:
-                tenant.plan = "sandbox"
-            if tenant.id == current_user.current_tenant_id:
-                tenant.current = True  # Set current=True for current tenant
-        return {"workspaces": marshal(tenants, tenants_fields)}, 200
+
+            # Create a dictionary with tenant attributes
+            tenant_dict = {
+                "id": tenant.id,
+                "name": tenant.name,
+                "status": tenant.status,
+                "created_at": tenant.created_at,
+                "plan": features.billing.subscription.plan if features.billing.enabled else "sandbox",
+                "current": tenant.id == current_user.current_tenant_id,
+            }
+            tenant_dicts.append(tenant_dict)
+
+        return {"workspaces": marshal(tenant_dicts, tenants_fields)}, 200


 class WorkspaceListApi(Resource):

View File

@@ -64,9 +64,24 @@ class PluginUploadFileApi(Resource):
             extension = guess_extension(tool_file.mimetype) or ".bin"
             preview_url = ToolFileManager.sign_file(tool_file_id=tool_file.id, extension=extension)
-            tool_file.mime_type = mimetype
-            tool_file.extension = extension
-            tool_file.preview_url = preview_url
+            # Create a dictionary with all the necessary attributes
+            result = {
+                "id": tool_file.id,
+                "user_id": tool_file.user_id,
+                "tenant_id": tool_file.tenant_id,
+                "conversation_id": tool_file.conversation_id,
+                "file_key": tool_file.file_key,
+                "mimetype": tool_file.mimetype,
+                "original_url": tool_file.original_url,
+                "name": tool_file.name,
+                "size": tool_file.size,
+                "mime_type": mimetype,
+                "extension": extension,
+                "preview_url": preview_url,
+            }
+            return result, 201
         except services.errors.file.FileTooLargeError as file_too_large_error:
             raise FileTooLargeError(file_too_large_error.description)
         except services.errors.file.UnsupportedFileTypeError:

View File

@@ -388,11 +388,22 @@ class DocumentIndexingStatusApi(DatasetApiResource):
                 .filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
                 .count()
             )
-            document.completed_segments = completed_segments
-            document.total_segments = total_segments
-            if document.is_paused:
-                document.indexing_status = "paused"
-            documents_status.append(marshal(document, document_status_fields))
+            # Create a dictionary with document attributes and additional fields
+            document_dict = {
+                "id": document.id,
+                "indexing_status": "paused" if document.is_paused else document.indexing_status,
+                "processing_started_at": document.processing_started_at,
+                "parsing_completed_at": document.parsing_completed_at,
+                "cleaning_completed_at": document.cleaning_completed_at,
+                "splitting_completed_at": document.splitting_completed_at,
+                "completed_at": document.completed_at,
+                "paused_at": document.paused_at,
+                "error": document.error,
+                "stopped_at": document.stopped_at,
+                "completed_segments": completed_segments,
+                "total_segments": total_segments,
+            }
+            documents_status.append(marshal(document_dict, document_status_fields))
         data = {"data": documents_status}
         return data

View File

@@ -405,7 +405,29 @@ class RetrievalService:
                     record["child_chunks"] = segment_child_map[record["segment"].id].get("child_chunks")  # type: ignore
                     record["score"] = segment_child_map[record["segment"].id]["max_score"]
-            return [RetrievalSegments(**record) for record in records]
+            result = []
+            for record in records:
+                # Extract segment
+                segment = record["segment"]
+                # Extract child_chunks, ensuring it's a list or None
+                child_chunks = record.get("child_chunks")
+                if not isinstance(child_chunks, list):
+                    child_chunks = None
+                # Extract score, ensuring it's a float or None
+                score_value = record.get("score")
+                score = (
+                    float(score_value)
+                    if score_value is not None and isinstance(score_value, int | float | str)
+                    else None
+                )
+                # Create RetrievalSegments object
+                retrieval_segment = RetrievalSegments(segment=segment, child_chunks=child_chunks, score=score)
+                result.append(retrieval_segment)
+            return result
         except Exception as e:
             db.session.rollback()
             raise e

View File

@@ -528,7 +528,7 @@ class ToolManager:
                     yield provider
                 except Exception:
-                    logger.exception(f"load builtin provider {provider}")
+                    logger.exception(f"load builtin provider {provider_path}")
                     continue
         # set builtin providers loaded
         cls._builtin_providers_loaded = True
@@ -644,10 +644,10 @@ class ToolManager:
         )
         workflow_provider_controllers: list[WorkflowToolProviderController] = []
-        for provider in workflow_providers:
+        for workflow_provider in workflow_providers:
             try:
                 workflow_provider_controllers.append(
-                    ToolTransformService.workflow_provider_to_controller(db_provider=provider)
+                    ToolTransformService.workflow_provider_to_controller(db_provider=workflow_provider)
                 )
             except Exception:
                 # app has been deleted

View File

@@ -1,7 +1,9 @@
 import json
 import logging
 from collections.abc import Generator
-from typing import Any, Optional, Union, cast
+from typing import Any, Optional, cast
+
+from flask_login import current_user

 from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod
 from core.tools.__base.tool import Tool
@@ -87,7 +89,7 @@ class WorkflowTool(Tool):
         result = generator.generate(
             app_model=app,
             workflow=workflow,
-            user=self._get_user(user_id),
+            user=cast("Account | EndUser", current_user),
             args={"inputs": tool_parameters, "files": files},
             invoke_from=self.runtime.invoke_from,
             streaming=False,
@@ -111,20 +113,6 @@ class WorkflowTool(Tool):
         yield self.create_text_message(json.dumps(outputs, ensure_ascii=False))
         yield self.create_json_message(outputs)

-    def _get_user(self, user_id: str) -> Union[EndUser, Account]:
-        """
-        get the user by user id
-        """
-
-        user = db.session.query(EndUser).filter(EndUser.id == user_id).first()
-        if not user:
-            user = db.session.query(Account).filter(Account.id == user_id).first()
-
-        if not user:
-            raise ValueError("user not found")
-
-        return user
-
     def fork_tool_runtime(self, runtime: ToolRuntime) -> "WorkflowTool":
         """
         fork a new tool with metadata

View File

@@ -3,11 +3,14 @@ import json
 import flask_login  # type: ignore
 from flask import Response, request
 from flask_login import user_loaded_from_request, user_logged_in
-from werkzeug.exceptions import Unauthorized
+from werkzeug.exceptions import NotFound, Unauthorized

 import contexts
 from dify_app import DifyApp
+from extensions.ext_database import db
 from libs.passport import PassportService
+from models.account import Account
+from models.model import EndUser
 from services.account_service import AccountService

 login_manager = flask_login.LoginManager()
@@ -17,34 +20,48 @@ login_manager = flask_login.LoginManager()
 @login_manager.request_loader
 def load_user_from_request(request_from_flask_login):
     """Load user based on the request."""
-    if request.blueprint not in {"console", "inner_api"}:
-        return None
-    # Check if the user_id contains a dot, indicating the old format
     auth_header = request.headers.get("Authorization", "")
-    if not auth_header:
-        auth_token = request.args.get("_token")
-        if not auth_token:
-            raise Unauthorized("Invalid Authorization token.")
-    else:
+    auth_token: str | None = None
+    if auth_header:
         if " " not in auth_header:
             raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
-        auth_scheme, auth_token = auth_header.split(None, 1)
+        auth_scheme, auth_token = auth_header.split(maxsplit=1)
         auth_scheme = auth_scheme.lower()
         if auth_scheme != "bearer":
             raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
+    else:
+        auth_token = request.args.get("_token")

-    decoded = PassportService().verify(auth_token)
-    user_id = decoded.get("user_id")
+    if request.blueprint in {"console", "inner_api"}:
+        if not auth_token:
+            raise Unauthorized("Invalid Authorization token.")
+        decoded = PassportService().verify(auth_token)
+        user_id = decoded.get("user_id")
+        if not user_id:
+            raise Unauthorized("Invalid Authorization token.")

-    logged_in_account = AccountService.load_logged_in_account(account_id=user_id)
-    return logged_in_account
+        logged_in_account = AccountService.load_logged_in_account(account_id=user_id)
+        return logged_in_account
+    elif request.blueprint == "web":
+        decoded = PassportService().verify(auth_token)
+        end_user_id = decoded.get("end_user_id")
+        if not end_user_id:
+            raise Unauthorized("Invalid Authorization token.")
+        end_user = db.session.query(EndUser).filter(EndUser.id == decoded["end_user_id"]).first()
+        if not end_user:
+            raise NotFound("End user not found.")
+        return end_user


 @user_logged_in.connect
 @user_loaded_from_request.connect
 def on_user_logged_in(_sender, user):
-    """Called when a user logged in."""
-    if user:
+    """Called when a user logged in.
+
+    Note: AccountService.load_logged_in_account will populate user.current_tenant_id
+    through the load_user method, which calls account.set_tenant_id().
+    """
+    if user and isinstance(user, Account) and user.current_tenant_id:
         contexts.tenant_id.set(user.current_tenant_id)

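For reference, the restructured loader follows flask_login's request_loader pattern: parse the bearer token once, then dispatch on the blueprint so console requests resolve to an Account and web requests to an EndUser. A stripped-down sketch; lookup_account is a hypothetical stand-in for AccountService.load_logged_in_account:

import flask_login
from flask import request
from werkzeug.exceptions import Unauthorized

login_manager = flask_login.LoginManager()


def lookup_account(token: str):
    """Hypothetical stand-in for the token-to-Account resolution above."""
    return None


@login_manager.request_loader
def load_user(_request):
    # The token may arrive as an Authorization header or a query parameter.
    auth_header = request.headers.get("Authorization", "")
    if " " in auth_header:
        _scheme, token = auth_header.split(maxsplit=1)
    else:
        token = request.args.get("_token")

    # Blueprint-based dispatch decides which principal type to load.
    if request.blueprint in {"console", "inner_api"}:
        if not token:
            raise Unauthorized("Invalid Authorization token.")
        return lookup_account(token)
    return None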
View File

@@ -1,10 +1,10 @@
 import enum
 import json
-from typing import cast
+from typing import Optional, cast

 from flask_login import UserMixin  # type: ignore
 from sqlalchemy import func
-from sqlalchemy.orm import Mapped, mapped_column
+from sqlalchemy.orm import Mapped, mapped_column, reconstructor

 from models.base import Base
@@ -12,6 +12,66 @@ from .engine import db
 from .types import StringUUID


+class TenantAccountRole(enum.StrEnum):
+    OWNER = "owner"
+    ADMIN = "admin"
+    EDITOR = "editor"
+    NORMAL = "normal"
+    DATASET_OPERATOR = "dataset_operator"
+
+    @staticmethod
+    def is_valid_role(role: str) -> bool:
+        if not role:
+            return False
+        return role in {
+            TenantAccountRole.OWNER,
+            TenantAccountRole.ADMIN,
+            TenantAccountRole.EDITOR,
+            TenantAccountRole.NORMAL,
+            TenantAccountRole.DATASET_OPERATOR,
+        }
+
+    @staticmethod
+    def is_privileged_role(role: Optional["TenantAccountRole"]) -> bool:
+        if not role:
+            return False
+        return role in {TenantAccountRole.OWNER, TenantAccountRole.ADMIN}
+
+    @staticmethod
+    def is_admin_role(role: Optional["TenantAccountRole"]) -> bool:
+        if not role:
+            return False
+        return role == TenantAccountRole.ADMIN
+
+    @staticmethod
+    def is_non_owner_role(role: Optional["TenantAccountRole"]) -> bool:
+        if not role:
+            return False
+        return role in {
+            TenantAccountRole.ADMIN,
+            TenantAccountRole.EDITOR,
+            TenantAccountRole.NORMAL,
+            TenantAccountRole.DATASET_OPERATOR,
+        }
+
+    @staticmethod
+    def is_editing_role(role: Optional["TenantAccountRole"]) -> bool:
+        if not role:
+            return False
+        return role in {TenantAccountRole.OWNER, TenantAccountRole.ADMIN, TenantAccountRole.EDITOR}
+
+    @staticmethod
+    def is_dataset_edit_role(role: Optional["TenantAccountRole"]) -> bool:
+        if not role:
+            return False
+        return role in {
+            TenantAccountRole.OWNER,
+            TenantAccountRole.ADMIN,
+            TenantAccountRole.EDITOR,
+            TenantAccountRole.DATASET_OPERATOR,
+        }
+
+
 class AccountStatus(enum.StrEnum):
     PENDING = "pending"
     UNINITIALIZED = "uninitialized"
@@ -41,24 +101,27 @@ class Account(UserMixin, Base):
     created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
     updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())

+    @reconstructor
+    def init_on_load(self):
+        self.role: Optional[TenantAccountRole] = None
+        self._current_tenant: Optional[Tenant] = None
+
     @property
     def is_password_set(self):
         return self.password is not None

     @property
     def current_tenant(self):
-        return self._current_tenant  # type: ignore
+        return self._current_tenant

     @current_tenant.setter
-    def current_tenant(self, value: "Tenant"):
-        tenant = value
+    def current_tenant(self, tenant: "Tenant"):
         ta = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=self.id).first()
         if ta:
-            tenant.current_role = ta.role
-        else:
-            tenant = None  # type: ignore
-        self._current_tenant = tenant
+            self.role = TenantAccountRole(ta.role)
+            self._current_tenant = tenant
+            return
+        self._current_tenant = None

     @property
     def current_tenant_id(self) -> str | None:
@@ -80,12 +143,12 @@ class Account(UserMixin, Base):
             return

         tenant, join = tenant_account_join
-        tenant.current_role = join.role
+        self.role = join.role
         self._current_tenant = tenant

     @property
     def current_role(self):
-        return self._current_tenant.current_role
+        return self.role

     def get_status(self) -> AccountStatus:
         status_str = self.status
@@ -105,23 +168,23 @@ class Account(UserMixin, Base):
     # check current_user.current_tenant.current_role in ['admin', 'owner']
     @property
     def is_admin_or_owner(self):
-        return TenantAccountRole.is_privileged_role(self._current_tenant.current_role)
+        return TenantAccountRole.is_privileged_role(self.role)

     @property
     def is_admin(self):
-        return TenantAccountRole.is_admin_role(self._current_tenant.current_role)
+        return TenantAccountRole.is_admin_role(self.role)

     @property
     def is_editor(self):
-        return TenantAccountRole.is_editing_role(self._current_tenant.current_role)
+        return TenantAccountRole.is_editing_role(self.role)

     @property
     def is_dataset_editor(self):
-        return TenantAccountRole.is_dataset_edit_role(self._current_tenant.current_role)
+        return TenantAccountRole.is_dataset_edit_role(self.role)

     @property
     def is_dataset_operator(self):
-        return self._current_tenant.current_role == TenantAccountRole.DATASET_OPERATOR
+        return self.role == TenantAccountRole.DATASET_OPERATOR


 class TenantStatus(enum.StrEnum):
@@ -129,66 +192,6 @@ class TenantStatus(enum.StrEnum):
     ARCHIVE = "archive"


-class TenantAccountRole(enum.StrEnum):
-    OWNER = "owner"
-    ADMIN = "admin"
-    EDITOR = "editor"
-    NORMAL = "normal"
-    DATASET_OPERATOR = "dataset_operator"
-
-    @staticmethod
-    def is_valid_role(role: str) -> bool:
-        if not role:
-            return False
-        return role in {
-            TenantAccountRole.OWNER,
-            TenantAccountRole.ADMIN,
-            TenantAccountRole.EDITOR,
-            TenantAccountRole.NORMAL,
-            TenantAccountRole.DATASET_OPERATOR,
-        }
-
-    @staticmethod
-    def is_privileged_role(role: str) -> bool:
-        if not role:
-            return False
-        return role in {TenantAccountRole.OWNER, TenantAccountRole.ADMIN}
-
-    @staticmethod
-    def is_admin_role(role: str) -> bool:
-        if not role:
-            return False
-        return role == TenantAccountRole.ADMIN
-
-    @staticmethod
-    def is_non_owner_role(role: str) -> bool:
-        if not role:
-            return False
-        return role in {
-            TenantAccountRole.ADMIN,
-            TenantAccountRole.EDITOR,
-            TenantAccountRole.NORMAL,
-            TenantAccountRole.DATASET_OPERATOR,
-        }
-
-    @staticmethod
-    def is_editing_role(role: str) -> bool:
-        if not role:
-            return False
-        return role in {TenantAccountRole.OWNER, TenantAccountRole.ADMIN, TenantAccountRole.EDITOR}
-
-    @staticmethod
-    def is_dataset_edit_role(role: str) -> bool:
-        if not role:
-            return False
-        return role in {
-            TenantAccountRole.OWNER,
-            TenantAccountRole.ADMIN,
-            TenantAccountRole.EDITOR,
-            TenantAccountRole.DATASET_OPERATOR,
-        }
-
-
 class Tenant(Base):
     __tablename__ = "tenants"
     __table_args__ = (db.PrimaryKeyConstraint("id", name="tenant_pkey"),)

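The @reconstructor hook added above exists because SQLAlchemy does not call __init__ when it materializes an instance from a database row, so non-column state such as role and _current_tenant must be initialized on load. A self-contained sketch of the mechanism; the model and column names here are illustrative, not from the diff:

from typing import Optional

from sqlalchemy import create_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column, reconstructor


class Base(DeclarativeBase):
    pass


class User(Base):
    __tablename__ = "users"
    id: Mapped[int] = mapped_column(primary_key=True)

    @reconstructor
    def init_on_load(self) -> None:
        # __init__ is skipped on ORM load; transient state is set up here.
        self.role: Optional[str] = None


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
with Session(engine) as session:
    session.add(User(id=1))
    session.commit()
with Session(engine) as session:
    loaded = session.get(User, 1)  # init_on_load fires on this fresh load
    assert loaded is not None and loaded.role is None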
View File

@@ -1,5 +1,7 @@
-from sqlalchemy.orm import declarative_base
+from sqlalchemy.orm import DeclarativeBase

 from models.engine import metadata

-Base = declarative_base(metadata=metadata)
+
+class Base(DeclarativeBase):
+    metadata = metadata

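This change is the core of the refactor: subclassing DeclarativeBase (the SQLAlchemy 2.0 style) instead of calling declarative_base() gives every model a statically typed base, so Mapped[...] annotations are visible to mypy/pyright rather than collapsing to Any. A minimal sketch with an illustrative model:

from sqlalchemy import String
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class Example(Base):
    __tablename__ = "examples"
    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str] = mapped_column(String(255))


e = Example(name="demo")
n: str = e.name  # type checkers infer str here, not Any
# Setting an undeclared attribute (e.completed_segments = 1) would now be
# flagged, which is why the API handlers above marshal plain dicts instead.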
View File

@@ -172,10 +172,6 @@ class WorkflowToolProvider(Base):
         db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")
     )

-    @property
-    def schema_type(self) -> ApiProviderSchemaType:
-        return ApiProviderSchemaType.value_of(self.schema_type_str)
-
     @property
     def user(self) -> Account | None:
         return db.session.query(Account).filter(Account.id == self.user_id).first()

View File

@@ -3,7 +3,7 @@ import logging
 from collections.abc import Mapping, Sequence
 from datetime import UTC, datetime
 from enum import Enum, StrEnum
-from typing import TYPE_CHECKING, Any, Optional, Self, Union
+from typing import TYPE_CHECKING, Any, Optional, Union
 from uuid import uuid4

 from core.variables import utils as variable_utils
@@ -150,7 +150,7 @@ class Workflow(Base):
         conversation_variables: Sequence[Variable],
         marked_name: str = "",
         marked_comment: str = "",
-    ) -> Self:
+    ) -> "Workflow":
         workflow = Workflow()
         workflow.id = str(uuid4())
         workflow.tenant_id = tenant_id

View File

@@ -23,11 +23,10 @@ class VectorService:
     ):
         documents: list[Document] = []
-        document: Document | None = None
         for segment in segments:
             if doc_form == IndexType.PARENT_CHILD_INDEX:
-                document = db.session.query(DatasetDocument).filter_by(id=segment.document_id).first()
-                if not document:
+                dataset_document = db.session.query(DatasetDocument).filter_by(id=segment.document_id).first()
+                if not dataset_document:
                     _logger.warning(
                         "Expected DatasetDocument record to exist, but none was found, document_id=%s, segment_id=%s",
                         segment.document_id,
@@ -37,7 +36,7 @@ class VectorService:
                 # get the process rule
                 processing_rule = (
                     db.session.query(DatasetProcessRule)
-                    .filter(DatasetProcessRule.id == document.dataset_process_rule_id)
+                    .filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
                     .first()
                 )
                 if not processing_rule:
@@ -61,9 +60,11 @@ class VectorService:
                     )
                 else:
                     raise ValueError("The knowledge base index technique is not high quality!")
-                cls.generate_child_chunks(segment, document, dataset, embedding_model_instance, processing_rule, False)
+                cls.generate_child_chunks(
+                    segment, dataset_document, dataset, embedding_model_instance, processing_rule, False
+                )
             else:
-                document = Document(
+                rag_document = Document(
                     page_content=segment.content,
                     metadata={
                         "doc_id": segment.index_node_id,
@@ -72,7 +73,7 @@ class VectorService:
                         "dataset_id": segment.dataset_id,
                     },
                 )
-                documents.append(document)
+                documents.append(rag_document)
         if len(documents) > 0:
             index_processor = IndexProcessorFactory(doc_form).init_index_processor()
             index_processor.load(dataset, documents, with_keywords=True, keywords_list=keywords_list)

View File

@@ -508,11 +508,11 @@ class WorkflowService:
             raise DraftWorkflowDeletionError("Cannot delete draft workflow versions")

         # Check if this workflow is currently referenced by an app
-        stmt = select(App).where(App.workflow_id == workflow_id)
-        app = session.scalar(stmt)
+        app_stmt = select(App).where(App.workflow_id == workflow_id)
+        app = session.scalar(app_stmt)
         if app:
             # Cannot delete a workflow that's currently in use by an app
-            raise WorkflowInUseError(f"Cannot delete workflow that is currently in use by app '{app.name}'")
+            raise WorkflowInUseError(f"Cannot delete workflow that is currently in use by app '{app.id}'")

         # Don't use workflow.tool_published as it's not accurate for specific workflow versions
         # Check if there's a tool provider using this specific workflow version

View File

@@ -111,7 +111,7 @@ def add_document_to_index_task(dataset_document_id: str):
         logging.exception("add document to index failed")
         dataset_document.enabled = False
         dataset_document.disabled_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
-        dataset_document.status = "error"
+        dataset_document.indexing_status = "error"
         dataset_document.error = str(e)
         db.session.commit()
     finally:

View File

@@ -193,7 +193,7 @@ def _delete_app_workflow_runs(tenant_id: str, app_id: str):
 def _delete_app_workflow_node_executions(tenant_id: str, app_id: str):
     # Get app's owner
     with Session(db.engine, expire_on_commit=False) as session:
-        stmt = select(Account).where(Account.id == App.owner_id).where(App.id == app_id)
+        stmt = select(Account).where(Account.id == App.created_by).where(App.id == app_id)
         user = session.scalar(stmt)

     if user is None:

View File

@@ -34,13 +34,13 @@ def test_workflow_tool_should_raise_tool_invoke_error_when_result_has_error_fiel
     # needs to patch those methods to avoid database access.
     monkeypatch.setattr(tool, "_get_app", lambda *args, **kwargs: None)
     monkeypatch.setattr(tool, "_get_workflow", lambda *args, **kwargs: None)
-    monkeypatch.setattr(tool, "_get_user", lambda *args, **kwargs: None)

     # replace `WorkflowAppGenerator.generate` 's return value.
     monkeypatch.setattr(
         "core.app.apps.workflow.app_generator.WorkflowAppGenerator.generate",
         lambda *args, **kwargs: {"data": {"error": "oops"}},
     )
+    monkeypatch.setattr("flask_login.current_user", lambda *args, **kwargs: None)

     with pytest.raises(ToolInvokeError) as exc_info:
         # WorkflowTool always returns a generator, so we need to iterate to