feat: bedrock invoke enhancement (#6808)
This commit is contained in:
parent
98d9837fbc
commit
9ce5cea911
@@ -12,6 +12,7 @@
 - cohere.command-r-v1.0
 - meta.llama3-1-8b-instruct-v1:0
 - meta.llama3-1-70b-instruct-v1:0
+- meta.llama3-1-405b-instruct-v1:0
 - meta.llama3-8b-instruct-v1:0
 - meta.llama3-70b-instruct-v1:0
 - meta.llama2-13b-chat-v1
@@ -3,8 +3,7 @@ label:
   en_US: Command R+
 model_type: llm
 features:
-  #- multi-tool-call
-  - agent-thought
+  - tool-call
   #- stream-tool-call
 model_properties:
   mode: chat
@@ -3,9 +3,7 @@ label:
   en_US: Command R
 model_type: llm
 features:
-  #- multi-tool-call
-  - agent-thought
-  #- stream-tool-call
+  - tool-call
 model_properties:
   mode: chat
   context_size: 128000
@@ -17,7 +17,6 @@ from botocore.exceptions import (
     ServiceNotInRegionError,
     UnknownServiceError,
 )
-from cohere import ChatMessage

 # local import
 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
@@ -42,7 +41,6 @@ from core.model_runtime.errors.invoke import (
 )
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
-from core.model_runtime.model_providers.cohere.llm.llm import CohereLargeLanguageModel

 logger = logging.getLogger(__name__)

@@ -59,6 +57,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
         {'prefix': 'mistral.mixtral-8x7b-instruct', 'support_system_prompts': False, 'support_tool_use': False},
         {'prefix': 'mistral.mistral-large', 'support_system_prompts': True, 'support_tool_use': True},
         {'prefix': 'mistral.mistral-small', 'support_system_prompts': True, 'support_tool_use': True},
+        {'prefix': 'cohere.command-r', 'support_system_prompts': True, 'support_tool_use': True},
         {'prefix': 'amazon.titan', 'support_system_prompts': False, 'support_tool_use': False}
     ]

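The list above is what gates routing to the Bedrock Converse API: a requested model ID is matched against these prefixes, and the matching entry's flags decide whether system prompts and tool use are passed through. A minimal sketch of that lookup is below; the constant and helper names are assumptions for illustration, not necessarily the exact Dify internals.

```python
# Sketch only: prefix lookup against a converse-enabled model table
# (the names CONVERSE_API_ENABLED_MODEL_INFO / find_model_info are assumed here).
from typing import Optional

CONVERSE_API_ENABLED_MODEL_INFO = [
    {'prefix': 'mistral.mistral-large', 'support_system_prompts': True, 'support_tool_use': True},
    {'prefix': 'cohere.command-r', 'support_system_prompts': True, 'support_tool_use': True},
    {'prefix': 'amazon.titan', 'support_system_prompts': False, 'support_tool_use': False},
]

def find_model_info(model_id: str) -> Optional[dict]:
    # First entry whose prefix matches wins; None means "not converse-enabled".
    for info in CONVERSE_API_ENABLED_MODEL_INFO:
        if model_id.startswith(info['prefix']):
            return info
    return None

# With the new entry, Command R / R+ IDs resolve to a tool-capable converse route.
assert find_model_info('cohere.command-r-plus-v1:0')['support_tool_use'] is True
```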
@@ -94,87 +93,9 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
             model_info['model'] = model
             # invoke models via boto3 converse API
             return self._generate_with_converse(model_info, credentials, prompt_messages, model_parameters, stop, stream, user, tools)
-        # invoke Cohere models via boto3 client
-        if "cohere.command-r" in model:
-            return self._generate_cohere_chat(model, credentials, prompt_messages, model_parameters, stop, stream, user, tools)
         # invoke other models via boto3 client
         return self._generate(model, credentials, prompt_messages, model_parameters, stop, stream, user)

-    def _generate_cohere_chat(
-            self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict,
-            stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None,
-            tools: Optional[list[PromptMessageTool]] = None,) -> Union[LLMResult, Generator]:
-        cohere_llm = CohereLargeLanguageModel()
-        client_config = Config(
-            region_name=credentials["aws_region"]
-        )
-
-        runtime_client = boto3.client(
-            service_name='bedrock-runtime',
-            config=client_config,
-            aws_access_key_id=credentials["aws_access_key_id"],
-            aws_secret_access_key=credentials["aws_secret_access_key"]
-        )
-
-        extra_model_kwargs = {}
-        if stop:
-            extra_model_kwargs['stop_sequences'] = stop
-
-        if tools:
-            tools = cohere_llm._convert_tools(tools)
-            model_parameters['tools'] = tools
-
-        message, chat_histories, tool_results \
-            = cohere_llm._convert_prompt_messages_to_message_and_chat_histories(prompt_messages)
-
-        if tool_results:
-            model_parameters['tool_results'] = tool_results
-
-        payload = {
-            **model_parameters,
-            "message": message,
-            "chat_history": chat_histories,
-        }
-
-        # need workaround for ai21 models which doesn't support streaming
-        if stream:
-            invoke = runtime_client.invoke_model_with_response_stream
-        else:
-            invoke = runtime_client.invoke_model
-
-        def serialize(obj):
-            if isinstance(obj, ChatMessage):
-                return obj.__dict__
-            raise TypeError(f"Type {type(obj)} not serializable")
-
-        try:
-            body_jsonstr=json.dumps(payload, default=serialize)
-            response = invoke(
-                modelId=model,
-                contentType="application/json",
-                accept="*/*",
-                body=body_jsonstr
-            )
-        except ClientError as ex:
-            error_code = ex.response['Error']['Code']
-            full_error_msg = f"{error_code}: {ex.response['Error']['Message']}"
-            raise self._map_client_to_invoke_error(error_code, full_error_msg)
-
-        except (EndpointConnectionError, NoRegionError, ServiceNotInRegionError) as ex:
-            raise InvokeConnectionError(str(ex))
-
-        except UnknownServiceError as ex:
-            raise InvokeServerUnavailableError(str(ex))
-
-        except Exception as ex:
-            raise InvokeError(str(ex))
-
-        if stream:
-            return self._handle_generate_stream_response(model, credentials, response, prompt_messages)
-
-        return self._handle_generate_response(model, credentials, response, prompt_messages)
-
-
     def _generate_with_converse(self, model_info: dict, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict,
                 stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None, tools: Optional[list[PromptMessageTool]] = None,) -> Union[LLMResult, Generator]:
         """
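With the bespoke `_generate_cohere_chat` path removed, Cohere Command R requests are expected to flow through the same boto3 Converse API call as the other converse-enabled models, which also gives them system-prompt and tool-use handling in one place. A minimal standalone sketch of such a call is below; the region, credentials source, model ID spelling, and parameter values are assumptions, and this is not the Dify wrapper itself.

```python
# Sketch: calling a Cohere Command R model on Bedrock through the Converse API.
# Assumes AWS credentials are available in the environment; values are illustrative.
import boto3

client = boto3.client("bedrock-runtime", region_name="us-east-1")

response = client.converse(
    modelId="cohere.command-r-v1:0",
    system=[{"text": "You are a concise assistant."}],  # system prompts supported
    messages=[{"role": "user", "content": [{"text": "Say hello in one sentence."}]}],
    inferenceConfig={"maxTokens": 256, "temperature": 0.5, "topP": 0.9},
)

print(response["output"]["message"]["content"][0]["text"])
print(response["usage"])  # inputTokens / outputTokens as reported by the service
```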
@@ -581,35 +502,6 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
         :param message: PromptMessage to convert.
         :return: String representation of the message.
         """
-
-        if model_prefix == "anthropic":
-            human_prompt_prefix = "\n\nHuman:"
-            human_prompt_postfix = ""
-            ai_prompt = "\n\nAssistant:"
-
-        elif model_prefix == "meta":
-            # LLAMA3
-            if model_name.startswith("llama3"):
-                human_prompt_prefix = "<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n"
-                human_prompt_postfix = "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
-                ai_prompt = "\n\nAssistant:"
-            else:
-                # LLAMA2
-                human_prompt_prefix = "\n[INST]"
-                human_prompt_postfix = "[\\INST]\n"
-                ai_prompt = ""
-
-        elif model_prefix == "mistral":
-            human_prompt_prefix = "<s>[INST]"
-            human_prompt_postfix = "[\\INST]\n"
-            ai_prompt = "\n\nAssistant:"
-
-        elif model_prefix == "amazon":
-            human_prompt_prefix = "\n\nUser:"
-            human_prompt_postfix = ""
-            ai_prompt = "\n\nBot:"
-
-        else:
-            human_prompt_prefix = ""
-            human_prompt_postfix = ""
-            ai_prompt = ""
+        human_prompt_prefix = ""
+        human_prompt_postfix = ""
+        ai_prompt = ""
@@ -663,13 +555,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
         model_prefix = model.split('.')[0]
         model_name = model.split('.')[1]

-        if model_prefix == "amazon":
-            payload["textGenerationConfig"] = { **model_parameters }
-            payload["textGenerationConfig"]["stopSequences"] = ["User:"]
-
-            payload["inputText"] = self._convert_messages_to_prompt(prompt_messages, model_prefix)
-
-        elif model_prefix == "ai21":
+        if model_prefix == "ai21":
             payload["temperature"] = model_parameters.get("temperature")
             payload["topP"] = model_parameters.get("topP")
             payload["maxTokens"] = model_parameters.get("maxTokens")
@@ -682,27 +568,11 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
             if model_parameters.get("countPenalty"):
                 payload["countPenalty"] = {model_parameters.get("countPenalty")}

-        elif model_prefix == "mistral":
-            payload["temperature"] = model_parameters.get("temperature")
-            payload["top_p"] = model_parameters.get("top_p")
-            payload["max_tokens"] = model_parameters.get("max_tokens")
-            payload["prompt"] = self._convert_messages_to_prompt(prompt_messages, model_prefix)
-            payload["stop"] = stop[:10] if stop else []
-
-        elif model_prefix == "anthropic":
-            payload = { **model_parameters }
-            payload["prompt"] = self._convert_messages_to_prompt(prompt_messages, model_prefix)
-            payload["stop_sequences"] = ["\n\nHuman:"] + (stop if stop else [])
-
         elif model_prefix == "cohere":
             payload = { **model_parameters }
             payload["prompt"] = prompt_messages[0].content
             payload["stream"] = stream

-        elif model_prefix == "meta":
-            payload = { **model_parameters }
-            payload["prompt"] = self._convert_messages_to_prompt(prompt_messages, model_prefix, model_name)
-
         else:
             raise ValueError(f"Got unknown model prefix {model_prefix}")

@@ -793,36 +663,16 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
         # get output text and calculate num tokens based on model / provider
         model_prefix = model.split('.')[0]

-        if model_prefix == "amazon":
-            output = response_body.get("results")[0].get("outputText").strip('\n')
-            prompt_tokens = response_body.get("inputTextTokenCount")
-            completion_tokens = response_body.get("results")[0].get("tokenCount")
-
-        elif model_prefix == "ai21":
+        if model_prefix == "ai21":
             output = response_body.get('completions')[0].get('data').get('text')
             prompt_tokens = len(response_body.get("prompt").get("tokens"))
             completion_tokens = len(response_body.get('completions')[0].get('data').get('tokens'))

-        elif model_prefix == "anthropic":
-            output = response_body.get("completion")
-            prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
-            completion_tokens = self.get_num_tokens(model, credentials, output if output else '')
-
         elif model_prefix == "cohere":
             output = response_body.get("generations")[0].get("text")
             prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
             completion_tokens = self.get_num_tokens(model, credentials, output if output else '')

-        elif model_prefix == "meta":
-            output = response_body.get("generation").strip('\n')
-            prompt_tokens = response_body.get("prompt_token_count")
-            completion_tokens = response_body.get("generation_token_count")
-
-        elif model_prefix == "mistral":
-            output = response_body.get("outputs")[0].get("text")
-            prompt_tokens = response.get('ResponseMetadata').get('HTTPHeaders').get('x-amzn-bedrock-input-token-count')
-            completion_tokens = response.get('ResponseMetadata').get('HTTPHeaders').get('x-amzn-bedrock-output-token-count')
-
         else:
             raise ValueError(f"Got unknown model prefix {model_prefix} when handling block response")

@@ -893,26 +743,10 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
             payload = json.loads(chunk.get('bytes').decode())

             model_prefix = model.split('.')[0]
-            if model_prefix == "amazon":
-                content_delta = payload.get("outputText").strip('\n')
-                finish_reason = payload.get("completion_reason")
-
-            elif model_prefix == "anthropic":
-                content_delta = payload.get("completion")
-                finish_reason = payload.get("stop_reason")
-
-            elif model_prefix == "cohere":
+            if model_prefix == "cohere":
                 content_delta = payload.get("text")
                 finish_reason = payload.get("finish_reason")

-            elif model_prefix == "mistral":
-                content_delta = payload.get('outputs')[0].get("text")
-                finish_reason = payload.get('outputs')[0].get("stop_reason")
-
-            elif model_prefix == "meta":
-                content_delta = payload.get("generation").strip('\n')
-                finish_reason = payload.get("stop_reason")
-
             else:
                 raise ValueError(f"Got unknown model prefix {model_prefix} when handling stream response")

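For the Cohere models still on the native invoke path, the kept branch above reads each streamed chunk's `text` and `finish_reason`. For context, a hedged sketch of the surrounding loop over a boto3 `invoke_model_with_response_stream` response is shown below; the helper name and standalone framing are assumptions, and only the chunk decoding mirrors the diff.

```python
# Sketch: iterating a bedrock-runtime invoke_model_with_response_stream response
# and decoding the JSON chunk bytes, as the kept Cohere branch above expects.
import json

def iter_cohere_stream(response):
    for event in response.get("body"):  # EventStream of dict events
        chunk = event.get("chunk")
        if not chunk:
            continue
        payload = json.loads(chunk.get("bytes").decode())
        yield payload.get("text"), payload.get("finish_reason")
```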
@@ -0,0 +1,25 @@
+model: meta.llama3-1-405b-instruct-v1:0
+label:
+  en_US: Llama 3.1 405B Instruct
+model_type: llm
+model_properties:
+  mode: completion
+  context_size: 128000
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+    default: 0.5
+  - name: top_p
+    use_template: top_p
+    default: 0.9
+  - name: max_gen_len
+    use_template: max_tokens
+    required: true
+    default: 512
+    min: 1
+    max: 2048
+pricing:
+  input: '0.00532'
+  output: '0.016'
+  unit: '0.001'
+  currency: USD
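The pricing block on the new model card follows the convention used elsewhere in these cards: the listed prices are scaled by `unit`, so with unit '0.001' they read as USD per 1,000 tokens. A quick worked example of the arithmetic that convention implies (a sketch of my reading, not Dify's billing code):

```python
# Cost estimate against the pricing block above: price-per-token = price * unit.
input_price, output_price, unit = 0.00532, 0.016, 0.001  # USD

prompt_tokens, completion_tokens = 3000, 500
cost = prompt_tokens * input_price * unit + completion_tokens * output_price * unit
print(f"${cost:.5f}")  # -> $0.02396 for this hypothetical request
```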