# Mirror of https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
# Synced 2025-05-28 00:58:20 +08:00 (97 lines, 3.6 KiB, Python)
from typing import Optional
|
|
|
|
from pydantic import BaseModel
|
|
|
|
from core.extension.api_based_extension_requestor import APIBasedExtensionPoint, APIBasedExtensionRequestor
|
|
from core.helper.encrypter import decrypt_token
|
|
from core.moderation.base import Moderation, ModerationAction, ModerationInputsResult, ModerationOutputsResult
|
|
from extensions.ext_database import db
|
|
from models.api_based_extension import APIBasedExtension
|
|
|
|
|
|
class ModerationInputParams(BaseModel):
    """Payload describing an input-moderation request.

    ``ApiModeration.moderation_for_inputs`` builds one of these and sends its
    ``model_dump()`` to the API-based extension as the request parameters.
    Every field has a default, so the model can be constructed with no
    arguments.
    """

    # Identifier of the app whose inputs are being moderated.
    app_id: str = ""
    # Input values supplied by the caller; the key/value schema is not
    # constrained here.
    inputs: dict = {}
    # Query text accompanying the inputs; empty when none is given.
    query: str = ""
|
|
|
|
|
|
class ModerationOutputParams(BaseModel):
    """Payload describing an output-moderation request.

    ``ApiModeration.moderation_for_outputs`` builds one of these and sends its
    ``model_dump()`` to the API-based extension as the request parameters.
    """

    # Identifier of the app whose output is being moderated.
    app_id: str = ""
    # Text to be moderated; required (no default).
    text: str
|
|
|
|
|
|
class ApiModeration(Moderation):
    """Moderation provider that delegates the decision to an API-based extension.

    The extension is looked up by the ``api_based_extension_id`` stored in the
    provider config, scoped to the current tenant. Input and output moderation
    are each gated by the ``enabled`` flag of the ``inputs_config`` /
    ``outputs_config`` section of that config.
    """

    name: str = "api"

    @classmethod
    def validate_config(cls, tenant_id: str, config: dict) -> None:
        """
        Validate the incoming form config data.

        :param tenant_id: the id of workspace
        :param config: the form config data
        :return:
        :raises ValueError: if ``api_based_extension_id`` is missing/empty or no
            matching extension exists for the tenant (the shared
            inputs/outputs validation may raise as well)
        """
        cls._validate_inputs_and_outputs_config(config, False)

        api_based_extension_id = config.get("api_based_extension_id")
        if not api_based_extension_id:
            raise ValueError("api_based_extension_id is required")

        if not cls._get_api_based_extension(tenant_id, api_based_extension_id):
            raise ValueError("API-based Extension not found. Please check it again.")

    def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
        """
        Moderate the app inputs (and optional query) through the extension.

        :param inputs: the input values to moderate
        :param query: the query text, empty when there is none
        :return: the result built from the extension response when input
            moderation is enabled, otherwise an unflagged result
        :raises ValueError: if the config is not set or the extension is not found
        """
        if self.config is None:
            raise ValueError("The config is not set.")

        if self._is_section_enabled(self.config, "inputs_config"):
            params = ModerationInputParams(app_id=self.app_id, inputs=inputs, query=query)
            result = self._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_INPUT, params.model_dump())
            return ModerationInputsResult(**result)

        # Input moderation is disabled: report "not flagged" without calling out.
        return ModerationInputsResult(flagged=False, action=ModerationAction.DIRECT_OUTPUT, preset_response="")

    def moderation_for_outputs(self, text: str) -> ModerationOutputsResult:
        """
        Moderate generated text through the extension.

        :param text: the text to moderate
        :return: the result built from the extension response when output
            moderation is enabled, otherwise an unflagged result
        :raises ValueError: if the config is not set or the extension is not found
        """
        if self.config is None:
            raise ValueError("The config is not set.")

        if self._is_section_enabled(self.config, "outputs_config"):
            params = ModerationOutputParams(app_id=self.app_id, text=text)
            result = self._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_OUTPUT, params.model_dump())
            return ModerationOutputsResult(**result)

        # Output moderation is disabled: report "not flagged" without calling out.
        return ModerationOutputsResult(flagged=False, action=ModerationAction.DIRECT_OUTPUT, preset_response="")

    @staticmethod
    def _is_section_enabled(config: dict, section: str) -> bool:
        """
        Tell whether ``config[section]["enabled"]`` is truthy.

        A missing section, a section that is not a dict, or a missing
        ``enabled`` key all count as disabled. Direct subscripting would raise
        ``KeyError``/``TypeError`` on such configs in the middle of a
        moderation call.

        :param config: the provider config
        :param section: the section name, e.g. ``"inputs_config"``
        :return: True only when the section exists and its ``enabled`` is truthy
        """
        section_config = config.get(section)
        if not isinstance(section_config, dict):
            return False
        return bool(section_config.get("enabled"))

    def _get_config_by_requestor(self, extension_point: APIBasedExtensionPoint, params: dict) -> dict:
        """
        Send ``params`` to the configured extension at ``extension_point``.

        :param extension_point: the extension point to invoke
        :param params: the request parameters
        :return: the response returned by the requestor
        :raises ValueError: if the config is not set or the extension is not found
        """
        if self.config is None:
            raise ValueError("The config is not set.")

        extension = self._get_api_based_extension(self.tenant_id, self.config.get("api_based_extension_id", ""))
        if not extension:
            raise ValueError("API-based Extension not found. Please check it again.")

        # The stored api_key is passed through decrypt_token before being
        # handed to the requestor.
        requestor = APIBasedExtensionRequestor(extension.api_endpoint, decrypt_token(self.tenant_id, extension.api_key))
        return requestor.request(extension_point, params)

    @staticmethod
    def _get_api_based_extension(tenant_id: str, api_based_extension_id: str) -> Optional[APIBasedExtension]:
        """
        Fetch the extension with the given id that belongs to the tenant.

        :param tenant_id: the id of workspace
        :param api_based_extension_id: the id of the API-based extension
        :return: the first matching row, or None when there is no match
        """
        return (
            db.session.query(APIBasedExtension)
            .filter(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id)
            .first()
        )
|