feat: add Anthropic claude-3 models support (#2684)
parent: 6a6133c102
commit: 5c258e212c
@@ -21,7 +21,7 @@ class AnthropicProvider(ModelProvider):
             # Use `claude-instant-1` model for validate,
             model_instance.validate_credentials(
-                model='claude-instant-1',
+                model='claude-instant-1.2',
                 credentials=credentials
             )
         except CredentialsValidateFailedError as ex:
@@ -2,8 +2,8 @@ provider: anthropic
 label:
   en_US: Anthropic
 description:
-  en_US: Anthropic’s powerful models, such as Claude 2 and Claude Instant.
-  zh_Hans: Anthropic 的强大模型,例如 Claude 2 和 Claude Instant。
+  en_US: Anthropic’s powerful models, such as Claude 3.
+  zh_Hans: Anthropic 的强大模型,例如 Claude 3。
 icon_small:
   en_US: icon_s_en.svg
 icon_large:
@@ -0,0 +1,6 @@
+- claude-3-opus-20240229
+- claude-3-sonnet-20240229
+- claude-2.1
+- claude-instant-1.2
+- claude-2
+- claude-instant-1
@@ -34,3 +34,4 @@ pricing:
   output: '24.00'
   unit: '0.000001'
   currency: USD
+deprecated: true
@@ -0,0 +1,37 @@
+model: claude-3-opus-20240229
+label:
+  en_US: claude-3-opus-20240229
+model_type: llm
+features:
+  - agent-thought
+  - vision
+model_properties:
+  mode: chat
+  context_size: 200000
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+  - name: top_p
+    use_template: top_p
+  - name: top_k
+    label:
+      zh_Hans: 取样数量
+      en_US: Top k
+    type: int
+    help:
+      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
+      en_US: Only sample from the top K options for each subsequent token.
+    required: false
+  - name: max_tokens
+    use_template: max_tokens
+    required: true
+    default: 4096
+    min: 1
+    max: 4096
+  - name: response_format
+    use_template: response_format
+pricing:
+  input: '15.00'
+  output: '75.00'
+  unit: '0.000001'
+  currency: USD
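The pricing block above is expressed per token: unit '0.000001' means the input/output figures are USD per million tokens. The sketch below shows how a request cost could be derived from these numbers; the YAML values come from the claude-3-opus file above, but the function and its names are purely illustrative, not dify's billing code.

# Illustrative only: derive a request cost from the pricing block above.
# input 15.00, output 75.00 and unit 0.000001 come from claude-3-opus-20240229;
# the function name and parameters are hypothetical.
def estimate_cost_usd(prompt_tokens: int, completion_tokens: int,
                      input_price: float = 15.00, output_price: float = 75.00,
                      unit: float = 0.000001) -> float:
    return (prompt_tokens * input_price + completion_tokens * output_price) * unit

# e.g. 1,000 prompt tokens + 500 completion tokens:
# (1000 * 15.00 + 500 * 75.00) * 0.000001 = 0.0525 USD
print(estimate_cost_usd(1000, 500))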
@@ -0,0 +1,37 @@
+model: claude-3-sonnet-20240229
+label:
+  en_US: claude-3-sonnet-20240229
+model_type: llm
+features:
+  - agent-thought
+  - vision
+model_properties:
+  mode: chat
+  context_size: 200000
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+  - name: top_p
+    use_template: top_p
+  - name: top_k
+    label:
+      zh_Hans: 取样数量
+      en_US: Top k
+    type: int
+    help:
+      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
+      en_US: Only sample from the top K options for each subsequent token.
+    required: false
+  - name: max_tokens
+    use_template: max_tokens
+    required: true
+    default: 4096
+    min: 1
+    max: 4096
+  - name: response_format
+    use_template: response_format
+pricing:
+  input: '3.00'
+  output: '15.00'
+  unit: '0.000001'
+  currency: USD
@@ -0,0 +1,35 @@
+model: claude-instant-1.2
+label:
+  en_US: claude-instant-1.2
+model_type: llm
+features: [ ]
+model_properties:
+  mode: chat
+  context_size: 100000
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+  - name: top_p
+    use_template: top_p
+  - name: top_k
+    label:
+      zh_Hans: 取样数量
+      en_US: Top k
+    type: int
+    help:
+      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
+      en_US: Only sample from the top K options for each subsequent token.
+    required: false
+  - name: max_tokens
+    use_template: max_tokens
+    required: true
+    default: 4096
+    min: 1
+    max: 4096
+  - name: response_format
+    use_template: response_format
+pricing:
+  input: '1.63'
+  output: '5.51'
+  unit: '0.000001'
+  currency: USD
@@ -33,3 +33,4 @@ pricing:
   output: '5.51'
   unit: '0.000001'
   currency: USD
+deprecated: true
@@ -1,18 +1,32 @@
+import base64
+import mimetypes
 from collections.abc import Generator
-from typing import Optional, Union
+from typing import Optional, Union, cast

 import anthropic
+import requests
 from anthropic import Anthropic, Stream
-from anthropic.types import Completion, completion_create_params
+from anthropic.types import (
+    ContentBlockDeltaEvent,
+    Message,
+    MessageDeltaEvent,
+    MessageStartEvent,
+    MessageStopEvent,
+    MessageStreamEvent,
+    completion_create_params,
+)
 from httpx import Timeout

 from core.model_runtime.callbacks.base_callback import Callback
 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
+    ImagePromptMessageContent,
     PromptMessage,
+    PromptMessageContentType,
     PromptMessageTool,
     SystemPromptMessage,
+    TextPromptMessageContent,
     UserPromptMessage,
 )
 from core.model_runtime.errors.invoke import (
@@ -35,6 +49,7 @@ if you are not sure about the structure.
 </instructions>
 """

+
 class AnthropicLargeLanguageModel(LargeLanguageModel):
     def _invoke(self, model: str, credentials: dict,
                 prompt_messages: list[PromptMessage], model_parameters: dict,
@@ -55,54 +70,114 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
         :return: full response or stream response chunk generator result
         """
         # invoke model
-        return self._generate(model, credentials, prompt_messages, model_parameters, stop, stream, user)
+        return self._chat_generate(model, credentials, prompt_messages, model_parameters, stop, stream, user)
+
+    def _chat_generate(self, model: str, credentials: dict,
+                       prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None,
+                       stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]:
+        """
+        Invoke llm chat model
+
+        :param model: model name
+        :param credentials: credentials
+        :param prompt_messages: prompt messages
+        :param model_parameters: model parameters
+        :param stop: stop words
+        :param stream: is stream response
+        :param user: unique user id
+        :return: full response or stream response chunk generator result
+        """
+        # transform credentials to kwargs for model instance
+        credentials_kwargs = self._to_credential_kwargs(credentials)
+
+        # transform model parameters from completion api of anthropic to chat api
+        if 'max_tokens_to_sample' in model_parameters:
+            model_parameters['max_tokens'] = model_parameters.pop('max_tokens_to_sample')
+
+        # init model client
+        client = Anthropic(**credentials_kwargs)
+
+        extra_model_kwargs = {}
+        if stop:
+            extra_model_kwargs['stop_sequences'] = stop
+
+        if user:
+            extra_model_kwargs['metadata'] = completion_create_params.Metadata(user_id=user)
+
+        system, prompt_message_dicts = self._convert_prompt_messages(prompt_messages)
+
+        if system:
+            extra_model_kwargs['system'] = system
+
+        # chat model
+        response = client.messages.create(
+            model=model,
+            messages=prompt_message_dicts,
+            stream=stream,
+            **model_parameters,
+            **extra_model_kwargs
+        )
+
+        if stream:
+            return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages)
+
+        return self._handle_chat_generate_response(model, credentials, response, prompt_messages)

     def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                                  model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None,
                                  stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None,
                                  callbacks: list[Callback] = None) -> Union[LLMResult, Generator]:
         """
         Code block mode wrapper for invoking large language model
         """
         if 'response_format' in model_parameters and model_parameters['response_format']:
             stop = stop or []
-            self._transform_json_prompts(
-                model, credentials, prompt_messages, model_parameters, tools, stop, stream, user, model_parameters['response_format']
+            # chat model
+            self._transform_chat_json_prompts(
+                model=model,
+                credentials=credentials,
+                prompt_messages=prompt_messages,
+                model_parameters=model_parameters,
+                tools=tools,
+                stop=stop,
+                stream=stream,
+                user=user,
+                response_format=model_parameters['response_format']
             )
             model_parameters.pop('response_format')

         return self._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)

-    def _transform_json_prompts(self, model: str, credentials: dict,
-                                prompt_messages: list[PromptMessage], model_parameters: dict,
-                                tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
-                                stream: bool = True, user: str | None = None, response_format: str = 'JSON') \
+    def _transform_chat_json_prompts(self, model: str, credentials: dict,
+                                     prompt_messages: list[PromptMessage], model_parameters: dict,
+                                     tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
+                                     stream: bool = True, user: str | None = None, response_format: str = 'JSON') \
             -> None:
         """
         Transform json prompts
         """
         if "```\n" not in stop:
             stop.append("```\n")
+        if "\n```" not in stop:
+            stop.append("\n```")
+
         # check if there is a system message
         if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage):
             # override the system message
             prompt_messages[0] = SystemPromptMessage(
                 content=ANTHROPIC_BLOCK_MODE_PROMPT
                 .replace("{{instructions}}", prompt_messages[0].content)
                 .replace("{{block}}", response_format)
             )
+            prompt_messages.append(AssistantPromptMessage(content=f"\n```{response_format}"))
         else:
             # insert the system message
             prompt_messages.insert(0, SystemPromptMessage(
                 content=ANTHROPIC_BLOCK_MODE_PROMPT
                 .replace("{{instructions}}", f"Please output a valid {response_format} object.")
                 .replace("{{block}}", response_format)
             ))
-
-        prompt_messages.append(AssistantPromptMessage(
-            content=f"```{response_format}\n"
-        ))
+            prompt_messages.append(AssistantPromptMessage(content=f"\n```{response_format}"))

     def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                        tools: Optional[list[PromptMessageTool]] = None) -> int:
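For reference, the client.messages.create(...) call that _chat_generate assembles above maps onto the Anthropic Messages API roughly as follows. This is a minimal standalone sketch against the anthropic SDK (>= 0.17) with a placeholder API key and prompt; the keyword names mirror the model_parameters / extra_model_kwargs built in the diff, not a fixed dify interface.

# Minimal sketch of the Messages API call assembled by _chat_generate above.
# The API key, prompt text and stop sequence are placeholders.
from anthropic import Anthropic

client = Anthropic(api_key="sk-ant-...")  # placeholder credentials

message = client.messages.create(
    model="claude-3-opus-20240229",
    max_tokens=4096,                          # required by the Messages API
    system="You are a helpful assistant.",    # what dify passes via extra_model_kwargs['system']
    stop_sequences=["\n\nHuman:"],            # extra_model_kwargs['stop_sequences']
    metadata={"user_id": "abc-123"},          # extra_model_kwargs['metadata']
    messages=[{"role": "user", "content": "Hello"}],
)
print(message.content[0].text)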
@@ -129,7 +204,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
         :return:
         """
         try:
-            self._generate(
+            self._chat_generate(
                 model=model,
                 credentials=credentials,
                 prompt_messages=[
@@ -137,58 +212,17 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
                 ],
                 model_parameters={
                     "temperature": 0,
-                    "max_tokens_to_sample": 20,
+                    "max_tokens": 20,
                 },
                 stream=False
             )
         except Exception as ex:
             raise CredentialsValidateFailedError(str(ex))

-    def _generate(self, model: str, credentials: dict,
-                  prompt_messages: list[PromptMessage], model_parameters: dict,
-                  stop: Optional[list[str]] = None, stream: bool = True,
-                  user: Optional[str] = None) -> Union[LLMResult, Generator]:
+    def _handle_chat_generate_response(self, model: str, credentials: dict, response: Message,
+                                       prompt_messages: list[PromptMessage]) -> LLMResult:
         """
-        Invoke large language model
+        Handle llm chat response

-        :param model: model name
-        :param credentials: credentials kwargs
-        :param prompt_messages: prompt messages
-        :param model_parameters: model parameters
-        :param stop: stop words
-        :param stream: is stream response
-        :param user: unique user id
-        :return: full response or stream response chunk generator result
-        """
-        # transform credentials to kwargs for model instance
-        credentials_kwargs = self._to_credential_kwargs(credentials)
-
-        client = Anthropic(**credentials_kwargs)
-
-        extra_model_kwargs = {}
-        if stop:
-            extra_model_kwargs['stop_sequences'] = stop
-
-        if user:
-            extra_model_kwargs['metadata'] = completion_create_params.Metadata(user_id=user)
-
-        response = client.completions.create(
-            model=model,
-            prompt=self._convert_messages_to_prompt_anthropic(prompt_messages),
-            stream=stream,
-            **model_parameters,
-            **extra_model_kwargs
-        )
-
-        if stream:
-            return self._handle_generate_stream_response(model, credentials, response, prompt_messages)
-
-        return self._handle_generate_response(model, credentials, response, prompt_messages)
-
-    def _handle_generate_response(self, model: str, credentials: dict, response: Completion,
-                                  prompt_messages: list[PromptMessage]) -> LLMResult:
-        """
-        Handle llm response
-
         :param model: model name
         :param credentials: credentials
@@ -198,75 +232,89 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
         """
         # transform assistant message to prompt message
         assistant_prompt_message = AssistantPromptMessage(
-            content=response.completion
+            content=response.content[0].text
         )

         # calculate num tokens
-        prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
-        completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
+        if response.usage:
+            # transform usage
+            prompt_tokens = response.usage.input_tokens
+            completion_tokens = response.usage.output_tokens
+        else:
+            # calculate num tokens
+            prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
+            completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])

         # transform usage
         usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)

         # transform response
-        result = LLMResult(
+        response = LLMResult(
             model=response.model,
             prompt_messages=prompt_messages,
             message=assistant_prompt_message,
-            usage=usage,
+            usage=usage
         )

-        return result
+        return response

-    def _handle_generate_stream_response(self, model: str, credentials: dict, response: Stream[Completion],
-                                         prompt_messages: list[PromptMessage]) -> Generator:
+    def _handle_chat_generate_stream_response(self, model: str, credentials: dict,
+                                              response: Stream[MessageStreamEvent],
+                                              prompt_messages: list[PromptMessage]) -> Generator:
         """
-        Handle llm stream response
+        Handle llm chat stream response

         :param model: model name
-        :param credentials: credentials
         :param response: response
         :param prompt_messages: prompt messages
-        :return: llm response chunk generator result
+        :return: llm response chunk generator
         """
-        index = -1
+        full_assistant_content = ''
+        return_model = None
+        input_tokens = 0
+        output_tokens = 0
+        finish_reason = None
+        index = 0
         for chunk in response:
-            content = chunk.completion
-            if chunk.stop_reason is None and (content is None or content == ''):
-                continue
-
-            # transform assistant message to prompt message
-            assistant_prompt_message = AssistantPromptMessage(
-                content=content if content else '',
-            )
-
-            index += 1
-
-            if chunk.stop_reason is not None:
-                # calculate num tokens
-                prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
-                completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
-
+            if isinstance(chunk, MessageStartEvent):
+                return_model = chunk.message.model
+                input_tokens = chunk.message.usage.input_tokens
+            elif isinstance(chunk, MessageDeltaEvent):
+                output_tokens = chunk.usage.output_tokens
+                finish_reason = chunk.delta.stop_reason
+            elif isinstance(chunk, MessageStopEvent):
                 # transform usage
-                usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
+                usage = self._calc_response_usage(model, credentials, input_tokens, output_tokens)

                 yield LLMResultChunk(
-                    model=chunk.model,
+                    model=return_model,
                     prompt_messages=prompt_messages,
                     delta=LLMResultChunkDelta(
-                        index=index,
-                        message=assistant_prompt_message,
-                        finish_reason=chunk.stop_reason,
+                        index=index + 1,
+                        message=AssistantPromptMessage(
+                            content=''
+                        ),
+                        finish_reason=finish_reason,
                         usage=usage
                     )
                 )
-            else:
+            elif isinstance(chunk, ContentBlockDeltaEvent):
+                chunk_text = chunk.delta.text if chunk.delta.text else ''
+                full_assistant_content += chunk_text
+
+                # transform assistant message to prompt message
+                assistant_prompt_message = AssistantPromptMessage(
+                    content=chunk_text
+                )
+
+                index = chunk.index
+
                 yield LLMResultChunk(
-                    model=chunk.model,
+                    model=return_model,
                     prompt_messages=prompt_messages,
                     delta=LLMResultChunkDelta(
-                        index=index,
-                        message=assistant_prompt_message
+                        index=chunk.index,
+                        message=assistant_prompt_message,
                     )
                 )

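The stream branch above walks the typed event stream returned by client.messages.create(stream=True). A compact standalone sketch of the same event dispatch, using the event classes imported in this diff (API key and prompt are placeholders):

# Standalone sketch mirroring the isinstance() dispatch in
# _handle_chat_generate_stream_response above; credentials are placeholders.
from anthropic import Anthropic
from anthropic.types import (
    ContentBlockDeltaEvent,
    MessageDeltaEvent,
    MessageStartEvent,
    MessageStopEvent,
)

client = Anthropic(api_key="sk-ant-...")
stream = client.messages.create(
    model="claude-3-sonnet-20240229",
    max_tokens=256,
    messages=[{"role": "user", "content": "Hello"}],
    stream=True,
)

for event in stream:
    if isinstance(event, MessageStartEvent):
        input_tokens = event.message.usage.input_tokens    # prompt usage arrives first
    elif isinstance(event, ContentBlockDeltaEvent):
        print(event.delta.text, end="", flush=True)        # incremental text delta
    elif isinstance(event, MessageDeltaEvent):
        output_tokens = event.usage.output_tokens          # completion usage
        finish_reason = event.delta.stop_reason
    elif isinstance(event, MessageStopEvent):
        break                                              # stream finished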
@@ -289,6 +337,80 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):

         return credentials_kwargs

+    def _convert_prompt_messages(self, prompt_messages: list[PromptMessage]) -> tuple[str, list[dict]]:
+        """
+        Convert prompt messages to dict list and system
+        """
+        system = ""
+        prompt_message_dicts = []
+
+        for message in prompt_messages:
+            if isinstance(message, SystemPromptMessage):
+                system += message.content + ("\n" if not system else "")
+            else:
+                prompt_message_dicts.append(self._convert_prompt_message_to_dict(message))
+
+        return system, prompt_message_dicts
+
+    def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
+        """
+        Convert PromptMessage to dict
+        """
+        if isinstance(message, UserPromptMessage):
+            message = cast(UserPromptMessage, message)
+            if isinstance(message.content, str):
+                message_dict = {"role": "user", "content": message.content}
+            else:
+                sub_messages = []
+                for message_content in message.content:
+                    if message_content.type == PromptMessageContentType.TEXT:
+                        message_content = cast(TextPromptMessageContent, message_content)
+                        sub_message_dict = {
+                            "type": "text",
+                            "text": message_content.data
+                        }
+                        sub_messages.append(sub_message_dict)
+                    elif message_content.type == PromptMessageContentType.IMAGE:
+                        message_content = cast(ImagePromptMessageContent, message_content)
+                        if not message_content.data.startswith("data:"):
+                            # fetch image data from url
+                            try:
+                                image_content = requests.get(message_content.data).content
+                                mime_type, _ = mimetypes.guess_type(message_content.data)
+                                base64_data = base64.b64encode(image_content).decode('utf-8')
+                            except Exception as ex:
+                                raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}")
+                        else:
+                            data_split = message_content.data.split(";base64,")
+                            mime_type = data_split[0].replace("data:", "")
+                            base64_data = data_split[1]
+
+                        if mime_type not in ["image/jpeg", "image/png", "image/gif", "image/webp"]:
+                            raise ValueError(f"Unsupported image type {mime_type}, "
+                                             f"only support image/jpeg, image/png, image/gif, and image/webp")
+
+                        sub_message_dict = {
+                            "type": "image",
+                            "source": {
+                                "type": "base64",
+                                "media_type": mime_type,
+                                "data": base64_data
+                            }
+                        }
+                        sub_messages.append(sub_message_dict)
+
+                message_dict = {"role": "user", "content": sub_messages}
+        elif isinstance(message, AssistantPromptMessage):
+            message = cast(AssistantPromptMessage, message)
+            message_dict = {"role": "assistant", "content": message.content}
+        elif isinstance(message, SystemPromptMessage):
+            message = cast(SystemPromptMessage, message)
+            message_dict = {"role": "system", "content": message.content}
+        else:
+            raise ValueError(f"Got unknown type {message}")
+
+        return message_dict
+
     def _convert_one_message_to_text(self, message: PromptMessage) -> str:
         """
         Convert a single message to a string.
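_convert_prompt_message_to_dict above turns an ImagePromptMessageContent into the Claude 3 base64 image block. For a combined text + image prompt, the resulting user-message dict looks roughly like this (the file name and its contents are placeholders; only jpeg/png/gif/webp pass the check above):

# Sketch of the message dict produced for a text + image prompt, matching the
# "type": "image" / "source": {"type": "base64", ...} structure built above.
import base64

with open("photo.png", "rb") as f:  # placeholder image file
    image_b64 = base64.b64encode(f.read()).decode("utf-8")

user_message = {
    "role": "user",
    "content": [
        {"type": "text", "text": "What is in this picture?"},
        {
            "type": "image",
            "source": {
                "type": "base64",
                "media_type": "image/png",
                "data": image_b64,
            },
        },
    ],
}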
@@ -35,7 +35,7 @@ docx2txt==0.8
 pypdfium2==4.16.0
 resend~=0.7.0
 pyjwt~=2.8.0
-anthropic~=0.7.7
+anthropic~=0.17.0
 newspaper3k==0.2.8
 google-api-python-client==2.90.0
 wikipedia==1.4.0
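The dependency bump matters because the Messages API and the stream event types used above are not available in anthropic 0.7.x, which only exposed the legacy Completions endpoint. A quick sanity check, assuming only the upgraded package in the project's environment:

# Quick check that the installed SDK is new enough for the Messages API.
import anthropic

print(anthropic.__version__)  # expect a 0.17.x release after this change
assert hasattr(anthropic.Anthropic(api_key="sk-test"), "messages")  # dummy key, no network call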
@@ -1,52 +1,87 @@
 import os
 from time import sleep
-from typing import Any, Generator, List, Literal, Union
+from typing import Any, Literal, Union, Iterable

+from anthropic.resources import Messages
+from anthropic.types.message_delta_event import Delta
+
 import anthropic
 import pytest
 from _pytest.monkeypatch import MonkeyPatch
-from anthropic import Anthropic
-from anthropic._types import NOT_GIVEN, Body, Headers, NotGiven, Query
-from anthropic.resources.completions import Completions
-from anthropic.types import Completion, completion_create_params
+from anthropic import Anthropic, Stream
+from anthropic.types import MessageParam, Message, MessageStreamEvent, \
+    ContentBlock, MessageStartEvent, Usage, TextDelta, MessageDeltaEvent, MessageStopEvent, ContentBlockDeltaEvent, \
+    MessageDeltaUsage

 MOCK = os.getenv('MOCK_SWITCH', 'false') == 'true'


 class MockAnthropicClass(object):
     @staticmethod
-    def mocked_anthropic_chat_create_sync(model: str) -> Completion:
-        return Completion(
-            completion='hello, I\'m a chatbot from anthropic',
+    def mocked_anthropic_chat_create_sync(model: str) -> Message:
+        return Message(
+            id='msg-123',
+            type='message',
+            role='assistant',
+            content=[ContentBlock(text='hello, I\'m a chatbot from anthropic', type='text')],
             model=model,
-            stop_reason='stop_sequence'
+            stop_reason='stop_sequence',
+            usage=Usage(
+                input_tokens=1,
+                output_tokens=1
+            )
         )

     @staticmethod
-    def mocked_anthropic_chat_create_stream(model: str) -> Generator[Completion, None, None]:
+    def mocked_anthropic_chat_create_stream(model: str) -> Stream[MessageStreamEvent]:
         full_response_text = "hello, I'm a chatbot from anthropic"

-        for i in range(0, len(full_response_text) + 1):
-            sleep(0.1)
-            if i == len(full_response_text):
-                yield Completion(
-                    completion='',
-                    model=model,
-                    stop_reason='stop_sequence'
-                )
-            else:
-                yield Completion(
-                    completion=full_response_text[i],
-                    model=model,
-                    stop_reason=''
-                )
+        yield MessageStartEvent(
+            type='message_start',
+            message=Message(
+                id='msg-123',
+                content=[],
+                role='assistant',
+                model=model,
+                stop_reason=None,
+                type='message',
+                usage=Usage(
+                    input_tokens=1,
+                    output_tokens=1
+                )
+            )
+        )

-    def mocked_anthropic(self: Completions, *,
-                         max_tokens_to_sample: int,
-                         model: Union[str, Literal["claude-2.1", "claude-instant-1"]],
-                         prompt: str,
-                         stream: Literal[True],
-                         **kwargs: Any
-                         ) -> Union[Completion, Generator[Completion, None, None]]:
+        index = 0
+        for i in range(0, len(full_response_text)):
+            sleep(0.1)
+            yield ContentBlockDeltaEvent(
+                type='content_block_delta',
+                delta=TextDelta(text=full_response_text[i], type='text_delta'),
+                index=index
+            )
+
+            index += 1
+
+        yield MessageDeltaEvent(
+            type='message_delta',
+            delta=Delta(
+                stop_reason='stop_sequence'
+            ),
+            usage=MessageDeltaUsage(
+                output_tokens=1
+            )
+        )
+
+        yield MessageStopEvent(type='message_stop')
+
+    def mocked_anthropic(self: Messages, *,
+                         max_tokens: int,
+                         messages: Iterable[MessageParam],
+                         model: str,
+                         stream: Literal[True],
+                         **kwargs: Any
+                         ) -> Union[Message, Stream[MessageStreamEvent]]:
         if len(self._client.api_key) < 18:
             raise anthropic.AuthenticationError('Invalid API key')

@@ -55,12 +90,13 @@ class MockAnthropicClass(object):
         else:
             return MockAnthropicClass.mocked_anthropic_chat_create_sync(model=model)

+
 @pytest.fixture
 def setup_anthropic_mock(request, monkeypatch: MonkeyPatch):
     if MOCK:
-        monkeypatch.setattr(Completions, 'create', MockAnthropicClass.mocked_anthropic)
+        monkeypatch.setattr(Messages, 'create', MockAnthropicClass.mocked_anthropic)

     yield

     if MOCK:
         monkeypatch.undo()
@@ -15,14 +15,14 @@ def test_validate_credentials(setup_anthropic_mock):

     with pytest.raises(CredentialsValidateFailedError):
         model.validate_credentials(
-            model='claude-instant-1',
+            model='claude-instant-1.2',
             credentials={
                 'anthropic_api_key': 'invalid_key'
             }
         )

     model.validate_credentials(
-        model='claude-instant-1',
+        model='claude-instant-1.2',
         credentials={
             'anthropic_api_key': os.environ.get('ANTHROPIC_API_KEY')
         }
@@ -33,7 +33,7 @@ def test_invoke_model(setup_anthropic_mock):
     model = AnthropicLargeLanguageModel()

     response = model.invoke(
-        model='claude-instant-1',
+        model='claude-instant-1.2',
         credentials={
             'anthropic_api_key': os.environ.get('ANTHROPIC_API_KEY'),
             'anthropic_api_url': os.environ.get('ANTHROPIC_API_URL')
@@ -49,7 +49,7 @@ def test_invoke_model(setup_anthropic_mock):
         model_parameters={
             'temperature': 0.0,
             'top_p': 1.0,
-            'max_tokens_to_sample': 10
+            'max_tokens': 10
         },
         stop=['How'],
         stream=False,
@@ -64,7 +64,7 @@ def test_invoke_stream_model(setup_anthropic_mock):
     model = AnthropicLargeLanguageModel()

     response = model.invoke(
-        model='claude-instant-1',
+        model='claude-instant-1.2',
         credentials={
             'anthropic_api_key': os.environ.get('ANTHROPIC_API_KEY')
         },
@@ -78,7 +78,7 @@ def test_invoke_stream_model(setup_anthropic_mock):
         ],
         model_parameters={
             'temperature': 0.0,
-            'max_tokens_to_sample': 100
+            'max_tokens': 100
         },
         stream=True,
         user="abc-123"
@@ -97,7 +97,7 @@ def test_get_num_tokens():
     model = AnthropicLargeLanguageModel()

     num_tokens = model.get_num_tokens(
-        model='claude-instant-1',
+        model='claude-instant-1.2',
         credentials={
             'anthropic_api_key': os.environ.get('ANTHROPIC_API_KEY')
         },