mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-06-04 11:14:10 +08:00
refactor: Use typed SQLAlchemy base model and fix type errors (#19980)
Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
parent
ef3569e667
commit
3196dc2d61
@ -202,18 +202,18 @@ class EmailCodeLoginApi(Resource):
|
||||
except AccountRegisterError as are:
|
||||
raise AccountInFreezeError()
|
||||
if account:
|
||||
tenant = TenantService.get_join_tenants(account)
|
||||
if not tenant:
|
||||
tenants = TenantService.get_join_tenants(account)
|
||||
if not tenants:
|
||||
workspaces = FeatureService.get_system_features().license.workspaces
|
||||
if not workspaces.is_available():
|
||||
raise WorkspacesLimitExceeded()
|
||||
if not FeatureService.get_system_features().is_allow_create_workspace:
|
||||
raise NotAllowedCreateWorkspace()
|
||||
else:
|
||||
tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
|
||||
TenantService.create_tenant_member(tenant, account, role="owner")
|
||||
account.current_tenant = tenant
|
||||
tenant_was_created.send(tenant)
|
||||
new_tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
|
||||
TenantService.create_tenant_member(new_tenant, account, role="owner")
|
||||
account.current_tenant = new_tenant
|
||||
tenant_was_created.send(new_tenant)
|
||||
|
||||
if account is None:
|
||||
try:
|
||||
|
@ -148,15 +148,15 @@ def _generate_account(provider: str, user_info: OAuthUserInfo):
|
||||
account = _get_account_by_openid_or_email(provider, user_info)
|
||||
|
||||
if account:
|
||||
tenant = TenantService.get_join_tenants(account)
|
||||
if not tenant:
|
||||
tenants = TenantService.get_join_tenants(account)
|
||||
if not tenants:
|
||||
if not FeatureService.get_system_features().is_allow_create_workspace:
|
||||
raise WorkSpaceNotAllowedCreateError()
|
||||
else:
|
||||
tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
|
||||
TenantService.create_tenant_member(tenant, account, role="owner")
|
||||
account.current_tenant = tenant
|
||||
tenant_was_created.send(tenant)
|
||||
new_tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
|
||||
TenantService.create_tenant_member(new_tenant, account, role="owner")
|
||||
account.current_tenant = new_tenant
|
||||
tenant_was_created.send(new_tenant)
|
||||
|
||||
if not account:
|
||||
if not FeatureService.get_system_features().is_allow_register:
|
||||
|
@ -540,9 +540,22 @@ class DatasetIndexingStatusApi(Resource):
|
||||
.filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
|
||||
.count()
|
||||
)
|
||||
document.completed_segments = completed_segments
|
||||
document.total_segments = total_segments
|
||||
documents_status.append(marshal(document, document_status_fields))
|
||||
# Create a dictionary with document attributes and additional fields
|
||||
document_dict = {
|
||||
"id": document.id,
|
||||
"indexing_status": document.indexing_status,
|
||||
"processing_started_at": document.processing_started_at,
|
||||
"parsing_completed_at": document.parsing_completed_at,
|
||||
"cleaning_completed_at": document.cleaning_completed_at,
|
||||
"splitting_completed_at": document.splitting_completed_at,
|
||||
"completed_at": document.completed_at,
|
||||
"paused_at": document.paused_at,
|
||||
"error": document.error,
|
||||
"stopped_at": document.stopped_at,
|
||||
"completed_segments": completed_segments,
|
||||
"total_segments": total_segments,
|
||||
}
|
||||
documents_status.append(marshal(document_dict, document_status_fields))
|
||||
data = {"data": documents_status}
|
||||
return data
|
||||
|
||||
|
@ -583,11 +583,22 @@ class DocumentBatchIndexingStatusApi(DocumentResource):
|
||||
.filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
|
||||
.count()
|
||||
)
|
||||
document.completed_segments = completed_segments
|
||||
document.total_segments = total_segments
|
||||
if document.is_paused:
|
||||
document.indexing_status = "paused"
|
||||
documents_status.append(marshal(document, document_status_fields))
|
||||
# Create a dictionary with document attributes and additional fields
|
||||
document_dict = {
|
||||
"id": document.id,
|
||||
"indexing_status": "paused" if document.is_paused else document.indexing_status,
|
||||
"processing_started_at": document.processing_started_at,
|
||||
"parsing_completed_at": document.parsing_completed_at,
|
||||
"cleaning_completed_at": document.cleaning_completed_at,
|
||||
"splitting_completed_at": document.splitting_completed_at,
|
||||
"completed_at": document.completed_at,
|
||||
"paused_at": document.paused_at,
|
||||
"error": document.error,
|
||||
"stopped_at": document.stopped_at,
|
||||
"completed_segments": completed_segments,
|
||||
"total_segments": total_segments,
|
||||
}
|
||||
documents_status.append(marshal(document_dict, document_status_fields))
|
||||
data = {"data": documents_status}
|
||||
return data
|
||||
|
||||
@ -616,11 +627,22 @@ class DocumentIndexingStatusApi(DocumentResource):
|
||||
.count()
|
||||
)
|
||||
|
||||
document.completed_segments = completed_segments
|
||||
document.total_segments = total_segments
|
||||
if document.is_paused:
|
||||
document.indexing_status = "paused"
|
||||
return marshal(document, document_status_fields)
|
||||
# Create a dictionary with document attributes and additional fields
|
||||
document_dict = {
|
||||
"id": document.id,
|
||||
"indexing_status": "paused" if document.is_paused else document.indexing_status,
|
||||
"processing_started_at": document.processing_started_at,
|
||||
"parsing_completed_at": document.parsing_completed_at,
|
||||
"cleaning_completed_at": document.cleaning_completed_at,
|
||||
"splitting_completed_at": document.splitting_completed_at,
|
||||
"completed_at": document.completed_at,
|
||||
"paused_at": document.paused_at,
|
||||
"error": document.error,
|
||||
"stopped_at": document.stopped_at,
|
||||
"completed_segments": completed_segments,
|
||||
"total_segments": total_segments,
|
||||
}
|
||||
return marshal(document_dict, document_status_fields)
|
||||
|
||||
|
||||
class DocumentDetailApi(DocumentResource):
|
||||
|
@ -68,16 +68,24 @@ class TenantListApi(Resource):
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
tenants = TenantService.get_join_tenants(current_user)
|
||||
tenant_dicts = []
|
||||
|
||||
for tenant in tenants:
|
||||
features = FeatureService.get_features(tenant.id)
|
||||
if features.billing.enabled:
|
||||
tenant.plan = features.billing.subscription.plan
|
||||
else:
|
||||
tenant.plan = "sandbox"
|
||||
if tenant.id == current_user.current_tenant_id:
|
||||
tenant.current = True # Set current=True for current tenant
|
||||
return {"workspaces": marshal(tenants, tenants_fields)}, 200
|
||||
|
||||
# Create a dictionary with tenant attributes
|
||||
tenant_dict = {
|
||||
"id": tenant.id,
|
||||
"name": tenant.name,
|
||||
"status": tenant.status,
|
||||
"created_at": tenant.created_at,
|
||||
"plan": features.billing.subscription.plan if features.billing.enabled else "sandbox",
|
||||
"current": tenant.id == current_user.current_tenant_id,
|
||||
}
|
||||
|
||||
tenant_dicts.append(tenant_dict)
|
||||
|
||||
return {"workspaces": marshal(tenant_dicts, tenants_fields)}, 200
|
||||
|
||||
|
||||
class WorkspaceListApi(Resource):
|
||||
|
@ -64,9 +64,24 @@ class PluginUploadFileApi(Resource):
|
||||
|
||||
extension = guess_extension(tool_file.mimetype) or ".bin"
|
||||
preview_url = ToolFileManager.sign_file(tool_file_id=tool_file.id, extension=extension)
|
||||
tool_file.mime_type = mimetype
|
||||
tool_file.extension = extension
|
||||
tool_file.preview_url = preview_url
|
||||
|
||||
# Create a dictionary with all the necessary attributes
|
||||
result = {
|
||||
"id": tool_file.id,
|
||||
"user_id": tool_file.user_id,
|
||||
"tenant_id": tool_file.tenant_id,
|
||||
"conversation_id": tool_file.conversation_id,
|
||||
"file_key": tool_file.file_key,
|
||||
"mimetype": tool_file.mimetype,
|
||||
"original_url": tool_file.original_url,
|
||||
"name": tool_file.name,
|
||||
"size": tool_file.size,
|
||||
"mime_type": mimetype,
|
||||
"extension": extension,
|
||||
"preview_url": preview_url,
|
||||
}
|
||||
|
||||
return result, 201
|
||||
except services.errors.file.FileTooLargeError as file_too_large_error:
|
||||
raise FileTooLargeError(file_too_large_error.description)
|
||||
except services.errors.file.UnsupportedFileTypeError:
|
||||
|
@ -388,11 +388,22 @@ class DocumentIndexingStatusApi(DatasetApiResource):
|
||||
.filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
|
||||
.count()
|
||||
)
|
||||
document.completed_segments = completed_segments
|
||||
document.total_segments = total_segments
|
||||
if document.is_paused:
|
||||
document.indexing_status = "paused"
|
||||
documents_status.append(marshal(document, document_status_fields))
|
||||
# Create a dictionary with document attributes and additional fields
|
||||
document_dict = {
|
||||
"id": document.id,
|
||||
"indexing_status": "paused" if document.is_paused else document.indexing_status,
|
||||
"processing_started_at": document.processing_started_at,
|
||||
"parsing_completed_at": document.parsing_completed_at,
|
||||
"cleaning_completed_at": document.cleaning_completed_at,
|
||||
"splitting_completed_at": document.splitting_completed_at,
|
||||
"completed_at": document.completed_at,
|
||||
"paused_at": document.paused_at,
|
||||
"error": document.error,
|
||||
"stopped_at": document.stopped_at,
|
||||
"completed_segments": completed_segments,
|
||||
"total_segments": total_segments,
|
||||
}
|
||||
documents_status.append(marshal(document_dict, document_status_fields))
|
||||
data = {"data": documents_status}
|
||||
return data
|
||||
|
||||
|
@ -405,7 +405,29 @@ class RetrievalService:
|
||||
record["child_chunks"] = segment_child_map[record["segment"].id].get("child_chunks") # type: ignore
|
||||
record["score"] = segment_child_map[record["segment"].id]["max_score"]
|
||||
|
||||
return [RetrievalSegments(**record) for record in records]
|
||||
result = []
|
||||
for record in records:
|
||||
# Extract segment
|
||||
segment = record["segment"]
|
||||
|
||||
# Extract child_chunks, ensuring it's a list or None
|
||||
child_chunks = record.get("child_chunks")
|
||||
if not isinstance(child_chunks, list):
|
||||
child_chunks = None
|
||||
|
||||
# Extract score, ensuring it's a float or None
|
||||
score_value = record.get("score")
|
||||
score = (
|
||||
float(score_value)
|
||||
if score_value is not None and isinstance(score_value, int | float | str)
|
||||
else None
|
||||
)
|
||||
|
||||
# Create RetrievalSegments object
|
||||
retrieval_segment = RetrievalSegments(segment=segment, child_chunks=child_chunks, score=score)
|
||||
result.append(retrieval_segment)
|
||||
|
||||
return result
|
||||
except Exception as e:
|
||||
db.session.rollback()
|
||||
raise e
|
||||
|
@ -528,7 +528,7 @@ class ToolManager:
|
||||
yield provider
|
||||
|
||||
except Exception:
|
||||
logger.exception(f"load builtin provider {provider}")
|
||||
logger.exception(f"load builtin provider {provider_path}")
|
||||
continue
|
||||
# set builtin providers loaded
|
||||
cls._builtin_providers_loaded = True
|
||||
@ -644,10 +644,10 @@ class ToolManager:
|
||||
)
|
||||
|
||||
workflow_provider_controllers: list[WorkflowToolProviderController] = []
|
||||
for provider in workflow_providers:
|
||||
for workflow_provider in workflow_providers:
|
||||
try:
|
||||
workflow_provider_controllers.append(
|
||||
ToolTransformService.workflow_provider_to_controller(db_provider=provider)
|
||||
ToolTransformService.workflow_provider_to_controller(db_provider=workflow_provider)
|
||||
)
|
||||
except Exception:
|
||||
# app has been deleted
|
||||
|
@ -1,7 +1,9 @@
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from typing import Any, Optional, Union, cast
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from flask_login import current_user
|
||||
|
||||
from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod
|
||||
from core.tools.__base.tool import Tool
|
||||
@ -87,7 +89,7 @@ class WorkflowTool(Tool):
|
||||
result = generator.generate(
|
||||
app_model=app,
|
||||
workflow=workflow,
|
||||
user=self._get_user(user_id),
|
||||
user=cast("Account | EndUser", current_user),
|
||||
args={"inputs": tool_parameters, "files": files},
|
||||
invoke_from=self.runtime.invoke_from,
|
||||
streaming=False,
|
||||
@ -111,20 +113,6 @@ class WorkflowTool(Tool):
|
||||
yield self.create_text_message(json.dumps(outputs, ensure_ascii=False))
|
||||
yield self.create_json_message(outputs)
|
||||
|
||||
def _get_user(self, user_id: str) -> Union[EndUser, Account]:
|
||||
"""
|
||||
get the user by user id
|
||||
"""
|
||||
|
||||
user = db.session.query(EndUser).filter(EndUser.id == user_id).first()
|
||||
if not user:
|
||||
user = db.session.query(Account).filter(Account.id == user_id).first()
|
||||
|
||||
if not user:
|
||||
raise ValueError("user not found")
|
||||
|
||||
return user
|
||||
|
||||
def fork_tool_runtime(self, runtime: ToolRuntime) -> "WorkflowTool":
|
||||
"""
|
||||
fork a new tool with metadata
|
||||
|
@ -3,11 +3,14 @@ import json
|
||||
import flask_login # type: ignore
|
||||
from flask import Response, request
|
||||
from flask_login import user_loaded_from_request, user_logged_in
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
from werkzeug.exceptions import NotFound, Unauthorized
|
||||
|
||||
import contexts
|
||||
from dify_app import DifyApp
|
||||
from extensions.ext_database import db
|
||||
from libs.passport import PassportService
|
||||
from models.account import Account
|
||||
from models.model import EndUser
|
||||
from services.account_service import AccountService
|
||||
|
||||
login_manager = flask_login.LoginManager()
|
||||
@ -17,34 +20,48 @@ login_manager = flask_login.LoginManager()
|
||||
@login_manager.request_loader
|
||||
def load_user_from_request(request_from_flask_login):
|
||||
"""Load user based on the request."""
|
||||
if request.blueprint not in {"console", "inner_api"}:
|
||||
return None
|
||||
# Check if the user_id contains a dot, indicating the old format
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
if not auth_header:
|
||||
auth_token = request.args.get("_token")
|
||||
if not auth_token:
|
||||
raise Unauthorized("Invalid Authorization token.")
|
||||
else:
|
||||
auth_token: str | None = None
|
||||
if auth_header:
|
||||
if " " not in auth_header:
|
||||
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
|
||||
auth_scheme, auth_token = auth_header.split(None, 1)
|
||||
auth_scheme, auth_token = auth_header.split(maxsplit=1)
|
||||
auth_scheme = auth_scheme.lower()
|
||||
if auth_scheme != "bearer":
|
||||
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
|
||||
else:
|
||||
auth_token = request.args.get("_token")
|
||||
|
||||
decoded = PassportService().verify(auth_token)
|
||||
user_id = decoded.get("user_id")
|
||||
if request.blueprint in {"console", "inner_api"}:
|
||||
if not auth_token:
|
||||
raise Unauthorized("Invalid Authorization token.")
|
||||
decoded = PassportService().verify(auth_token)
|
||||
user_id = decoded.get("user_id")
|
||||
if not user_id:
|
||||
raise Unauthorized("Invalid Authorization token.")
|
||||
|
||||
logged_in_account = AccountService.load_logged_in_account(account_id=user_id)
|
||||
return logged_in_account
|
||||
logged_in_account = AccountService.load_logged_in_account(account_id=user_id)
|
||||
return logged_in_account
|
||||
elif request.blueprint == "web":
|
||||
decoded = PassportService().verify(auth_token)
|
||||
end_user_id = decoded.get("end_user_id")
|
||||
if not end_user_id:
|
||||
raise Unauthorized("Invalid Authorization token.")
|
||||
end_user = db.session.query(EndUser).filter(EndUser.id == decoded["end_user_id"]).first()
|
||||
if not end_user:
|
||||
raise NotFound("End user not found.")
|
||||
return end_user
|
||||
|
||||
|
||||
@user_logged_in.connect
|
||||
@user_loaded_from_request.connect
|
||||
def on_user_logged_in(_sender, user):
|
||||
"""Called when a user logged in."""
|
||||
if user:
|
||||
"""Called when a user logged in.
|
||||
|
||||
Note: AccountService.load_logged_in_account will populate user.current_tenant_id
|
||||
through the load_user method, which calls account.set_tenant_id().
|
||||
"""
|
||||
if user and isinstance(user, Account) and user.current_tenant_id:
|
||||
contexts.tenant_id.set(user.current_tenant_id)
|
||||
|
||||
|
||||
|
@ -1,10 +1,10 @@
|
||||
import enum
|
||||
import json
|
||||
from typing import cast
|
||||
from typing import Optional, cast
|
||||
|
||||
from flask_login import UserMixin # type: ignore
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
from sqlalchemy.orm import Mapped, mapped_column, reconstructor
|
||||
|
||||
from models.base import Base
|
||||
|
||||
@ -12,6 +12,66 @@ from .engine import db
|
||||
from .types import StringUUID
|
||||
|
||||
|
||||
class TenantAccountRole(enum.StrEnum):
|
||||
OWNER = "owner"
|
||||
ADMIN = "admin"
|
||||
EDITOR = "editor"
|
||||
NORMAL = "normal"
|
||||
DATASET_OPERATOR = "dataset_operator"
|
||||
|
||||
@staticmethod
|
||||
def is_valid_role(role: str) -> bool:
|
||||
if not role:
|
||||
return False
|
||||
return role in {
|
||||
TenantAccountRole.OWNER,
|
||||
TenantAccountRole.ADMIN,
|
||||
TenantAccountRole.EDITOR,
|
||||
TenantAccountRole.NORMAL,
|
||||
TenantAccountRole.DATASET_OPERATOR,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def is_privileged_role(role: Optional["TenantAccountRole"]) -> bool:
|
||||
if not role:
|
||||
return False
|
||||
return role in {TenantAccountRole.OWNER, TenantAccountRole.ADMIN}
|
||||
|
||||
@staticmethod
|
||||
def is_admin_role(role: Optional["TenantAccountRole"]) -> bool:
|
||||
if not role:
|
||||
return False
|
||||
return role == TenantAccountRole.ADMIN
|
||||
|
||||
@staticmethod
|
||||
def is_non_owner_role(role: Optional["TenantAccountRole"]) -> bool:
|
||||
if not role:
|
||||
return False
|
||||
return role in {
|
||||
TenantAccountRole.ADMIN,
|
||||
TenantAccountRole.EDITOR,
|
||||
TenantAccountRole.NORMAL,
|
||||
TenantAccountRole.DATASET_OPERATOR,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def is_editing_role(role: Optional["TenantAccountRole"]) -> bool:
|
||||
if not role:
|
||||
return False
|
||||
return role in {TenantAccountRole.OWNER, TenantAccountRole.ADMIN, TenantAccountRole.EDITOR}
|
||||
|
||||
@staticmethod
|
||||
def is_dataset_edit_role(role: Optional["TenantAccountRole"]) -> bool:
|
||||
if not role:
|
||||
return False
|
||||
return role in {
|
||||
TenantAccountRole.OWNER,
|
||||
TenantAccountRole.ADMIN,
|
||||
TenantAccountRole.EDITOR,
|
||||
TenantAccountRole.DATASET_OPERATOR,
|
||||
}
|
||||
|
||||
|
||||
class AccountStatus(enum.StrEnum):
|
||||
PENDING = "pending"
|
||||
UNINITIALIZED = "uninitialized"
|
||||
@ -41,24 +101,27 @@ class Account(UserMixin, Base):
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
@reconstructor
|
||||
def init_on_load(self):
|
||||
self.role: Optional[TenantAccountRole] = None
|
||||
self._current_tenant: Optional[Tenant] = None
|
||||
|
||||
@property
|
||||
def is_password_set(self):
|
||||
return self.password is not None
|
||||
|
||||
@property
|
||||
def current_tenant(self):
|
||||
return self._current_tenant # type: ignore
|
||||
return self._current_tenant
|
||||
|
||||
@current_tenant.setter
|
||||
def current_tenant(self, value: "Tenant"):
|
||||
tenant = value
|
||||
def current_tenant(self, tenant: "Tenant"):
|
||||
ta = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=self.id).first()
|
||||
if ta:
|
||||
tenant.current_role = ta.role
|
||||
else:
|
||||
tenant = None # type: ignore
|
||||
|
||||
self._current_tenant = tenant
|
||||
self.role = TenantAccountRole(ta.role)
|
||||
self._current_tenant = tenant
|
||||
return
|
||||
self._current_tenant = None
|
||||
|
||||
@property
|
||||
def current_tenant_id(self) -> str | None:
|
||||
@ -80,12 +143,12 @@ class Account(UserMixin, Base):
|
||||
return
|
||||
|
||||
tenant, join = tenant_account_join
|
||||
tenant.current_role = join.role
|
||||
self.role = join.role
|
||||
self._current_tenant = tenant
|
||||
|
||||
@property
|
||||
def current_role(self):
|
||||
return self._current_tenant.current_role
|
||||
return self.role
|
||||
|
||||
def get_status(self) -> AccountStatus:
|
||||
status_str = self.status
|
||||
@ -105,23 +168,23 @@ class Account(UserMixin, Base):
|
||||
# check current_user.current_tenant.current_role in ['admin', 'owner']
|
||||
@property
|
||||
def is_admin_or_owner(self):
|
||||
return TenantAccountRole.is_privileged_role(self._current_tenant.current_role)
|
||||
return TenantAccountRole.is_privileged_role(self.role)
|
||||
|
||||
@property
|
||||
def is_admin(self):
|
||||
return TenantAccountRole.is_admin_role(self._current_tenant.current_role)
|
||||
return TenantAccountRole.is_admin_role(self.role)
|
||||
|
||||
@property
|
||||
def is_editor(self):
|
||||
return TenantAccountRole.is_editing_role(self._current_tenant.current_role)
|
||||
return TenantAccountRole.is_editing_role(self.role)
|
||||
|
||||
@property
|
||||
def is_dataset_editor(self):
|
||||
return TenantAccountRole.is_dataset_edit_role(self._current_tenant.current_role)
|
||||
return TenantAccountRole.is_dataset_edit_role(self.role)
|
||||
|
||||
@property
|
||||
def is_dataset_operator(self):
|
||||
return self._current_tenant.current_role == TenantAccountRole.DATASET_OPERATOR
|
||||
return self.role == TenantAccountRole.DATASET_OPERATOR
|
||||
|
||||
|
||||
class TenantStatus(enum.StrEnum):
|
||||
@ -129,66 +192,6 @@ class TenantStatus(enum.StrEnum):
|
||||
ARCHIVE = "archive"
|
||||
|
||||
|
||||
class TenantAccountRole(enum.StrEnum):
|
||||
OWNER = "owner"
|
||||
ADMIN = "admin"
|
||||
EDITOR = "editor"
|
||||
NORMAL = "normal"
|
||||
DATASET_OPERATOR = "dataset_operator"
|
||||
|
||||
@staticmethod
|
||||
def is_valid_role(role: str) -> bool:
|
||||
if not role:
|
||||
return False
|
||||
return role in {
|
||||
TenantAccountRole.OWNER,
|
||||
TenantAccountRole.ADMIN,
|
||||
TenantAccountRole.EDITOR,
|
||||
TenantAccountRole.NORMAL,
|
||||
TenantAccountRole.DATASET_OPERATOR,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def is_privileged_role(role: str) -> bool:
|
||||
if not role:
|
||||
return False
|
||||
return role in {TenantAccountRole.OWNER, TenantAccountRole.ADMIN}
|
||||
|
||||
@staticmethod
|
||||
def is_admin_role(role: str) -> bool:
|
||||
if not role:
|
||||
return False
|
||||
return role == TenantAccountRole.ADMIN
|
||||
|
||||
@staticmethod
|
||||
def is_non_owner_role(role: str) -> bool:
|
||||
if not role:
|
||||
return False
|
||||
return role in {
|
||||
TenantAccountRole.ADMIN,
|
||||
TenantAccountRole.EDITOR,
|
||||
TenantAccountRole.NORMAL,
|
||||
TenantAccountRole.DATASET_OPERATOR,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def is_editing_role(role: str) -> bool:
|
||||
if not role:
|
||||
return False
|
||||
return role in {TenantAccountRole.OWNER, TenantAccountRole.ADMIN, TenantAccountRole.EDITOR}
|
||||
|
||||
@staticmethod
|
||||
def is_dataset_edit_role(role: str) -> bool:
|
||||
if not role:
|
||||
return False
|
||||
return role in {
|
||||
TenantAccountRole.OWNER,
|
||||
TenantAccountRole.ADMIN,
|
||||
TenantAccountRole.EDITOR,
|
||||
TenantAccountRole.DATASET_OPERATOR,
|
||||
}
|
||||
|
||||
|
||||
class Tenant(Base):
|
||||
__tablename__ = "tenants"
|
||||
__table_args__ = (db.PrimaryKeyConstraint("id", name="tenant_pkey"),)
|
||||
|
@ -1,5 +1,7 @@
|
||||
from sqlalchemy.orm import declarative_base
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
|
||||
from models.engine import metadata
|
||||
|
||||
Base = declarative_base(metadata=metadata)
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
metadata = metadata
|
||||
|
@ -172,10 +172,6 @@ class WorkflowToolProvider(Base):
|
||||
db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")
|
||||
)
|
||||
|
||||
@property
|
||||
def schema_type(self) -> ApiProviderSchemaType:
|
||||
return ApiProviderSchemaType.value_of(self.schema_type_str)
|
||||
|
||||
@property
|
||||
def user(self) -> Account | None:
|
||||
return db.session.query(Account).filter(Account.id == self.user_id).first()
|
||||
|
@ -3,7 +3,7 @@ import logging
|
||||
from collections.abc import Mapping, Sequence
|
||||
from datetime import UTC, datetime
|
||||
from enum import Enum, StrEnum
|
||||
from typing import TYPE_CHECKING, Any, Optional, Self, Union
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
from uuid import uuid4
|
||||
|
||||
from core.variables import utils as variable_utils
|
||||
@ -150,7 +150,7 @@ class Workflow(Base):
|
||||
conversation_variables: Sequence[Variable],
|
||||
marked_name: str = "",
|
||||
marked_comment: str = "",
|
||||
) -> Self:
|
||||
) -> "Workflow":
|
||||
workflow = Workflow()
|
||||
workflow.id = str(uuid4())
|
||||
workflow.tenant_id = tenant_id
|
||||
|
@ -23,11 +23,10 @@ class VectorService:
|
||||
):
|
||||
documents: list[Document] = []
|
||||
|
||||
document: Document | None = None
|
||||
for segment in segments:
|
||||
if doc_form == IndexType.PARENT_CHILD_INDEX:
|
||||
document = db.session.query(DatasetDocument).filter_by(id=segment.document_id).first()
|
||||
if not document:
|
||||
dataset_document = db.session.query(DatasetDocument).filter_by(id=segment.document_id).first()
|
||||
if not dataset_document:
|
||||
_logger.warning(
|
||||
"Expected DatasetDocument record to exist, but none was found, document_id=%s, segment_id=%s",
|
||||
segment.document_id,
|
||||
@ -37,7 +36,7 @@ class VectorService:
|
||||
# get the process rule
|
||||
processing_rule = (
|
||||
db.session.query(DatasetProcessRule)
|
||||
.filter(DatasetProcessRule.id == document.dataset_process_rule_id)
|
||||
.filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
|
||||
.first()
|
||||
)
|
||||
if not processing_rule:
|
||||
@ -61,9 +60,11 @@ class VectorService:
|
||||
)
|
||||
else:
|
||||
raise ValueError("The knowledge base index technique is not high quality!")
|
||||
cls.generate_child_chunks(segment, document, dataset, embedding_model_instance, processing_rule, False)
|
||||
cls.generate_child_chunks(
|
||||
segment, dataset_document, dataset, embedding_model_instance, processing_rule, False
|
||||
)
|
||||
else:
|
||||
document = Document(
|
||||
rag_document = Document(
|
||||
page_content=segment.content,
|
||||
metadata={
|
||||
"doc_id": segment.index_node_id,
|
||||
@ -72,7 +73,7 @@ class VectorService:
|
||||
"dataset_id": segment.dataset_id,
|
||||
},
|
||||
)
|
||||
documents.append(document)
|
||||
documents.append(rag_document)
|
||||
if len(documents) > 0:
|
||||
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
|
||||
index_processor.load(dataset, documents, with_keywords=True, keywords_list=keywords_list)
|
||||
|
@ -508,11 +508,11 @@ class WorkflowService:
|
||||
raise DraftWorkflowDeletionError("Cannot delete draft workflow versions")
|
||||
|
||||
# Check if this workflow is currently referenced by an app
|
||||
stmt = select(App).where(App.workflow_id == workflow_id)
|
||||
app = session.scalar(stmt)
|
||||
app_stmt = select(App).where(App.workflow_id == workflow_id)
|
||||
app = session.scalar(app_stmt)
|
||||
if app:
|
||||
# Cannot delete a workflow that's currently in use by an app
|
||||
raise WorkflowInUseError(f"Cannot delete workflow that is currently in use by app '{app.name}'")
|
||||
raise WorkflowInUseError(f"Cannot delete workflow that is currently in use by app '{app.id}'")
|
||||
|
||||
# Don't use workflow.tool_published as it's not accurate for specific workflow versions
|
||||
# Check if there's a tool provider using this specific workflow version
|
||||
|
@ -111,7 +111,7 @@ def add_document_to_index_task(dataset_document_id: str):
|
||||
logging.exception("add document to index failed")
|
||||
dataset_document.enabled = False
|
||||
dataset_document.disabled_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
|
||||
dataset_document.status = "error"
|
||||
dataset_document.indexing_status = "error"
|
||||
dataset_document.error = str(e)
|
||||
db.session.commit()
|
||||
finally:
|
||||
|
@ -193,7 +193,7 @@ def _delete_app_workflow_runs(tenant_id: str, app_id: str):
|
||||
def _delete_app_workflow_node_executions(tenant_id: str, app_id: str):
|
||||
# Get app's owner
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = select(Account).where(Account.id == App.owner_id).where(App.id == app_id)
|
||||
stmt = select(Account).where(Account.id == App.created_by).where(App.id == app_id)
|
||||
user = session.scalar(stmt)
|
||||
|
||||
if user is None:
|
||||
|
@ -34,13 +34,13 @@ def test_workflow_tool_should_raise_tool_invoke_error_when_result_has_error_fiel
|
||||
# needs to patch those methods to avoid database access.
|
||||
monkeypatch.setattr(tool, "_get_app", lambda *args, **kwargs: None)
|
||||
monkeypatch.setattr(tool, "_get_workflow", lambda *args, **kwargs: None)
|
||||
monkeypatch.setattr(tool, "_get_user", lambda *args, **kwargs: None)
|
||||
|
||||
# replace `WorkflowAppGenerator.generate` 's return value.
|
||||
monkeypatch.setattr(
|
||||
"core.app.apps.workflow.app_generator.WorkflowAppGenerator.generate",
|
||||
lambda *args, **kwargs: {"data": {"error": "oops"}},
|
||||
)
|
||||
monkeypatch.setattr("flask_login.current_user", lambda *args, **kwargs: None)
|
||||
|
||||
with pytest.raises(ToolInvokeError) as exc_info:
|
||||
# WorkflowTool always returns a generator, so we need to iterate to
|
||||
|
Loading…
x
Reference in New Issue
Block a user