feat: add hosted moderation (#1158)

parent 983834cd52
commit f9082104ed
api/config.py

@@ -61,6 +61,8 @@ DEFAULTS = {
     'HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA': 1000000,
     'HOSTED_ANTHROPIC_PAID_MIN_QUANTITY': 20,
     'HOSTED_ANTHROPIC_PAID_MAX_QUANTITY': 100,
+    'HOSTED_MODERATION_ENABLED': 'False',
+    'HOSTED_MODERATION_PROVIDERS': '',
     'TENANT_DOCUMENT_COUNT': 100,
     'CLEAN_DAY_SETTING': 30,
     'UPLOAD_FILE_SIZE_LIMIT': 15,

@@ -230,6 +232,9 @@ class Config:
         self.HOSTED_ANTHROPIC_PAID_MIN_QUANTITY = int(get_env('HOSTED_ANTHROPIC_PAID_MIN_QUANTITY'))
         self.HOSTED_ANTHROPIC_PAID_MAX_QUANTITY = int(get_env('HOSTED_ANTHROPIC_PAID_MAX_QUANTITY'))

+        self.HOSTED_MODERATION_ENABLED = get_bool_env('HOSTED_MODERATION_ENABLED')
+        self.HOSTED_MODERATION_PROVIDERS = get_env('HOSTED_MODERATION_PROVIDERS')
+
         self.STRIPE_API_KEY = get_env('STRIPE_API_KEY')
         self.STRIPE_WEBHOOK_SECRET = get_env('STRIPE_WEBHOOK_SECRET')
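The two new settings follow the same pattern as the other hosted-provider flags: read once at startup, with the enable flag parsed from a string default. A minimal standalone sketch of the gating logic, using stand-in get_env/get_bool_env helpers (the real helpers live in api/config.py; the exact parsing shown here is an assumption):

import os

# stand-ins for dify's config helpers (assumed behavior)
def get_env(key: str, default: str = '') -> str:
    return os.environ.get(key, default)

def get_bool_env(key: str) -> bool:
    return get_env(key, 'False').lower() == 'true'

os.environ['HOSTED_MODERATION_ENABLED'] = 'True'
os.environ['HOSTED_MODERATION_PROVIDERS'] = 'openai,anthropic'

enabled = get_bool_env('HOSTED_MODERATION_ENABLED')
providers = get_env('HOSTED_MODERATION_PROVIDERS')
if enabled and providers:
    # same split the moderation helper performs before checking the provider name
    print(providers.split(','))  # ['openai', 'anthropic']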
api/core/agent/agent_executor.py

@@ -16,6 +16,8 @@ from core.agent.agent.structed_multi_dataset_router_agent import StructuredMulti
 from core.agent.agent.structured_chat import AutoSummarizingStructuredChatAgent
 from langchain.agents import AgentExecutor as LCAgentExecutor

+from core.helper import moderation
+from core.model_providers.error import LLMError
 from core.model_providers.models.llm.base import BaseLLM
 from core.tool.dataset_retriever_tool import DatasetRetrieverTool
@@ -116,6 +118,18 @@ class AgentExecutor:
         return self.agent.should_use_agent(query)

     def run(self, query: str) -> AgentExecuteResult:
+        moderation_result = moderation.check_moderation(
+            self.configuration.model_instance.model_provider,
+            query
+        )
+
+        if not moderation_result:
+            return AgentExecuteResult(
+                output="I apologize for any confusion, but I'm an AI assistant to be helpful, harmless, and honest.",
+                strategy=self.configuration.strategy,
+                configuration=self.configuration
+            )
+
         agent_executor = LCAgentExecutor.from_agent_and_tools(
             agent=self.agent,
             tools=self.configuration.tools,
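run() now short-circuits before any agent or tool work: if the query fails moderation, a canned AgentExecuteResult comes back immediately. A sketch of the guard pattern with a stubbed moderation check (the stub stands in for core.helper.moderation.check_moderation; the string is the commit's canned reply):

CANNED = ("I apologize for any confusion, but I'm an AI assistant "
          "to be helpful, harmless, and honest.")

def check_moderation(provider, text: str) -> bool:
    # stand-in: the real helper calls OpenAI's moderation endpoint
    return 'blocked-term' not in text

def run(query: str) -> str:
    if not check_moderation('openai', query):
        return CANNED  # short-circuit: no agent, no tools, no LLM call
    return f'agent output for: {query}'

print(run('what is the weather'))  # normal path
print(run('blocked-term query'))   # canned response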
@@ -128,7 +142,9 @@ class AgentExecutor:

         try:
             output = agent_executor.run(query)
-        except Exception:
+        except LLMError as ex:
+            raise ex
+        except Exception as ex:
             logging.exception("agent_executor run failed")
             output = None
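The except order matters here: LLMError (provider failures, including the moderation rate-limit error) now propagates to the caller, while anything else is still logged and degraded to output = None. A minimal standalone reproduction:

import logging
from typing import Optional

class LLMError(Exception):
    pass

def run_agent(fail_with: Optional[Exception]) -> Optional[str]:
    try:
        if fail_with:
            raise fail_with
        return 'output'
    except LLMError as ex:
        raise ex                 # surface provider errors to the caller
    except Exception:
        logging.exception("agent_executor run failed")
        return None              # everything else degrades gracefully

print(run_agent(None))           # 'output'
print(run_agent(ValueError()))   # None (exception is logged)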
api/core/callback_handler/agent_loop_gather_callback_handler.py

@@ -6,7 +6,7 @@ from typing import Any, Dict, List, Union, Optional

 from langchain.agents import openai_functions_agent, openai_functions_multi_agent
 from langchain.callbacks.base import BaseCallbackHandler
-from langchain.schema import AgentAction, AgentFinish, LLMResult, ChatGeneration
+from langchain.schema import AgentAction, AgentFinish, LLMResult, ChatGeneration, BaseMessage

 from core.callback_handler.entity.agent_loop import AgentLoop
 from core.conversation_message_task import ConversationMessageTask
@@ -18,9 +18,9 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
     """Callback Handler that prints to std out."""
     raise_error: bool = True

-    def __init__(self, model_instant: BaseLLM, conversation_message_task: ConversationMessageTask) -> None:
+    def __init__(self, model_instance: BaseLLM, conversation_message_task: ConversationMessageTask) -> None:
         """Initialize callback handler."""
-        self.model_instant = model_instant
+        self.model_instance = model_instance
         self.conversation_message_task = conversation_message_task
         self._agent_loops = []
         self._current_loop = None
@@ -46,6 +46,21 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
         """Whether to ignore chain callbacks."""
         return True

+    def on_chat_model_start(
+            self,
+            serialized: Dict[str, Any],
+            messages: List[List[BaseMessage]],
+            **kwargs: Any
+    ) -> Any:
+        if not self._current_loop:
+            # Agent start with a LLM query
+            self._current_loop = AgentLoop(
+                position=len(self._agent_loops) + 1,
+                prompt="\n".join([message.content for message in messages[0]]),
+                status='llm_started',
+                started_at=time.perf_counter()
+            )
+
     def on_llm_start(
         self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
     ) -> None:
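The new on_chat_model_start hook opens an agent loop keyed to the chat-message call: the loop prompt is the first message batch joined with newlines, and timing starts at the hook. A small sketch of that assembly with a stub message type (Msg stands in for langchain's BaseMessage):

import time
from dataclasses import dataclass

@dataclass
class Msg:  # stand-in for langchain's BaseMessage
    content: str

messages = [[Msg('system: be helpful'), Msg('user: hi')]]

# same joining the handler uses for AgentLoop.prompt
prompt = "\n".join(message.content for message in messages[0])
started_at = time.perf_counter()  # loop timing starts here
print(prompt)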
@@ -70,7 +85,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
         if response.llm_output:
             self._current_loop.prompt_tokens = response.llm_output['token_usage']['prompt_tokens']
         else:
-            self._current_loop.prompt_tokens = self.model_instant.get_num_tokens(
+            self._current_loop.prompt_tokens = self.model_instance.get_num_tokens(
                 [PromptMessage(content=self._current_loop.prompt)]
             )
         completion_generation = response.generations[0][0]
@@ -87,7 +102,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
         if response.llm_output:
             self._current_loop.completion_tokens = response.llm_output['token_usage']['completion_tokens']
         else:
-            self._current_loop.completion_tokens = self.model_instant.get_num_tokens(
+            self._current_loop.completion_tokens = self.model_instance.get_num_tokens(
                 [PromptMessage(content=self._current_loop.completion)]
             )
@@ -162,7 +177,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
             self._current_loop.latency = self._current_loop.completed_at - self._current_loop.started_at

             self.conversation_message_task.on_agent_end(
-                self._message_agent_thought, self.model_instant, self._current_loop
+                self._message_agent_thought, self.model_instance, self._current_loop
             )

             self._agent_loops.append(self._current_loop)
@@ -193,7 +208,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
             )

             self.conversation_message_task.on_agent_end(
-                self._message_agent_thought, self.model_instant, self._current_loop
+                self._message_agent_thought, self.model_instance, self._current_loop
             )

             self._agent_loops.append(self._current_loop)
api/core/callback_handler/entity/llm_message.py

@@ -6,4 +6,3 @@ class LLMMessage(BaseModel):
     prompt_tokens: int = 0
     completion: str = ''
     completion_tokens: int = 0
-    latency: float = 0.0
api/core/callback_handler/llm_callback_handler.py

@@ -1,5 +1,4 @@
 import logging
-import time
 from typing import Any, Dict, List, Union

 from langchain.callbacks.base import BaseCallbackHandler
@@ -32,7 +31,6 @@ class LLMCallbackHandler(BaseCallbackHandler):
             messages: List[List[BaseMessage]],
             **kwargs: Any
     ) -> Any:
-        self.start_at = time.perf_counter()
         real_prompts = []
         for message in messages[0]:
             if message.type == 'human':
@@ -53,8 +51,6 @@ class LLMCallbackHandler(BaseCallbackHandler):
     def on_llm_start(
         self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
     ) -> None:
-        self.start_at = time.perf_counter()
-
         self.llm_message.prompt = [{
             "role": 'user',
             "text": prompts[0]
@@ -63,9 +59,6 @@ class LLMCallbackHandler(BaseCallbackHandler):
         self.llm_message.prompt_tokens = self.model_instance.get_num_tokens([PromptMessage(content=prompts[0])])

     def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
-        end_at = time.perf_counter()
-        self.llm_message.latency = end_at - self.start_at
-
         if not self.conversation_message_task.streaming:
             self.conversation_message_task.append_message_text(response.generations[0][0].text)
             self.llm_message.completion = response.generations[0][0].text
@@ -89,8 +82,6 @@ class LLMCallbackHandler(BaseCallbackHandler):
         """Do nothing."""
         if isinstance(error, ConversationTaskStoppedException):
             if self.conversation_message_task.streaming:
-                end_at = time.perf_counter()
-                self.llm_message.latency = end_at - self.start_at
                 self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
                     [PromptMessage(content=self.llm_message.completion)]
                 )
api/core/chain/sensitive_word_avoidance_chain.py

@@ -1,15 +1,38 @@
+import enum
+import logging
 from typing import List, Dict, Optional, Any

+import openai
+from flask import current_app
 from langchain.callbacks.manager import CallbackManagerForChainRun
 from langchain.chains.base import Chain
+from openai import InvalidRequestError
+from openai.error import APIConnectionError, APIError, ServiceUnavailableError, Timeout, RateLimitError, \
+    AuthenticationError, OpenAIError
+from pydantic import BaseModel
+
+from core.model_providers.error import LLMBadRequestError
+from core.model_providers.model_factory import ModelFactory
+from core.model_providers.models.llm.base import BaseLLM
+from core.model_providers.models.moderation import openai_moderation
+
+
+class SensitiveWordAvoidanceRule(BaseModel):
+    class Type(enum.Enum):
+        MODERATION = "moderation"
+        KEYWORDS = "keywords"
+
+    type: Type
+    canned_response: str = 'Your content violates our usage policy. Please revise and try again.'
+    extra_params: dict = {}


 class SensitiveWordAvoidanceChain(Chain):
     input_key: str = "input"  #: :meta private:
     output_key: str = "output"  #: :meta private:

-    sensitive_words: List[str] = []
-    canned_response: str = None
+    model_instance: BaseLLM
+    sensitive_word_avoidance_rule: SensitiveWordAvoidanceRule

     @property
     def _chain_type(self) -> str:
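The rule object is a plain pydantic model, so either flavor can be built directly. A self-contained sketch (the class is restated with the same defaults so the snippet runs on its own):

import enum
from pydantic import BaseModel

class SensitiveWordAvoidanceRule(BaseModel):
    class Type(enum.Enum):
        MODERATION = "moderation"
        KEYWORDS = "keywords"

    type: Type
    canned_response: str = 'Your content violates our usage policy. Please revise and try again.'
    extra_params: dict = {}

# keyword rule carries its word list in extra_params
kw_rule = SensitiveWordAvoidanceRule(
    type=SensitiveWordAvoidanceRule.Type.KEYWORDS,
    extra_params={'sensitive_words': ['foo', 'bar']},
)
# moderation rule needs no extra params
mod_rule = SensitiveWordAvoidanceRule(type=SensitiveWordAvoidanceRule.Type.MODERATION)
print(kw_rule.type, mod_rule.canned_response)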
@@ -31,11 +54,24 @@ class SensitiveWordAvoidanceChain(Chain):
         """
         return [self.output_key]

-    def _check_sensitive_word(self, text: str) -> str:
-        for word in self.sensitive_words:
+    def _check_sensitive_word(self, text: str) -> bool:
+        for word in self.sensitive_word_avoidance_rule.extra_params.get('sensitive_words', []):
             if word in text:
-                return self.canned_response
-        return text
+                return False
+        return True
+
+    def _check_moderation(self, text: str) -> bool:
+        moderation_model_instance = ModelFactory.get_moderation_model(
+            tenant_id=self.model_instance.model_provider.provider.tenant_id,
+            model_provider_name='openai',
+            model_name=openai_moderation.DEFAULT_MODEL
+        )
+
+        try:
+            return moderation_model_instance.run(text=text)
+        except Exception as ex:
+            logging.exception(ex)
+            raise LLMBadRequestError('Rate limit exceeded, please try again later.')

     def _call(
         self,
@@ -43,5 +79,13 @@ class SensitiveWordAvoidanceChain(Chain):
         run_manager: Optional[CallbackManagerForChainRun] = None,
     ) -> Dict[str, Any]:
         text = inputs[self.input_key]
-        output = self._check_sensitive_word(text)
-        return {self.output_key: output}
+
+        if self.sensitive_word_avoidance_rule.type == SensitiveWordAvoidanceRule.Type.KEYWORDS:
+            result = self._check_sensitive_word(text)
+        else:
+            result = self._check_moderation(text)
+
+        if not result:
+            raise LLMBadRequestError(self.sensitive_word_avoidance_rule.canned_response)
+
+        return {self.output_key: text}
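Note the behavioral change in _call: the old version returned a possibly rewritten string, while the new one passes the text through untouched and raises LLMBadRequestError with the rule's canned response on a violation. A standalone sketch of the keywords path (names are illustrative):

class LLMBadRequestError(Exception):
    pass

def call(text: str, sensitive_words, canned_response: str) -> dict:
    ok = not any(word in text for word in sensitive_words)  # keywords rule
    if not ok:
        raise LLMBadRequestError(canned_response)
    return {'output': text}  # text is passed through unchanged

print(call('hello', ['bomb'], 'Your content violates our usage policy.'))
try:
    call('a bomb', ['bomb'], 'Your content violates our usage policy.')
except LLMBadRequestError as e:
    print('blocked:', e)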
api/core/completion.py

@@ -1,9 +1,7 @@
 import json
 import logging
-import re
-from typing import Optional, List, Union, Tuple
+from typing import Optional, List, Union

-from langchain.schema import BaseMessage
 from requests.exceptions import ChunkedEncodingError

 from core.agent.agent_executor import AgentExecuteResult, PlanningStrategy
@@ -14,11 +12,10 @@ from core.model_providers.error import LLMBadRequestError
 from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
     ReadOnlyConversationTokenDBBufferSharedMemory
 from core.model_providers.model_factory import ModelFactory
-from core.model_providers.models.entity.message import PromptMessage, to_prompt_messages
+from core.model_providers.models.entity.message import PromptMessage
 from core.model_providers.models.llm.base import BaseLLM
 from core.orchestrator_rule_parser import OrchestratorRuleParser
 from core.prompt.prompt_builder import PromptBuilder
-from core.prompt.prompt_template import JinjaPromptTemplate
 from core.prompt.prompts import MORE_LIKE_THIS_GENERATE_PROMPT
 from models.dataset import DocumentSegment, Dataset, Document
 from models.model import App, AppModelConfig, Account, Conversation, Message, EndUser
@@ -81,7 +78,7 @@ class Completion:

         # parse sensitive_word_avoidance_chain
         chain_callback = MainChainGatherCallbackHandler(conversation_message_task)
-        sensitive_word_avoidance_chain = orchestrator_rule_parser.to_sensitive_word_avoidance_chain([chain_callback])
+        sensitive_word_avoidance_chain = orchestrator_rule_parser.to_sensitive_word_avoidance_chain(final_model_instance, [chain_callback])
         if sensitive_word_avoidance_chain:
             query = sensitive_word_avoidance_chain.run(query)
api/core/conversation_message_task.py

@@ -1,5 +1,5 @@
-import decimal
 import json
+import time
 from typing import Optional, Union, List

 from core.callback_handler.entity.agent_loop import AgentLoop
@@ -23,6 +23,8 @@ class ConversationMessageTask:
     def __init__(self, task_id: str, app: App, app_model_config: AppModelConfig, user: Account,
                  inputs: dict, query: str, streaming: bool, model_instance: BaseLLM,
                  conversation: Optional[Conversation] = None, is_override: bool = False):
+        self.start_at = time.perf_counter()
+
         self.task_id = task_id

         self.app = app
@@ -61,6 +63,7 @@ class ConversationMessageTask:
         )

     def init(self):
+
         override_model_configs = None
         if self.is_override:
             override_model_configs = self.app_model_config.to_dict()
@@ -165,7 +168,7 @@ class ConversationMessageTask:
         self.message.answer_tokens = answer_tokens
         self.message.answer_unit_price = answer_unit_price
         self.message.answer_price_unit = answer_price_unit
-        self.message.provider_response_latency = llm_message.latency
+        self.message.provider_response_latency = time.perf_counter() - self.start_at
         self.message.total_price = total_price

         db.session.commit()
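Latency accounting moves with this commit: the per-call latency field on LLMMessage is gone, and provider_response_latency now spans from task construction to message save. A minimal sketch of the perf_counter pattern:

import time

class Task:
    def __init__(self):
        self.start_at = time.perf_counter()   # set in __init__, as above

    def finish(self) -> float:
        return time.perf_counter() - self.start_at  # provider_response_latency

task = Task()
time.sleep(0.05)  # stand-in for the whole completion pipeline
print(f'{task.finish():.3f}s')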
@@ -220,18 +223,18 @@ class ConversationMessageTask:

         return message_agent_thought

-    def on_agent_end(self, message_agent_thought: MessageAgentThought, agent_model_instant: BaseLLM,
+    def on_agent_end(self, message_agent_thought: MessageAgentThought, agent_model_instance: BaseLLM,
                      agent_loop: AgentLoop):
-        agent_message_unit_price = agent_model_instant.get_tokens_unit_price(MessageType.HUMAN)
-        agent_message_price_unit = agent_model_instant.get_price_unit(MessageType.HUMAN)
-        agent_answer_unit_price = agent_model_instant.get_tokens_unit_price(MessageType.ASSISTANT)
-        agent_answer_price_unit = agent_model_instant.get_price_unit(MessageType.ASSISTANT)
+        agent_message_unit_price = agent_model_instance.get_tokens_unit_price(MessageType.HUMAN)
+        agent_message_price_unit = agent_model_instance.get_price_unit(MessageType.HUMAN)
+        agent_answer_unit_price = agent_model_instance.get_tokens_unit_price(MessageType.ASSISTANT)
+        agent_answer_price_unit = agent_model_instance.get_price_unit(MessageType.ASSISTANT)

         loop_message_tokens = agent_loop.prompt_tokens
         loop_answer_tokens = agent_loop.completion_tokens

-        loop_message_total_price = agent_model_instant.calc_tokens_price(loop_message_tokens, MessageType.HUMAN)
-        loop_answer_total_price = agent_model_instant.calc_tokens_price(loop_answer_tokens, MessageType.ASSISTANT)
+        loop_message_total_price = agent_model_instance.calc_tokens_price(loop_message_tokens, MessageType.HUMAN)
+        loop_answer_total_price = agent_model_instance.calc_tokens_price(loop_answer_tokens, MessageType.ASSISTANT)
         loop_total_price = loop_message_total_price + loop_answer_total_price

         message_agent_thought.observation = agent_loop.tool_output
@@ -245,7 +248,7 @@ class ConversationMessageTask:
         message_agent_thought.latency = agent_loop.latency
         message_agent_thought.tokens = agent_loop.prompt_tokens + agent_loop.completion_tokens
         message_agent_thought.total_price = loop_total_price
-        message_agent_thought.currency = agent_model_instant.get_currency()
+        message_agent_thought.currency = agent_model_instance.get_currency()
         db.session.flush()

     def on_dataset_query_end(self, dataset_query_obj: DatasetQueryObj):
api/core/helper/moderation.py (new file, 32 lines)

@@ -0,0 +1,32 @@
+import logging
+
+import openai
+from flask import current_app
+
+from core.model_providers.error import LLMBadRequestError
+from core.model_providers.providers.base import BaseModelProvider
+from models.provider import ProviderType
+
+
+def check_moderation(model_provider: BaseModelProvider, text: str) -> bool:
+    if current_app.config['HOSTED_MODERATION_ENABLED'] and current_app.config['HOSTED_MODERATION_PROVIDERS']:
+        moderation_providers = current_app.config['HOSTED_MODERATION_PROVIDERS'].split(',')
+
+        if model_provider.provider.provider_type == ProviderType.SYSTEM.value \
+                and model_provider.provider_name in moderation_providers:
+            # 2000 text per chunk
+            length = 2000
+            chunks = [text[i:i + length] for i in range(0, len(text), length)]
+
+            try:
+                moderation_result = openai.Moderation.create(input=chunks,
+                                                             api_key=current_app.config['HOSTED_OPENAI_API_KEY'])
+            except Exception as ex:
+                logging.exception(ex)
+                raise LLMBadRequestError('Rate limit exceeded, please try again later.')
+
+            for result in moderation_result.results:
+                if result['flagged'] is True:
+                    return False
+
+    return True
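The helper splits input into 2000-character chunks, sends them as one batch, and fails closed if any chunk is flagged. A standalone demo of that logic with faked API results:

length = 2000
text = 'x' * 4500
chunks = [text[i:i + length] for i in range(0, len(text), length)]
print([len(c) for c in chunks])  # [2000, 2000, 500]

# faked per-chunk results, same shape as openai.Moderation.create(...).results
results = [{'flagged': False}, {'flagged': True}, {'flagged': False}]
passed = not any(r['flagged'] for r in results)
print(passed)  # False -> check_moderation would return False for this text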
api/core/model_providers/model_factory.py

@@ -8,6 +8,7 @@ from core.model_providers.models.base import BaseProviderModel
 from core.model_providers.models.embedding.base import BaseEmbedding
 from core.model_providers.models.entity.model_params import ModelKwargs, ModelType
 from core.model_providers.models.llm.base import BaseLLM
+from core.model_providers.models.moderation.base import BaseModeration
 from core.model_providers.models.speech2text.base import BaseSpeech2Text
 from extensions.ext_database import db
 from models.provider import TenantDefaultModel

@@ -180,7 +181,7 @@
     def get_moderation_model(cls,
                              tenant_id: str,
                              model_provider_name: str,
-                             model_name: str) -> Optional[BaseProviderModel]:
+                             model_name: str) -> Optional[BaseModeration]:
         """
         get moderation model.
api/core/model_providers/models/llm/base.py

@@ -10,6 +10,7 @@ from langchain.memory.chat_memory import BaseChatMemory
 from langchain.schema import LLMResult, SystemMessage, AIMessage, HumanMessage, BaseMessage, ChatGeneration

 from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, DifyStdOutCallbackHandler
+from core.helper import moderation
 from core.model_providers.models.base import BaseProviderModel
 from core.model_providers.models.entity.message import PromptMessage, MessageType, LLMRunResult, to_prompt_messages
 from core.model_providers.models.entity.model_params import ModelType, ModelKwargs, ModelMode, ModelKwargsRules

@@ -116,6 +117,15 @@ class BaseLLM(BaseProviderModel):
         :param callbacks:
         :return:
         """
+        moderation_result = moderation.check_moderation(
+            self.model_provider,
+            "\n".join([message.content for message in messages])
+        )
+
+        if not moderation_result:
+            kwargs['fake_response'] = "I apologize for any confusion, " \
+                                      "but I'm an AI assistant to be helpful, harmless, and honest."
+
         if self.deduct_quota:
             self.model_provider.check_quota_over_limit()
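When the joined prompt fails moderation, BaseLLM.run injects a canned reply through kwargs['fake_response'] rather than raising. How that kwarg is consumed downstream is sketched here with an illustrative stub, not dify's actual signature:

CANNED = ("I apologize for any confusion, but I'm an AI assistant "
          "to be helpful, harmless, and honest.")

def run(messages, moderation_ok: bool, **kwargs):
    # assumed consumption: if a fake_response is present, skip the real call
    if not moderation_ok:
        kwargs['fake_response'] = CANNED
    if 'fake_response' in kwargs:
        return kwargs['fake_response']
    return 'real model output for: ' + ' '.join(messages)

print(run(['hi'], moderation_ok=True))
print(run(['bad'], moderation_ok=False))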
api/core/model_providers/models/moderation/base.py (new file, 29 lines)

@@ -0,0 +1,29 @@
+from abc import abstractmethod
+from typing import Any
+
+from core.model_providers.models.base import BaseProviderModel
+from core.model_providers.models.entity.model_params import ModelType
+from core.model_providers.providers.base import BaseModelProvider
+
+
+class BaseModeration(BaseProviderModel):
+    name: str
+    type: ModelType = ModelType.MODERATION
+
+    def __init__(self, model_provider: BaseModelProvider, client: Any, name: str):
+        super().__init__(model_provider, client)
+        self.name = name
+
+    def run(self, text: str) -> bool:
+        try:
+            return self._run(text)
+        except Exception as ex:
+            raise self.handle_exceptions(ex)
+
+    @abstractmethod
+    def _run(self, text: str) -> bool:
+        raise NotImplementedError
+
+    @abstractmethod
+    def handle_exceptions(self, ex: Exception) -> Exception:
+        raise NotImplementedError
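Any provider moderation model now subclasses BaseModeration and implements _run plus handle_exceptions; run wraps _run and funnels errors through the translator. A toy subclass, independent of dify's provider classes, showing the contract:

from abc import ABC, abstractmethod
from typing import Any

class BaseModeration(ABC):  # simplified: the real base also wires a provider/client
    def __init__(self, client: Any, name: str):
        self._client = client
        self.name = name

    def run(self, text: str) -> bool:
        try:
            return self._run(text)
        except Exception as ex:
            raise self.handle_exceptions(ex)

    @abstractmethod
    def _run(self, text: str) -> bool: ...

    @abstractmethod
    def handle_exceptions(self, ex: Exception) -> Exception: ...

class KeywordModeration(BaseModeration):
    def _run(self, text: str) -> bool:
        return 'bad' not in text

    def handle_exceptions(self, ex: Exception) -> Exception:
        return RuntimeError(f'moderation failed: {ex}')

print(KeywordModeration(client=None, name='keywords').run('all good'))  # True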
api/core/model_providers/models/moderation/openai_moderation.py

@@ -4,29 +4,35 @@ import openai

 from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \
     LLMRateLimitError, LLMAuthorizationError
-from core.model_providers.models.base import BaseProviderModel
-from core.model_providers.models.entity.model_params import ModelType
+from core.model_providers.models.moderation.base import BaseModeration
 from core.model_providers.providers.base import BaseModelProvider

-DEFAULT_AUDIO_MODEL = 'whisper-1'
+DEFAULT_MODEL = 'whisper-1'


-class OpenAIModeration(BaseProviderModel):
-    type: ModelType = ModelType.MODERATION
-
+class OpenAIModeration(BaseModeration):
     def __init__(self, model_provider: BaseModelProvider, name: str):
-        super().__init__(model_provider, openai.Moderation)
+        super().__init__(model_provider, openai.Moderation, name)

-    def run(self, text):
+    def _run(self, text: str) -> bool:
         credentials = self.model_provider.get_model_credentials(
-            model_name=DEFAULT_AUDIO_MODEL,
+            model_name=self.name,
             model_type=self.type
         )

-        try:
-            return self._client.create(input=text, api_key=credentials['openai_api_key'])
-        except Exception as ex:
-            raise self.handle_exceptions(ex)
+        # 2000 text per chunk
+        length = 2000
+        chunks = [text[i:i + length] for i in range(0, len(text), length)]
+
+        moderation_result = self._client.create(input=chunks,
+                                                api_key=credentials['openai_api_key'])
+
+        for result in moderation_result.results:
+            if result['flagged'] is True:
+                return False
+
+        return True

     def handle_exceptions(self, ex: Exception) -> Exception:
         if isinstance(ex, openai.error.InvalidRequestError):
api/core/orchestrator_rule_parser.py

@@ -1,6 +1,7 @@
 import math
 from typing import Optional

+from flask import current_app
 from langchain import WikipediaAPIWrapper
 from langchain.callbacks.manager import Callbacks
 from langchain.memory.chat_memory import BaseChatMemory

@@ -12,7 +13,7 @@ from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGa
 from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler
 from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
 from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
-from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceChain
+from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceChain, SensitiveWordAvoidanceRule
 from core.conversation_message_task import ConversationMessageTask
 from core.model_providers.error import ProviderTokenNotInitError
 from core.model_providers.model_factory import ModelFactory

@@ -26,6 +27,7 @@ from core.tool.web_reader_tool import WebReaderTool
 from extensions.ext_database import db
 from models.dataset import Dataset, DatasetProcessRule
 from models.model import AppModelConfig
+from models.provider import ProviderType


 class OrchestratorRuleParser:
@@ -63,7 +65,7 @@ class OrchestratorRuleParser:

         # add agent callback to record agent thoughts
         agent_callback = AgentLoopGatherCallbackHandler(
-            model_instant=agent_model_instance,
+            model_instance=agent_model_instance,
             conversation_message_task=conversation_message_task
         )
@@ -123,23 +125,45 @@ class OrchestratorRuleParser:

         return chain

-    def to_sensitive_word_avoidance_chain(self, callbacks: Callbacks = None, **kwargs) \
+    def to_sensitive_word_avoidance_chain(self, model_instance: BaseLLM, callbacks: Callbacks = None, **kwargs) \
             -> Optional[SensitiveWordAvoidanceChain]:
         """
         Convert app sensitive word avoidance config to chain
+
+        :param model_instance: model instance
+        :param callbacks: callbacks for the chain
         :param kwargs:
         :return:
         """
-        if not self.app_model_config.sensitive_word_avoidance_dict:
-            return None
+        sensitive_word_avoidance_rule = None

-        sensitive_word_avoidance_config = self.app_model_config.sensitive_word_avoidance_dict
+        if self.app_model_config.sensitive_word_avoidance_dict:
+            sensitive_word_avoidance_config = self.app_model_config.sensitive_word_avoidance_dict
+            if sensitive_word_avoidance_config.get("enabled", False):
+                if sensitive_word_avoidance_config.get('type') == 'moderation':
+                    sensitive_word_avoidance_rule = SensitiveWordAvoidanceRule(
+                        type=SensitiveWordAvoidanceRule.Type.MODERATION,
+                        canned_response=sensitive_word_avoidance_config.get("canned_response")
+                        if sensitive_word_avoidance_config.get("canned_response")
+                        else 'Your content violates our usage policy. Please revise and try again.',
+                    )
+                else:
+                    sensitive_words = sensitive_word_avoidance_config.get("words", "")
+                    if sensitive_words:
+                        sensitive_word_avoidance_rule = SensitiveWordAvoidanceRule(
+                            type=SensitiveWordAvoidanceRule.Type.KEYWORDS,
+                            canned_response=sensitive_word_avoidance_config.get("canned_response")
+                            if sensitive_word_avoidance_config.get("canned_response")
+                            else 'Your content violates our usage policy. Please revise and try again.',
+                            extra_params={
+                                'sensitive_words': sensitive_words.split(','),
+                            }
+                        )
+
+        if sensitive_word_avoidance_rule:
             return SensitiveWordAvoidanceChain(
-                sensitive_words=sensitive_words.split(","),
-                canned_response=sensitive_word_avoidance_config.get("canned_response", ''),
+                model_instance=model_instance,
+                sensitive_word_avoidance_rule=sensitive_word_avoidance_rule,
                 output_key="sensitive_word_avoidance_output",
                 callbacks=callbacks,
                 **kwargs
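The parser now builds a rule from the app's config dict and only then constructs the chain. A standalone sketch of that dict-to-rule decision table (config keys as in the hunk above; the rule is reduced to a tuple here for brevity):

from typing import Optional, Tuple

DEFAULT_RESPONSE = 'Your content violates our usage policy. Please revise and try again.'

def parse_rule(config: dict) -> Optional[Tuple[str, dict]]:
    if not config or not config.get("enabled", False):
        return None
    canned = config.get("canned_response") or DEFAULT_RESPONSE
    if config.get('type') == 'moderation':
        return ('moderation', {'canned_response': canned})
    words = config.get("words", "")
    if words:
        return ('keywords', {'canned_response': canned,
                             'sensitive_words': words.split(',')})
    return None

print(parse_rule({'enabled': True, 'type': 'moderation'}))
print(parse_rule({'enabled': True, 'words': 'foo,bar'}))
print(parse_rule({'enabled': False}))  # None -> no chain is created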
api tests: OpenAI moderation model unit test

@@ -2,7 +2,7 @@ import json
 import os
 from unittest.mock import patch

-from core.model_providers.models.moderation.openai_moderation import OpenAIModeration, DEFAULT_AUDIO_MODEL
+from core.model_providers.models.moderation.openai_moderation import OpenAIModeration, DEFAULT_MODEL
 from core.model_providers.providers.openai_provider import OpenAIProvider
 from models.provider import Provider, ProviderType

@@ -23,7 +23,7 @@ def get_mock_openai_moderation_model():
     openai_provider = OpenAIProvider(provider=get_mock_provider(valid_openai_api_key))
     return OpenAIModeration(
         model_provider=openai_provider,
-        name=DEFAULT_AUDIO_MODEL
+        name=DEFAULT_MODEL
     )

@@ -36,5 +36,4 @@ def test_run(mock_decrypt):
     model = get_mock_openai_moderation_model()
     rst = model.run('hello')

-    assert isinstance(rst, dict)
-    assert 'id' in rst
+    assert rst is True