From 5c258e212ca7459119ecbf3f490e293930c85458 Mon Sep 17 00:00:00 2001
From: takatost
Date: Tue, 5 Mar 2024 01:37:42 +0800
Subject: [PATCH] feat: add Anthropic claude-3 models support (#2684)

---
 .../model_providers/anthropic/anthropic.py    |   2 +-
 .../model_providers/anthropic/anthropic.yaml  |   4 +-
 .../anthropic/llm/_position.yaml              |   6 +
 .../anthropic/llm/claude-2.yaml               |   1 +
 .../anthropic/llm/claude-3-opus-20240229.yaml |  37 ++
 .../llm/claude-3-sonnet-20240229.yaml         |  37 ++
 .../anthropic/llm/claude-instant-1.2.yaml     |  35 ++
 .../anthropic/llm/claude-instant-1.yaml       |   1 +
 .../model_providers/anthropic/llm/llm.py      | 332 ++++++++++++------
 api/requirements.txt                          |   2 +-
 .../model_runtime/__mock/anthropic.py         | 100 ++++--
 .../model_runtime/anthropic/test_llm.py       |  14 +-
 12 files changed, 423 insertions(+), 148 deletions(-)
 create mode 100644 api/core/model_runtime/model_providers/anthropic/llm/_position.yaml
 create mode 100644 api/core/model_runtime/model_providers/anthropic/llm/claude-3-opus-20240229.yaml
 create mode 100644 api/core/model_runtime/model_providers/anthropic/llm/claude-3-sonnet-20240229.yaml
 create mode 100644 api/core/model_runtime/model_providers/anthropic/llm/claude-instant-1.2.yaml

diff --git a/api/core/model_runtime/model_providers/anthropic/anthropic.py b/api/core/model_runtime/model_providers/anthropic/anthropic.py
index ece6d2a7a4..00a6bbce3b 100644
--- a/api/core/model_runtime/model_providers/anthropic/anthropic.py
+++ b/api/core/model_runtime/model_providers/anthropic/anthropic.py
@@ -21,7 +21,7 @@ class AnthropicProvider(ModelProvider):
 
             # Use `claude-instant-1` model for validate,
             model_instance.validate_credentials(
-                model='claude-instant-1',
+                model='claude-instant-1.2',
                 credentials=credentials
             )
         except CredentialsValidateFailedError as ex:
diff --git a/api/core/model_runtime/model_providers/anthropic/anthropic.yaml b/api/core/model_runtime/model_providers/anthropic/anthropic.yaml
index d32b763301..cf41f544ef 100644
--- a/api/core/model_runtime/model_providers/anthropic/anthropic.yaml
+++ b/api/core/model_runtime/model_providers/anthropic/anthropic.yaml
@@ -2,8 +2,8 @@ provider: anthropic
 label:
   en_US: Anthropic
 description:
-  en_US: Anthropic’s powerful models, such as Claude 2 and Claude Instant.
-  zh_Hans: Anthropic 的强大模型,例如 Claude 2 和 Claude Instant。
+  en_US: Anthropic’s powerful models, such as Claude 3.
+  zh_Hans: Anthropic 的强大模型,例如 Claude 3。
 icon_small:
   en_US: icon_s_en.svg
 icon_large:
diff --git a/api/core/model_runtime/model_providers/anthropic/llm/_position.yaml b/api/core/model_runtime/model_providers/anthropic/llm/_position.yaml
new file mode 100644
index 0000000000..e7b002878a
--- /dev/null
+++ b/api/core/model_runtime/model_providers/anthropic/llm/_position.yaml
@@ -0,0 +1,6 @@
+- claude-3-opus-20240229
+- claude-3-sonnet-20240229
+- claude-2.1
+- claude-instant-1.2
+- claude-2
+- claude-instant-1
diff --git a/api/core/model_runtime/model_providers/anthropic/llm/claude-2.yaml b/api/core/model_runtime/model_providers/anthropic/llm/claude-2.yaml
index 12faf60bc9..1986947129 100644
--- a/api/core/model_runtime/model_providers/anthropic/llm/claude-2.yaml
+++ b/api/core/model_runtime/model_providers/anthropic/llm/claude-2.yaml
@@ -34,3 +34,4 @@ pricing:
   output: '24.00'
   unit: '0.000001'
   currency: USD
+deprecated: true
diff --git a/api/core/model_runtime/model_providers/anthropic/llm/claude-3-opus-20240229.yaml b/api/core/model_runtime/model_providers/anthropic/llm/claude-3-opus-20240229.yaml
new file mode 100644
index 0000000000..ab3e92a059
--- /dev/null
+++ b/api/core/model_runtime/model_providers/anthropic/llm/claude-3-opus-20240229.yaml
@@ -0,0 +1,37 @@
+model: claude-3-opus-20240229
+label:
+  en_US: claude-3-opus-20240229
+model_type: llm
+features:
+  - agent-thought
+  - vision
+model_properties:
+  mode: chat
+  context_size: 200000
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+  - name: top_p
+    use_template: top_p
+  - name: top_k
+    label:
+      zh_Hans: 取样数量
+      en_US: Top k
+    type: int
+    help:
+      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
+      en_US: Only sample from the top K options for each subsequent token.
+    required: false
+  - name: max_tokens
+    use_template: max_tokens
+    required: true
+    default: 4096
+    min: 1
+    max: 4096
+  - name: response_format
+    use_template: response_format
+pricing:
+  input: '15.00'
+  output: '75.00'
+  unit: '0.000001'
+  currency: USD
diff --git a/api/core/model_runtime/model_providers/anthropic/llm/claude-3-sonnet-20240229.yaml b/api/core/model_runtime/model_providers/anthropic/llm/claude-3-sonnet-20240229.yaml
new file mode 100644
index 0000000000..65cdab9bc6
--- /dev/null
+++ b/api/core/model_runtime/model_providers/anthropic/llm/claude-3-sonnet-20240229.yaml
@@ -0,0 +1,37 @@
+model: claude-3-sonnet-20240229
+label:
+  en_US: claude-3-sonnet-20240229
+model_type: llm
+features:
+  - agent-thought
+  - vision
+model_properties:
+  mode: chat
+  context_size: 200000
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+  - name: top_p
+    use_template: top_p
+  - name: top_k
+    label:
+      zh_Hans: 取样数量
+      en_US: Top k
+    type: int
+    help:
+      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
+      en_US: Only sample from the top K options for each subsequent token.
+    required: false
+  - name: max_tokens
+    use_template: max_tokens
+    required: true
+    default: 4096
+    min: 1
+    max: 4096
+  - name: response_format
+    use_template: response_format
+pricing:
+  input: '3.00'
+  output: '15.00'
+  unit: '0.000001'
+  currency: USD
diff --git a/api/core/model_runtime/model_providers/anthropic/llm/claude-instant-1.2.yaml b/api/core/model_runtime/model_providers/anthropic/llm/claude-instant-1.2.yaml
new file mode 100644
index 0000000000..929a7f8725
--- /dev/null
+++ b/api/core/model_runtime/model_providers/anthropic/llm/claude-instant-1.2.yaml
@@ -0,0 +1,35 @@
+model: claude-instant-1.2
+label:
+  en_US: claude-instant-1.2
+model_type: llm
+features: [ ]
+model_properties:
+  mode: chat
+  context_size: 100000
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+  - name: top_p
+    use_template: top_p
+  - name: top_k
+    label:
+      zh_Hans: 取样数量
+      en_US: Top k
+    type: int
+    help:
+      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
+      en_US: Only sample from the top K options for each subsequent token.
+    required: false
+  - name: max_tokens
+    use_template: max_tokens
+    required: true
+    default: 4096
+    min: 1
+    max: 4096
+  - name: response_format
+    use_template: response_format
+pricing:
+  input: '1.63'
+  output: '5.51'
+  unit: '0.000001'
+  currency: USD
diff --git a/api/core/model_runtime/model_providers/anthropic/llm/claude-instant-1.yaml b/api/core/model_runtime/model_providers/anthropic/llm/claude-instant-1.yaml
index 25d32a09af..5e76d5b1c2 100644
--- a/api/core/model_runtime/model_providers/anthropic/llm/claude-instant-1.yaml
+++ b/api/core/model_runtime/model_providers/anthropic/llm/claude-instant-1.yaml
@@ -33,3 +33,4 @@ pricing:
   output: '5.51'
   unit: '0.000001'
   currency: USD
+deprecated: true
diff --git a/api/core/model_runtime/model_providers/anthropic/llm/llm.py b/api/core/model_runtime/model_providers/anthropic/llm/llm.py
index 00e5ef6fda..6f9f41ca44 100644
--- a/api/core/model_runtime/model_providers/anthropic/llm/llm.py
+++ b/api/core/model_runtime/model_providers/anthropic/llm/llm.py
@@ -1,18 +1,32 @@
+import base64
+import mimetypes
 from collections.abc import Generator
-from typing import Optional, Union
+from typing import Optional, Union, cast
 
 import anthropic
+import requests
 from anthropic import Anthropic, Stream
-from anthropic.types import Completion, completion_create_params
+from anthropic.types import (
+    ContentBlockDeltaEvent,
+    Message,
+    MessageDeltaEvent,
+    MessageStartEvent,
+    MessageStopEvent,
+    MessageStreamEvent,
+    completion_create_params,
+)
 from httpx import Timeout
 
 from core.model_runtime.callbacks.base_callback import Callback
 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
+    ImagePromptMessageContent,
     PromptMessage,
+    PromptMessageContentType,
     PromptMessageTool,
     SystemPromptMessage,
+    TextPromptMessageContent,
     UserPromptMessage,
 )
 from core.model_runtime.errors.invoke import (
@@ -35,6 +49,7 @@ if you are not sure about the structure.
""" + class AnthropicLargeLanguageModel(LargeLanguageModel): def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, @@ -55,54 +70,114 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): :return: full response or stream response chunk generator result """ # invoke model - return self._generate(model, credentials, prompt_messages, model_parameters, stop, stream, user) - + return self._chat_generate(model, credentials, prompt_messages, model_parameters, stop, stream, user) + + def _chat_generate(self, model: str, credentials: dict, + prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None, + stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: + """ + Invoke llm chat model + + :param model: model name + :param credentials: credentials + :param prompt_messages: prompt messages + :param model_parameters: model parameters + :param stop: stop words + :param stream: is stream response + :param user: unique user id + :return: full response or stream response chunk generator result + """ + # transform credentials to kwargs for model instance + credentials_kwargs = self._to_credential_kwargs(credentials) + + # transform model parameters from completion api of anthropic to chat api + if 'max_tokens_to_sample' in model_parameters: + model_parameters['max_tokens'] = model_parameters.pop('max_tokens_to_sample') + + # init model client + client = Anthropic(**credentials_kwargs) + + extra_model_kwargs = {} + if stop: + extra_model_kwargs['stop_sequences'] = stop + + if user: + extra_model_kwargs['metadata'] = completion_create_params.Metadata(user_id=user) + + system, prompt_message_dicts = self._convert_prompt_messages(prompt_messages) + + if system: + extra_model_kwargs['system'] = system + + # chat model + response = client.messages.create( + model=model, + messages=prompt_message_dicts, + stream=stream, + **model_parameters, + **extra_model_kwargs + ) + + if stream: + return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages) + + return self._handle_chat_generate_response(model, credentials, response, prompt_messages) + def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None, - callbacks: list[Callback] = None) -> Union[LLMResult, Generator]: + model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None, + callbacks: list[Callback] = None) -> Union[LLMResult, Generator]: """ Code block mode wrapper for invoking large language model """ if 'response_format' in model_parameters and model_parameters['response_format']: stop = stop or [] - self._transform_json_prompts( - model, credentials, prompt_messages, model_parameters, tools, stop, stream, user, model_parameters['response_format'] + # chat model + self._transform_chat_json_prompts( + model=model, + credentials=credentials, + prompt_messages=prompt_messages, + model_parameters=model_parameters, + tools=tools, + stop=stop, + stream=stream, + user=user, + response_format=model_parameters['response_format'] ) model_parameters.pop('response_format') return self._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) - def _transform_json_prompts(self, model: 
-    def _transform_json_prompts(self, model: str, credentials: dict,
-                                prompt_messages: list[PromptMessage], model_parameters: dict,
-                                tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
-                                stream: bool = True, user: str | None = None, response_format: str = 'JSON') \
-            -> None:
+    def _transform_chat_json_prompts(self, model: str, credentials: dict,
+                                     prompt_messages: list[PromptMessage], model_parameters: dict,
+                                     tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
+                                     stream: bool = True, user: str | None = None, response_format: str = 'JSON') \
+            -> None:
         """
         Transform json prompts
         """
         if "```\n" not in stop:
             stop.append("```\n")
+        if "\n```" not in stop:
+            stop.append("\n```")
 
         # check if there is a system message
         if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage):
             # override the system message
             prompt_messages[0] = SystemPromptMessage(
                 content=ANTHROPIC_BLOCK_MODE_PROMPT
-                .replace("{{instructions}}", prompt_messages[0].content)
-                .replace("{{block}}", response_format)
+                    .replace("{{instructions}}", prompt_messages[0].content)
+                    .replace("{{block}}", response_format)
             )
+            prompt_messages.append(AssistantPromptMessage(content=f"\n```{response_format}"))
         else:
             # insert the system message
             prompt_messages.insert(0, SystemPromptMessage(
                 content=ANTHROPIC_BLOCK_MODE_PROMPT
-                .replace("{{instructions}}", f"Please output a valid {response_format} object.")
-                .replace("{{block}}", response_format)
+                    .replace("{{instructions}}", f"Please output a valid {response_format} object.")
+                    .replace("{{block}}", response_format)
             ))
-
-        prompt_messages.append(AssistantPromptMessage(
-            content=f"```{response_format}\n"
-        ))
+            prompt_messages.append(AssistantPromptMessage(content=f"\n```{response_format}"))
 
     def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                        tools: Optional[list[PromptMessageTool]] = None) -> int:
@@ -129,7 +204,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
         :return:
         """
         try:
-            self._generate(
+            self._chat_generate(
                 model=model,
                 credentials=credentials,
                 prompt_messages=[
@@ -137,58 +212,17 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
                 ],
                 model_parameters={
                     "temperature": 0,
-                    "max_tokens_to_sample": 20,
+                    "max_tokens": 20,
                 },
                 stream=False
             )
         except Exception as ex:
             raise CredentialsValidateFailedError(str(ex))
 
-    def _generate(self, model: str, credentials: dict,
-                  prompt_messages: list[PromptMessage], model_parameters: dict,
-                  stop: Optional[list[str]] = None, stream: bool = True,
-                  user: Optional[str] = None) -> Union[LLMResult, Generator]:
+    def _handle_chat_generate_response(self, model: str, credentials: dict, response: Message,
+                                       prompt_messages: list[PromptMessage]) -> LLMResult:
         """
-        Invoke large language model
-
-        :param model: model name
-        :param credentials: credentials kwargs
-        :param prompt_messages: prompt messages
-        :param model_parameters: model parameters
-        :param stop: stop words
-        :param stream: is stream response
-        :param user: unique user id
-        :return: full response or stream response chunk generator result
-        """
-        # transform credentials to kwargs for model instance
-        credentials_kwargs = self._to_credential_kwargs(credentials)
-
-        client = Anthropic(**credentials_kwargs)
-
-        extra_model_kwargs = {}
-        if stop:
-            extra_model_kwargs['stop_sequences'] = stop
-
-        if user:
-            extra_model_kwargs['metadata'] = completion_create_params.Metadata(user_id=user)
-
-        response = client.completions.create(
-            model=model,
-            prompt=self._convert_messages_to_prompt_anthropic(prompt_messages),
-            stream=stream,
-            **model_parameters,
-            **extra_model_kwargs
-        )
-
-        if stream:
-            return self._handle_generate_stream_response(model, credentials, response, prompt_messages)
-
-        return self._handle_generate_response(model, credentials, response, prompt_messages)
-
-    def _handle_generate_response(self, model: str, credentials: dict, response: Completion,
-                                  prompt_messages: list[PromptMessage]) -> LLMResult:
-        """
-        Handle llm response
+        Handle llm chat response
 
         :param model: model name
         :param credentials: credentials
@@ -198,75 +232,89 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
         """
         # transform assistant message to prompt message
         assistant_prompt_message = AssistantPromptMessage(
-            content=response.completion
+            content=response.content[0].text
         )
 
         # calculate num tokens
-        prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
-        completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
+        if response.usage:
+            # transform usage
+            prompt_tokens = response.usage.input_tokens
+            completion_tokens = response.usage.output_tokens
+        else:
+            # calculate num tokens
+            prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
+            completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
 
         # transform usage
         usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
 
         # transform response
-        result = LLMResult(
+        response = LLMResult(
            model=response.model,
            prompt_messages=prompt_messages,
            message=assistant_prompt_message,
-            usage=usage,
+            usage=usage
        )
 
-        return result
+        return response
 
-    def _handle_generate_stream_response(self, model: str, credentials: dict, response: Stream[Completion],
-                                         prompt_messages: list[PromptMessage]) -> Generator:
+    def _handle_chat_generate_stream_response(self, model: str, credentials: dict,
+                                              response: Stream[MessageStreamEvent],
+                                              prompt_messages: list[PromptMessage]) -> Generator:
         """
-        Handle llm stream response
+        Handle llm chat stream response
 
         :param model: model name
-        :param credentials: credentials
         :param response: response
         :param prompt_messages: prompt messages
-        :return: llm response chunk generator result
+        :return: llm response chunk generator
         """
-        index = -1
+        full_assistant_content = ''
+        return_model = None
+        input_tokens = 0
+        output_tokens = 0
+        finish_reason = None
+        index = 0
         for chunk in response:
-            content = chunk.completion
-            if chunk.stop_reason is None and (content is None or content == ''):
-                continue
-
-            # transform assistant message to prompt message
-            assistant_prompt_message = AssistantPromptMessage(
-                content=content if content else '',
-            )
-
-            index += 1
-
-            if chunk.stop_reason is not None:
-                # calculate num tokens
-                prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
-                completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
-
+            if isinstance(chunk, MessageStartEvent):
+                return_model = chunk.message.model
+                input_tokens = chunk.message.usage.input_tokens
+            elif isinstance(chunk, MessageDeltaEvent):
+                output_tokens = chunk.usage.output_tokens
+                finish_reason = chunk.delta.stop_reason
+            elif isinstance(chunk, MessageStopEvent):
                 # transform usage
-                usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
+                usage = self._calc_response_usage(model, credentials, input_tokens, output_tokens)
 
                 yield LLMResultChunk(
-                    model=chunk.model,
+                    model=return_model,
                     prompt_messages=prompt_messages,
                     delta=LLMResultChunkDelta(
-                        index=index,
-                        message=assistant_prompt_message,
-                        finish_reason=chunk.stop_reason,
+                        index=index + 1,
+                        message=AssistantPromptMessage(
+                            content=''
+                        ),
+                        finish_reason=finish_reason,
                         usage=usage
                     )
                 )
-            else:
+            elif isinstance(chunk, ContentBlockDeltaEvent):
+                chunk_text = chunk.delta.text if chunk.delta.text else ''
+                full_assistant_content += chunk_text
+
+                # transform assistant message to prompt message
+                assistant_prompt_message = AssistantPromptMessage(
+                    content=chunk_text
+                )
+
+                index = chunk.index
+
                 yield LLMResultChunk(
-                    model=chunk.model,
+                    model=return_model,
                     prompt_messages=prompt_messages,
                     delta=LLMResultChunkDelta(
-                        index=index,
-                        message=assistant_prompt_message
+                        index=chunk.index,
+                        message=assistant_prompt_message,
                     )
                 )
 
@@ -289,6 +337,80 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
 
         return credentials_kwargs
 
+    def _convert_prompt_messages(self, prompt_messages: list[PromptMessage]) -> tuple[str, list[dict]]:
+        """
+        Convert prompt messages to dict list and system
+        """
+        system = ""
+        prompt_message_dicts = []
+
+        for message in prompt_messages:
+            if isinstance(message, SystemPromptMessage):
+                system += message.content + ("\n" if not system else "")
+            else:
+                prompt_message_dicts.append(self._convert_prompt_message_to_dict(message))
+
+        return system, prompt_message_dicts
+
+    def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
+        """
+        Convert PromptMessage to dict
+        """
+        if isinstance(message, UserPromptMessage):
+            message = cast(UserPromptMessage, message)
+            if isinstance(message.content, str):
+                message_dict = {"role": "user", "content": message.content}
+            else:
+                sub_messages = []
+                for message_content in message.content:
+                    if message_content.type == PromptMessageContentType.TEXT:
+                        message_content = cast(TextPromptMessageContent, message_content)
+                        sub_message_dict = {
+                            "type": "text",
+                            "text": message_content.data
+                        }
+                        sub_messages.append(sub_message_dict)
+                    elif message_content.type == PromptMessageContentType.IMAGE:
+                        message_content = cast(ImagePromptMessageContent, message_content)
+                        if not message_content.data.startswith("data:"):
+                            # fetch image data from url
+                            try:
+                                image_content = requests.get(message_content.data).content
+                                mime_type, _ = mimetypes.guess_type(message_content.data)
+                                base64_data = base64.b64encode(image_content).decode('utf-8')
+                            except Exception as ex:
+                                raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}")
+                        else:
+                            data_split = message_content.data.split(";base64,")
+                            mime_type = data_split[0].replace("data:", "")
+                            base64_data = data_split[1]
+
+                        if mime_type not in ["image/jpeg", "image/png", "image/gif", "image/webp"]:
+                            raise ValueError(f"Unsupported image type {mime_type}, "
+                                             f"only support image/jpeg, image/png, image/gif, and image/webp")
+
+                        sub_message_dict = {
+                            "type": "image",
+                            "source": {
+                                "type": "base64",
+                                "media_type": mime_type,
+                                "data": base64_data
+                            }
+                        }
+                        sub_messages.append(sub_message_dict)
+
+                message_dict = {"role": "user", "content": sub_messages}
+        elif isinstance(message, AssistantPromptMessage):
+            message = cast(AssistantPromptMessage, message)
+            message_dict = {"role": "assistant", "content": message.content}
+        elif isinstance(message, SystemPromptMessage):
+            message = cast(SystemPromptMessage, message)
+            message_dict = {"role": "system", "content": message.content}
+        else:
+            raise ValueError(f"Got unknown type {message}")
+
+        return message_dict
+
     def _convert_one_message_to_text(self, message: PromptMessage) -> str:
         """
         Convert a single message to a string.
diff --git a/api/requirements.txt b/api/requirements.txt
index ae5c77137a..1c3e89e780 100644
--- a/api/requirements.txt
+++ b/api/requirements.txt
@@ -35,7 +35,7 @@ docx2txt==0.8
 pypdfium2==4.16.0
 resend~=0.7.0
 pyjwt~=2.8.0
-anthropic~=0.7.7
+anthropic~=0.17.0
 newspaper3k==0.2.8
 google-api-python-client==2.90.0
 wikipedia==1.4.0
diff --git a/api/tests/integration_tests/model_runtime/__mock/anthropic.py b/api/tests/integration_tests/model_runtime/__mock/anthropic.py
index 96fd8f2026..2247d33e24 100644
--- a/api/tests/integration_tests/model_runtime/__mock/anthropic.py
+++ b/api/tests/integration_tests/model_runtime/__mock/anthropic.py
@@ -1,52 +1,87 @@
 import os
 from time import sleep
-from typing import Any, Generator, List, Literal, Union
+from typing import Any, Literal, Union, Iterable
+
+from anthropic.resources import Messages
+from anthropic.types.message_delta_event import Delta
 
 import anthropic
 import pytest
 from _pytest.monkeypatch import MonkeyPatch
-from anthropic import Anthropic
-from anthropic._types import NOT_GIVEN, Body, Headers, NotGiven, Query
-from anthropic.resources.completions import Completions
-from anthropic.types import Completion, completion_create_params
+from anthropic import Anthropic, Stream
+from anthropic.types import MessageParam, Message, MessageStreamEvent, \
+    ContentBlock, MessageStartEvent, Usage, TextDelta, MessageDeltaEvent, MessageStopEvent, ContentBlockDeltaEvent, \
+    MessageDeltaUsage
 
 MOCK = os.getenv('MOCK_SWITCH', 'false') == 'true'
 
+
 class MockAnthropicClass(object):
     @staticmethod
-    def mocked_anthropic_chat_create_sync(model: str) -> Completion:
-        return Completion(
-            completion='hello, I\'m a chatbot from anthropic',
+    def mocked_anthropic_chat_create_sync(model: str) -> Message:
+        return Message(
+            id='msg-123',
+            type='message',
+            role='assistant',
+            content=[ContentBlock(text='hello, I\'m a chatbot from anthropic', type='text')],
             model=model,
-            stop_reason='stop_sequence'
+            stop_reason='stop_sequence',
+            usage=Usage(
+                input_tokens=1,
+                output_tokens=1
+            )
         )
 
     @staticmethod
-    def mocked_anthropic_chat_create_stream(model: str) -> Generator[Completion, None, None]:
+    def mocked_anthropic_chat_create_stream(model: str) -> Stream[MessageStreamEvent]:
         full_response_text = "hello, I'm a chatbot from anthropic"
 
-        for i in range(0, len(full_response_text) + 1):
-            sleep(0.1)
-            if i == len(full_response_text):
-                yield Completion(
-                    completion='',
-                    model=model,
-                    stop_reason='stop_sequence'
-                )
-            else:
-                yield Completion(
-                    completion=full_response_text[i],
-                    model=model,
-                    stop_reason=''
+        yield MessageStartEvent(
+            type='message_start',
+            message=Message(
+                id='msg-123',
+                content=[],
+                role='assistant',
+                model=model,
+                stop_reason=None,
+                type='message',
+                usage=Usage(
+                    input_tokens=1,
+                    output_tokens=1
                 )
+            )
+        )
 
-    def mocked_anthropic(self: Completions, *,
-                         max_tokens_to_sample: int,
-                         model: Union[str, Literal["claude-2.1", "claude-instant-1"]],
-                         prompt: str,
-                         stream: Literal[True],
-                         **kwargs: Any
-                         ) -> Union[Completion, Generator[Completion, None, None]]:
+        index = 0
+        for i in range(0, len(full_response_text)):
+            sleep(0.1)
+            yield ContentBlockDeltaEvent(
+                type='content_block_delta',
+                delta=TextDelta(text=full_response_text[i], type='text_delta'),
+                index=index
+            )
+
+            index += 1
+
+        yield MessageDeltaEvent(
+            type='message_delta',
+            delta=Delta(
+                stop_reason='stop_sequence'
+            ),
+            usage=MessageDeltaUsage(
+                output_tokens=1
) + ) + + yield MessageStopEvent(type='message_stop') + + def mocked_anthropic(self: Messages, *, + max_tokens: int, + messages: Iterable[MessageParam], + model: str, + stream: Literal[True], + **kwargs: Any + ) -> Union[Message, Stream[MessageStreamEvent]]: if len(self._client.api_key) < 18: raise anthropic.AuthenticationError('Invalid API key') @@ -55,12 +90,13 @@ class MockAnthropicClass(object): else: return MockAnthropicClass.mocked_anthropic_chat_create_sync(model=model) + @pytest.fixture def setup_anthropic_mock(request, monkeypatch: MonkeyPatch): if MOCK: - monkeypatch.setattr(Completions, 'create', MockAnthropicClass.mocked_anthropic) + monkeypatch.setattr(Messages, 'create', MockAnthropicClass.mocked_anthropic) yield if MOCK: - monkeypatch.undo() \ No newline at end of file + monkeypatch.undo() diff --git a/api/tests/integration_tests/model_runtime/anthropic/test_llm.py b/api/tests/integration_tests/model_runtime/anthropic/test_llm.py index ddba2a40ce..b3f6414800 100644 --- a/api/tests/integration_tests/model_runtime/anthropic/test_llm.py +++ b/api/tests/integration_tests/model_runtime/anthropic/test_llm.py @@ -15,14 +15,14 @@ def test_validate_credentials(setup_anthropic_mock): with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( - model='claude-instant-1', + model='claude-instant-1.2', credentials={ 'anthropic_api_key': 'invalid_key' } ) model.validate_credentials( - model='claude-instant-1', + model='claude-instant-1.2', credentials={ 'anthropic_api_key': os.environ.get('ANTHROPIC_API_KEY') } @@ -33,7 +33,7 @@ def test_invoke_model(setup_anthropic_mock): model = AnthropicLargeLanguageModel() response = model.invoke( - model='claude-instant-1', + model='claude-instant-1.2', credentials={ 'anthropic_api_key': os.environ.get('ANTHROPIC_API_KEY'), 'anthropic_api_url': os.environ.get('ANTHROPIC_API_URL') @@ -49,7 +49,7 @@ def test_invoke_model(setup_anthropic_mock): model_parameters={ 'temperature': 0.0, 'top_p': 1.0, - 'max_tokens_to_sample': 10 + 'max_tokens': 10 }, stop=['How'], stream=False, @@ -64,7 +64,7 @@ def test_invoke_stream_model(setup_anthropic_mock): model = AnthropicLargeLanguageModel() response = model.invoke( - model='claude-instant-1', + model='claude-instant-1.2', credentials={ 'anthropic_api_key': os.environ.get('ANTHROPIC_API_KEY') }, @@ -78,7 +78,7 @@ def test_invoke_stream_model(setup_anthropic_mock): ], model_parameters={ 'temperature': 0.0, - 'max_tokens_to_sample': 100 + 'max_tokens': 100 }, stream=True, user="abc-123" @@ -97,7 +97,7 @@ def test_get_num_tokens(): model = AnthropicLargeLanguageModel() num_tokens = model.get_num_tokens( - model='claude-instant-1', + model='claude-instant-1.2', credentials={ 'anthropic_api_key': os.environ.get('ANTHROPIC_API_KEY') },
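
Reviewer note (not part of the patch): a minimal sketch of how the new chat path can be exercised end to end, modeled on test_invoke_model above. The model name, credentials, and image payload are illustrative assumptions; claude-3 models accept multimodal user content, which _convert_prompt_message_to_dict turns into base64 "source" blocks.

    import os

    from core.model_runtime.entities.message_entities import (
        ImagePromptMessageContent,
        SystemPromptMessage,
        TextPromptMessageContent,
        UserPromptMessage,
    )
    from core.model_runtime.model_providers.anthropic.llm.llm import AnthropicLargeLanguageModel

    model = AnthropicLargeLanguageModel()

    response = model.invoke(
        model='claude-3-sonnet-20240229',
        credentials={'anthropic_api_key': os.environ.get('ANTHROPIC_API_KEY')},
        prompt_messages=[
            SystemPromptMessage(content='You are a helpful assistant.'),
            UserPromptMessage(content=[
                TextPromptMessageContent(data='What is in this image?'),
                # a data: URI takes the base64 branch; a plain URL would be
                # fetched over HTTP via requests (hypothetical placeholder payload)
                ImagePromptMessageContent(data='data:image/png;base64,<base64-encoded image>'),
            ]),
        ],
        model_parameters={'temperature': 0.0, 'max_tokens': 100},
        stream=False,
        user='abc-123',
    )
    print(response.message.content)

Note that completion-style parameters still work: _chat_generate renames max_tokens_to_sample to max_tokens before calling the messages API, which is why the tests above could switch parameter names without a compatibility shim.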
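Reviewer note (not part of the patch): for reference, the four streaming event types handled by _handle_chat_generate_stream_response map onto the anthropic~=0.17 messages API roughly as below. The event classes are the ones imported in llm.py above; client construction and the model name are assumptions.

    from anthropic import Anthropic
    from anthropic.types import (
        ContentBlockDeltaEvent,
        MessageDeltaEvent,
        MessageStartEvent,
        MessageStopEvent,
    )

    client = Anthropic()  # reads ANTHROPIC_API_KEY from the environment

    stream = client.messages.create(
        model='claude-3-opus-20240229',
        max_tokens=100,
        messages=[{'role': 'user', 'content': 'Hello'}],
        stream=True,
    )

    input_tokens = output_tokens = 0
    for event in stream:
        if isinstance(event, MessageStartEvent):
            input_tokens = event.message.usage.input_tokens   # prompt-side usage
        elif isinstance(event, ContentBlockDeltaEvent):
            print(event.delta.text, end='')                   # incremental text
        elif isinstance(event, MessageDeltaEvent):
            output_tokens = event.usage.output_tokens         # completion-side usage
        elif isinstance(event, MessageStopEvent):
            print(f'\n[{input_tokens} prompt / {output_tokens} completion tokens]')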