fix: hosted moderation

takatost 2024-11-11 20:31:11 +08:00
parent bc0724b499
commit 1d2118fc5d
2 changed files with 21 additions and 8 deletions

@@ -24,6 +24,8 @@ class HostingModerationFeature:
             if isinstance(prompt_message.content, str):
                 text += prompt_message.content + "\n"

-        moderation_result = moderation.check_moderation(model_config, text)
+        moderation_result = moderation.check_moderation(
+            tenant_id=application_generate_entity.app_config.tenant_id, model_config=model_config, text=text
+        )

         return moderation_result
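For context, the call above now threads the tenant id into the moderation helper. The sketch below is a hypothetical, condensed caller shape based only on this hunk; the real HostingModerationFeature.check body, the import path, and the origin of model_config are assumptions, not part of the commit.

from core.helper import moderation  # assumed module path for the helper changed below


def run_hosting_moderation(application_generate_entity, prompt_messages, model_config) -> bool:
    # Concatenate plain-text prompt content, as in the hunk above.
    text = ""
    for prompt_message in prompt_messages:
        if isinstance(prompt_message.content, str):
            text += prompt_message.content + "\n"

    # New in this commit: the tenant id travels with the request so the helper
    # can resolve the moderation model per tenant instead of hard-coding OpenAI.
    return moderation.check_moderation(
        tenant_id=application_generate_entity.app_config.tenant_id,
        model_config=model_config,
        text=text,
    )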

@@ -1,27 +1,32 @@
 import logging
 import random
+from typing import cast

 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
+from core.entities import DEFAULT_PLUGIN_ID
+from core.model_runtime.entities.model_entities import ModelType
 from core.model_runtime.errors.invoke import InvokeBadRequestError
-from core.model_runtime.model_providers.openai.moderation.moderation import OpenAIModerationModel
+from core.model_runtime.model_providers.__base.moderation_model import ModerationModel
+from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
 from extensions.ext_hosting_provider import hosting_configuration
 from models.provider import ProviderType

 logger = logging.getLogger(__name__)


-def check_moderation(model_config: ModelConfigWithCredentialsEntity, text: str) -> bool:
+def check_moderation(tenant_id: str, model_config: ModelConfigWithCredentialsEntity, text: str) -> bool:
     moderation_config = hosting_configuration.moderation_config
+    openai_provider_name = f"{DEFAULT_PLUGIN_ID}/openai/openai"
     if (
         moderation_config
         and moderation_config.enabled is True
-        and "openai" in hosting_configuration.provider_map
-        and hosting_configuration.provider_map["openai"].enabled is True
+        and openai_provider_name in hosting_configuration.provider_map
+        and hosting_configuration.provider_map[openai_provider_name].enabled is True
     ):
         using_provider_type = model_config.provider_model_bundle.configuration.using_provider_type
         provider_name = model_config.provider
         if using_provider_type == ProviderType.SYSTEM and provider_name in moderation_config.providers:
-            hosting_openai_config = hosting_configuration.provider_map["openai"]
+            hosting_openai_config = hosting_configuration.provider_map[openai_provider_name]

             if hosting_openai_config.credentials is None:
                 return False
@@ -36,9 +41,15 @@ def check_moderation(model_config: ModelConfigWithCredentialsEntity, text: str)
             text_chunk = random.choice(text_chunks)

             try:
-                model_type_instance = OpenAIModerationModel()
+                model_provider_factory = ModelProviderFactory(tenant_id)
+
+                # Get model instance of LLM
+                model_type_instance = model_provider_factory.get_model_type_instance(
+                    provider=openai_provider_name, model_type=ModelType.MODERATION
+                )
+                model_type_instance = cast(ModerationModel, model_type_instance)
                 moderation_result = model_type_instance.invoke(
-                    model="text-moderation-stable", credentials=hosting_openai_config.credentials, text=text_chunk
+                    model="omni-moderation-latest", credentials=hosting_openai_config.credentials, text=text_chunk
                 )

                 if moderation_result is True:
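Taken together, the helper's new resolution path looks roughly like the sketch below: the hosted OpenAI provider is addressed by its plugin-scoped name instead of the bare "openai" key, the moderation model is obtained from a per-tenant ModelProviderFactory rather than a direct OpenAIModerationModel() construction, and the hosted check moves to omni-moderation-latest. This is a condensed, hypothetical extract (moderate_chunk is not a real function in the codebase; the chunking, length limits, and exception handling around the real call are omitted).

from typing import cast

from core.entities import DEFAULT_PLUGIN_ID
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.__base.moderation_model import ModerationModel
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from extensions.ext_hosting_provider import hosting_configuration


def moderate_chunk(tenant_id: str, text_chunk: str) -> bool:
    # Plugin-scoped provider key, replacing the bare "openai" lookup.
    openai_provider_name = f"{DEFAULT_PLUGIN_ID}/openai/openai"
    hosting_openai_config = hosting_configuration.provider_map[openai_provider_name]

    # Resolve the moderation model through the tenant-scoped factory.
    model_provider_factory = ModelProviderFactory(tenant_id)
    model_type_instance = model_provider_factory.get_model_type_instance(
        provider=openai_provider_name, model_type=ModelType.MODERATION
    )
    model_type_instance = cast(ModerationModel, model_type_instance)

    # Hosted moderation now targets omni-moderation-latest.
    return model_type_instance.invoke(
        model="omni-moderation-latest",
        credentials=hosting_openai_config.credentials,
        text=text_chunk,
    )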