make sure validation flow works for all model providers in bedrock (#3250)

This commit is contained in:
Chenhe Gu 2024-04-09 20:42:18 +08:00 committed by GitHub
parent e635f3dc1d
commit eb76d7a226
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 16 additions and 16 deletions

View File

@@ -74,7 +74,7 @@ provider_credential_schema:
label:
en_US: Available Model Name
zh_Hans: 可用模型名称
type: secret-input
type: text-input
placeholder:
en_US: A model you have access to (e.g. amazon.titan-text-lite-v1) for validation.
zh_Hans: 为了进行验证,请输入一个您可用的模型名称 (例如amazon.titan-text-lite-v1)

View File

@@ -402,25 +402,25 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
:param credentials: model credentials
:return:
"""
if "anthropic.claude-3" in model:
try:
self._invoke_claude(model=model,
credentials=credentials,
prompt_messages=[{"role": "user", "content": "ping"}],
model_parameters={},
stop=None,
stream=False)
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
required_params = {}
if "anthropic" in model:
required_params = {
"max_tokens": 32,
}
elif "ai21" in model:
# ValidationException: Malformed input request: #/temperature: expected type: Number, found: Null#/maxTokens: expected type: Integer, found: Null#/topP: expected type: Number, found: Null, please reformat your input and try again.
required_params = {
"temperature": 0.7,
"topP": 0.9,
"maxTokens": 32,
}
try:
ping_message = UserPromptMessage(content="ping")
self._generate(model=model,
self._invoke(model=model,
credentials=credentials,
prompt_messages=[ping_message],
model_parameters={},
model_parameters=required_params,
stream=False)
except ClientError as ex: