mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-14 13:46:02 +08:00
fix: remove openllm pypi package because of this package too large (#931)
This commit is contained in:
parent
25264e7852
commit
6c832ee328
@ -1,13 +1,13 @@
|
|||||||
from typing import List, Optional, Any
|
from typing import List, Optional, Any
|
||||||
|
|
||||||
from langchain.callbacks.manager import Callbacks
|
from langchain.callbacks.manager import Callbacks
|
||||||
from langchain.llms import OpenLLM
|
|
||||||
from langchain.schema import LLMResult
|
from langchain.schema import LLMResult
|
||||||
|
|
||||||
from core.model_providers.error import LLMBadRequestError
|
from core.model_providers.error import LLMBadRequestError
|
||||||
from core.model_providers.models.llm.base import BaseLLM
|
from core.model_providers.models.llm.base import BaseLLM
|
||||||
from core.model_providers.models.entity.message import PromptMessage
|
from core.model_providers.models.entity.message import PromptMessage
|
||||||
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
|
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
|
||||||
|
from core.third_party.langchain.llms.openllm import OpenLLM
|
||||||
|
|
||||||
|
|
||||||
class OpenLLMModel(BaseLLM):
|
class OpenLLMModel(BaseLLM):
|
||||||
@ -19,7 +19,7 @@ class OpenLLMModel(BaseLLM):
|
|||||||
client = OpenLLM(
|
client = OpenLLM(
|
||||||
server_url=self.credentials.get('server_url'),
|
server_url=self.credentials.get('server_url'),
|
||||||
callbacks=self.callbacks,
|
callbacks=self.callbacks,
|
||||||
**self.provider_model_kwargs
|
llm_kwargs=self.provider_model_kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
return client
|
return client
|
||||||
|
@ -1,14 +1,13 @@
|
|||||||
import json
|
import json
|
||||||
from typing import Type
|
from typing import Type
|
||||||
|
|
||||||
from langchain.llms import OpenLLM
|
|
||||||
|
|
||||||
from core.helper import encrypter
|
from core.helper import encrypter
|
||||||
from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType
|
from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType
|
||||||
from core.model_providers.models.llm.openllm_model import OpenLLMModel
|
from core.model_providers.models.llm.openllm_model import OpenLLMModel
|
||||||
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
|
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
|
||||||
|
|
||||||
from core.model_providers.models.base import BaseProviderModel
|
from core.model_providers.models.base import BaseProviderModel
|
||||||
|
from core.third_party.langchain.llms.openllm import OpenLLM
|
||||||
from models.provider import ProviderType
|
from models.provider import ProviderType
|
||||||
|
|
||||||
|
|
||||||
@ -46,11 +45,11 @@ class OpenLLMProvider(BaseModelProvider):
|
|||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
return ModelKwargsRules(
|
return ModelKwargsRules(
|
||||||
temperature=KwargRule[float](min=0, max=2, default=1),
|
temperature=KwargRule[float](min=0.01, max=2, default=1),
|
||||||
top_p=KwargRule[float](min=0, max=1, default=0.7),
|
top_p=KwargRule[float](min=0, max=1, default=0.7),
|
||||||
presence_penalty=KwargRule[float](min=-2, max=2, default=0),
|
presence_penalty=KwargRule[float](min=-2, max=2, default=0),
|
||||||
frequency_penalty=KwargRule[float](min=-2, max=2, default=0),
|
frequency_penalty=KwargRule[float](min=-2, max=2, default=0),
|
||||||
max_tokens=KwargRule[int](min=10, max=4000, default=128),
|
max_tokens=KwargRule[int](alias='max_new_tokens', min=10, max=4000, default=128),
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -71,7 +70,9 @@ class OpenLLMProvider(BaseModelProvider):
|
|||||||
}
|
}
|
||||||
|
|
||||||
llm = OpenLLM(
|
llm = OpenLLM(
|
||||||
max_tokens=10,
|
llm_kwargs={
|
||||||
|
'max_new_tokens': 10
|
||||||
|
},
|
||||||
**credential_kwargs
|
**credential_kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
|
87
api/core/third_party/langchain/llms/openllm.py
vendored
Normal file
87
api/core/third_party/langchain/llms/openllm.py
vendored
Normal file
@ -0,0 +1,87 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import (
|
||||||
|
Any,
|
||||||
|
Dict,
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
)
|
||||||
|
|
||||||
|
import requests
|
||||||
|
from langchain.llms.utils import enforce_stop_tokens
|
||||||
|
from pydantic import Field
|
||||||
|
|
||||||
|
from langchain.callbacks.manager import (
|
||||||
|
AsyncCallbackManagerForLLMRun,
|
||||||
|
CallbackManagerForLLMRun,
|
||||||
|
)
|
||||||
|
from langchain.llms.base import LLM
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class OpenLLM(LLM):
|
||||||
|
"""OpenLLM, supporting both in-process model
|
||||||
|
instance and remote OpenLLM servers.
|
||||||
|
|
||||||
|
If you have a OpenLLM server running, you can also use it remotely:
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
from langchain.llms import OpenLLM
|
||||||
|
llm = OpenLLM(server_url='http://localhost:3000')
|
||||||
|
llm("What is the difference between a duck and a goose?")
|
||||||
|
"""
|
||||||
|
|
||||||
|
server_url: Optional[str] = None
|
||||||
|
"""Optional server URL that currently runs a LLMServer with 'openllm start'."""
|
||||||
|
llm_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||||
|
"""Key word arguments to be passed to openllm.LLM"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _llm_type(self) -> str:
|
||||||
|
return "openllm"
|
||||||
|
|
||||||
|
def _call(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
stop: Optional[List[str]] = None,
|
||||||
|
run_manager: CallbackManagerForLLMRun | None = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> str:
|
||||||
|
params = {
|
||||||
|
"prompt": prompt,
|
||||||
|
"llm_config": self.llm_kwargs
|
||||||
|
}
|
||||||
|
|
||||||
|
headers = {"Content-Type": "application/json"}
|
||||||
|
response = requests.post(
|
||||||
|
f'{self.server_url}/v1/generate',
|
||||||
|
headers=headers,
|
||||||
|
json=params
|
||||||
|
)
|
||||||
|
|
||||||
|
if not response.ok:
|
||||||
|
raise ValueError(f"OpenLLM HTTP {response.status_code} error: {response.text}")
|
||||||
|
|
||||||
|
json_response = response.json()
|
||||||
|
completion = json_response["responses"][0]
|
||||||
|
|
||||||
|
if completion:
|
||||||
|
completion = completion[len(prompt):]
|
||||||
|
|
||||||
|
if stop is not None:
|
||||||
|
completion = enforce_stop_tokens(completion, stop)
|
||||||
|
|
||||||
|
return completion
|
||||||
|
|
||||||
|
async def _acall(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
stop: Optional[List[str]] = None,
|
||||||
|
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> str:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"Async call is not supported for OpenLLM at the moment."
|
||||||
|
)
|
@ -49,5 +49,4 @@ huggingface_hub~=0.16.4
|
|||||||
transformers~=4.31.0
|
transformers~=4.31.0
|
||||||
stripe~=5.5.0
|
stripe~=5.5.0
|
||||||
pandas==1.5.3
|
pandas==1.5.3
|
||||||
xinference==0.2.0
|
xinference==0.2.0
|
||||||
openllm~=0.2.26
|
|
@ -23,8 +23,7 @@ def decrypt_side_effect(tenant_id, encrypted_key):
|
|||||||
|
|
||||||
|
|
||||||
def test_is_credentials_valid_or_raise_valid(mocker):
|
def test_is_credentials_valid_or_raise_valid(mocker):
|
||||||
mocker.patch('langchain.llms.openllm.OpenLLM._identifying_params', return_value=None)
|
mocker.patch('core.third_party.langchain.llms.openllm.OpenLLM._call',
|
||||||
mocker.patch('langchain.llms.openllm.OpenLLM._call',
|
|
||||||
return_value="abc")
|
return_value="abc")
|
||||||
|
|
||||||
MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
|
MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
|
||||||
@ -35,8 +34,6 @@ def test_is_credentials_valid_or_raise_valid(mocker):
|
|||||||
|
|
||||||
|
|
||||||
def test_is_credentials_valid_or_raise_invalid(mocker):
|
def test_is_credentials_valid_or_raise_invalid(mocker):
|
||||||
mocker.patch('langchain.llms.openllm.OpenLLM._identifying_params', return_value=None)
|
|
||||||
|
|
||||||
# raise CredentialsValidateFailedError if credential is not in credentials
|
# raise CredentialsValidateFailedError if credential is not in credentials
|
||||||
with pytest.raises(CredentialsValidateFailedError):
|
with pytest.raises(CredentialsValidateFailedError):
|
||||||
MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
|
MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user