From 6c832ee32831283fe9ff1be9346906bf8d6fd53f Mon Sep 17 00:00:00 2001
From: takatost
Date: Mon, 21 Aug 2023 02:12:28 +0800
Subject: [PATCH] fix: remove openllm pypi package because of this package too large (#931)

---
 .../models/llm/openllm_model.py              |  4 +-
 .../providers/openllm_provider.py            | 11 +--
 .../third_party/langchain/llms/openllm.py    | 87 +++++++++++++++++++
 api/requirements.txt                         |  3 +-
 .../model_providers/test_openllm_provider.py |  5 +-
 5 files changed, 97 insertions(+), 13 deletions(-)
 create mode 100644 api/core/third_party/langchain/llms/openllm.py

diff --git a/api/core/model_providers/models/llm/openllm_model.py b/api/core/model_providers/models/llm/openllm_model.py
index eba0d44eb5..2f9876a92b 100644
--- a/api/core/model_providers/models/llm/openllm_model.py
+++ b/api/core/model_providers/models/llm/openllm_model.py
@@ -1,13 +1,13 @@
 from typing import List, Optional, Any
 
 from langchain.callbacks.manager import Callbacks
-from langchain.llms import OpenLLM
 from langchain.schema import LLMResult
 
 from core.model_providers.error import LLMBadRequestError
 from core.model_providers.models.llm.base import BaseLLM
 from core.model_providers.models.entity.message import PromptMessage
 from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
+from core.third_party.langchain.llms.openllm import OpenLLM
 
 
 class OpenLLMModel(BaseLLM):
@@ -19,7 +19,7 @@ class OpenLLMModel(BaseLLM):
         client = OpenLLM(
             server_url=self.credentials.get('server_url'),
             callbacks=self.callbacks,
-            **self.provider_model_kwargs
+            llm_kwargs=self.provider_model_kwargs
         )
 
         return client
diff --git a/api/core/model_providers/providers/openllm_provider.py b/api/core/model_providers/providers/openllm_provider.py
index efcc62c52e..5abb5efa63 100644
--- a/api/core/model_providers/providers/openllm_provider.py
+++ b/api/core/model_providers/providers/openllm_provider.py
@@ -1,14 +1,13 @@
 import json
 from typing import Type
 
-from langchain.llms import OpenLLM
-
 from core.helper import encrypter
 from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType
 from core.model_providers.models.llm.openllm_model import OpenLLMModel
 from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
 
 from core.model_providers.models.base import BaseProviderModel
+from core.third_party.langchain.llms.openllm import OpenLLM
 from models.provider import ProviderType
 
 
@@ -46,11 +45,11 @@ class OpenLLMProvider(BaseModelProvider):
         :return:
         """
         return ModelKwargsRules(
-            temperature=KwargRule[float](min=0, max=2, default=1),
+            temperature=KwargRule[float](min=0.01, max=2, default=1),
             top_p=KwargRule[float](min=0, max=1, default=0.7),
             presence_penalty=KwargRule[float](min=-2, max=2, default=0),
             frequency_penalty=KwargRule[float](min=-2, max=2, default=0),
-            max_tokens=KwargRule[int](min=10, max=4000, default=128),
+            max_tokens=KwargRule[int](alias='max_new_tokens', min=10, max=4000, default=128),
         )
 
     @classmethod
@@ -71,7 +70,9 @@
         }
 
         llm = OpenLLM(
-            max_tokens=10,
+            llm_kwargs={
+                'max_new_tokens': 10
+            },
             **credential_kwargs
         )
 
diff --git a/api/core/third_party/langchain/llms/openllm.py b/api/core/third_party/langchain/llms/openllm.py
new file mode 100644
index 0000000000..95ed80daf4
--- /dev/null
+++ b/api/core/third_party/langchain/llms/openllm.py
@@ -0,0 +1,87 @@
+from __future__ import annotations
+
+import logging
+from typing import (
+    Any,
+    Dict,
+    List,
+    Optional,
+)
+
+import requests
+from langchain.llms.utils import enforce_stop_tokens
+from pydantic import Field
+
+from langchain.callbacks.manager import (
+    AsyncCallbackManagerForLLMRun,
+    CallbackManagerForLLMRun,
+)
+from langchain.llms.base import LLM
+
+logger = logging.getLogger(__name__)
+
+
+class OpenLLM(LLM):
+    """OpenLLM, supporting both in-process model
+    instance and remote OpenLLM servers.
+
+    If you have a OpenLLM server running, you can also use it remotely:
+        .. code-block:: python
+
+            from langchain.llms import OpenLLM
+            llm = OpenLLM(server_url='http://localhost:3000')
+            llm("What is the difference between a duck and a goose?")
+    """
+
+    server_url: Optional[str] = None
+    """Optional server URL that currently runs a LLMServer with 'openllm start'."""
+    llm_kwargs: Dict[str, Any] = Field(default_factory=dict)
+    """Key word arguments to be passed to openllm.LLM"""
+
+    @property
+    def _llm_type(self) -> str:
+        return "openllm"
+
+    def _call(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: CallbackManagerForLLMRun | None = None,
+        **kwargs: Any,
+    ) -> str:
+        params = {
+            "prompt": prompt,
+            "llm_config": self.llm_kwargs
+        }
+
+        headers = {"Content-Type": "application/json"}
+        response = requests.post(
+            f'{self.server_url}/v1/generate',
+            headers=headers,
+            json=params
+        )
+
+        if not response.ok:
+            raise ValueError(f"OpenLLM HTTP {response.status_code} error: {response.text}")
+
+        json_response = response.json()
+        completion = json_response["responses"][0]
+
+        if completion:
+            completion = completion[len(prompt):]
+
+        if stop is not None:
+            completion = enforce_stop_tokens(completion, stop)
+
+        return completion
+
+    async def _acall(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> str:
+        raise NotImplementedError(
+            "Async call is not supported for OpenLLM at the moment."
+        )
diff --git a/api/requirements.txt b/api/requirements.txt
index 3ae232cd6f..43028d9225 100644
--- a/api/requirements.txt
+++ b/api/requirements.txt
@@ -49,5 +49,4 @@ huggingface_hub~=0.16.4
 transformers~=4.31.0
 stripe~=5.5.0
 pandas==1.5.3
-xinference==0.2.0
-openllm~=0.2.26
\ No newline at end of file
+xinference==0.2.0
\ No newline at end of file
diff --git a/api/tests/unit_tests/model_providers/test_openllm_provider.py b/api/tests/unit_tests/model_providers/test_openllm_provider.py
index 42609ed360..bd00d2d86f 100644
--- a/api/tests/unit_tests/model_providers/test_openllm_provider.py
+++ b/api/tests/unit_tests/model_providers/test_openllm_provider.py
@@ -23,8 +23,7 @@ def decrypt_side_effect(tenant_id, encrypted_key):
 
 
 def test_is_credentials_valid_or_raise_valid(mocker):
-    mocker.patch('langchain.llms.openllm.OpenLLM._identifying_params', return_value=None)
-    mocker.patch('langchain.llms.openllm.OpenLLM._call',
+    mocker.patch('core.third_party.langchain.llms.openllm.OpenLLM._call',
                  return_value="abc")
 
     MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
@@ -35,8 +34,6 @@
 
 
 def test_is_credentials_valid_or_raise_invalid(mocker):
-    mocker.patch('langchain.llms.openllm.OpenLLM._identifying_params', return_value=None)
-
     # raise CredentialsValidateFailedError if credential is not in credentials
     with pytest.raises(CredentialsValidateFailedError):
         MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
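
Note (not part of the patch): a minimal usage sketch of the new in-repo wrapper introduced above. It mirrors what OpenLLMProvider.is_model_credentials_valid_or_raise now does: llm_kwargs is sent as the "llm_config" field of a POST to {server_url}/v1/generate, and the echoed prompt is stripped from responses[0]. The server URL and prompt below are assumptions for illustration, not values taken from the diff.

    from core.third_party.langchain.llms.openllm import OpenLLM

    # Assumed endpoint of a server launched with 'openllm start'; adjust to your setup.
    llm = OpenLLM(
        server_url='http://127.0.0.1:3333',
        llm_kwargs={'max_new_tokens': 10},  # forwarded as 'llm_config' in the request body
    )

    # LangChain LLM instances are callable; this triggers the POST described in _call().
    print(llm('ping'))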