fix: remove openllm pypi package because of this package too large (#931)

This commit is contained in:
takatost 2023-08-21 02:12:28 +08:00 committed by GitHub
parent 25264e7852
commit 6c832ee328
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 97 additions and 13 deletions

View File

@ -1,13 +1,13 @@
from typing import List, Optional, Any from typing import List, Optional, Any
from langchain.callbacks.manager import Callbacks from langchain.callbacks.manager import Callbacks
from langchain.llms import OpenLLM
from langchain.schema import LLMResult from langchain.schema import LLMResult
from core.model_providers.error import LLMBadRequestError from core.model_providers.error import LLMBadRequestError
from core.model_providers.models.llm.base import BaseLLM from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage from core.model_providers.models.entity.message import PromptMessage
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
from core.third_party.langchain.llms.openllm import OpenLLM
class OpenLLMModel(BaseLLM): class OpenLLMModel(BaseLLM):
@ -19,7 +19,7 @@ class OpenLLMModel(BaseLLM):
client = OpenLLM( client = OpenLLM(
server_url=self.credentials.get('server_url'), server_url=self.credentials.get('server_url'),
callbacks=self.callbacks, callbacks=self.callbacks,
**self.provider_model_kwargs llm_kwargs=self.provider_model_kwargs
) )
return client return client

View File

@ -1,14 +1,13 @@
import json import json
from typing import Type from typing import Type
from langchain.llms import OpenLLM
from core.helper import encrypter from core.helper import encrypter
from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType
from core.model_providers.models.llm.openllm_model import OpenLLMModel from core.model_providers.models.llm.openllm_model import OpenLLMModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.model_providers.models.base import BaseProviderModel from core.model_providers.models.base import BaseProviderModel
from core.third_party.langchain.llms.openllm import OpenLLM
from models.provider import ProviderType from models.provider import ProviderType
@ -46,11 +45,11 @@ class OpenLLMProvider(BaseModelProvider):
:return: :return:
""" """
return ModelKwargsRules( return ModelKwargsRules(
temperature=KwargRule[float](min=0, max=2, default=1), temperature=KwargRule[float](min=0.01, max=2, default=1),
top_p=KwargRule[float](min=0, max=1, default=0.7), top_p=KwargRule[float](min=0, max=1, default=0.7),
presence_penalty=KwargRule[float](min=-2, max=2, default=0), presence_penalty=KwargRule[float](min=-2, max=2, default=0),
frequency_penalty=KwargRule[float](min=-2, max=2, default=0), frequency_penalty=KwargRule[float](min=-2, max=2, default=0),
max_tokens=KwargRule[int](min=10, max=4000, default=128), max_tokens=KwargRule[int](alias='max_new_tokens', min=10, max=4000, default=128),
) )
@classmethod @classmethod
@ -71,7 +70,9 @@ class OpenLLMProvider(BaseModelProvider):
} }
llm = OpenLLM( llm = OpenLLM(
max_tokens=10, llm_kwargs={
'max_new_tokens': 10
},
**credential_kwargs **credential_kwargs
) )

View File

@ -0,0 +1,87 @@
from __future__ import annotations
import logging
from typing import (
Any,
Dict,
List,
Optional,
)
import requests
from langchain.llms.utils import enforce_stop_tokens
from pydantic import Field
from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain.llms.base import LLM
logger = logging.getLogger(__name__)
class OpenLLM(LLM):
"""OpenLLM, supporting both in-process model
instance and remote OpenLLM servers.
If you have a OpenLLM server running, you can also use it remotely:
.. code-block:: python
from langchain.llms import OpenLLM
llm = OpenLLM(server_url='http://localhost:3000')
llm("What is the difference between a duck and a goose?")
"""
server_url: Optional[str] = None
"""Optional server URL that currently runs a LLMServer with 'openllm start'."""
llm_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Key word arguments to be passed to openllm.LLM"""
@property
def _llm_type(self) -> str:
return "openllm"
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: CallbackManagerForLLMRun | None = None,
**kwargs: Any,
) -> str:
params = {
"prompt": prompt,
"llm_config": self.llm_kwargs
}
headers = {"Content-Type": "application/json"}
response = requests.post(
f'{self.server_url}/v1/generate',
headers=headers,
json=params
)
if not response.ok:
raise ValueError(f"OpenLLM HTTP {response.status_code} error: {response.text}")
json_response = response.json()
completion = json_response["responses"][0]
if completion:
completion = completion[len(prompt):]
if stop is not None:
completion = enforce_stop_tokens(completion, stop)
return completion
async def _acall(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
raise NotImplementedError(
"Async call is not supported for OpenLLM at the moment."
)

View File

@ -50,4 +50,3 @@ transformers~=4.31.0
stripe~=5.5.0 stripe~=5.5.0
pandas==1.5.3 pandas==1.5.3
xinference==0.2.0 xinference==0.2.0
openllm~=0.2.26

View File

@ -23,8 +23,7 @@ def decrypt_side_effect(tenant_id, encrypted_key):
def test_is_credentials_valid_or_raise_valid(mocker): def test_is_credentials_valid_or_raise_valid(mocker):
mocker.patch('langchain.llms.openllm.OpenLLM._identifying_params', return_value=None) mocker.patch('core.third_party.langchain.llms.openllm.OpenLLM._call',
mocker.patch('langchain.llms.openllm.OpenLLM._call',
return_value="abc") return_value="abc")
MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise( MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
@ -35,8 +34,6 @@ def test_is_credentials_valid_or_raise_valid(mocker):
def test_is_credentials_valid_or_raise_invalid(mocker): def test_is_credentials_valid_or_raise_invalid(mocker):
mocker.patch('langchain.llms.openllm.OpenLLM._identifying_params', return_value=None)
# raise CredentialsValidateFailedError if credential is not in credentials # raise CredentialsValidateFailedError if credential is not in credentials
with pytest.raises(CredentialsValidateFailedError): with pytest.raises(CredentialsValidateFailedError):
MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise( MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(