fix: hf hosted inference check (#1128)

This commit is contained in:
takatost 2023-09-09 00:29:48 +08:00 committed by GitHub
parent 681eb1cfcc
commit c4d8bdc3db
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 69 additions and 4 deletions

View File

@ -1,6 +1,5 @@
from typing import List, Optional, Any from typing import List, Optional, Any
from langchain import HuggingFaceHub
from langchain.callbacks.manager import Callbacks from langchain.callbacks.manager import Callbacks
from langchain.schema import LLMResult from langchain.schema import LLMResult
@ -9,6 +8,7 @@ from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage from core.model_providers.models.entity.message import PromptMessage
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
from core.third_party.langchain.llms.huggingface_endpoint_llm import HuggingFaceEndpointLLM from core.third_party.langchain.llms.huggingface_endpoint_llm import HuggingFaceEndpointLLM
from core.third_party.langchain.llms.huggingface_hub_llm import HuggingFaceHubLLM
class HuggingfaceHubModel(BaseLLM): class HuggingfaceHubModel(BaseLLM):
@ -31,7 +31,7 @@ class HuggingfaceHubModel(BaseLLM):
streaming=streaming streaming=streaming
) )
else: else:
client = HuggingFaceHub( client = HuggingFaceHubLLM(
repo_id=self.name, repo_id=self.name,
task=self.credentials['task_type'], task=self.credentials['task_type'],
model_kwargs=provider_model_kwargs, model_kwargs=provider_model_kwargs,
@ -88,4 +88,6 @@ class HuggingfaceHubModel(BaseLLM):
if 'baichuan' in self.name.lower(): if 'baichuan' in self.name.lower():
return False return False
return True return True
else:
return False

View File

@ -89,7 +89,8 @@ class HuggingfaceHubProvider(BaseModelProvider):
raise CredentialsValidateFailedError('Task Type must be provided.') raise CredentialsValidateFailedError('Task Type must be provided.')
if credentials['task_type'] not in ("text2text-generation", "text-generation", "summarization"): if credentials['task_type'] not in ("text2text-generation", "text-generation", "summarization"):
raise CredentialsValidateFailedError('Task Type must be one of text2text-generation, text-generation, summarization.') raise CredentialsValidateFailedError('Task Type must be one of text2text-generation, '
'text-generation, summarization.')
try: try:
llm = HuggingFaceEndpointLLM( llm = HuggingFaceEndpointLLM(

View File

@ -0,0 +1,62 @@
from typing import Dict, Optional, List, Any
from huggingface_hub import HfApi, InferenceApi
from langchain import HuggingFaceHub
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.huggingface_hub import VALID_TASKS
from pydantic import root_validator
from langchain.utils import get_from_dict_or_env
class HuggingFaceHubLLM(HuggingFaceHub):
"""HuggingFaceHub models.
To use, you should have the ``huggingface_hub`` python package installed, and the
environment variable ``HUGGINGFACEHUB_API_TOKEN`` set with your API token, or pass
it as a named parameter to the constructor.
Only supports `text-generation`, `text2text-generation` and `summarization` for now.
Example:
.. code-block:: python
from langchain.llms import HuggingFaceHub
hf = HuggingFaceHub(repo_id="gpt2", huggingfacehub_api_token="my-api-key")
"""
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
huggingfacehub_api_token = get_from_dict_or_env(
values, "huggingfacehub_api_token", "HUGGINGFACEHUB_API_TOKEN"
)
client = InferenceApi(
repo_id=values["repo_id"],
token=huggingfacehub_api_token,
task=values.get("task"),
)
client.options = {"wait_for_model": False, "use_gpu": False}
values["client"] = client
return values
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
hfapi = HfApi(token=self.huggingfacehub_api_token)
model_info = hfapi.model_info(repo_id=self.repo_id)
if not model_info:
raise ValueError(f"Model {self.repo_id} not found.")
if 'inference' in model_info.cardData and not model_info.cardData['inference']:
raise ValueError(f"Inference API has been turned off for this model {self.repo_id}.")
if model_info.pipeline_tag not in VALID_TASKS:
raise ValueError(f"Model {self.repo_id} is not a valid task, "
f"must be one of {VALID_TASKS}.")
return super()._call(prompt, stop, run_manager, **kwargs)