fix: remove harm category setting from vertex ai (#8721)

This commit is contained in:
Shota Totsuka 2024-09-24 21:53:26 +09:00 committed by GitHub
parent 9ca2e2c968
commit 1c7877b048
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -2,6 +2,7 @@ import base64
import io import io
import json import json
import logging import logging
import time
from collections.abc import Generator from collections.abc import Generator
from typing import Optional, Union, cast from typing import Optional, Union, cast
@ -20,7 +21,6 @@ from google.api_core import exceptions
from google.cloud import aiplatform from google.cloud import aiplatform
from google.oauth2 import service_account from google.oauth2 import service_account
from PIL import Image from PIL import Image
from vertexai.generative_models import HarmBlockThreshold, HarmCategory
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from core.model_runtime.entities.message_entities import ( from core.model_runtime.entities.message_entities import (
@ -34,6 +34,7 @@ from core.model_runtime.entities.message_entities import (
ToolPromptMessage, ToolPromptMessage,
UserPromptMessage, UserPromptMessage,
) )
from core.model_runtime.entities.model_entities import PriceType
from core.model_runtime.errors.invoke import ( from core.model_runtime.errors.invoke import (
InvokeAuthorizationError, InvokeAuthorizationError,
InvokeBadRequestError, InvokeBadRequestError,
@ -503,20 +504,12 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
else: else:
history.append(content) history.append(content)
safety_settings = {
HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
}
google_model = glm.GenerativeModel(model_name=model, system_instruction=system_instruction) google_model = glm.GenerativeModel(model_name=model, system_instruction=system_instruction)
response = google_model.generate_content( response = google_model.generate_content(
contents=history, contents=history,
generation_config=glm.GenerationConfig(**config_kwargs), generation_config=glm.GenerationConfig(**config_kwargs),
stream=stream, stream=stream,
safety_settings=safety_settings,
tools=self._convert_tools_to_glm_tool(tools) if tools else None, tools=self._convert_tools_to_glm_tool(tools) if tools else None,
) )