feat: optimize hf inference endpoint (#975)

takatost 2023-08-23 19:47:50 +08:00 committed by GitHub
parent 1fc57d7358
commit a76fde3d23
4 changed files with 59 additions and 11 deletions

View File

@@ -1,16 +1,14 @@
-import decimal
-from functools import wraps
 from typing import List, Optional, Any
 
 from langchain import HuggingFaceHub
 from langchain.callbacks.manager import Callbacks
-from langchain.llms import HuggingFaceEndpoint
 from langchain.schema import LLMResult
 
 from core.model_providers.error import LLMBadRequestError
 from core.model_providers.models.llm.base import BaseLLM
-from core.model_providers.models.entity.message import PromptMessage, MessageType
+from core.model_providers.models.entity.message import PromptMessage
 from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
+from core.third_party.langchain.llms.huggingface_endpoint_llm import HuggingFaceEndpointLLM
 
 
 class HuggingfaceHubModel(BaseLLM):
@@ -19,12 +17,12 @@ class HuggingfaceHubModel(BaseLLM):
     def _init_client(self) -> Any:
         provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
         if self.credentials['huggingfacehub_api_type'] == 'inference_endpoints':
-            client = HuggingFaceEndpoint(
+            client = HuggingFaceEndpointLLM(
                 endpoint_url=self.credentials['huggingfacehub_endpoint_url'],
-                task='text2text-generation',
+                task=self.credentials['task_type'],
                 model_kwargs=provider_model_kwargs,
                 huggingfacehub_api_token=self.credentials['huggingfacehub_api_token'],
-                callbacks=self.callbacks,
+                callbacks=self.callbacks
             )
         else:
             client = HuggingFaceHub(
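
Note: the endpoint task is no longer hard-coded to text2text-generation; it is read from the stored credentials. A rough sketch of the credentials dict the updated _init_client now expects for the inference_endpoints branch (all values below are placeholders, not taken from this diff):

    credentials = {
        'huggingfacehub_api_type': 'inference_endpoints',
        'huggingfacehub_api_token': 'hf_xxx',  # placeholder token
        'huggingfacehub_endpoint_url': 'https://example.endpoints.huggingface.cloud',  # placeholder URL
        'task_type': 'text-generation',  # new key, forwarded as the endpoint task
    }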

View File

@@ -2,7 +2,6 @@ import json
 from typing import Type
 
 from huggingface_hub import HfApi
-from langchain.llms import HuggingFaceEndpoint
 
 from core.helper import encrypter
 from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType
@@ -10,6 +9,7 @@ from core.model_providers.models.llm.huggingface_hub_model import HuggingfaceHub
 from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
 from core.model_providers.models.base import BaseProviderModel
+from core.third_party.langchain.llms.huggingface_endpoint_llm import HuggingFaceEndpointLLM
 from models.provider import ProviderType
@@ -85,10 +85,16 @@ class HuggingfaceHubProvider(BaseModelProvider):
         if 'huggingfacehub_endpoint_url' not in credentials:
             raise CredentialsValidateFailedError('Hugging Face Hub Endpoint URL must be provided.')
 
+        if 'task_type' not in credentials:
+            raise CredentialsValidateFailedError('Task Type must be provided.')
+
+        if credentials['task_type'] not in ("text2text-generation", "text-generation", "summarization"):
+            raise CredentialsValidateFailedError('Task Type must be one of text2text-generation, text-generation, summarization.')
+
         try:
-            llm = HuggingFaceEndpoint(
+            llm = HuggingFaceEndpointLLM(
                 endpoint_url=credentials['huggingfacehub_endpoint_url'],
-                task="text2text-generation",
+                task=credentials['task_type'],
                 model_kwargs={"temperature": 0.5, "max_new_tokens": 200},
                 huggingfacehub_api_token=credentials['huggingfacehub_api_token']
             )
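
The added checks reject a missing or unsupported task before any request reaches the endpoint. A standalone restatement of the rule as a hypothetical helper (not part of this diff; CredentialsValidateFailedError is the same exception used above):

    SUPPORTED_TASKS = ("text2text-generation", "text-generation", "summarization")

    def check_task_type(credentials: dict) -> None:
        # Mirrors the validation added in this hunk.
        if 'task_type' not in credentials:
            raise CredentialsValidateFailedError('Task Type must be provided.')
        if credentials['task_type'] not in SUPPORTED_TASKS:
            raise CredentialsValidateFailedError(
                'Task Type must be one of text2text-generation, text-generation, summarization.')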
@@ -160,6 +166,10 @@
             }
 
         credentials = json.loads(provider_model.encrypted_config)
+
+        if 'task_type' not in credentials:
+            credentials['task_type'] = 'text-generation'
+
         if credentials['huggingfacehub_api_token']:
             credentials['huggingfacehub_api_token'] = encrypter.decrypt_token(
                 self.provider.tenant_id,

View File

@@ -0,0 +1,39 @@
+from typing import Dict
+
+from langchain.llms import HuggingFaceEndpoint
+from pydantic import Extra, root_validator
+
+from langchain.utils import get_from_dict_or_env
+
+
+class HuggingFaceEndpointLLM(HuggingFaceEndpoint):
+    """HuggingFace Endpoint models.
+
+    To use, you should have the ``huggingface_hub`` python package installed, and the
+    environment variable ``HUGGINGFACEHUB_API_TOKEN`` set with your API token, or pass
+    it as a named parameter to the constructor.
+
+    Only supports `text-generation` and `text2text-generation` for now.
+
+    Example:
+        .. code-block:: python
+
+            from langchain.llms import HuggingFaceEndpoint
+            endpoint_url = (
+                "https://abcdefghijklmnop.us-east-1.aws.endpoints.huggingface.cloud"
+            )
+            hf = HuggingFaceEndpoint(
+                endpoint_url=endpoint_url,
+                huggingfacehub_api_token="my-api-key"
+            )
+    """
+
+    @root_validator(allow_reuse=True)
+    def validate_environment(cls, values: Dict) -> Dict:
+        """Validate that api key and python package exists in environment."""
+        huggingfacehub_api_token = get_from_dict_or_env(
+            values, "huggingfacehub_api_token", "HUGGINGFACEHUB_API_TOKEN"
+        )
+
+        values["huggingfacehub_api_token"] = huggingfacehub_api_token
+        return values
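
The override above only resolves the API token from the keyword argument or the HUGGINGFACEHUB_API_TOKEN environment variable; unlike the stock LangChain validator it appears to skip any further authentication round-trip, so constructing the wrapper stays cheap. A minimal usage sketch (endpoint URL and token are placeholders; the task value mirrors what the provider now passes in):

    from core.third_party.langchain.llms.huggingface_endpoint_llm import HuggingFaceEndpointLLM

    llm = HuggingFaceEndpointLLM(
        endpoint_url="https://abcdefghijklmnop.us-east-1.aws.endpoints.huggingface.cloud",  # placeholder
        task="text-generation",
        model_kwargs={"temperature": 0.5, "max_new_tokens": 200},
        huggingfacehub_api_token="hf_xxx",  # placeholder, or set HUGGINGFACEHUB_API_TOKEN
    )
    print(llm("Q: What is a Hugging Face Inference Endpoint?\nA:"))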

View File

@@ -17,7 +17,8 @@ HOSTED_INFERENCE_API_VALIDATE_CREDENTIAL = {
 INFERENCE_ENDPOINTS_VALIDATE_CREDENTIAL = {
     'huggingfacehub_api_type': 'inference_endpoints',
     'huggingfacehub_api_token': 'valid_key',
-    'huggingfacehub_endpoint_url': 'valid_url'
+    'huggingfacehub_endpoint_url': 'valid_url',
+    'task_type': 'text-generation'
 }
 
 def encrypt_side_effect(tenant_id, encrypt_key):
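
For completeness, a hedged sketch of an assertion the updated fixture now supports; the test below is illustrative only and not part of this commit:

    def test_inference_endpoints_credential_has_supported_task_type():
        # The fixture's task_type should be one of the task types the provider accepts.
        assert INFERENCE_ENDPOINTS_VALIDATE_CREDENTIAL['task_type'] in (
            "text2text-generation", "text-generation", "summarization"
        )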