From eb76d7a226533058740231f9265a71953b77a03c Mon Sep 17 00:00:00 2001 From: Chenhe Gu Date: Tue, 9 Apr 2024 20:42:18 +0800 Subject: [PATCH] make sure validation flow works for all model providers in bedrock (#3250) --- .../model_providers/bedrock/bedrock.yaml | 2 +- .../model_providers/bedrock/llm/llm.py | 30 +++++++++---------- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/api/core/model_runtime/model_providers/bedrock/bedrock.yaml b/api/core/model_runtime/model_providers/bedrock/bedrock.yaml index 35374c69ba..e1923f8f8a 100644 --- a/api/core/model_runtime/model_providers/bedrock/bedrock.yaml +++ b/api/core/model_runtime/model_providers/bedrock/bedrock.yaml @@ -74,7 +74,7 @@ provider_credential_schema: label: en_US: Available Model Name zh_Hans: 可用模型名称 - type: secret-input + type: text-input placeholder: en_US: A model you have access to (e.g. amazon.titan-text-lite-v1) for validation. zh_Hans: 为了进行验证,请输入一个您可用的模型名称 (例如:amazon.titan-text-lite-v1) diff --git a/api/core/model_runtime/model_providers/bedrock/llm/llm.py b/api/core/model_runtime/model_providers/bedrock/llm/llm.py index 0e256999c0..3dbbcb9b4f 100644 --- a/api/core/model_runtime/model_providers/bedrock/llm/llm.py +++ b/api/core/model_runtime/model_providers/bedrock/llm/llm.py @@ -402,25 +402,25 @@ class BedrockLargeLanguageModel(LargeLanguageModel): :param credentials: model credentials :return: """ - - if "anthropic.claude-3" in model: - try: - self._invoke_claude(model=model, - credentials=credentials, - prompt_messages=[{"role": "user", "content": "ping"}], - model_parameters={}, - stop=None, - stream=False) - - except Exception as ex: - raise CredentialsValidateFailedError(str(ex)) - + required_params = {} + if "anthropic" in model: + required_params = { + "max_tokens": 32, + } + elif "ai21" in model: + # ValidationException: Malformed input request: #/temperature: expected type: Number, found: Null#/maxTokens: expected type: Integer, found: Null#/topP: expected type: Number, found: Null, please reformat your input and try again. + required_params = { + "temperature": 0.7, + "topP": 0.9, + "maxTokens": 32, + } + try: ping_message = UserPromptMessage(content="ping") - self._generate(model=model, + self._invoke(model=model, credentials=credentials, prompt_messages=[ping_message], - model_parameters={}, + model_parameters=required_params, stream=False) except ClientError as ex: