mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-06-04 11:14:10 +08:00
refactor: Use typed SQLAlchemy base model and fix type errors (#19980)
Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
parent
ef3569e667
commit
3196dc2d61
@ -202,18 +202,18 @@ class EmailCodeLoginApi(Resource):
|
|||||||
except AccountRegisterError as are:
|
except AccountRegisterError as are:
|
||||||
raise AccountInFreezeError()
|
raise AccountInFreezeError()
|
||||||
if account:
|
if account:
|
||||||
tenant = TenantService.get_join_tenants(account)
|
tenants = TenantService.get_join_tenants(account)
|
||||||
if not tenant:
|
if not tenants:
|
||||||
workspaces = FeatureService.get_system_features().license.workspaces
|
workspaces = FeatureService.get_system_features().license.workspaces
|
||||||
if not workspaces.is_available():
|
if not workspaces.is_available():
|
||||||
raise WorkspacesLimitExceeded()
|
raise WorkspacesLimitExceeded()
|
||||||
if not FeatureService.get_system_features().is_allow_create_workspace:
|
if not FeatureService.get_system_features().is_allow_create_workspace:
|
||||||
raise NotAllowedCreateWorkspace()
|
raise NotAllowedCreateWorkspace()
|
||||||
else:
|
else:
|
||||||
tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
|
new_tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
|
||||||
TenantService.create_tenant_member(tenant, account, role="owner")
|
TenantService.create_tenant_member(new_tenant, account, role="owner")
|
||||||
account.current_tenant = tenant
|
account.current_tenant = new_tenant
|
||||||
tenant_was_created.send(tenant)
|
tenant_was_created.send(new_tenant)
|
||||||
|
|
||||||
if account is None:
|
if account is None:
|
||||||
try:
|
try:
|
||||||
|
@ -148,15 +148,15 @@ def _generate_account(provider: str, user_info: OAuthUserInfo):
|
|||||||
account = _get_account_by_openid_or_email(provider, user_info)
|
account = _get_account_by_openid_or_email(provider, user_info)
|
||||||
|
|
||||||
if account:
|
if account:
|
||||||
tenant = TenantService.get_join_tenants(account)
|
tenants = TenantService.get_join_tenants(account)
|
||||||
if not tenant:
|
if not tenants:
|
||||||
if not FeatureService.get_system_features().is_allow_create_workspace:
|
if not FeatureService.get_system_features().is_allow_create_workspace:
|
||||||
raise WorkSpaceNotAllowedCreateError()
|
raise WorkSpaceNotAllowedCreateError()
|
||||||
else:
|
else:
|
||||||
tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
|
new_tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
|
||||||
TenantService.create_tenant_member(tenant, account, role="owner")
|
TenantService.create_tenant_member(new_tenant, account, role="owner")
|
||||||
account.current_tenant = tenant
|
account.current_tenant = new_tenant
|
||||||
tenant_was_created.send(tenant)
|
tenant_was_created.send(new_tenant)
|
||||||
|
|
||||||
if not account:
|
if not account:
|
||||||
if not FeatureService.get_system_features().is_allow_register:
|
if not FeatureService.get_system_features().is_allow_register:
|
||||||
|
@ -540,9 +540,22 @@ class DatasetIndexingStatusApi(Resource):
|
|||||||
.filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
|
.filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
|
||||||
.count()
|
.count()
|
||||||
)
|
)
|
||||||
document.completed_segments = completed_segments
|
# Create a dictionary with document attributes and additional fields
|
||||||
document.total_segments = total_segments
|
document_dict = {
|
||||||
documents_status.append(marshal(document, document_status_fields))
|
"id": document.id,
|
||||||
|
"indexing_status": document.indexing_status,
|
||||||
|
"processing_started_at": document.processing_started_at,
|
||||||
|
"parsing_completed_at": document.parsing_completed_at,
|
||||||
|
"cleaning_completed_at": document.cleaning_completed_at,
|
||||||
|
"splitting_completed_at": document.splitting_completed_at,
|
||||||
|
"completed_at": document.completed_at,
|
||||||
|
"paused_at": document.paused_at,
|
||||||
|
"error": document.error,
|
||||||
|
"stopped_at": document.stopped_at,
|
||||||
|
"completed_segments": completed_segments,
|
||||||
|
"total_segments": total_segments,
|
||||||
|
}
|
||||||
|
documents_status.append(marshal(document_dict, document_status_fields))
|
||||||
data = {"data": documents_status}
|
data = {"data": documents_status}
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
@ -583,11 +583,22 @@ class DocumentBatchIndexingStatusApi(DocumentResource):
|
|||||||
.filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
|
.filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
|
||||||
.count()
|
.count()
|
||||||
)
|
)
|
||||||
document.completed_segments = completed_segments
|
# Create a dictionary with document attributes and additional fields
|
||||||
document.total_segments = total_segments
|
document_dict = {
|
||||||
if document.is_paused:
|
"id": document.id,
|
||||||
document.indexing_status = "paused"
|
"indexing_status": "paused" if document.is_paused else document.indexing_status,
|
||||||
documents_status.append(marshal(document, document_status_fields))
|
"processing_started_at": document.processing_started_at,
|
||||||
|
"parsing_completed_at": document.parsing_completed_at,
|
||||||
|
"cleaning_completed_at": document.cleaning_completed_at,
|
||||||
|
"splitting_completed_at": document.splitting_completed_at,
|
||||||
|
"completed_at": document.completed_at,
|
||||||
|
"paused_at": document.paused_at,
|
||||||
|
"error": document.error,
|
||||||
|
"stopped_at": document.stopped_at,
|
||||||
|
"completed_segments": completed_segments,
|
||||||
|
"total_segments": total_segments,
|
||||||
|
}
|
||||||
|
documents_status.append(marshal(document_dict, document_status_fields))
|
||||||
data = {"data": documents_status}
|
data = {"data": documents_status}
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@ -616,11 +627,22 @@ class DocumentIndexingStatusApi(DocumentResource):
|
|||||||
.count()
|
.count()
|
||||||
)
|
)
|
||||||
|
|
||||||
document.completed_segments = completed_segments
|
# Create a dictionary with document attributes and additional fields
|
||||||
document.total_segments = total_segments
|
document_dict = {
|
||||||
if document.is_paused:
|
"id": document.id,
|
||||||
document.indexing_status = "paused"
|
"indexing_status": "paused" if document.is_paused else document.indexing_status,
|
||||||
return marshal(document, document_status_fields)
|
"processing_started_at": document.processing_started_at,
|
||||||
|
"parsing_completed_at": document.parsing_completed_at,
|
||||||
|
"cleaning_completed_at": document.cleaning_completed_at,
|
||||||
|
"splitting_completed_at": document.splitting_completed_at,
|
||||||
|
"completed_at": document.completed_at,
|
||||||
|
"paused_at": document.paused_at,
|
||||||
|
"error": document.error,
|
||||||
|
"stopped_at": document.stopped_at,
|
||||||
|
"completed_segments": completed_segments,
|
||||||
|
"total_segments": total_segments,
|
||||||
|
}
|
||||||
|
return marshal(document_dict, document_status_fields)
|
||||||
|
|
||||||
|
|
||||||
class DocumentDetailApi(DocumentResource):
|
class DocumentDetailApi(DocumentResource):
|
||||||
|
@ -68,16 +68,24 @@ class TenantListApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self):
|
def get(self):
|
||||||
tenants = TenantService.get_join_tenants(current_user)
|
tenants = TenantService.get_join_tenants(current_user)
|
||||||
|
tenant_dicts = []
|
||||||
|
|
||||||
for tenant in tenants:
|
for tenant in tenants:
|
||||||
features = FeatureService.get_features(tenant.id)
|
features = FeatureService.get_features(tenant.id)
|
||||||
if features.billing.enabled:
|
|
||||||
tenant.plan = features.billing.subscription.plan
|
# Create a dictionary with tenant attributes
|
||||||
else:
|
tenant_dict = {
|
||||||
tenant.plan = "sandbox"
|
"id": tenant.id,
|
||||||
if tenant.id == current_user.current_tenant_id:
|
"name": tenant.name,
|
||||||
tenant.current = True # Set current=True for current tenant
|
"status": tenant.status,
|
||||||
return {"workspaces": marshal(tenants, tenants_fields)}, 200
|
"created_at": tenant.created_at,
|
||||||
|
"plan": features.billing.subscription.plan if features.billing.enabled else "sandbox",
|
||||||
|
"current": tenant.id == current_user.current_tenant_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
tenant_dicts.append(tenant_dict)
|
||||||
|
|
||||||
|
return {"workspaces": marshal(tenant_dicts, tenants_fields)}, 200
|
||||||
|
|
||||||
|
|
||||||
class WorkspaceListApi(Resource):
|
class WorkspaceListApi(Resource):
|
||||||
|
@ -64,9 +64,24 @@ class PluginUploadFileApi(Resource):
|
|||||||
|
|
||||||
extension = guess_extension(tool_file.mimetype) or ".bin"
|
extension = guess_extension(tool_file.mimetype) or ".bin"
|
||||||
preview_url = ToolFileManager.sign_file(tool_file_id=tool_file.id, extension=extension)
|
preview_url = ToolFileManager.sign_file(tool_file_id=tool_file.id, extension=extension)
|
||||||
tool_file.mime_type = mimetype
|
|
||||||
tool_file.extension = extension
|
# Create a dictionary with all the necessary attributes
|
||||||
tool_file.preview_url = preview_url
|
result = {
|
||||||
|
"id": tool_file.id,
|
||||||
|
"user_id": tool_file.user_id,
|
||||||
|
"tenant_id": tool_file.tenant_id,
|
||||||
|
"conversation_id": tool_file.conversation_id,
|
||||||
|
"file_key": tool_file.file_key,
|
||||||
|
"mimetype": tool_file.mimetype,
|
||||||
|
"original_url": tool_file.original_url,
|
||||||
|
"name": tool_file.name,
|
||||||
|
"size": tool_file.size,
|
||||||
|
"mime_type": mimetype,
|
||||||
|
"extension": extension,
|
||||||
|
"preview_url": preview_url,
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, 201
|
||||||
except services.errors.file.FileTooLargeError as file_too_large_error:
|
except services.errors.file.FileTooLargeError as file_too_large_error:
|
||||||
raise FileTooLargeError(file_too_large_error.description)
|
raise FileTooLargeError(file_too_large_error.description)
|
||||||
except services.errors.file.UnsupportedFileTypeError:
|
except services.errors.file.UnsupportedFileTypeError:
|
||||||
|
@ -388,11 +388,22 @@ class DocumentIndexingStatusApi(DatasetApiResource):
|
|||||||
.filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
|
.filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
|
||||||
.count()
|
.count()
|
||||||
)
|
)
|
||||||
document.completed_segments = completed_segments
|
# Create a dictionary with document attributes and additional fields
|
||||||
document.total_segments = total_segments
|
document_dict = {
|
||||||
if document.is_paused:
|
"id": document.id,
|
||||||
document.indexing_status = "paused"
|
"indexing_status": "paused" if document.is_paused else document.indexing_status,
|
||||||
documents_status.append(marshal(document, document_status_fields))
|
"processing_started_at": document.processing_started_at,
|
||||||
|
"parsing_completed_at": document.parsing_completed_at,
|
||||||
|
"cleaning_completed_at": document.cleaning_completed_at,
|
||||||
|
"splitting_completed_at": document.splitting_completed_at,
|
||||||
|
"completed_at": document.completed_at,
|
||||||
|
"paused_at": document.paused_at,
|
||||||
|
"error": document.error,
|
||||||
|
"stopped_at": document.stopped_at,
|
||||||
|
"completed_segments": completed_segments,
|
||||||
|
"total_segments": total_segments,
|
||||||
|
}
|
||||||
|
documents_status.append(marshal(document_dict, document_status_fields))
|
||||||
data = {"data": documents_status}
|
data = {"data": documents_status}
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
@ -405,7 +405,29 @@ class RetrievalService:
|
|||||||
record["child_chunks"] = segment_child_map[record["segment"].id].get("child_chunks") # type: ignore
|
record["child_chunks"] = segment_child_map[record["segment"].id].get("child_chunks") # type: ignore
|
||||||
record["score"] = segment_child_map[record["segment"].id]["max_score"]
|
record["score"] = segment_child_map[record["segment"].id]["max_score"]
|
||||||
|
|
||||||
return [RetrievalSegments(**record) for record in records]
|
result = []
|
||||||
|
for record in records:
|
||||||
|
# Extract segment
|
||||||
|
segment = record["segment"]
|
||||||
|
|
||||||
|
# Extract child_chunks, ensuring it's a list or None
|
||||||
|
child_chunks = record.get("child_chunks")
|
||||||
|
if not isinstance(child_chunks, list):
|
||||||
|
child_chunks = None
|
||||||
|
|
||||||
|
# Extract score, ensuring it's a float or None
|
||||||
|
score_value = record.get("score")
|
||||||
|
score = (
|
||||||
|
float(score_value)
|
||||||
|
if score_value is not None and isinstance(score_value, int | float | str)
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create RetrievalSegments object
|
||||||
|
retrieval_segment = RetrievalSegments(segment=segment, child_chunks=child_chunks, score=score)
|
||||||
|
result.append(retrieval_segment)
|
||||||
|
|
||||||
|
return result
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
db.session.rollback()
|
db.session.rollback()
|
||||||
raise e
|
raise e
|
||||||
|
@ -528,7 +528,7 @@ class ToolManager:
|
|||||||
yield provider
|
yield provider
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception(f"load builtin provider {provider}")
|
logger.exception(f"load builtin provider {provider_path}")
|
||||||
continue
|
continue
|
||||||
# set builtin providers loaded
|
# set builtin providers loaded
|
||||||
cls._builtin_providers_loaded = True
|
cls._builtin_providers_loaded = True
|
||||||
@ -644,10 +644,10 @@ class ToolManager:
|
|||||||
)
|
)
|
||||||
|
|
||||||
workflow_provider_controllers: list[WorkflowToolProviderController] = []
|
workflow_provider_controllers: list[WorkflowToolProviderController] = []
|
||||||
for provider in workflow_providers:
|
for workflow_provider in workflow_providers:
|
||||||
try:
|
try:
|
||||||
workflow_provider_controllers.append(
|
workflow_provider_controllers.append(
|
||||||
ToolTransformService.workflow_provider_to_controller(db_provider=provider)
|
ToolTransformService.workflow_provider_to_controller(db_provider=workflow_provider)
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
# app has been deleted
|
# app has been deleted
|
||||||
|
@ -1,7 +1,9 @@
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
from typing import Any, Optional, Union, cast
|
from typing import Any, Optional, cast
|
||||||
|
|
||||||
|
from flask_login import current_user
|
||||||
|
|
||||||
from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod
|
from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod
|
||||||
from core.tools.__base.tool import Tool
|
from core.tools.__base.tool import Tool
|
||||||
@ -87,7 +89,7 @@ class WorkflowTool(Tool):
|
|||||||
result = generator.generate(
|
result = generator.generate(
|
||||||
app_model=app,
|
app_model=app,
|
||||||
workflow=workflow,
|
workflow=workflow,
|
||||||
user=self._get_user(user_id),
|
user=cast("Account | EndUser", current_user),
|
||||||
args={"inputs": tool_parameters, "files": files},
|
args={"inputs": tool_parameters, "files": files},
|
||||||
invoke_from=self.runtime.invoke_from,
|
invoke_from=self.runtime.invoke_from,
|
||||||
streaming=False,
|
streaming=False,
|
||||||
@ -111,20 +113,6 @@ class WorkflowTool(Tool):
|
|||||||
yield self.create_text_message(json.dumps(outputs, ensure_ascii=False))
|
yield self.create_text_message(json.dumps(outputs, ensure_ascii=False))
|
||||||
yield self.create_json_message(outputs)
|
yield self.create_json_message(outputs)
|
||||||
|
|
||||||
def _get_user(self, user_id: str) -> Union[EndUser, Account]:
|
|
||||||
"""
|
|
||||||
get the user by user id
|
|
||||||
"""
|
|
||||||
|
|
||||||
user = db.session.query(EndUser).filter(EndUser.id == user_id).first()
|
|
||||||
if not user:
|
|
||||||
user = db.session.query(Account).filter(Account.id == user_id).first()
|
|
||||||
|
|
||||||
if not user:
|
|
||||||
raise ValueError("user not found")
|
|
||||||
|
|
||||||
return user
|
|
||||||
|
|
||||||
def fork_tool_runtime(self, runtime: ToolRuntime) -> "WorkflowTool":
|
def fork_tool_runtime(self, runtime: ToolRuntime) -> "WorkflowTool":
|
||||||
"""
|
"""
|
||||||
fork a new tool with metadata
|
fork a new tool with metadata
|
||||||
|
@ -3,11 +3,14 @@ import json
|
|||||||
import flask_login # type: ignore
|
import flask_login # type: ignore
|
||||||
from flask import Response, request
|
from flask import Response, request
|
||||||
from flask_login import user_loaded_from_request, user_logged_in
|
from flask_login import user_loaded_from_request, user_logged_in
|
||||||
from werkzeug.exceptions import Unauthorized
|
from werkzeug.exceptions import NotFound, Unauthorized
|
||||||
|
|
||||||
import contexts
|
import contexts
|
||||||
from dify_app import DifyApp
|
from dify_app import DifyApp
|
||||||
|
from extensions.ext_database import db
|
||||||
from libs.passport import PassportService
|
from libs.passport import PassportService
|
||||||
|
from models.account import Account
|
||||||
|
from models.model import EndUser
|
||||||
from services.account_service import AccountService
|
from services.account_service import AccountService
|
||||||
|
|
||||||
login_manager = flask_login.LoginManager()
|
login_manager = flask_login.LoginManager()
|
||||||
@ -17,34 +20,48 @@ login_manager = flask_login.LoginManager()
|
|||||||
@login_manager.request_loader
|
@login_manager.request_loader
|
||||||
def load_user_from_request(request_from_flask_login):
|
def load_user_from_request(request_from_flask_login):
|
||||||
"""Load user based on the request."""
|
"""Load user based on the request."""
|
||||||
if request.blueprint not in {"console", "inner_api"}:
|
|
||||||
return None
|
|
||||||
# Check if the user_id contains a dot, indicating the old format
|
|
||||||
auth_header = request.headers.get("Authorization", "")
|
auth_header = request.headers.get("Authorization", "")
|
||||||
if not auth_header:
|
auth_token: str | None = None
|
||||||
auth_token = request.args.get("_token")
|
if auth_header:
|
||||||
if not auth_token:
|
|
||||||
raise Unauthorized("Invalid Authorization token.")
|
|
||||||
else:
|
|
||||||
if " " not in auth_header:
|
if " " not in auth_header:
|
||||||
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
|
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
|
||||||
auth_scheme, auth_token = auth_header.split(None, 1)
|
auth_scheme, auth_token = auth_header.split(maxsplit=1)
|
||||||
auth_scheme = auth_scheme.lower()
|
auth_scheme = auth_scheme.lower()
|
||||||
if auth_scheme != "bearer":
|
if auth_scheme != "bearer":
|
||||||
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
|
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
|
||||||
|
else:
|
||||||
|
auth_token = request.args.get("_token")
|
||||||
|
|
||||||
decoded = PassportService().verify(auth_token)
|
if request.blueprint in {"console", "inner_api"}:
|
||||||
user_id = decoded.get("user_id")
|
if not auth_token:
|
||||||
|
raise Unauthorized("Invalid Authorization token.")
|
||||||
|
decoded = PassportService().verify(auth_token)
|
||||||
|
user_id = decoded.get("user_id")
|
||||||
|
if not user_id:
|
||||||
|
raise Unauthorized("Invalid Authorization token.")
|
||||||
|
|
||||||
logged_in_account = AccountService.load_logged_in_account(account_id=user_id)
|
logged_in_account = AccountService.load_logged_in_account(account_id=user_id)
|
||||||
return logged_in_account
|
return logged_in_account
|
||||||
|
elif request.blueprint == "web":
|
||||||
|
decoded = PassportService().verify(auth_token)
|
||||||
|
end_user_id = decoded.get("end_user_id")
|
||||||
|
if not end_user_id:
|
||||||
|
raise Unauthorized("Invalid Authorization token.")
|
||||||
|
end_user = db.session.query(EndUser).filter(EndUser.id == decoded["end_user_id"]).first()
|
||||||
|
if not end_user:
|
||||||
|
raise NotFound("End user not found.")
|
||||||
|
return end_user
|
||||||
|
|
||||||
|
|
||||||
@user_logged_in.connect
|
@user_logged_in.connect
|
||||||
@user_loaded_from_request.connect
|
@user_loaded_from_request.connect
|
||||||
def on_user_logged_in(_sender, user):
|
def on_user_logged_in(_sender, user):
|
||||||
"""Called when a user logged in."""
|
"""Called when a user logged in.
|
||||||
if user:
|
|
||||||
|
Note: AccountService.load_logged_in_account will populate user.current_tenant_id
|
||||||
|
through the load_user method, which calls account.set_tenant_id().
|
||||||
|
"""
|
||||||
|
if user and isinstance(user, Account) and user.current_tenant_id:
|
||||||
contexts.tenant_id.set(user.current_tenant_id)
|
contexts.tenant_id.set(user.current_tenant_id)
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,10 +1,10 @@
|
|||||||
import enum
|
import enum
|
||||||
import json
|
import json
|
||||||
from typing import cast
|
from typing import Optional, cast
|
||||||
|
|
||||||
from flask_login import UserMixin # type: ignore
|
from flask_login import UserMixin # type: ignore
|
||||||
from sqlalchemy import func
|
from sqlalchemy import func
|
||||||
from sqlalchemy.orm import Mapped, mapped_column
|
from sqlalchemy.orm import Mapped, mapped_column, reconstructor
|
||||||
|
|
||||||
from models.base import Base
|
from models.base import Base
|
||||||
|
|
||||||
@ -12,6 +12,66 @@ from .engine import db
|
|||||||
from .types import StringUUID
|
from .types import StringUUID
|
||||||
|
|
||||||
|
|
||||||
|
class TenantAccountRole(enum.StrEnum):
|
||||||
|
OWNER = "owner"
|
||||||
|
ADMIN = "admin"
|
||||||
|
EDITOR = "editor"
|
||||||
|
NORMAL = "normal"
|
||||||
|
DATASET_OPERATOR = "dataset_operator"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def is_valid_role(role: str) -> bool:
|
||||||
|
if not role:
|
||||||
|
return False
|
||||||
|
return role in {
|
||||||
|
TenantAccountRole.OWNER,
|
||||||
|
TenantAccountRole.ADMIN,
|
||||||
|
TenantAccountRole.EDITOR,
|
||||||
|
TenantAccountRole.NORMAL,
|
||||||
|
TenantAccountRole.DATASET_OPERATOR,
|
||||||
|
}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def is_privileged_role(role: Optional["TenantAccountRole"]) -> bool:
|
||||||
|
if not role:
|
||||||
|
return False
|
||||||
|
return role in {TenantAccountRole.OWNER, TenantAccountRole.ADMIN}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def is_admin_role(role: Optional["TenantAccountRole"]) -> bool:
|
||||||
|
if not role:
|
||||||
|
return False
|
||||||
|
return role == TenantAccountRole.ADMIN
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def is_non_owner_role(role: Optional["TenantAccountRole"]) -> bool:
|
||||||
|
if not role:
|
||||||
|
return False
|
||||||
|
return role in {
|
||||||
|
TenantAccountRole.ADMIN,
|
||||||
|
TenantAccountRole.EDITOR,
|
||||||
|
TenantAccountRole.NORMAL,
|
||||||
|
TenantAccountRole.DATASET_OPERATOR,
|
||||||
|
}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def is_editing_role(role: Optional["TenantAccountRole"]) -> bool:
|
||||||
|
if not role:
|
||||||
|
return False
|
||||||
|
return role in {TenantAccountRole.OWNER, TenantAccountRole.ADMIN, TenantAccountRole.EDITOR}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def is_dataset_edit_role(role: Optional["TenantAccountRole"]) -> bool:
|
||||||
|
if not role:
|
||||||
|
return False
|
||||||
|
return role in {
|
||||||
|
TenantAccountRole.OWNER,
|
||||||
|
TenantAccountRole.ADMIN,
|
||||||
|
TenantAccountRole.EDITOR,
|
||||||
|
TenantAccountRole.DATASET_OPERATOR,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class AccountStatus(enum.StrEnum):
|
class AccountStatus(enum.StrEnum):
|
||||||
PENDING = "pending"
|
PENDING = "pending"
|
||||||
UNINITIALIZED = "uninitialized"
|
UNINITIALIZED = "uninitialized"
|
||||||
@ -41,24 +101,27 @@ class Account(UserMixin, Base):
|
|||||||
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||||
|
|
||||||
|
@reconstructor
|
||||||
|
def init_on_load(self):
|
||||||
|
self.role: Optional[TenantAccountRole] = None
|
||||||
|
self._current_tenant: Optional[Tenant] = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_password_set(self):
|
def is_password_set(self):
|
||||||
return self.password is not None
|
return self.password is not None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def current_tenant(self):
|
def current_tenant(self):
|
||||||
return self._current_tenant # type: ignore
|
return self._current_tenant
|
||||||
|
|
||||||
@current_tenant.setter
|
@current_tenant.setter
|
||||||
def current_tenant(self, value: "Tenant"):
|
def current_tenant(self, tenant: "Tenant"):
|
||||||
tenant = value
|
|
||||||
ta = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=self.id).first()
|
ta = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=self.id).first()
|
||||||
if ta:
|
if ta:
|
||||||
tenant.current_role = ta.role
|
self.role = TenantAccountRole(ta.role)
|
||||||
else:
|
self._current_tenant = tenant
|
||||||
tenant = None # type: ignore
|
return
|
||||||
|
self._current_tenant = None
|
||||||
self._current_tenant = tenant
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def current_tenant_id(self) -> str | None:
|
def current_tenant_id(self) -> str | None:
|
||||||
@ -80,12 +143,12 @@ class Account(UserMixin, Base):
|
|||||||
return
|
return
|
||||||
|
|
||||||
tenant, join = tenant_account_join
|
tenant, join = tenant_account_join
|
||||||
tenant.current_role = join.role
|
self.role = join.role
|
||||||
self._current_tenant = tenant
|
self._current_tenant = tenant
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def current_role(self):
|
def current_role(self):
|
||||||
return self._current_tenant.current_role
|
return self.role
|
||||||
|
|
||||||
def get_status(self) -> AccountStatus:
|
def get_status(self) -> AccountStatus:
|
||||||
status_str = self.status
|
status_str = self.status
|
||||||
@ -105,23 +168,23 @@ class Account(UserMixin, Base):
|
|||||||
# check current_user.current_tenant.current_role in ['admin', 'owner']
|
# check current_user.current_tenant.current_role in ['admin', 'owner']
|
||||||
@property
|
@property
|
||||||
def is_admin_or_owner(self):
|
def is_admin_or_owner(self):
|
||||||
return TenantAccountRole.is_privileged_role(self._current_tenant.current_role)
|
return TenantAccountRole.is_privileged_role(self.role)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_admin(self):
|
def is_admin(self):
|
||||||
return TenantAccountRole.is_admin_role(self._current_tenant.current_role)
|
return TenantAccountRole.is_admin_role(self.role)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_editor(self):
|
def is_editor(self):
|
||||||
return TenantAccountRole.is_editing_role(self._current_tenant.current_role)
|
return TenantAccountRole.is_editing_role(self.role)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_dataset_editor(self):
|
def is_dataset_editor(self):
|
||||||
return TenantAccountRole.is_dataset_edit_role(self._current_tenant.current_role)
|
return TenantAccountRole.is_dataset_edit_role(self.role)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_dataset_operator(self):
|
def is_dataset_operator(self):
|
||||||
return self._current_tenant.current_role == TenantAccountRole.DATASET_OPERATOR
|
return self.role == TenantAccountRole.DATASET_OPERATOR
|
||||||
|
|
||||||
|
|
||||||
class TenantStatus(enum.StrEnum):
|
class TenantStatus(enum.StrEnum):
|
||||||
@ -129,66 +192,6 @@ class TenantStatus(enum.StrEnum):
|
|||||||
ARCHIVE = "archive"
|
ARCHIVE = "archive"
|
||||||
|
|
||||||
|
|
||||||
class TenantAccountRole(enum.StrEnum):
|
|
||||||
OWNER = "owner"
|
|
||||||
ADMIN = "admin"
|
|
||||||
EDITOR = "editor"
|
|
||||||
NORMAL = "normal"
|
|
||||||
DATASET_OPERATOR = "dataset_operator"
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def is_valid_role(role: str) -> bool:
|
|
||||||
if not role:
|
|
||||||
return False
|
|
||||||
return role in {
|
|
||||||
TenantAccountRole.OWNER,
|
|
||||||
TenantAccountRole.ADMIN,
|
|
||||||
TenantAccountRole.EDITOR,
|
|
||||||
TenantAccountRole.NORMAL,
|
|
||||||
TenantAccountRole.DATASET_OPERATOR,
|
|
||||||
}
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def is_privileged_role(role: str) -> bool:
|
|
||||||
if not role:
|
|
||||||
return False
|
|
||||||
return role in {TenantAccountRole.OWNER, TenantAccountRole.ADMIN}
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def is_admin_role(role: str) -> bool:
|
|
||||||
if not role:
|
|
||||||
return False
|
|
||||||
return role == TenantAccountRole.ADMIN
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def is_non_owner_role(role: str) -> bool:
|
|
||||||
if not role:
|
|
||||||
return False
|
|
||||||
return role in {
|
|
||||||
TenantAccountRole.ADMIN,
|
|
||||||
TenantAccountRole.EDITOR,
|
|
||||||
TenantAccountRole.NORMAL,
|
|
||||||
TenantAccountRole.DATASET_OPERATOR,
|
|
||||||
}
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def is_editing_role(role: str) -> bool:
|
|
||||||
if not role:
|
|
||||||
return False
|
|
||||||
return role in {TenantAccountRole.OWNER, TenantAccountRole.ADMIN, TenantAccountRole.EDITOR}
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def is_dataset_edit_role(role: str) -> bool:
|
|
||||||
if not role:
|
|
||||||
return False
|
|
||||||
return role in {
|
|
||||||
TenantAccountRole.OWNER,
|
|
||||||
TenantAccountRole.ADMIN,
|
|
||||||
TenantAccountRole.EDITOR,
|
|
||||||
TenantAccountRole.DATASET_OPERATOR,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class Tenant(Base):
|
class Tenant(Base):
|
||||||
__tablename__ = "tenants"
|
__tablename__ = "tenants"
|
||||||
__table_args__ = (db.PrimaryKeyConstraint("id", name="tenant_pkey"),)
|
__table_args__ = (db.PrimaryKeyConstraint("id", name="tenant_pkey"),)
|
||||||
|
@ -1,5 +1,7 @@
|
|||||||
from sqlalchemy.orm import declarative_base
|
from sqlalchemy.orm import DeclarativeBase
|
||||||
|
|
||||||
from models.engine import metadata
|
from models.engine import metadata
|
||||||
|
|
||||||
Base = declarative_base(metadata=metadata)
|
|
||||||
|
class Base(DeclarativeBase):
|
||||||
|
metadata = metadata
|
||||||
|
@ -172,10 +172,6 @@ class WorkflowToolProvider(Base):
|
|||||||
db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")
|
db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
|
||||||
def schema_type(self) -> ApiProviderSchemaType:
|
|
||||||
return ApiProviderSchemaType.value_of(self.schema_type_str)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def user(self) -> Account | None:
|
def user(self) -> Account | None:
|
||||||
return db.session.query(Account).filter(Account.id == self.user_id).first()
|
return db.session.query(Account).filter(Account.id == self.user_id).first()
|
||||||
|
@ -3,7 +3,7 @@ import logging
|
|||||||
from collections.abc import Mapping, Sequence
|
from collections.abc import Mapping, Sequence
|
||||||
from datetime import UTC, datetime
|
from datetime import UTC, datetime
|
||||||
from enum import Enum, StrEnum
|
from enum import Enum, StrEnum
|
||||||
from typing import TYPE_CHECKING, Any, Optional, Self, Union
|
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
from core.variables import utils as variable_utils
|
from core.variables import utils as variable_utils
|
||||||
@ -150,7 +150,7 @@ class Workflow(Base):
|
|||||||
conversation_variables: Sequence[Variable],
|
conversation_variables: Sequence[Variable],
|
||||||
marked_name: str = "",
|
marked_name: str = "",
|
||||||
marked_comment: str = "",
|
marked_comment: str = "",
|
||||||
) -> Self:
|
) -> "Workflow":
|
||||||
workflow = Workflow()
|
workflow = Workflow()
|
||||||
workflow.id = str(uuid4())
|
workflow.id = str(uuid4())
|
||||||
workflow.tenant_id = tenant_id
|
workflow.tenant_id = tenant_id
|
||||||
|
@ -23,11 +23,10 @@ class VectorService:
|
|||||||
):
|
):
|
||||||
documents: list[Document] = []
|
documents: list[Document] = []
|
||||||
|
|
||||||
document: Document | None = None
|
|
||||||
for segment in segments:
|
for segment in segments:
|
||||||
if doc_form == IndexType.PARENT_CHILD_INDEX:
|
if doc_form == IndexType.PARENT_CHILD_INDEX:
|
||||||
document = db.session.query(DatasetDocument).filter_by(id=segment.document_id).first()
|
dataset_document = db.session.query(DatasetDocument).filter_by(id=segment.document_id).first()
|
||||||
if not document:
|
if not dataset_document:
|
||||||
_logger.warning(
|
_logger.warning(
|
||||||
"Expected DatasetDocument record to exist, but none was found, document_id=%s, segment_id=%s",
|
"Expected DatasetDocument record to exist, but none was found, document_id=%s, segment_id=%s",
|
||||||
segment.document_id,
|
segment.document_id,
|
||||||
@ -37,7 +36,7 @@ class VectorService:
|
|||||||
# get the process rule
|
# get the process rule
|
||||||
processing_rule = (
|
processing_rule = (
|
||||||
db.session.query(DatasetProcessRule)
|
db.session.query(DatasetProcessRule)
|
||||||
.filter(DatasetProcessRule.id == document.dataset_process_rule_id)
|
.filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
|
||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
if not processing_rule:
|
if not processing_rule:
|
||||||
@ -61,9 +60,11 @@ class VectorService:
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError("The knowledge base index technique is not high quality!")
|
raise ValueError("The knowledge base index technique is not high quality!")
|
||||||
cls.generate_child_chunks(segment, document, dataset, embedding_model_instance, processing_rule, False)
|
cls.generate_child_chunks(
|
||||||
|
segment, dataset_document, dataset, embedding_model_instance, processing_rule, False
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
document = Document(
|
rag_document = Document(
|
||||||
page_content=segment.content,
|
page_content=segment.content,
|
||||||
metadata={
|
metadata={
|
||||||
"doc_id": segment.index_node_id,
|
"doc_id": segment.index_node_id,
|
||||||
@ -72,7 +73,7 @@ class VectorService:
|
|||||||
"dataset_id": segment.dataset_id,
|
"dataset_id": segment.dataset_id,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
documents.append(document)
|
documents.append(rag_document)
|
||||||
if len(documents) > 0:
|
if len(documents) > 0:
|
||||||
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
|
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
|
||||||
index_processor.load(dataset, documents, with_keywords=True, keywords_list=keywords_list)
|
index_processor.load(dataset, documents, with_keywords=True, keywords_list=keywords_list)
|
||||||
|
@ -508,11 +508,11 @@ class WorkflowService:
|
|||||||
raise DraftWorkflowDeletionError("Cannot delete draft workflow versions")
|
raise DraftWorkflowDeletionError("Cannot delete draft workflow versions")
|
||||||
|
|
||||||
# Check if this workflow is currently referenced by an app
|
# Check if this workflow is currently referenced by an app
|
||||||
stmt = select(App).where(App.workflow_id == workflow_id)
|
app_stmt = select(App).where(App.workflow_id == workflow_id)
|
||||||
app = session.scalar(stmt)
|
app = session.scalar(app_stmt)
|
||||||
if app:
|
if app:
|
||||||
# Cannot delete a workflow that's currently in use by an app
|
# Cannot delete a workflow that's currently in use by an app
|
||||||
raise WorkflowInUseError(f"Cannot delete workflow that is currently in use by app '{app.name}'")
|
raise WorkflowInUseError(f"Cannot delete workflow that is currently in use by app '{app.id}'")
|
||||||
|
|
||||||
# Don't use workflow.tool_published as it's not accurate for specific workflow versions
|
# Don't use workflow.tool_published as it's not accurate for specific workflow versions
|
||||||
# Check if there's a tool provider using this specific workflow version
|
# Check if there's a tool provider using this specific workflow version
|
||||||
|
@ -111,7 +111,7 @@ def add_document_to_index_task(dataset_document_id: str):
|
|||||||
logging.exception("add document to index failed")
|
logging.exception("add document to index failed")
|
||||||
dataset_document.enabled = False
|
dataset_document.enabled = False
|
||||||
dataset_document.disabled_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
|
dataset_document.disabled_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
|
||||||
dataset_document.status = "error"
|
dataset_document.indexing_status = "error"
|
||||||
dataset_document.error = str(e)
|
dataset_document.error = str(e)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
finally:
|
finally:
|
||||||
|
@ -193,7 +193,7 @@ def _delete_app_workflow_runs(tenant_id: str, app_id: str):
|
|||||||
def _delete_app_workflow_node_executions(tenant_id: str, app_id: str):
|
def _delete_app_workflow_node_executions(tenant_id: str, app_id: str):
|
||||||
# Get app's owner
|
# Get app's owner
|
||||||
with Session(db.engine, expire_on_commit=False) as session:
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
stmt = select(Account).where(Account.id == App.owner_id).where(App.id == app_id)
|
stmt = select(Account).where(Account.id == App.created_by).where(App.id == app_id)
|
||||||
user = session.scalar(stmt)
|
user = session.scalar(stmt)
|
||||||
|
|
||||||
if user is None:
|
if user is None:
|
||||||
|
@ -34,13 +34,13 @@ def test_workflow_tool_should_raise_tool_invoke_error_when_result_has_error_fiel
|
|||||||
# needs to patch those methods to avoid database access.
|
# needs to patch those methods to avoid database access.
|
||||||
monkeypatch.setattr(tool, "_get_app", lambda *args, **kwargs: None)
|
monkeypatch.setattr(tool, "_get_app", lambda *args, **kwargs: None)
|
||||||
monkeypatch.setattr(tool, "_get_workflow", lambda *args, **kwargs: None)
|
monkeypatch.setattr(tool, "_get_workflow", lambda *args, **kwargs: None)
|
||||||
monkeypatch.setattr(tool, "_get_user", lambda *args, **kwargs: None)
|
|
||||||
|
|
||||||
# replace `WorkflowAppGenerator.generate` 's return value.
|
# replace `WorkflowAppGenerator.generate` 's return value.
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
"core.app.apps.workflow.app_generator.WorkflowAppGenerator.generate",
|
"core.app.apps.workflow.app_generator.WorkflowAppGenerator.generate",
|
||||||
lambda *args, **kwargs: {"data": {"error": "oops"}},
|
lambda *args, **kwargs: {"data": {"error": "oops"}},
|
||||||
)
|
)
|
||||||
|
monkeypatch.setattr("flask_login.current_user", lambda *args, **kwargs: None)
|
||||||
|
|
||||||
with pytest.raises(ToolInvokeError) as exc_info:
|
with pytest.raises(ToolInvokeError) as exc_info:
|
||||||
# WorkflowTool always returns a generator, so we need to iterate to
|
# WorkflowTool always returns a generator, so we need to iterate to
|
||||||
|
Loading…
x
Reference in New Issue
Block a user