mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-01 14:02:00 +08:00
Fix/dsl kb encrypt (#17353)
This commit is contained in:
parent
e1304dc0c3
commit
2e9997110a
@ -1,3 +1,5 @@
|
|||||||
|
import base64
|
||||||
|
import hashlib
|
||||||
import logging
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
@ -7,6 +9,8 @@ from urllib.parse import urlparse
|
|||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
import yaml # type: ignore
|
import yaml # type: ignore
|
||||||
|
from Crypto.Cipher import AES
|
||||||
|
from Crypto.Util.Padding import pad, unpad
|
||||||
from packaging import version
|
from packaging import version
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
@ -478,6 +482,15 @@ class AppDslService:
|
|||||||
unique_hash = current_draft_workflow.unique_hash
|
unique_hash = current_draft_workflow.unique_hash
|
||||||
else:
|
else:
|
||||||
unique_hash = None
|
unique_hash = None
|
||||||
|
graph = workflow_data.get("graph", {})
|
||||||
|
for node in graph.get("nodes", []):
|
||||||
|
if node.get("data", {}).get("type", "") == NodeType.KNOWLEDGE_RETRIEVAL.value:
|
||||||
|
dataset_ids = node["data"].get("dataset_ids", [])
|
||||||
|
node["data"]["dataset_ids"] = [
|
||||||
|
decrypted_id
|
||||||
|
for dataset_id in dataset_ids
|
||||||
|
if (decrypted_id := self.decrypt_dataset_id(encrypted_data=dataset_id, tenant_id=app.tenant_id))
|
||||||
|
]
|
||||||
workflow_service.sync_draft_workflow(
|
workflow_service.sync_draft_workflow(
|
||||||
app_model=app,
|
app_model=app,
|
||||||
graph=workflow_data.get("graph", {}),
|
graph=workflow_data.get("graph", {}),
|
||||||
@ -552,7 +565,15 @@ class AppDslService:
|
|||||||
if not workflow:
|
if not workflow:
|
||||||
raise ValueError("Missing draft workflow configuration, please check.")
|
raise ValueError("Missing draft workflow configuration, please check.")
|
||||||
|
|
||||||
export_data["workflow"] = workflow.to_dict(include_secret=include_secret)
|
workflow_dict = workflow.to_dict(include_secret=include_secret)
|
||||||
|
for node in workflow_dict.get("graph", {}).get("nodes", []):
|
||||||
|
if node.get("data", {}).get("type", "") == NodeType.KNOWLEDGE_RETRIEVAL.value:
|
||||||
|
dataset_ids = node["data"].get("dataset_ids", [])
|
||||||
|
node["data"]["dataset_ids"] = [
|
||||||
|
cls.encrypt_dataset_id(dataset_id=dataset_id, tenant_id=app_model.tenant_id)
|
||||||
|
for dataset_id in dataset_ids
|
||||||
|
]
|
||||||
|
export_data["workflow"] = workflow_dict
|
||||||
dependencies = cls._extract_dependencies_from_workflow(workflow)
|
dependencies = cls._extract_dependencies_from_workflow(workflow)
|
||||||
export_data["dependencies"] = [
|
export_data["dependencies"] = [
|
||||||
jsonable_encoder(d.model_dump())
|
jsonable_encoder(d.model_dump())
|
||||||
@ -724,3 +745,29 @@ class AppDslService:
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
return DependenciesAnalysisService.get_leaked_dependencies(tenant_id=tenant_id, dependencies=dependencies)
|
return DependenciesAnalysisService.get_leaked_dependencies(tenant_id=tenant_id, dependencies=dependencies)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _generate_aes_key(tenant_id: str) -> bytes:
|
||||||
|
"""Generate AES key based on tenant_id"""
|
||||||
|
return hashlib.sha256(tenant_id.encode()).digest()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def encrypt_dataset_id(cls, dataset_id: str, tenant_id: str) -> str:
|
||||||
|
"""Encrypt dataset_id using AES-CBC mode"""
|
||||||
|
key = cls._generate_aes_key(tenant_id)
|
||||||
|
iv = key[:16]
|
||||||
|
cipher = AES.new(key, AES.MODE_CBC, iv)
|
||||||
|
ct_bytes = cipher.encrypt(pad(dataset_id.encode(), AES.block_size))
|
||||||
|
return base64.b64encode(ct_bytes).decode()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def decrypt_dataset_id(cls, encrypted_data: str, tenant_id: str) -> str | None:
|
||||||
|
"""AES decryption"""
|
||||||
|
try:
|
||||||
|
key = cls._generate_aes_key(tenant_id)
|
||||||
|
iv = key[:16]
|
||||||
|
cipher = AES.new(key, AES.MODE_CBC, iv)
|
||||||
|
pt = unpad(cipher.decrypt(base64.b64decode(encrypted_data)), AES.block_size)
|
||||||
|
return pt.decode()
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
Loading…
x
Reference in New Issue
Block a user