Fix/dsl kb encrypt (#17353)

This commit is contained in:
Dongyu Li 2025-04-03 17:29:34 +08:00 committed by GitHub
parent e1304dc0c3
commit 2e9997110a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,3 +1,5 @@
import base64
import hashlib
import logging import logging
import uuid import uuid
from collections.abc import Mapping from collections.abc import Mapping
@ -7,6 +9,8 @@ from urllib.parse import urlparse
from uuid import uuid4 from uuid import uuid4
import yaml # type: ignore import yaml # type: ignore
from Crypto.Cipher import AES
from Crypto.Util.Padding import pad, unpad
from packaging import version from packaging import version
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from sqlalchemy import select from sqlalchemy import select
@ -478,6 +482,15 @@ class AppDslService:
unique_hash = current_draft_workflow.unique_hash unique_hash = current_draft_workflow.unique_hash
else: else:
unique_hash = None unique_hash = None
graph = workflow_data.get("graph", {})
for node in graph.get("nodes", []):
if node.get("data", {}).get("type", "") == NodeType.KNOWLEDGE_RETRIEVAL.value:
dataset_ids = node["data"].get("dataset_ids", [])
node["data"]["dataset_ids"] = [
decrypted_id
for dataset_id in dataset_ids
if (decrypted_id := self.decrypt_dataset_id(encrypted_data=dataset_id, tenant_id=app.tenant_id))
]
workflow_service.sync_draft_workflow( workflow_service.sync_draft_workflow(
app_model=app, app_model=app,
graph=workflow_data.get("graph", {}), graph=workflow_data.get("graph", {}),
@ -552,7 +565,15 @@ class AppDslService:
if not workflow: if not workflow:
raise ValueError("Missing draft workflow configuration, please check.") raise ValueError("Missing draft workflow configuration, please check.")
export_data["workflow"] = workflow.to_dict(include_secret=include_secret) workflow_dict = workflow.to_dict(include_secret=include_secret)
for node in workflow_dict.get("graph", {}).get("nodes", []):
if node.get("data", {}).get("type", "") == NodeType.KNOWLEDGE_RETRIEVAL.value:
dataset_ids = node["data"].get("dataset_ids", [])
node["data"]["dataset_ids"] = [
cls.encrypt_dataset_id(dataset_id=dataset_id, tenant_id=app_model.tenant_id)
for dataset_id in dataset_ids
]
export_data["workflow"] = workflow_dict
dependencies = cls._extract_dependencies_from_workflow(workflow) dependencies = cls._extract_dependencies_from_workflow(workflow)
export_data["dependencies"] = [ export_data["dependencies"] = [
jsonable_encoder(d.model_dump()) jsonable_encoder(d.model_dump())
@ -724,3 +745,29 @@ class AppDslService:
return [] return []
return DependenciesAnalysisService.get_leaked_dependencies(tenant_id=tenant_id, dependencies=dependencies) return DependenciesAnalysisService.get_leaked_dependencies(tenant_id=tenant_id, dependencies=dependencies)
@staticmethod
def _generate_aes_key(tenant_id: str) -> bytes:
"""Generate AES key based on tenant_id"""
return hashlib.sha256(tenant_id.encode()).digest()
@classmethod
def encrypt_dataset_id(cls, dataset_id: str, tenant_id: str) -> str:
"""Encrypt dataset_id using AES-CBC mode"""
key = cls._generate_aes_key(tenant_id)
iv = key[:16]
cipher = AES.new(key, AES.MODE_CBC, iv)
ct_bytes = cipher.encrypt(pad(dataset_id.encode(), AES.block_size))
return base64.b64encode(ct_bytes).decode()
@classmethod
def decrypt_dataset_id(cls, encrypted_data: str, tenant_id: str) -> str | None:
"""AES decryption"""
try:
key = cls._generate_aes_key(tenant_id)
iv = key[:16]
cipher = AES.new(key, AES.MODE_CBC, iv)
pt = unpad(cipher.decrypt(base64.b64decode(encrypted_data)), AES.block_size)
return pt.decode()
except Exception:
return None