From 2e9997110a1606f200be17a3f6c1e086703651ff Mon Sep 17 00:00:00 2001 From: Dongyu Li <544104925@qq.com> Date: Thu, 3 Apr 2025 17:29:34 +0800 Subject: [PATCH] Fix/dsl kb encrypt (#17353) --- api/services/app_dsl_service.py | 49 ++++++++++++++++++++++++++++++++- 1 file changed, 48 insertions(+), 1 deletion(-) diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py index 67ef4c319a..2e2b729021 100644 --- a/api/services/app_dsl_service.py +++ b/api/services/app_dsl_service.py @@ -1,3 +1,5 @@ +import base64 +import hashlib import logging import uuid from collections.abc import Mapping @@ -7,6 +9,8 @@ from urllib.parse import urlparse from uuid import uuid4 import yaml # type: ignore +from Crypto.Cipher import AES +from Crypto.Util.Padding import pad, unpad from packaging import version from pydantic import BaseModel, Field from sqlalchemy import select @@ -478,6 +482,15 @@ class AppDslService: unique_hash = current_draft_workflow.unique_hash else: unique_hash = None + graph = workflow_data.get("graph", {}) + for node in graph.get("nodes", []): + if node.get("data", {}).get("type", "") == NodeType.KNOWLEDGE_RETRIEVAL.value: + dataset_ids = node["data"].get("dataset_ids", []) + node["data"]["dataset_ids"] = [ + decrypted_id + for dataset_id in dataset_ids + if (decrypted_id := self.decrypt_dataset_id(encrypted_data=dataset_id, tenant_id=app.tenant_id)) + ] workflow_service.sync_draft_workflow( app_model=app, graph=workflow_data.get("graph", {}), @@ -552,7 +565,15 @@ class AppDslService: if not workflow: raise ValueError("Missing draft workflow configuration, please check.") - export_data["workflow"] = workflow.to_dict(include_secret=include_secret) + workflow_dict = workflow.to_dict(include_secret=include_secret) + for node in workflow_dict.get("graph", {}).get("nodes", []): + if node.get("data", {}).get("type", "") == NodeType.KNOWLEDGE_RETRIEVAL.value: + dataset_ids = node["data"].get("dataset_ids", []) + node["data"]["dataset_ids"] = [ + cls.encrypt_dataset_id(dataset_id=dataset_id, tenant_id=app_model.tenant_id) + for dataset_id in dataset_ids + ] + export_data["workflow"] = workflow_dict dependencies = cls._extract_dependencies_from_workflow(workflow) export_data["dependencies"] = [ jsonable_encoder(d.model_dump()) @@ -724,3 +745,29 @@ class AppDslService: return [] return DependenciesAnalysisService.get_leaked_dependencies(tenant_id=tenant_id, dependencies=dependencies) + + @staticmethod + def _generate_aes_key(tenant_id: str) -> bytes: + """Generate AES key based on tenant_id""" + return hashlib.sha256(tenant_id.encode()).digest() + + @classmethod + def encrypt_dataset_id(cls, dataset_id: str, tenant_id: str) -> str: + """Encrypt dataset_id using AES-CBC mode""" + key = cls._generate_aes_key(tenant_id) + iv = key[:16] + cipher = AES.new(key, AES.MODE_CBC, iv) + ct_bytes = cipher.encrypt(pad(dataset_id.encode(), AES.block_size)) + return base64.b64encode(ct_bytes).decode() + + @classmethod + def decrypt_dataset_id(cls, encrypted_data: str, tenant_id: str) -> str | None: + """AES decryption""" + try: + key = cls._generate_aes_key(tenant_id) + iv = key[:16] + cipher = AES.new(key, AES.MODE_CBC, iv) + pt = unpad(cipher.decrypt(base64.b64decode(encrypted_data)), AES.block_size) + return pt.decode() + except Exception: + return None