From 4a55d5729d0114f5ec3afb203f73a032f45fc700 Mon Sep 17 00:00:00 2001 From: takatost Date: Wed, 22 Nov 2023 01:46:19 +0800 Subject: [PATCH] feat: add anthropic claude-2.1 support (#1591) --- .../providers/anthropic_provider.py | 20 ++++++++++++++++--- api/core/model_providers/rules/anthropic.json | 10 ++++++++-- .../langchain/llms/anthropic_llm.py | 18 ++++++++++++++--- api/requirements.txt | 2 +- .../test_anthropic_provider.py | 8 ++++---- 5 files changed, 45 insertions(+), 13 deletions(-) diff --git a/api/core/model_providers/providers/anthropic_provider.py b/api/core/model_providers/providers/anthropic_provider.py index 2f667aed96..c98a56e510 100644 --- a/api/core/model_providers/providers/anthropic_provider.py +++ b/api/core/model_providers/providers/anthropic_provider.py @@ -32,9 +32,12 @@ class AnthropicProvider(BaseModelProvider): if model_type == ModelType.TEXT_GENERATION: return [ { - 'id': 'claude-instant-1', - 'name': 'claude-instant-1', + 'id': 'claude-2.1', + 'name': 'claude-2.1', 'mode': ModelMode.CHAT.value, + 'features': [ + ModelFeature.AGENT_THOUGHT.value + ] }, { 'id': 'claude-2', @@ -44,6 +47,11 @@ class AnthropicProvider(BaseModelProvider): ModelFeature.AGENT_THOUGHT.value ] }, + { + 'id': 'claude-instant-1', + 'name': 'claude-instant-1', + 'mode': ModelMode.CHAT.value, + }, ] else: return [] @@ -73,12 +81,18 @@ class AnthropicProvider(BaseModelProvider): :param model_type: :return: """ + model_max_tokens = { + 'claude-instant-1': 100000, + 'claude-2': 100000, + 'claude-2.1': 200000, + } + return ModelKwargsRules( temperature=KwargRule[float](min=0, max=1, default=1, precision=2), top_p=KwargRule[float](min=0, max=1, default=0.7, precision=2), presence_penalty=KwargRule[float](enabled=False), frequency_penalty=KwargRule[float](enabled=False), - max_tokens=KwargRule[int](alias="max_tokens_to_sample", min=10, max=100000, default=256, precision=0), + max_tokens=KwargRule[int](alias="max_tokens_to_sample", min=10, max=model_max_tokens.get(model_name, 100000), default=256, precision=0), ) @classmethod diff --git a/api/core/model_providers/rules/anthropic.json b/api/core/model_providers/rules/anthropic.json index a302b1de13..bb02ce845b 100644 --- a/api/core/model_providers/rules/anthropic.json +++ b/api/core/model_providers/rules/anthropic.json @@ -23,8 +23,14 @@ "currency": "USD" }, "claude-2": { - "prompt": "11.02", - "completion": "32.68", + "prompt": "8.00", + "completion": "24.00", + "unit": "0.000001", + "currency": "USD" + }, + "claude-2.1": { + "prompt": "8.00", + "completion": "24.00", "unit": "0.000001", "currency": "USD" } diff --git a/api/core/third_party/langchain/llms/anthropic_llm.py b/api/core/third_party/langchain/llms/anthropic_llm.py index 3513bbbe7d..9dfce8e435 100644 --- a/api/core/third_party/langchain/llms/anthropic_llm.py +++ b/api/core/third_party/langchain/llms/anthropic_llm.py @@ -1,7 +1,7 @@ from typing import Dict -from httpx import Limits from langchain.chat_models import ChatAnthropic +from langchain.schema import ChatMessage, BaseMessage, HumanMessage, AIMessage, SystemMessage from langchain.utils import get_from_dict_or_env, check_package_version from pydantic import root_validator @@ -29,8 +29,7 @@ class AnthropicLLM(ChatAnthropic): base_url=values["anthropic_api_url"], api_key=values["anthropic_api_key"], timeout=values["default_request_timeout"], - max_retries=0, - connection_pool_limits=Limits(max_connections=200, max_keepalive_connections=100), + max_retries=0 ) values["async_client"] = anthropic.AsyncAnthropic( base_url=values["anthropic_api_url"], @@ -46,3 +45,16 @@ class AnthropicLLM(ChatAnthropic): "Please it install it with `pip install anthropic`." ) return values + + def _convert_one_message_to_text(self, message: BaseMessage) -> str: + if isinstance(message, ChatMessage): + message_text = f"\n\n{message.role.capitalize()}: {message.content}" + elif isinstance(message, HumanMessage): + message_text = f"{self.HUMAN_PROMPT} {message.content}" + elif isinstance(message, AIMessage): + message_text = f"{self.AI_PROMPT} {message.content}" + elif isinstance(message, SystemMessage): + message_text = f"{message.content}" + else: + raise ValueError(f"Got unknown type {message}") + return message_text diff --git a/api/requirements.txt b/api/requirements.txt index 6e32064a05..7b5ed73f8c 100644 --- a/api/requirements.txt +++ b/api/requirements.txt @@ -35,7 +35,7 @@ docx2txt==0.8 pypdfium2==4.16.0 resend~=0.5.1 pyjwt~=2.6.0 -anthropic~=0.3.4 +anthropic~=0.7.2 newspaper3k==0.2.8 google-api-python-client==2.90.0 wikipedia==1.4.0 diff --git a/api/tests/unit_tests/model_providers/test_anthropic_provider.py b/api/tests/unit_tests/model_providers/test_anthropic_provider.py index ea4b62a20a..d4cc9beaaa 100644 --- a/api/tests/unit_tests/model_providers/test_anthropic_provider.py +++ b/api/tests/unit_tests/model_providers/test_anthropic_provider.py @@ -31,12 +31,12 @@ def mock_chat_generate_invalid(messages: List[BaseMessage], run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any): raise anthropic.APIStatusError('Invalid credentials', - request=httpx._models.Request( - method='POST', - url='https://api.anthropic.com/v1/completions', - ), response=httpx._models.Response( status_code=401, + request=httpx._models.Request( + method='POST', + url='https://api.anthropic.com/v1/completions', + ) ), body=None )