diff --git a/api/core/model_providers/models/llm/anthropic_model.py b/api/core/model_providers/models/llm/anthropic_model.py index 62a9d992ba..5ba7e29c1d 100644 --- a/api/core/model_providers/models/llm/anthropic_model.py +++ b/api/core/model_providers/models/llm/anthropic_model.py @@ -1,11 +1,8 @@ -import decimal import logging -from functools import wraps from typing import List, Optional, Any import anthropic from langchain.callbacks.manager import Callbacks -from langchain.chat_models import ChatAnthropic from langchain.schema import LLMResult from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \ @@ -13,6 +10,7 @@ from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError from core.model_providers.models.llm.base import BaseLLM from core.model_providers.models.entity.message import PromptMessage, MessageType from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs +from core.third_party.langchain.llms.anthropic_llm import AnthropicLLM class AnthropicModel(BaseLLM): @@ -20,7 +18,7 @@ class AnthropicModel(BaseLLM): def _init_client(self) -> Any: provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs) - return ChatAnthropic( + return AnthropicLLM( model=self.name, streaming=self.streaming, callbacks=self.callbacks, diff --git a/api/core/model_providers/providers/anthropic_provider.py b/api/core/model_providers/providers/anthropic_provider.py index 8bab7bb251..e1ef06d140 100644 --- a/api/core/model_providers/providers/anthropic_provider.py +++ b/api/core/model_providers/providers/anthropic_provider.py @@ -5,7 +5,6 @@ from typing import Type, Optional import anthropic from flask import current_app -from langchain.chat_models import ChatAnthropic from langchain.schema import HumanMessage from core.helper import encrypter @@ -16,6 +15,7 @@ from core.model_providers.models.llm.anthropic_model import AnthropicModel from core.model_providers.models.llm.base import ModelType from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError from core.model_providers.providers.hosted import hosted_model_providers +from core.third_party.langchain.llms.anthropic_llm import AnthropicLLM from models.provider import ProviderType @@ -92,7 +92,7 @@ class AnthropicProvider(BaseModelProvider): if 'anthropic_api_url' in credentials: credential_kwargs['anthropic_api_url'] = credentials['anthropic_api_url'] - chat_llm = ChatAnthropic( + chat_llm = AnthropicLLM( model='claude-instant-1', max_tokens_to_sample=10, temperature=0, diff --git a/api/core/third_party/langchain/llms/anthropic_llm.py b/api/core/third_party/langchain/llms/anthropic_llm.py new file mode 100644 index 0000000000..20403bb85b --- /dev/null +++ b/api/core/third_party/langchain/llms/anthropic_llm.py @@ -0,0 +1,47 @@ +from typing import Dict + +from httpx import Limits +from langchain.chat_models import ChatAnthropic +from langchain.utils import get_from_dict_or_env, check_package_version +from pydantic import root_validator + + +class AnthropicLLM(ChatAnthropic): + @root_validator() + def validate_environment(cls, values: Dict) -> Dict: + """Validate that api key and python package exists in environment.""" + values["anthropic_api_key"] = get_from_dict_or_env( + values, "anthropic_api_key", "ANTHROPIC_API_KEY" + ) + # Get custom api url from environment. + values["anthropic_api_url"] = get_from_dict_or_env( + values, + "anthropic_api_url", + "ANTHROPIC_API_URL", + default="https://api.anthropic.com", + ) + + try: + import anthropic + + check_package_version("anthropic", gte_version="0.3") + values["client"] = anthropic.Anthropic( + base_url=values["anthropic_api_url"], + api_key=values["anthropic_api_key"], + timeout=values["default_request_timeout"], + connection_pool_limits=Limits(max_connections=200, max_keepalive_connections=100), + ) + values["async_client"] = anthropic.AsyncAnthropic( + base_url=values["anthropic_api_url"], + api_key=values["anthropic_api_key"], + timeout=values["default_request_timeout"], + ) + values["HUMAN_PROMPT"] = anthropic.HUMAN_PROMPT + values["AI_PROMPT"] = anthropic.AI_PROMPT + values["count_tokens"] = values["client"].count_tokens + except ImportError: + raise ImportError( + "Could not import anthropic python package. " + "Please it install it with `pip install anthropic`." + ) + return values