fix: gemini block error (#1877)

Co-authored-by: chenhe <guchenhe@gmail.com>
This commit is contained in:
takatost 2024-01-03 17:45:15 +08:00 committed by GitHub
parent 61aaeff413
commit ede69b4659
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 17 additions and 9 deletions

View File

@@ -132,8 +132,8 @@ class LargeLanguageModel(AIModel):
system_fingerprint = None
real_model = model
for chunk in result:
try:
try:
for chunk in result:
yield chunk
self._trigger_new_chunk_callbacks(
@@ -156,8 +156,8 @@ class LargeLanguageModel(AIModel):
if chunk.system_fingerprint:
system_fingerprint = chunk.system_fingerprint
except Exception as e:
raise self._transform_invoke_error(e)
except Exception as e:
raise self._transform_invoke_error(e)
self._trigger_after_invoke_callbacks(
model=model,

View File

@@ -3,6 +3,7 @@ from typing import Optional, Generator, Union, List
import google.generativeai as genai
import google.api_core.exceptions as exceptions
import google.generativeai.client as client
from google.generativeai.types import HarmCategory, HarmBlockThreshold
from google.generativeai.types import GenerateContentResponse, ContentType
from google.generativeai.types.content_types import to_part
@@ -140,12 +141,20 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
google_model._client = new_custom_client
safety_settings={
HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
}
response = google_model.generate_content(
contents=history,
generation_config=genai.types.GenerationConfig(
**config_kwargs
),
stream=stream
stream=stream,
safety_settings=safety_settings
)
if stream:
@@ -169,7 +178,6 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
content=response.text
)
# calculate num tokens
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])